	Upload folder using huggingface_hub
- .gitignore +42 -0
- README.md +43 -12
- README_zh_cn.md +242 -0
- app.py +231 -110
- env_install.sh +1 -1
- infer/gif_render.py +3 -3
- infer/image_to_views.py +9 -4
- infer/text_to_image.py +1 -2
- infer/utils.py +7 -1
- infer/views_to_mesh.py +7 -4
- main.py +60 -12
- requirements.txt +1 -0
- svrm/ldm/models/svrm.py +16 -19
- svrm/ldm/modules/attention.py +20 -11
- svrm/ldm/vis_util.py +14 -15
- svrm/predictor.py +1 -3
- third_party/check.py +25 -0
- third_party/dust3r_utils.py +366 -0
- third_party/gen_baking.py +288 -0
- third_party/mesh_baker.py +142 -0
- third_party/utils/camera_utils.py +90 -0
- third_party/utils/img_utils.py +211 -0
    	
.gitignore
ADDED

@@ -0,0 +1,42 @@
+**/*~
+**/*.bk
+**/*.xx
+**/*.so
+**/*.ipynb
+**/*.log
+**/*.swp
+**/*.zip
+**/*.look
+**/*.lock
+**/*.think
+**/dosth.sh
+**/nohup.out
+**/*polaris*
+**/*egg*/
+**/cl5/
+**/tmp/
+**/look/
+**/temp/
+**/build/
+**/model/
+**/log/
+**/backup/
+**/outputs/
+**/work_dir/
+**/work_dirs/
+**/__pycache__/
+**/.ipynb_checkpoints/
+*.jpg
+*.png
+*.gif
+### PreCI ###
+.codecc
+
+app_hg.py
+outputs
+weights
+.vscode/
+baking
+inference.py
+third_party/weights
+third_party/dust3r
    	
        README.md
    CHANGED
    
@@ -1,14 +1,5 @@
-
-
-emoji: 😻
-colorFrom: purple
-colorTo: red
-sdk: gradio
-sdk_version: 5.5.0
-app_file: app_hg.py
-pinned: false
-short_description: Text-to-3D and Image-to-3D Generation
----
+[English](README.md) | [简体中文](README_zh_cn.md)
+
 <!-- ## **Hunyuan3D-1.0** -->
 
 <p align="center">
@@ -19,7 +10,7 @@ short_description: Text-to-3D and Image-to-3D Generation
 
 <div align="center">
   <a href="https://github.com/tencent/Hunyuan3D-1"><img src="https://img.shields.io/static/v1?label=Code&message=Github&color=blue&logo=github-pages"></a>
-  <a href="https://3d.hunyuan.tencent.com"><img src="https://img.shields.io/static/v1?label=Homepage&message=Tencent
+  <a href="https://3d.hunyuan.tencent.com"><img src="https://img.shields.io/static/v1?label=Homepage&message=Tencent%20Hunyuan3D&color=blue&logo=github-pages"></a>
   <a href="https://arxiv.org/pdf/2411.02293"><img src="https://img.shields.io/static/v1?label=Tech Report&message=Arxiv&color=red&logo=arxiv"></a>
   <a href="https://huggingface.co/Tencent/Hunyuan3D-1"><img src="https://img.shields.io/static/v1?label=Checkpoints&message=HuggingFace&color=yellow"></a>
   <a href="https://huggingface.co/spaces/Tencent/Hunyuan3D-1"><img src="https://img.shields.io/static/v1?label=Demo&message=HuggingFace&color=yellow"></a>
@@ -101,6 +92,19 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
 # step 3. install other packages
 bash env_install.sh
 ```
+
+because of dust3r, we offer a guide:
+
+```
+cd third_party
+git clone --recursive https://github.com/naver/dust3r.git
+
+cd ../third_party/weights
+wget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth
+
+```
+
+
 <details>
 <summary>💡Other tips for envrionment installation</summary>
 
@@ -204,6 +208,33 @@ bash scripts/image_to_3d_std_separately.sh ./demos/example_000.png ./outputs/tes
 bash scripts/image_to_3d_lite_separately.sh ./demos/example_000.png ./outputs/test # >= 10G
 ```
 
+#### Baking related
+
+```bash
+cd ./third_party
+git clone --recursive https://github.com/naver/dust3r.git
+
+mkdir -p weights/DUSt3R_ViTLarge_BaseDecoder_512_dpt
+cd weights/DUSt3R_ViTLarge_BaseDecoder_512_dpt
+
+wget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth
+cd ../../..
+```
+
+If you download related code and weights, we list some additional arg:
+
+|    Argument        |  Default  |                     Description                     |
+|:------------------:|:---------:|:---------------------------------------------------:|
+|`--do_bake`  |   False   | baking multi-view into mesh   |
+|`--bake_align_times`  |   3   | the times of align image with mesh |
+
+
+Note: When running main.py, ensure that do_bake is set to True and do_texture_mapping is also set to True.
+
+```bash
+python main.py ... --do_texture_mapping --do_bake (--do_render)
+```
+
 #### Using Gradio
 
 We have prepared two versions of multi-view generation, std and lite.
    	
README_zh_cn.md
ADDED

@@ -0,0 +1,242 @@
+[English](README.md) | [简体中文](README_zh_cn.md)
+
+<!-- ## **Hunyuan3D-1.0** -->
+
+<p align="center">
+  <img src="./assets/logo.png"  height=200>
+</p>
+
+# Tencent Hunyuan3D-1.0: A Unified Framework for Text-to-3D and Image-to-3D Generation
+
+<div align="center">
+  <a href="https://github.com/tencent/Hunyuan3D-1"><img src="https://img.shields.io/static/v1?label=Code&message=Github&color=blue&logo=github-pages"></a>
+  <a href="https://3d.hunyuan.tencent.com"><img src="https://img.shields.io/static/v1?label=Homepage&message=Tencent%20Hunyuan3D&color=blue&logo=github-pages"></a>
+  <a href="https://arxiv.org/pdf/2411.02293"><img src="https://img.shields.io/static/v1?label=Tech Report&message=Arxiv&color=red&logo=arxiv"></a>
+  <a href="https://huggingface.co/Tencent/Hunyuan3D-1"><img src="https://img.shields.io/static/v1?label=Checkpoints&message=HuggingFace&color=yellow"></a>
+  <a href="https://huggingface.co/spaces/Tencent/Hunyuan3D-1"><img src="https://img.shields.io/static/v1?label=Demo&message=HuggingFace&color=yellow"></a>
+</div>
+
+
+## 🔥🔥🔥 Updates!!
+
+* Nov 5, 2024: 💬 Image-to-3D is now supported. Try it via the [script](#using-gradio).
+* Nov 5, 2024: 💬 Text-to-3D is now supported. Try it via the [script](#using-gradio).
+
+
+## 📑 Open-Source Plan
+
+- [x] Inference 
+- [x] Checkpoints
+- [ ] Baking related
+- [ ] Training
+- [ ] ComfyUI
+- [ ] Distillation Version
+- [ ] TensorRT Version
+
+
+
+## **Abstract**
+<p align="center">
+  <img src="./assets/teaser.png"  height=450>
+</p>
+
+To address the shortcomings of existing 3D generation models in generation speed and generalization, we open-source the Hunyuan3D-1.0 model to help 3D creators and artists automate the production of 3D assets. Our model adopts a two-stage generation approach that can produce a 3D asset in only about 10 seconds while remaining controllable and high-quality. In the first stage, we use a multi-view diffusion model; the lite version can efficiently generate multi-view images in roughly 4 seconds. These multi-view images capture rich texture and geometry priors of the 3D asset from different viewpoints, relaxing the task from single-view reconstruction to multi-view reconstruction. In the second stage, we introduce a feed-forward reconstruction model that takes the multi-view images generated in the first stage and can reconstruct the 3D asset quickly and faithfully in about 3 seconds. The reconstruction model learns to handle the noise and inconsistency introduced by multi-view diffusion and leverages the information available in the condition image to efficiently recover the 3D structure. In the end, the model can generate a 3D asset from an arbitrary single-view input.
+
+
+## 🎉 **Hunyuan3D-1.0 Model Architecture**
+
+<p align="center">
+  <img src="./assets/overview_3.png"  height=400>
+</p>
+
+
+## 📈 Comparison
+
+Compared with other open-source models, Hunyuan3D-1.0 received the highest user scores on all 5 metrics. See the user study results below for details.
+
+On an A100 GPU, the lite model generates a 3D asset from a single image in only about 10 seconds, while the standard model takes roughly 25 seconds. The scatter plot below shows that Tencent Hunyuan3D-1.0 achieves a reasonable balance between quality and speed.
+
+<p align="center">
+  <img src="./assets/radar.png"  height=300>
+  <img src="./assets/runtime.png"  height=300>
+</p>
+
+## Usage
+
+#### Clone the repository
+
+```shell
+git clone https://github.com/tencent/Hunyuan3D-1
+cd Hunyuan3D-1
+```
+
+#### Installation on Linux
+
+The env_install.sh script shows how to set up the environment:
+
+```
+# step 1: create the environment
+conda create -n hunyuan3d-1 python=3.9 or 3.10 or 3.11 or 3.12
+conda activate hunyuan3d-1
+
+# step 2: install torch and related packages
+which pip # check pip corresponds to python
+
+# modify the cuda version according to your machine (recommended)
+pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
+
+# step 3: install other packages
+bash env_install.sh
+```
+
+Due to the dust3r license, we only provide an installation guide:
+
+```
+cd third_party
+git clone --recursive https://github.com/naver/dust3r.git
+
+cd ../third_party/weights
+wget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth
+
+```
+
+
+<details>
+<summary>💡Some tips for environment installation</summary>
+
+You can optionally install xformers or flash_attn for acceleration:
+
+```
+pip install xformers --index-url https://download.pytorch.org/whl/cu121
+```
+```
+pip install flash_attn
+```
+
+Most environment errors are caused by a mismatch between machine and packages. You can try manually specifying the version, as shown in the following successful cases:
+```
+# python3.9
+pip install torch==2.0.1 torchvision==0.15.2 --index-url https://download.pytorch.org/whl/cu118 
+```
+
+when install pytorch3d, the gcc version is preferably greater than 9, and the gpu driver should not be too old.
+
+</details>
+
+#### Download pretrained models
+
+The models are available at [https://huggingface.co/tencent/Hunyuan3D-1](https://huggingface.co/tencent/Hunyuan3D-1):
+
++ `Hunyuan3D-1/lite`, lite model for multi-view generation.
++ `Hunyuan3D-1/std`, standard model for multi-view generation.
++ `Hunyuan3D-1/svrm`, sparse-view reconstruction model.
+
+
+To download the models via Hugging Face, please first install huggingface-cli. (Installation details can be found [here](https://huggingface.co/docs/huggingface_hub/guides/cli).)
+
+```shell
+python3 -m pip install "huggingface_hub[cli]"
+```
+
+Then download the models with the following commands:
+
+```shell
+mkdir weights
+huggingface-cli download tencent/Hunyuan3D-1 --local-dir ./weights
+
+mkdir weights/hunyuanDiT
+huggingface-cli download Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers-Distilled --local-dir ./weights/hunyuanDiT
+```
+
+#### Inference
+For text-to-3D, we support bilingual Chinese/English generation. Run local inference with the following command:
+```python
+python3 main.py \
+    --text_prompt "a lovely rabbit" \
+    --save_folder ./outputs/test/ \
+    --max_faces_num 90000 \
+    --do_texture_mapping \
+    --do_render
+```
+
+For image-to-3D, run local inference with the following command:
+```python
+python3 main.py \
+    --image_prompt "/path/to/your/image" \
+    --save_folder ./outputs/test/ \
+    --max_faces_num 90000 \
+    --do_texture_mapping \
+    --do_render
+```
+More argument details:
+
+|    Argument        |  Default  |                     Description                     |
+|:------------------:|:---------:|:---------------------------------------------------:|
+|`--text_prompt`  |   None    |The text prompt for 3D generation         |
+|`--image_prompt` |   None    |The image prompt for 3D generation         |
+|`--t2i_seed`     |    0      |The random seed for generating images        |
+|`--t2i_steps`    |    25     |The number of steps for sampling of text to image  |
+|`--gen_seed`     |    0      |The random seed for generating 3d generation        |
+|`--gen_steps`    |    50     |The number of steps for sampling of 3d generation  |
+|`--max_faces_numm` | 90000  |The limit number of faces of 3d mesh |
+|`--save_memory`   | False   |module will move to cpu automatically|
+|`--do_texture_mapping` |   False    |Change vertex shadding to texture shading  |
+|`--do_render`  |   False   |render gif   |
+
+
+If GPU memory is limited, you can use the `--save_memory` option. The minimum VRAM requirements are:
+- Inference Std-pipeline requires 30GB VRAM (24G VRAM with --save_memory).
+- Inference Lite-pipeline requires 22GB VRAM (18G VRAM with --save_memory).
+- Note: --save_memory will increase inference time
+
+```bash
+bash scripts/text_to_3d_std.sh 
+bash scripts/text_to_3d_lite.sh 
+bash scripts/image_to_3d_std.sh 
+bash scripts/image_to_3d_lite.sh 
+```
+
+If your GPU has 16GB of VRAM, you can load the models onto the GPU one at a time:
+```bash
+bash scripts/text_to_3d_std_separately.sh 'a lovely rabbit' ./outputs/test # >= 16G
+bash scripts/text_to_3d_lite_separately.sh 'a lovely rabbit' ./outputs/test # >= 14G
+bash scripts/image_to_3d_std_separately.sh ./demos/example_000.png ./outputs/test  # >= 16G
+bash scripts/image_to_3d_lite_separately.sh ./demos/example_000.png ./outputs/test # >= 10G
+```
+
+#### Using Gradio
+
+We provide both a lite and a standard interface:
+
+```shell
+# std 
+python3 app.py
+python3 app.py --save_memory
+
+# lite
+python3 app.py --use_lite
+python3 app.py --use_lite --save_memory
+```
+
+The Gradio demo is then served at http://0.0.0.0:8080, where 0.0.0.0 should be replaced with the IP address of the machine running the model.
+
+## Camera parameters
+
+The viewpoints for multi-view generation are fixed to
+
++ Azimuth (relative to input view): `+0, +60, +120, +180, +240, +300`.
+
+
+## Citation
+
+If you find our repository helpful, please cite our work:
+```bibtex
+@misc{yang2024tencent,
+    title={Tencent Hunyuan3D-1.0: A Unified Framework for Text-to-3D and Image-to-3D Generation},
+    author={Xianghui Yang and Huiwen Shi and Bowen Zhang and Fan Yang and Jiacheng Wang and Hongxu Zhao and Xinhai Liu and Xinzhou Wang and Qingxiang Lin and Jiaao Yu and Lifu Wang and Zhuo Chen and Sicong Liu and Yuhong Liu and Yong Yang and Di Wang and Jie Jiang and Chunchao Guo},
+    year={2024},
+    eprint={2411.02293},
+    archivePrefix={arXiv},
+    primaryClass={cs.CV}
+}
+```
    	
        app.py
    CHANGED
    
@@ -32,9 +32,21 @@ import torch
 import numpy as np
 from PIL import Image
 from einops import rearrange
+import pandas as pd
 
 from infer import seed_everything, save_gif
 from infer import Text2Image, Removebg, Image2Views, Views2Mesh, GifRenderer
+from third_party.check import check_bake_available
+
+try:
+    from third_party.mesh_baker import MeshBaker
+    BAKE_AVAILEBLE = True
+except Exception as err:
+    print(err)
+    print("import baking related fail, run without baking")
+    check_bake_available()
+    BAKE_AVAILEBLE = False
+
 
 warnings.simplefilter('ignore', category=UserWarning)
 warnings.simplefilter('ignore', category=FutureWarning)
@@ -58,33 +70,19 @@ CONST_MAX_QUEUE = 1
 CONST_SERVER = '0.0.0.0'
 
 CONST_HEADER = '''
-<h2><
-
-Code: <a href='https://github.com/tencent/Hunyuan3D-1' target='_blank'>GitHub</a>. Techenical report: <a href='https://arxiv.org/abs/placeholder' target='_blank'>ArXiv</a>.
-
-❗️❗️❗️**Important Notes:**
-- By default, our demo can export a .obj mesh with vertex colors or a .glb mesh.
-- If you select "texture mapping," it will export a .obj mesh with a texture map or a .glb mesh.
-- If you select "render GIF," it will export a GIF image rendering of the .glb file.
-- If the result is unsatisfactory, please try a different seed value (Default: 0).
+<h2><a href='https://github.com/tencent/Hunyuan3D-1' target='_blank'><b>Tencent Hunyuan3D-1.0: A Unified Framework for Text-to-3D and Image-to-3D Generation</b></a></h2>
+⭐️Technical report: <a href='https://arxiv.org/pdf/2411.02293' target='_blank'>ArXiv</a>. ⭐️Code: <a href='https://github.com/tencent/Hunyuan3D-1' target='_blank'>GitHub</a>.
 '''
 
-
-
-
-
-
-
-
-
-
-    year={2024},
-    eprint={2411.02293},
-    archivePrefix={arXiv},
-    primaryClass={cs.CV}
-}
-```
-"""
+CONST_NOTE = '''
+❗️❗️❗️Usage❗️❗️❗️<br>
+
+Limited by format, the model can only export *.obj mesh with vertex colors. The "texture" mod can only work on *.glb.<br>
+Please click "Do Rendering" to export a GIF.<br>
+You can click "Do Baking" to bake multi-view imgaes onto the shape.<br>
+
+If the results aren't satisfactory, please try a different radnom seed (default is 0).
+'''
 
 ################################################################
 # prepare text examples and image examples
@@ -129,6 +127,13 @@ worker_v23 = Views2Mesh(
 )
 worker_gif = GifRenderer(args.device)
 
+
+if BAKE_AVAILEBLE:
+    worker_baker = MeshBaker()
+
+
+### functional modules
+
 def stage_0_t2i(text, image, seed, step):
     os.makedirs('./outputs/app_output', exist_ok=True)
     exists = set(int(_) for _ in os.listdir('./outputs/app_output') if not _.startswith("."))
@@ -153,11 +158,11 @@ def stage_0_t2i(text, image, seed, step):
     dst = worker_xbg(image, save_folder)
     return dst, save_folder
 
-def stage_1_xbg(image, save_folder): 
+def stage_1_xbg(image, save_folder, force_remove): 
     if isinstance(image, str):
         image = Image.open(image)
     dst =  save_folder + '/img_nobg.png'
-    rgba = worker_xbg(image)
+    rgba = worker_xbg(image, force=force_remove)
     rgba.save(dst)
     return dst
 
@@ -181,12 +186,9 @@ def stage_3_v23(
     seed, 
     save_folder,
     target_face_count = 30000,
-
-    do_render =True
+    texture_color = 'texture'
 ): 
-    do_texture_mapping = 
-    obj_dst = save_folder + '/mesh_with_colors.obj'
-    glb_dst = save_folder + '/mesh.glb'
+    do_texture_mapping = texture_color == 'texture'
     worker_v23(
         views_pil, 
         cond_pil, 
@@ -195,149 +197,268 @@ def stage_3_v23(
         target_face_count = target_face_count,
         do_texture_mapping = do_texture_mapping
     )
+    glb_dst = save_folder + '/mesh.glb' if do_texture_mapping else None
+    obj_dst =  save_folder + '/mesh.obj'
+    obj_dst = save_folder + '/mesh_vertex_colors.obj' # gradio just only can show vertex shading
     return obj_dst, glb_dst
 
-def 
-    if 
-
-
-
-
-
+def stage_3p_baking(save_folder, color, bake):
+    if color == "texture" and bake:
+        obj_dst = worker_baker(save_folder)
+        glb_dst = obj_dst.replace(".obj", ".glb")
+        return glb_dst
+    else:
+        return None
+
+def stage_4_gif(save_folder, color, bake, render):
+    if not render: return None
+    if os.path.exists(save_folder + '/view_1/bake/mesh.obj'):
+        obj_dst = save_folder + '/view_1/bake/mesh.obj'
+    elif os.path.exists(save_folder + '/view_0/bake/mesh.obj'):
+        obj_dst = save_folder + '/view_0/bake/mesh.obj'
+    elif os.path.exists(save_folder + '/mesh.obj'):
+        obj_dst = save_folder + '/mesh.obj'
+    else:
+        print(save_folder)
+        raise FileNotFoundError("mesh obj file not found")
+    gif_dst = obj_dst.replace(".obj", ".gif")
+    worker_gif(obj_dst, gif_dst_path=gif_dst)
     return gif_dst
+
+
+def check_image_available(image):
+    if image.mode == "RGBA":
+        data = np.array(image)
+        alpha_channel = data[:, :, 3]
+        unique_alpha_values = np.unique(alpha_channel)
+        if len(unique_alpha_values) == 1:
+            msg = "The alpha channel is missing or invalid. The background removal option is selected for you."
+            return msg, gr.update(value=True, interactive=False)
+        else:
+            msg = "The image has four channels, and you can choose to remove the background or not."
+            return msg, gr.update(value=False, interactive=True)
+    elif image.mode == "RGB":
+        msg = "The alpha channel is missing or invalid. The background removal option is selected for you."
+        return msg, gr.update(value=True, interactive=False)
+    else:
+        raise Exception("Image Error")
+
+def update_bake_render(color):
+    if color == "vertex":
+        return gr.update(value=False, interactive=False), gr.update(value=False, interactive=False)
+    else:
+        return gr.update(interactive=True), gr.update(interactive=True)
+
 # ===============================================================
 # gradio display
 # ===============================================================
+
 with gr.Blocks() as demo:
     gr.Markdown(CONST_HEADER)
     with gr.Row(variant="panel"):
+
+        ###### Input region
+
         with gr.Column(scale=2):
+
+            ### Text iutput region
+
             with gr.Tab("Text to 3D"):
                 with gr.Column():
-                    text = gr.TextArea('一只黑白相间的熊猫在白色背景上居中坐着,呈现出卡通风格和可爱氛围。', 
+                    text = gr.TextArea('一只黑白相间的熊猫在白色背景上居中坐着,呈现出卡通风格和可爱氛围。', 
+                                       lines=3, max_lines=20, label='Input text')
                 with gr.Row():
-
-
-
-
-
-
+                    textgen_color = gr.Radio(choices=["vertex", "texture"], label="Color", value="texture")
+                with gr.Row():
+                    textgen_render = gr.Checkbox(label="Do Rendering", value=True, interactive=True)
                 with gr.Row():
-                    textgen_do_texture_mapping = gr.Checkbox(label="texture mapping", value=False, interactive=True)
-                    textgen_do_render_gif = gr.Checkbox(label="Render gif", value=False, interactive=True)
                     textgen_submit = gr.Button("Generate", variant="primary")
 
                 with gr.Row():
-                    gr.Examples(examples=example_ts, inputs=[text], label=
 
             with gr.Tab("Image to 3D"):
-                with gr.
-                    input_image = gr.Image(label="Input image",
-
-
-
-
-
-
-
 
-                with gr.Row():
-                    imggen_do_texture_mapping = gr.Checkbox(label="texture mapping", value=False, interactive=True)
-                    imggen_do_render_gif = gr.Checkbox(label="Render gif", value=False, interactive=True)
-                    imggen_submit = gr.Button("Generate", variant="primary")       
-                with gr.Row():
-                    gr.Examples(
-                        examples=example_is, 
-                        inputs=[input_image], 
-                        label="Img examples",
-                        examples_per_page=10
-                    )
-
         with gr.Column(scale=3):
             with gr.Row():
                 with gr.Column(scale=2):
-                    rem_bg_image = gr.Image(
-
                 with gr.Column(scale=3):
-                    result_image = gr.Image(
-
-
                     result_3dobj = gr.Model3D(
                         clear_color=[0.0, 0.0, 0.0, 0.0],
-                        label=
                         show_label=True,
                         visible=True,
                         camera_position=[90, 90, None],
                         interactive=False
                     )
 
-
                         clear_color=[0.0, 0.0, 0.0, 0.0],
-                        label=
                         show_label=True,
                         visible=True,
                         camera_position=[90, 90, None],
-                        interactive=False
-                    )
-                    result_gif = gr.Image(label="Rendered GIF", interactive=False)
 
-                with gr.Row():
-                    gr.Markdown(
-
-
-
-
-
-#===============================================================
-# gradio running code
-#===============================================================
 
     none = gr.State(None)
     save_folder = gr.State()
     cond_image = gr.State()
     views_image = gr.State()
     text_image = gr.State()
 
     textgen_submit.click(
-        fn=stage_0_t2i, 
         outputs=[rem_bg_image, save_folder],
     ).success(
-        fn=stage_2_i2v, 
         outputs=[views_image, cond_image, result_image],
     ).success(
-        fn=stage_3_v23, 
-
-
-        outputs=[result_3dobj, result_3dglb],
     ).success(
-        fn=
        outputs=[result_gif],
     ).success(lambda: print('Text_to_3D Done ...'))
 
     imggen_submit.click(
-        fn=stage_0_t2i, 
         outputs=[text_image, save_folder],
     ).success(
-        fn=stage_1_xbg, 
         outputs=[rem_bg_image],
     ).success(
-        fn=stage_2_i2v, 
         outputs=[views_image, cond_image, result_image],
     ).success(
-        fn=stage_3_v23, 
-
-
-
     ).success(
-        fn=stage_4_gif, 
        outputs=[result_gif],
     ).success(lambda: print('Image_to_3D Done ...'))
 
-#===============================================================
-# start gradio server
-#===============================================================
 
-    gr.Markdown(CONST_CITATION)
     demo.queue(max_size=CONST_MAX_QUEUE)
     demo.launch(server_name=CONST_SERVER, server_port=CONST_PORT)
 
         | 
| 274 | 
            +
                                    if BAKE_AVAILEBLE:
         | 
| 275 | 
            +
                                        textgen_bake = gr.Checkbox(label="Do Baking", value=True, interactive=True)
         | 
| 276 | 
            +
                                    else:
         | 
| 277 | 
            +
                                        textgen_bake = gr.Checkbox(label="Do Baking", value=False, interactive=False)
         | 
| 278 | 
            +
                                
         | 
| 279 | 
            +
                                textgen_color.change(
         | 
| 280 | 
            +
                                    fn=update_bake_render, 
         | 
| 281 | 
            +
                                    inputs=textgen_color, 
         | 
| 282 | 
            +
                                    outputs=[textgen_bake, textgen_render]
         | 
| 283 | 
            +
                                )
         | 
| 284 | 
            +
             | 
| 285 | 
            +
                                with gr.Row():
         | 
| 286 | 
            +
                                    textgen_seed = gr.Number(value=0, label="T2I seed", precision=0, interactive=True)
         | 
| 287 | 
            +
                                    textgen_step = gr.Number(value=25, label="T2I steps", precision=0, 
         | 
| 288 | 
            +
                                                             minimum=10, maximum=50, interactive=True)
         | 
| 289 | 
            +
                                    textgen_SEED = gr.Number(value=0, label="Gen seed", precision=0, interactive=True)
         | 
| 290 | 
            +
                                    textgen_STEP = gr.Number(value=50, label="Gen steps", precision=0, 
         | 
| 291 | 
            +
                                                             minimum=40, maximum=100, interactive=True)
         | 
| 292 | 
            +
                                    textgen_max_faces = gr.Number(value=90000, label="Face number", precision=0, 
         | 
| 293 | 
            +
                                                                  minimum=5000, maximum=1000000, interactive=True)
         | 
| 294 | 
             
                                with gr.Row():
         | 
|  | |
|  | |
| 295 | 
             
                                    textgen_submit = gr.Button("Generate", variant="primary")
         | 
| 296 |  | 
| 297 | 
             
                                with gr.Row():
         | 
| 298 | 
            +
                                    gr.Examples(examples=example_ts, inputs=[text], label="Text examples", examples_per_page=10)
         | 
| 299 |  | 
| 300 | 
            +
                        ### Image iutput region
         | 
| 301 | 
            +
                        
         | 
| 302 | 
             
                        with gr.Tab("Image to 3D"):
         | 
| 303 | 
            +
                            with gr.Row():
         | 
| 304 | 
            +
                                input_image = gr.Image(label="Input image", width=256, height=256, type="pil",
         | 
| 305 | 
            +
                                                       image_mode="RGBA", sources="upload", interactive=True)
         | 
| 306 | 
            +
                            with gr.Row():
         | 
| 307 | 
            +
                                alert_message = gr.Markdown("")  # for warning 
         | 
| 308 | 
            +
                            with gr.Row():
         | 
| 309 | 
            +
                                imggen_color = gr.Radio(choices=["vertex", "texture"], label="Color", value="texture")
         | 
| 310 | 
            +
                            with gr.Row():
         | 
| 311 | 
            +
                                imggen_removebg = gr.Checkbox(label="Remove Background", value=True, interactive=True)
         | 
| 312 | 
            +
                                imggen_render = gr.Checkbox(label="Do Rendering", value=True, interactive=True)
         | 
| 313 | 
            +
                                if BAKE_AVAILEBLE:
         | 
| 314 | 
            +
                                    imggen_bake = gr.Checkbox(label="Do Baking", value=True, interactive=True)
         | 
| 315 | 
            +
                                else:
         | 
| 316 | 
            +
                                    imggen_bake = gr.Checkbox(label="Do Baking", value=False, interactive=False)
         | 
| 317 | 
            +
                                
         | 
| 318 | 
            +
                            input_image.change(
         | 
| 319 | 
            +
                                fn=check_image_available, 
         | 
| 320 | 
            +
                                inputs=input_image, 
         | 
| 321 | 
            +
                                outputs=[alert_message, imggen_removebg]
         | 
| 322 | 
            +
                            )
         | 
| 323 | 
            +
                            imggen_color.change(
         | 
| 324 | 
            +
                                fn=update_bake_render, 
         | 
| 325 | 
            +
                                inputs=imggen_color, 
         | 
| 326 | 
            +
                                outputs=[imggen_bake, imggen_render]
         | 
| 327 | 
            +
                            )
         | 
| 328 | 
            +
                                    
         | 
| 329 | 
            +
                            with gr.Row():
         | 
| 330 | 
            +
                                imggen_SEED = gr.Number(value=0, label="Gen seed", precision=0, interactive=True)
         | 
| 331 | 
            +
                                imggen_STEP = gr.Number(value=50, label="Gen steps", precision=0, 
         | 
| 332 | 
            +
                                                        minimum=40, maximum=100, interactive=True)
         | 
| 333 | 
            +
                                imggen_max_faces = gr.Number(value=90000, label="Face number", precision=0, 
         | 
| 334 | 
            +
                                                                 minimum=5000, maximum=1000000, interactive=True)
         | 
| 335 | 
            +
                            with gr.Row():
         | 
| 336 | 
            +
                                imggen_submit = gr.Button("Generate", variant="primary")      
         | 
| 337 | 
            +
             | 
| 338 | 
            +
                            with gr.Row():
         | 
| 339 | 
            +
                                gr.Examples(examples=example_is,  inputs=[input_image], 
         | 
| 340 | 
            +
                                    label="Img examples", examples_per_page=10)
         | 
| 341 | 
            +
                                
         | 
| 342 | 
            +
                        gr.Markdown(CONST_NOTE)
         | 
| 343 | 
            +
                                
         | 
| 344 | 
            +
                    ###### Output region
         | 
| 345 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 346 | 
             
                    with gr.Column(scale=3):
         | 
| 347 | 
             
                        with gr.Row():
         | 
| 348 | 
             
                            with gr.Column(scale=2):
         | 
| 349 | 
            +
                                rem_bg_image = gr.Image(
         | 
| 350 | 
            +
                                    label="Image without background", 
         | 
| 351 | 
            +
                                    type="pil",
         | 
| 352 | 
            +
                                    image_mode="RGBA", 
         | 
| 353 | 
            +
                                    interactive=False
         | 
| 354 | 
            +
                                )
         | 
| 355 | 
             
                            with gr.Column(scale=3):
         | 
| 356 | 
            +
                                result_image = gr.Image(
         | 
| 357 | 
            +
                                    label="Multi-view images", 
         | 
| 358 | 
            +
                                    type="pil", 
         | 
| 359 | 
            +
                                    interactive=False
         | 
| 360 | 
            +
                                )
         | 
| 361 | 
            +
                        
         | 
| 362 | 
            +
                        with gr.Row():          
         | 
| 363 | 
             
                            result_3dobj = gr.Model3D(
         | 
| 364 | 
             
                                clear_color=[0.0, 0.0, 0.0, 0.0],
         | 
| 365 | 
            +
                                label="OBJ vertex color",
         | 
| 366 | 
             
                                show_label=True,
         | 
| 367 | 
             
                                visible=True,
         | 
| 368 | 
             
                                camera_position=[90, 90, None],
         | 
| 369 | 
             
                                interactive=False
         | 
| 370 | 
             
                            )
         | 
| 371 | 
            +
                            result_gif = gr.Image(label="GIF", interactive=False)
         | 
| 372 | 
            +
                            
         | 
| 373 | 
            +
                        with gr.Row():
         | 
| 374 | 
            +
                            result_3dglb_texture = gr.Model3D(
         | 
| 375 | 
            +
                                clear_color=[0.0, 0.0, 0.0, 0.0],
         | 
| 376 | 
            +
                                label="GLB texture color",
         | 
| 377 | 
            +
                                show_label=True,
         | 
| 378 | 
            +
                                visible=True,
         | 
| 379 | 
            +
                                camera_position=[90, 90, None],
         | 
| 380 | 
            +
                                interactive=False)
         | 
| 381 |  | 
| 382 | 
            +
                            result_3dglb_baked = gr.Model3D(
         | 
| 383 | 
             
                                clear_color=[0.0, 0.0, 0.0, 0.0],
         | 
| 384 | 
            +
                                label="GLB baked color",
         | 
| 385 | 
             
                                show_label=True,
         | 
| 386 | 
             
                                visible=True,
         | 
| 387 | 
             
                                camera_position=[90, 90, None],
         | 
| 388 | 
            +
                                interactive=False)
         | 
|  | |
|  | |
| 389 |  | 
| 390 | 
            +
                        with gr.Row():
         | 
| 391 | 
            +
                            gr.Markdown(
         | 
| 392 | 
            +
                                "Due to Gradio limitations, OBJ files are displayed with vertex shading only, "
         | 
| 393 | 
            +
                                "while GLB files can be viewed with texture shading. <br>For the best experience, "
         | 
| 394 | 
            +
                                "we recommend downloading the GLB files and opening them with 3D software "
         | 
| 395 | 
            +
                                "like Blender or MeshLab."
         | 
| 396 | 
            +
                            )
         | 
|  | |
|  | |
|  | |
| 397 |  | 
| 398 | 
            +
                #===============================================================
         | 
| 399 | 
            +
                # gradio running code
         | 
| 400 | 
            +
                #===============================================================
         | 
| 401 | 
            +
                
         | 
| 402 | 
             
                none = gr.State(None)
         | 
| 403 | 
             
                save_folder = gr.State()
         | 
| 404 | 
             
                cond_image = gr.State()
         | 
| 405 | 
             
                views_image = gr.State()
         | 
| 406 | 
             
                text_image = gr.State()
         | 
| 407 |  | 
| 408 | 
            +
                
         | 
| 409 | 
             
                textgen_submit.click(
         | 
| 410 | 
            +
                    fn=stage_0_t2i, 
         | 
| 411 | 
            +
                    inputs=[text, none, textgen_seed, textgen_step], 
         | 
| 412 | 
             
                    outputs=[rem_bg_image, save_folder],
         | 
| 413 | 
             
                ).success(
         | 
| 414 | 
            +
                    fn=stage_2_i2v, 
         | 
| 415 | 
            +
                    inputs=[rem_bg_image, textgen_SEED, textgen_STEP, save_folder], 
         | 
| 416 | 
             
                    outputs=[views_image, cond_image, result_image],
         | 
| 417 | 
             
                ).success(
         | 
| 418 | 
            +
                    fn=stage_3_v23, 
         | 
| 419 | 
            +
                    inputs=[views_image, cond_image, textgen_SEED, save_folder, textgen_max_faces, textgen_color], 
         | 
| 420 | 
            +
                    outputs=[result_3dobj, result_3dglb_texture],
         | 
|  | |
| 421 | 
             
                ).success(
         | 
| 422 | 
            +
                    fn=stage_3p_baking, 
         | 
| 423 | 
            +
                    inputs=[save_folder, textgen_color, textgen_bake], 
         | 
| 424 | 
            +
                    outputs=[result_3dglb_baked],
         | 
| 425 | 
            +
                ).success(
         | 
| 426 | 
            +
                    fn=stage_4_gif, 
         | 
| 427 | 
            +
                    inputs=[save_folder, textgen_color, textgen_bake, textgen_render], 
         | 
| 428 | 
             
                    outputs=[result_gif],
         | 
| 429 | 
             
                ).success(lambda: print('Text_to_3D Done ...'))
         | 
| 430 |  | 
| 431 | 
            +
                
         | 
| 432 | 
             
                imggen_submit.click(
         | 
| 433 | 
            +
                    fn=stage_0_t2i, 
         | 
| 434 | 
            +
                    inputs=[none, input_image, textgen_seed, textgen_step], 
         | 
| 435 | 
             
                    outputs=[text_image, save_folder],
         | 
| 436 | 
             
                ).success(
         | 
| 437 | 
            +
                    fn=stage_1_xbg, 
         | 
| 438 | 
            +
                    inputs=[text_image, save_folder, imggen_removebg], 
         | 
| 439 | 
             
                    outputs=[rem_bg_image],
         | 
| 440 | 
             
                ).success(
         | 
| 441 | 
            +
                    fn=stage_2_i2v, 
         | 
| 442 | 
            +
                    inputs=[rem_bg_image, imggen_SEED, imggen_STEP, save_folder], 
         | 
| 443 | 
             
                    outputs=[views_image, cond_image, result_image],
         | 
| 444 | 
             
                ).success(
         | 
| 445 | 
            +
                    fn=stage_3_v23, 
         | 
| 446 | 
            +
                    inputs=[views_image, cond_image, imggen_SEED, save_folder, imggen_max_faces, imggen_color],
         | 
| 447 | 
            +
                    outputs=[result_3dobj, result_3dglb_texture],
         | 
| 448 | 
            +
                ).success(
         | 
| 449 | 
            +
                    fn=stage_3p_baking, 
         | 
| 450 | 
            +
                    inputs=[save_folder, imggen_color, imggen_bake], 
         | 
| 451 | 
            +
                    outputs=[result_3dglb_baked],
         | 
| 452 | 
             
                ).success(
         | 
| 453 | 
            +
                    fn=stage_4_gif, 
         | 
| 454 | 
            +
                    inputs=[save_folder, imggen_color, imggen_bake, imggen_render], 
         | 
| 455 | 
             
                    outputs=[result_gif],
         | 
| 456 | 
             
                ).success(lambda: print('Image_to_3D Done ...'))
         | 
| 457 |  | 
| 458 | 
            +
                #===============================================================
         | 
| 459 | 
            +
                # start gradio server
         | 
| 460 | 
            +
                #===============================================================
         | 
| 461 |  | 
|  | |
| 462 | 
             
                demo.queue(max_size=CONST_MAX_QUEUE)
         | 
| 463 | 
             
                demo.launch(server_name=CONST_SERVER, server_port=CONST_PORT)
         | 
| 464 |  | 
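The click chains above run the text-to-3D path as stage_0_t2i, stage_2_i2v, stage_3_v23, stage_3p_baking and stage_4_gif in sequence. As a rough, illustrative sketch (not part of the commit), the same chain driven without the UI would look roughly like this; argument order follows the inputs/outputs lists above and every value here is a placeholder:

res_img, folder = stage_0_t2i("a cartoon panda", None, 0, 25)   # text, image, T2I seed, T2I steps
views, cond, grid = stage_2_i2v(res_img, 0, 50, folder)         # gen seed, gen steps, save folder
obj_path, glb_path = stage_3_v23(views, cond, 0, folder, 90000, "texture")
baked_glb = stage_3p_baking(folder, "texture", True)            # None unless color == "texture" and bake
gif_path = stage_4_gif(folder, "texture", True, True)           # prefers a baked mesh when one exists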
    	
env_install.sh    CHANGED
@@ -1,6 +1,6 @@
 pip3 install diffusers transformers
 pip3 install rembg tqdm omegaconf matplotlib opencv-python imageio jaxtyping einops 
-pip3 install SentencePiece accelerate trimesh PyMCubes xatlas libigl ninja gradio
+pip3 install SentencePiece accelerate trimesh PyMCubes xatlas libigl ninja gradio roma
 pip3 install git+https://github.com/facebookresearch/pytorch3d@stable
 pip3 install git+https://github.com/NVlabs/nvdiffrast
 pip3 install open3d
infer/gif_render.py    CHANGED
@@ -25,7 +25,7 @@
 import os, sys
 sys.path.insert(0, f"{os.path.dirname(os.path.dirname(os.path.abspath(__file__)))}")
 
-from svrm.ldm.vis_util import 
+from svrm.ldm.vis_util import render_func
 from infer.utils import seed_everything, timing_decorator
 
 class GifRenderer():
@@ -40,14 +40,14 @@ class GifRenderer():
         self, 
         obj_filename, 
         elev=0, 
-        azim=
+        azim=None, 
         resolution=512, 
         gif_dst_path='', 
         n_views=120, 
         fps=30, 
         rgb=True
     ):
-
+        render_func(
             obj_filename,
             elev=elev, 
             azim=azim, 
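After this change GifRenderer simply forwards the call to render_func, with azim left unset by default. A minimal usage sketch, assuming placeholder file paths and the device keyword used by main.py:

renderer = GifRenderer(device="cuda:0")
renderer(
    "outputs/test/mesh.obj",                 # placeholder path to a generated mesh
    gif_dst_path="outputs/test/output.gif",  # where the turntable GIF is written
    n_views=120,
    fps=30,
)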
    	
infer/image_to_views.py    CHANGED
@@ -48,21 +48,26 @@ def save_gif(pils, save_path, df=False):
 
 
 class Image2Views():
-    def __init__(self, 
+    def __init__(self, 
+            device="cuda:0", use_lite=False, save_memory=False,
+            std_pretrain='./weights/mvd_std', lite_pretrain='./weights/mvd_lite'
+        ):
         self.device = device
         if use_lite:
+            print("loading", lite_pretrain)
             self.pipe = Hunyuan3d_MVD_Lite_Pipeline.from_pretrained(
-
+                lite_pretrain,
                 torch_dtype = torch.float16,
                 use_safetensors = True,
             )
         else:
+            print("loading", std_pretrain)
             self.pipe = HunYuan3D_MVD_Std_Pipeline.from_pretrained(
-
+                std_pretrain,
                 torch_dtype = torch.float16,
                 use_safetensors = True,
             )
-        self.pipe = self.pipe.to(device)
+        self.pipe = self.pipe if save_memory else self.pipe.to(device)
         self.order = [0, 1, 2, 3, 4, 5] if use_lite else [0, 2, 4, 5, 3, 1]
         self.save_memory = save_memory
         set_parameter_grad_false(self.pipe.unet)
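The constructor now exposes the weight directories and a save_memory switch that keeps the pipeline off the GPU at init. A minimal instantiation sketch using the defaults shown above:

i2v_model = Image2Views(
    device="cuda:0",
    use_lite=False,                     # set True to load the lite multi-view pipeline
    save_memory=False,                  # when True, .to(device) is skipped here
    std_pretrain="./weights/mvd_std",
    lite_pretrain="./weights/mvd_lite",
)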
    	
infer/text_to_image.py    CHANGED
@@ -46,8 +46,7 @@ class Text2Image():
         )
         set_parameter_grad_false(self.pipe.transformer)
         print('text2image transformer model', get_parameter_number(self.pipe.transformer))
-        if 
-            self.pipe = self.pipe.to(device)
+        self.pipe = self.pipe if save_memory else self.pipe.to(device)
         self.neg_txt = "文本,特写,裁剪,出框,最差质量,低质量,JPEG伪影,PGLY,重复,病态,残缺,多余的手指,变异的手," \
                        "画得不好的手,画得不好的脸,变异,畸形,模糊,脱水,糟糕的解剖学,糟糕的比例,多余的肢体,克隆的脸," \
                        "毁容,恶心的比例,畸形的肢体,缺失的手臂,缺失的腿,额外的手臂,额外的腿,融合的手指,手指太多,长脖子"
infer/utils.py    CHANGED
@@ -21,7 +21,8 @@
 # optimizer states), machine-learning model code, inference-enabling code, training-enabling code, 
 # fine-tuning enabling code and other elements of the foregoing made publicly available 
 # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
-
+import sys
+import io
 import os
 import time
 import random
@@ -30,6 +31,7 @@ import torch
 from torch.cuda.amp import autocast, GradScaler
 from functools import wraps
 
+
 def seed_everything(seed):
     '''
         seed everthing
@@ -39,6 +41,7 @@ def seed_everything(seed):
     torch.manual_seed(seed)
     os.environ["PL_GLOBAL_SEED"] = str(seed)
 
+
 def timing_decorator(category: str):
     '''
         timing_decorator: record time
@@ -57,6 +60,7 @@ def timing_decorator(category: str):
         return wrapper
     return decorator
 
+
 def auto_amp_inference(func):
     '''
         with torch.cuda.amp.autocast()"
@@ -69,11 +73,13 @@ def auto_amp_inference(func):
         return output
     return wrapper
 
+
 def get_parameter_number(model):
     total_num = sum(p.numel() for p in model.parameters())
     trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
     return {'Total': total_num, 'Trainable': trainable_num}
 
+
 def set_parameter_grad_false(model):
     for p in model.parameters():
         p.requires_grad = False
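For reference, the helpers above are used as below; the decorated function is a toy example, not code from the repo:

seed_everything(0)

@timing_decorator("toy stage")
def toy_stage(x):
    # the decorator records how long the wrapped call takes under the given category
    return x * 2

toy_stage(21)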
    	
infer/views_to_mesh.py    CHANGED
@@ -47,11 +47,15 @@ class Views2Mesh():
             use_lite: lite version
             save_memory: cpu auto
         '''
-        self.mv23d_predictor = MV23DPredictor(mv23d_ckt_path, mv23d_cfg_path, device=device)  
-        self.mv23d_predictor.model.eval()
-        self.order = [0, 1, 2, 3, 4, 5] if use_lite else [0, 2, 4, 5, 3, 1]
         self.device = device
         self.save_memory = save_memory
+        self.mv23d_predictor = MV23DPredictor(
+            mv23d_ckt_path, 
+            mv23d_cfg_path, 
+            device = "cpu" if save_memory else device
+        )  
+        self.mv23d_predictor.model.eval()
+        self.order = [0, 1, 2, 3, 4, 5] if use_lite else [0, 2, 4, 5, 3, 1]
         set_parameter_grad_false(self.mv23d_predictor.model)
         print('view2mesh model', get_parameter_number(self.mv23d_predictor.model))
 
@@ -109,7 +113,6 @@ class Views2Mesh():
             do_texture_mapping = do_texture_mapping
         )
         torch.cuda.empty_cache()
-        return save_dir
 
 
 if __name__ == "__main__":
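With this change the reconstruction predictor is built on CPU whenever save_memory is set. A hedged instantiation sketch; the keyword names mirror the variables used in the constructor body above, and both file paths are placeholders:

v2m_model = Views2Mesh(
    mv23d_cfg_path="./svrm/configs/svrm.yaml",         # placeholder config path
    mv23d_ckt_path="./weights/svrm/svrm.safetensors",  # placeholder checkpoint path
    device="cuda:0",
    use_lite=False,
    save_memory=True,   # predictor starts on CPU, as in the diff above
)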
    	
main.py    CHANGED
@@ -24,16 +24,28 @@
 
 import os
 import warnings
-import torch
-from PIL import Image
 import argparse
-
-from 
+import time
+from PIL import Image
+import torch
 
 warnings.simplefilter('ignore', category=UserWarning)
 warnings.simplefilter('ignore', category=FutureWarning)
 warnings.simplefilter('ignore', category=DeprecationWarning)
 
+from infer import Text2Image, Removebg, Image2Views, Views2Mesh, GifRenderer
+from third_party.mesh_baker import MeshBaker
+from third_party.check import check_bake_available
+
+try:
+    from third_party.mesh_baker import MeshBaker
+    assert check_bake_available()
+    BAKE_AVAILEBLE = True
+except Exception as err:
+    print(err)
+    print("import baking related fail, run without baking")
+    BAKE_AVAILEBLE = False
+
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -73,8 +85,8 @@ def get_args():
         "--gen_steps", default=50, type=int
     )
     parser.add_argument(
-        "--max_faces_num", default=
-        help="max num of face, suggest 
+        "--max_faces_num", default=90000, type=int, 
+        help="max num of face, suggest 90000 for effect, 10000 for speed"
     )
     parser.add_argument(
         "--save_memory", default=False, action="store_true"
@@ -85,6 +97,13 @@ def get_args():
     parser.add_argument(
         "--do_render", default=False, action="store_true"
    )
+    parser.add_argument(
+        "--do_bake", default=False, action="store_true"
+    )
+    parser.add_argument(
+        "--bake_align_times", default=3, type=int,
+        help="align times between view image and mesh, suggest 1~6"
+    )
     return parser.parse_args()
 
 
@@ -95,6 +114,7 @@ if __name__ == "__main__":
     assert args.text_prompt or args.image_prompt,        "Text and image can only be given to one"
 
     # init model
+    st = time.time()
     rembg_model = Removebg()
     image_to_views_model = Image2Views(
         device=args.device, 
@@ -116,9 +136,18 @@ if __name__ == "__main__":
         device = args.device, 
         save_memory = args.save_memory
     )
-
+    
+    if args.do_bake and BAKE_AVAILEBLE:
+        mesh_baker = MeshBaker(
+            device = args.device,
+            align_times = args.bake_align_times
+        )
+        
+    if check_bake_available():
         gif_renderer = GifRenderer(device=args.device)
-
+        
+    print(f"Init Models cost {time.time()-st}s")
+    
     # ---- ----- ---- ---- ---- ----
 
     os.makedirs(args.save_folder, exist_ok=True)
@@ -136,7 +165,7 @@ if __name__ == "__main__":
 
     # stage 2, remove back ground
     res_rgba_pil = rembg_model(res_rgb_pil)
-
+    res_rgba_pil.save(os.path.join(args.save_folder, "img_nobg.png"))
 
     # stage 3, image to views
     (views_grid_pil, cond_img), view_pil_list = image_to_views_model(
@@ -155,10 +184,29 @@ if __name__ == "__main__":
         save_folder = args.save_folder,
         do_texture_mapping = args.do_texture_mapping
     )
-
-    #
+    
+    # stage 5, baking
+    mesh_file_for_render = None
+    if args.do_bake and BAKE_AVAILEBLE:
+        mesh_file_for_render = mesh_baker(args.save_folder)
+        
+    # stage 6, render gif
+    # todo fix: if the save folder is not cleared, the wrong mesh may be rendered
     if args.do_render:
+        if mesh_file_for_render and os.path.exists(mesh_file_for_render):
+            mesh_file_for_render = mesh_file_for_render
+        elif os.path.exists(os.path.join(args.save_folder, 'view_1/bake/mesh.obj')):
+            mesh_file_for_render = os.path.join(args.save_folder, 'view_1/bake/mesh.obj')
+        elif os.path.exists(os.path.join(args.save_folder, 'view_0/bake/mesh.obj')):
+            mesh_file_for_render = os.path.join(args.save_folder, 'view_0/bake/mesh.obj')
+        elif os.path.exists(os.path.join(args.save_folder, 'mesh.obj')):
+            mesh_file_for_render = os.path.join(args.save_folder, 'mesh.obj')
+        else:
+            raise FileNotFoundError("mesh_file_for_render not found")
+            
+        print("Rendering 3d file:", mesh_file_for_render)
+        
         gif_renderer(
-
+            mesh_file_for_render,
             gif_dst_path = os.path.join(args.save_folder, 'output.gif'),
         )
        requirements.txt
    CHANGED
    
    | @@ -22,3 +22,4 @@ git+https://github.com/facebookresearch/pytorch3d@stable | |
| 22 | 
             
            git+https://github.com/NVlabs/nvdiffrast
         | 
| 23 | 
             
            open3d
         | 
| 24 | 
             
            ninja
         | 
|  | 
|  | |
| 22 | 
             
            git+https://github.com/NVlabs/nvdiffrast
         | 
| 23 | 
             
            open3d
         | 
| 24 | 
             
            ninja
         | 
| 25 | 
            +
            roma
         | 
    	
        svrm/ldm/models/svrm.py
    CHANGED
    
    | @@ -46,7 +46,7 @@ from ..modules.rendering_neus.rasterize import NVDiffRasterizerContext | |
| 46 |  | 
| 47 | 
             
            from ..utils.ops import scale_tensor
         | 
| 48 | 
             
            from ..util import count_params, instantiate_from_config
         | 
| 49 | 
            -
            from ..vis_util import  | 
| 50 |  | 
| 51 |  | 
| 52 | 
             
            def unwrap_uv(v_pos, t_pos_idx):
         | 
| @@ -58,7 +58,6 @@ def unwrap_uv(v_pos, t_pos_idx): | |
| 58 | 
             
                indices = indices.astype(np.int64, casting="same_kind")
         | 
| 59 | 
             
                return uvs, indices
         | 
| 60 |  | 
| 61 | 
            -
             | 
| 62 | 
             
            def uv_padding(image, hole_mask, uv_padding_size = 2):
         | 
| 63 | 
             
                return cv2.inpaint(
         | 
| 64 | 
             
                    (image.detach().cpu().numpy() * 255).astype(np.uint8),
         | 
| @@ -120,14 +119,16 @@ class SVRMModel(torch.nn.Module): | |
| 120 | 
             
                    out_dir = 'outputs/test'
         | 
| 121 | 
             
                ):
         | 
| 122 | 
             
                    """
         | 
| 123 | 
            -
                     | 
| 124 | 
             
                    """
         | 
| 125 |  | 
| 126 | 
            -
                    obj_vertext_path = os.path.join(out_dir, ' | 
| 127 | 
            -
                     | 
| 128 | 
            -
                     | 
| 129 | 
            -
             | 
| 130 | 
            -
             | 
|  | |
|  | |
| 131 |  | 
| 132 | 
             
                    st = time.time()
         | 
| 133 |  | 
| @@ -204,15 +205,13 @@ class SVRMModel(torch.nn.Module): | |
| 204 | 
             
                    mesh = trimesh.load_mesh(obj_vertext_path)
         | 
| 205 | 
             
                    print(f"=====> generate mesh with vertex shading time: {time.time() - st}")
         | 
| 206 | 
             
                    st = time.time()
         | 
| 207 | 
            -
             | 
| 208 | 
             
                    if not do_texture_mapping:
         | 
| 209 | 
            -
                         | 
| 210 | 
            -
                        mesh.export(glb_path, file_type='glb')
         | 
| 211 | 
            -
                        return None
         | 
| 212 |  | 
| 213 | 
            -
             | 
| 214 | 
            -
                     | 
| 215 | 
            -
                    
         | 
| 216 |  | 
| 217 | 
             
                    st = time.time()
         | 
| 218 |  | 
| @@ -238,12 +237,9 @@ class SVRMModel(torch.nn.Module): | |
| 238 |  | 
| 239 | 
             
                    # Interpolate world space position
         | 
| 240 | 
             
                    gb_pos = ctx.interpolate_one(vtx_refine, rast[None, ...], faces_refine)[0][0]
         | 
| 241 | 
            -
                    
         | 
| 242 | 
             
                    with torch.no_grad():
         | 
| 243 | 
             
                        gb_mask_pos_scale = scale_tensor(gb_pos.unsqueeze(0).view(1, -1, 3), (-1, 1), (-1, 1))
         | 
| 244 | 
            -
                        
         | 
| 245 | 
             
                        tex_map = self.render.forward_points(cur_triplane, gb_mask_pos_scale)['rgb']
         | 
| 246 | 
            -
                        
         | 
| 247 | 
             
                        tex_map = tex_map.float().squeeze(0)  # (0, 1)
         | 
| 248 | 
             
                        tex_map = tex_map.view((texture_res, texture_res, 3)) 
         | 
| 249 | 
             
                        img = uv_padding(tex_map, hole_mask)
         | 
| @@ -257,7 +253,7 @@ class SVRMModel(torch.nn.Module): | |
| 257 | 
             
                        fid.write('newmtl material_0\n')
         | 
| 258 | 
             
                        fid.write("Ka 1.000 1.000 1.000\n")
         | 
| 259 | 
             
                        fid.write("Kd 1.000 1.000 1.000\n")
         | 
| 260 | 
            -
                        fid.write("Ks 0. | 
| 261 | 
             
                        fid.write("d 1.0\n")
         | 
| 262 | 
             
                        fid.write("illum 2\n")
         | 
| 263 | 
             
                        fid.write(f'map_Kd texture.png\n')
         | 
| @@ -278,4 +274,5 @@ class SVRMModel(torch.nn.Module): | |
| 278 | 
             
                    mesh = trimesh.load_mesh(obj_path)
         | 
| 279 | 
             
                    mesh.export(glb_path, file_type='glb')
         | 
| 280 | 
             
                    print(f"=====> generate mesh with texture shading time: {time.time() - st}")
         | 
|  | |
| 281 |  | 
|  | |
| 46 |  | 
| 47 | 
             
            from ..utils.ops import scale_tensor
         | 
| 48 | 
             
            from ..util import count_params, instantiate_from_config
         | 
| 49 | 
            +
            from ..vis_util import render_func
         | 
| 50 |  | 
| 51 |  | 
| 52 | 
             
            def unwrap_uv(v_pos, t_pos_idx):
         | 
|  | |
| 58 | 
             
                indices = indices.astype(np.int64, casting="same_kind")
         | 
| 59 | 
             
                return uvs, indices
         | 
| 60 |  | 
|  | |
| 61 | 
             
            def uv_padding(image, hole_mask, uv_padding_size = 2):
         | 
| 62 | 
             
                return cv2.inpaint(
         | 
| 63 | 
             
                    (image.detach().cpu().numpy() * 255).astype(np.uint8),
         | 
|  | |
| 119 | 
             
                    out_dir = 'outputs/test'
         | 
| 120 | 
             
                ):
         | 
| 121 | 
             
                    """
         | 
| 122 | 
            +
                    do_texture_mapping: True to bake a UV texture map (texture.png/.mtl), False to export vertex colors only
         | 
| 123 | 
             
                    """
         | 
| 124 |  | 
| 125 | 
            +
                    obj_vertext_path = os.path.join(out_dir, 'mesh_vertex_colors.obj')
         | 
| 126 | 
            +
                    
         | 
| 127 | 
            +
                    if do_texture_mapping:
         | 
| 128 | 
            +
                        obj_path = os.path.join(out_dir, 'mesh.obj')
         | 
| 129 | 
            +
                        obj_texture_path = os.path.join(out_dir, 'texture.png')
         | 
| 130 | 
            +
                        obj_mtl_path = os.path.join(out_dir, 'texture.mtl')
         | 
| 131 | 
            +
                        glb_path = os.path.join(out_dir, 'mesh.glb')
         | 
| 132 |  | 
| 133 | 
             
                    st = time.time()
         | 
| 134 |  | 
|  | |
| 205 | 
             
                    mesh = trimesh.load_mesh(obj_vertext_path)
         | 
| 206 | 
             
                    print(f"=====> generate mesh with vertex shading time: {time.time() - st}")
         | 
| 207 | 
             
                    st = time.time()
         | 
| 208 | 
            +
                    
         | 
| 209 | 
             
                    if not do_texture_mapping:
         | 
| 210 | 
            +
                        return obj_vertext_path, None
         | 
|  | |
|  | |
| 211 |  | 
| 212 | 
            +
                    ###########################################################
         | 
| 213 | 
            +
                    #-------------    export texture    -----------------------
         | 
| 214 | 
            +
                    ###########################################################
         | 
| 215 |  | 
| 216 | 
             
                    st = time.time()
         | 
| 217 |  | 
|  | |
| 237 |  | 
| 238 | 
             
                    # Interpolate world space position
         | 
| 239 | 
             
                    gb_pos = ctx.interpolate_one(vtx_refine, rast[None, ...], faces_refine)[0][0]
         | 
|  | |
| 240 | 
             
                    with torch.no_grad():
         | 
| 241 | 
             
                        gb_mask_pos_scale = scale_tensor(gb_pos.unsqueeze(0).view(1, -1, 3), (-1, 1), (-1, 1))
         | 
|  | |
| 242 | 
             
                        tex_map = self.render.forward_points(cur_triplane, gb_mask_pos_scale)['rgb']
         | 
|  | |
| 243 | 
             
                        tex_map = tex_map.float().squeeze(0)  # (0, 1)
         | 
| 244 | 
             
                        tex_map = tex_map.view((texture_res, texture_res, 3)) 
         | 
| 245 | 
             
                        img = uv_padding(tex_map, hole_mask)
         | 
|  | |
| 253 | 
             
                        fid.write('newmtl material_0\n')
         | 
| 254 | 
             
                        fid.write("Ka 1.000 1.000 1.000\n")
         | 
| 255 | 
             
                        fid.write("Kd 1.000 1.000 1.000\n")
         | 
| 256 | 
            +
                        fid.write("Ks 0.500 0.500 0.500\n")
         | 
| 257 | 
             
                        fid.write("d 1.0\n")
         | 
| 258 | 
             
                        fid.write("illum 2\n")
         | 
| 259 | 
             
                        fid.write(f'map_Kd texture.png\n')
         | 
|  | |
| 274 | 
             
                    mesh = trimesh.load_mesh(obj_path)
         | 
| 275 | 
             
                    mesh.export(glb_path, file_type='glb')
         | 
| 276 | 
             
                    print(f"=====> generate mesh with texture shading time: {time.time() - st}")
         | 
| 277 | 
            +
                    return obj_path, glb_path
         | 
| 278 |  | 
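A minimal caller-side sketch of the new return contract above (the exporting method itself is not named in this hunk, so the helper below only shows how its two return values could be handled):

    import os

    def handle_export_result(obj_path, glb_path):
        # Mirrors the return values above: (mesh_vertex_colors.obj, None) when
        # do_texture_mapping is False, (mesh.obj, mesh.glb) when it is True.
        if glb_path is None:
            print(f"vertex-colored mesh only: {obj_path}")
        else:
            print(f"textured mesh: {obj_path} ({os.path.getsize(glb_path)} bytes as GLB)")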
    	
        svrm/ldm/modules/attention.py
    CHANGED
    
    | @@ -246,8 +246,11 @@ class CrossAttention(nn.Module): | |
| 246 | 
             
            class FlashAttention(nn.Module):
         | 
| 247 | 
             
                def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
         | 
| 248 | 
             
                    super().__init__()
         | 
| 249 | 
            -
                    print( | 
| 250 | 
            -
             | 
|  | |
|  | |
|  | |
| 251 | 
             
                    inner_dim = dim_head * heads
         | 
| 252 | 
             
                    context_dim = default(context_dim, query_dim)
         | 
| 253 | 
             
                    self.scale = dim_head ** -0.5
         | 
| @@ -269,7 +272,12 @@ class FlashAttention(nn.Module): | |
| 269 | 
             
                    k = self.to_k(context).to(dtype)
         | 
| 270 | 
             
                    v = self.to_v(context).to(dtype)
         | 
| 271 | 
             
                    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q, k, v)) # q is [b, 3079, 16, 64]
         | 
| 272 | 
            -
                    out = flash_attn_func(q, k, v,  | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 273 | 
             
                    out = rearrange(out, 'b n h d -> b n (h d)', h=h)
         | 
| 274 | 
             
                    return self.to_out(out.float())
         | 
| 275 |  | 
| @@ -277,8 +285,11 @@ class MemoryEfficientCrossAttention(nn.Module): | |
| 277 | 
             
                # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
         | 
| 278 | 
             
                def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
         | 
| 279 | 
             
                    super().__init__()
         | 
| 280 | 
            -
                    print( | 
| 281 | 
            -
             | 
|  | |
|  | |
|  | |
| 282 | 
             
                    inner_dim = dim_head * heads
         | 
| 283 | 
             
                    context_dim = default(context_dim, query_dim)
         | 
| 284 |  | 
| @@ -327,10 +338,12 @@ class BasicTransformerBlock(nn.Module): | |
| 327 | 
             
                    super().__init__()
         | 
| 328 | 
             
                    self.disable_self_attn = disable_self_attn
         | 
| 329 | 
             
                    self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
         | 
| 330 | 
            -
                                                context_dim=context_dim if self.disable_self_attn else None) | 
|  | |
| 331 | 
             
                    self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
         | 
| 332 | 
             
                    self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
         | 
| 333 | 
            -
                                                heads=n_heads, dim_head=d_head, dropout=dropout)   | 
|  | |
| 334 | 
             
                    self.norm1 = Fp32LayerNorm(dim)
         | 
| 335 | 
             
                    self.norm2 = Fp32LayerNorm(dim)
         | 
| 336 | 
             
                    self.norm3 = Fp32LayerNorm(dim)
         | 
| @@ -451,7 +464,3 @@ class ImgToTriplaneTransformer(nn.Module): | |
| 451 | 
             
                    x = self.norm(x)
         | 
| 452 | 
             
                    return x
         | 
| 453 |  | 
| 454 | 
            -
             | 
| 455 | 
            -
             | 
| 456 | 
            -
             | 
| 457 | 
            -
             | 
|  | |
| 246 | 
             
            class FlashAttention(nn.Module):
         | 
| 247 | 
             
                def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
         | 
| 248 | 
             
                    super().__init__()
         | 
| 249 | 
            +
                    # print(
         | 
| 250 | 
            +
                    #     f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, "
         | 
| 251 | 
            +
                    #     "context_dim is {context_dim} and using "
         | 
| 252 | 
            +
                    #     f"{heads} heads."
         | 
| 253 | 
            +
                    # )
         | 
| 254 | 
             
                    inner_dim = dim_head * heads
         | 
| 255 | 
             
                    context_dim = default(context_dim, query_dim)
         | 
| 256 | 
             
                    self.scale = dim_head ** -0.5
         | 
|  | |
| 272 | 
             
                    k = self.to_k(context).to(dtype)
         | 
| 273 | 
             
                    v = self.to_v(context).to(dtype)
         | 
| 274 | 
             
                    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q, k, v)) # q is [b, 3079, 16, 64]
         | 
| 275 | 
            +
                    out = flash_attn_func(q, k, v, 
         | 
| 276 | 
            +
                                          dropout_p=self.dropout, 
         | 
| 277 | 
            +
                                          softmax_scale=None, 
         | 
| 278 | 
            +
                                          causal=False, 
         | 
| 279 | 
            +
                                          window_size=(-1, -1)
         | 
| 280 | 
            +
                                         ) # out is same shape to q
         | 
| 281 | 
             
                    out = rearrange(out, 'b n h d -> b n (h d)', h=h)
         | 
| 282 | 
             
                    return self.to_out(out.float())
         | 
| 283 |  | 
|  | |
| 285 | 
             
                # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
         | 
| 286 | 
             
                def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
         | 
| 287 | 
             
                    super().__init__()
         | 
| 288 | 
            +
                    # print(
         | 
| 289 | 
            +
                    #     f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, "
         | 
| 290 | 
            +
                    #     "context_dim is {context_dim} and using "
         | 
| 291 | 
            +
                    #     f"{heads} heads."
         | 
| 292 | 
            +
                    # )
         | 
| 293 | 
             
                    inner_dim = dim_head * heads
         | 
| 294 | 
             
                    context_dim = default(context_dim, query_dim)
         | 
| 295 |  | 
|  | |
| 338 | 
             
                    super().__init__()
         | 
| 339 | 
             
                    self.disable_self_attn = disable_self_attn
         | 
| 340 | 
             
                    self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
         | 
| 341 | 
            +
                                                context_dim=context_dim if self.disable_self_attn else None)
         | 
| 342 | 
            +
                    # is a self-attention if not self.disable_self_attn
         | 
| 343 | 
             
                    self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
         | 
| 344 | 
             
                    self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
         | 
| 345 | 
            +
                                                heads=n_heads, dim_head=d_head, dropout=dropout)  
         | 
| 346 | 
            +
                    # is self-attn if context is none
         | 
| 347 | 
             
                    self.norm1 = Fp32LayerNorm(dim)
         | 
| 348 | 
             
                    self.norm2 = Fp32LayerNorm(dim)
         | 
| 349 | 
             
                    self.norm3 = Fp32LayerNorm(dim)
         | 
|  | |
| 464 | 
             
                    x = self.norm(x)
         | 
| 465 | 
             
                    return x
         | 
| 466 |  | 
|  | |
|  | |
|  | |
|  | 
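For reference, a standalone sketch of the flash_attn_func call this diff switches to (assuming flash-attn 2.x and a CUDA device; the tensor shapes follow the [b, n, h, d] comment in the hunk):

    import torch
    from flash_attn import flash_attn_func

    # q, k, v are laid out as [batch, seq_len, heads, dim_head], as in the rearrange above
    q = torch.randn(1, 3079, 16, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(1, 3079, 16, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(1, 3079, 16, 64, device="cuda", dtype=torch.float16)

    out = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None,
                          causal=False, window_size=(-1, -1))
    print(out.shape)  # same shape as q, as the new inline comment notes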
    	
        svrm/ldm/vis_util.py
    CHANGED
    
    | @@ -27,10 +27,10 @@ from pytorch3d.renderer import ( | |
| 27 | 
             
            )
         | 
| 28 |  | 
| 29 |  | 
| 30 | 
            -
            def  | 
| 31 | 
             
                obj_filename, 
         | 
| 32 | 
             
                elev=0, 
         | 
| 33 | 
            -
                azim= | 
| 34 | 
             
                resolution=512, 
         | 
| 35 | 
             
                gif_dst_path='', 
         | 
| 36 | 
             
                n_views=120, 
         | 
| @@ -49,7 +49,7 @@ def render( | |
| 49 | 
             
                mesh = load_objs_as_meshes([obj_filename], device=device)
         | 
| 50 | 
             
                meshes = mesh.extend(n_views)
         | 
| 51 |  | 
| 52 | 
            -
                if  | 
| 53 | 
             
                    elev = torch.linspace(elev, elev, n_views+1)[:-1]
         | 
| 54 | 
             
                    azim = torch.linspace(0, 360, n_views+1)[:-1]
         | 
| 55 |  | 
| @@ -76,16 +76,15 @@ def render( | |
| 76 | 
             
                )
         | 
| 77 | 
             
                images = renderer(meshes)
         | 
| 78 |  | 
| 79 | 
            -
                 | 
| 80 | 
            -
             | 
| 81 | 
            -
             | 
| 82 | 
            -
             | 
| 83 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 84 |  | 
| 85 | 
            -
                # orbit frames rendering
         | 
| 86 | 
            -
                with imageio.get_writer(uri=gif_dst_path, mode='I', duration=1. / fps * 1000, loop=0) as writer:
         | 
| 87 | 
            -
                    for i in range(n_views):
         | 
| 88 | 
            -
                        frame = images[i, ..., :3] if rgb else images[i, ...]
         | 
| 89 | 
            -
                        frame = Image.fromarray((frame.cpu().squeeze(0) * 255).numpy().astype("uint8"))
         | 
| 90 | 
            -
                        writer.append_data(frame)
         | 
| 91 | 
            -
                    return gif_dst_path
         | 
|  | |
| 27 | 
             
            )
         | 
| 28 |  | 
| 29 |  | 
| 30 | 
            +
            def render_func(
         | 
| 31 | 
             
                obj_filename, 
         | 
| 32 | 
             
                elev=0, 
         | 
| 33 | 
            +
                azim=None, 
         | 
| 34 | 
             
                resolution=512, 
         | 
| 35 | 
             
                gif_dst_path='', 
         | 
| 36 | 
             
                n_views=120, 
         | 
|  | |
| 49 | 
             
                mesh = load_objs_as_meshes([obj_filename], device=device)
         | 
| 50 | 
             
                meshes = mesh.extend(n_views)
         | 
| 51 |  | 
| 52 | 
            +
                if azim is None:
         | 
| 53 | 
             
                    elev = torch.linspace(elev, elev, n_views+1)[:-1]
         | 
| 54 | 
             
                    azim = torch.linspace(0, 360, n_views+1)[:-1]
         | 
| 55 |  | 
|  | |
| 76 | 
             
                )
         | 
| 77 | 
             
                images = renderer(meshes)
         | 
| 78 |  | 
| 79 | 
            +
                if gif_dst_path != '': 
         | 
| 80 | 
            +
                    with imageio.get_writer(uri=gif_dst_path, mode='I', duration=1. / fps * 1000, loop=0) as writer:
         | 
| 81 | 
            +
                        for i in range(n_views):
         | 
| 82 | 
            +
                            frame = images[i, ..., :3] if rgb else images[i, ...]
         | 
| 83 | 
            +
                            frame = Image.fromarray((frame.cpu().squeeze(0) * 255).numpy().astype("uint8"))
         | 
| 84 | 
            +
                            writer.append_data(frame)
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                frame = images[..., :3] if rgb else images
         | 
| 87 | 
            +
                frames = [Image.fromarray((fra.cpu().squeeze(0) * 255).numpy().astype("uint8")) for fra in frame]
         | 
| 88 | 
            +
                return frames
         | 
| 89 | 
            +
             | 
| 90 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
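A minimal usage sketch of the renamed render_func (the .obj path is a placeholder; only parameters visible in this hunk are used):

    from svrm.ldm.vis_util import render_func

    # frames are always returned; a GIF is written only when gif_dst_path is non-empty
    frames = render_func(
        "outputs/test/mesh.obj",      # placeholder .obj path
        elev=0,
        azim=None,                    # None -> evenly spaced 0..360 degree orbit
        n_views=120,
        gif_dst_path="outputs/test/orbit.gif",
    )
    print(len(frames))                # 120 PIL.Image frames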
    	
        svrm/predictor.py
    CHANGED
    
    | @@ -33,7 +33,7 @@ from omegaconf import OmegaConf | |
| 33 | 
             
            from torchvision import transforms
         | 
| 34 | 
             
            from safetensors.torch import save_file, load_file
         | 
| 35 | 
             
            from .ldm.util import instantiate_from_config
         | 
| 36 | 
            -
            from .ldm.vis_util import  | 
| 37 |  | 
| 38 | 
             
            class MV23DPredictor(object):
         | 
| 39 | 
             
                def __init__(self, ckpt_path, cfg_path, elevation=15, number_view=60, 
         | 
| @@ -46,9 +46,7 @@ class MV23DPredictor(object): | |
| 46 | 
             
                    self.elevation_list = [0, 0, 0, 0, 0, 0, 0]
         | 
| 47 | 
             
                    self.azimuth_list = [0, 60, 120, 180, 240, 300, 0]
         | 
| 48 |  | 
| 49 | 
            -
                    st = time.time()
         | 
| 50 | 
             
                    self.model = self.init_model(ckpt_path, cfg_path)
         | 
| 51 | 
            -
                    print(f"=====> mv23d model init time: {time.time() - st}")
         | 
| 52 |  | 
| 53 | 
             
                    self.input_view_transform = transforms.Compose([
         | 
| 54 | 
             
                        transforms.Resize(504, interpolation=Image.BICUBIC),
         | 
|  | |
| 33 | 
             
            from torchvision import transforms
         | 
| 34 | 
             
            from safetensors.torch import save_file, load_file
         | 
| 35 | 
             
            from .ldm.util import instantiate_from_config
         | 
| 36 | 
            +
            from .ldm.vis_util import render_func
         | 
| 37 |  | 
| 38 | 
             
            class MV23DPredictor(object):
         | 
| 39 | 
             
                def __init__(self, ckpt_path, cfg_path, elevation=15, number_view=60, 
         | 
|  | |
| 46 | 
             
                    self.elevation_list = [0, 0, 0, 0, 0, 0, 0]
         | 
| 47 | 
             
                    self.azimuth_list = [0, 60, 120, 180, 240, 300, 0]
         | 
| 48 |  | 
|  | |
| 49 | 
             
                    self.model = self.init_model(ckpt_path, cfg_path)
         | 
|  | |
| 50 |  | 
| 51 | 
             
                    self.input_view_transform = transforms.Compose([
         | 
| 52 | 
             
                        transforms.Resize(504, interpolation=Image.BICUBIC),
         | 
    	
        third_party/check.py
    ADDED
    
    | @@ -0,0 +1,25 @@ | |
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import sys
         | 
| 3 | 
            +
            import io
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            def check_bake_available():
         | 
| 6 | 
            +
                is_ok = os.path.exists("./third_party/weights/DUSt3R_ViTLarge_BaseDecoder_512_dpt/model.safetensors")
         | 
| 7 | 
            +
                is_ok = is_ok and os.path.exists("./third_party/dust3r")
         | 
| 8 | 
            +
                is_ok = is_ok and os.path.exists("./third_party/dust3r/dust3r")
         | 
| 9 | 
            +
                is_ok = is_ok and os.path.exists("./third_party/dust3r/croco/models")
         | 
| 10 | 
            +
                if is_ok:
         | 
| 11 | 
            +
                    print("Baking is avaliable")
         | 
| 12 | 
            +
                    print("Baking is avaliable")
         | 
| 13 | 
            +
                    print("Baking is avaliable")
         | 
| 14 | 
            +
                else:
         | 
| 15 | 
            +
                    print("Baking is unavailable, please download related files in README")
         | 
| 16 | 
            +
                    print("Baking is unavailable, please download related files in README")
         | 
| 17 | 
            +
                    print("Baking is unavailable, please download related files in README")
         | 
| 18 | 
            +
                return is_ok
         | 
| 19 | 
            +
             | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            if __name__ == "__main__":
         | 
| 23 | 
            +
                
         | 
| 24 | 
            +
                check_bake_available()
         | 
| 25 | 
            +
                
         | 
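The helper can also be imported instead of run as a script; a minimal sketch (run from the repository root so the relative paths above resolve):

    from third_party.check import check_bake_available

    # True only when the DUSt3R weights and the dust3r/croco checkouts listed above exist
    if check_bake_available():
        print("baking pipeline can be enabled")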
    	
        third_party/dust3r_utils.py
    ADDED
    
    | @@ -0,0 +1,366 @@ | |
|  | |
|  | |
| 1 | 
            +
            import sys
         | 
| 2 | 
            +
            import io
         | 
| 3 | 
            +
            import os
         | 
| 4 | 
            +
            import cv2
         | 
| 5 | 
            +
            import math
         | 
| 6 | 
            +
            import numpy as np
         | 
| 7 | 
            +
            from scipy.signal import medfilt
         | 
| 8 | 
            +
            from scipy.spatial import KDTree
         | 
| 9 | 
            +
            from matplotlib import pyplot as plt
         | 
| 10 | 
            +
            from PIL import Image
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            from dust3r.inference import inference
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            from dust3r.utils.image import load_images  # , resize_images
         | 
| 15 | 
            +
            from dust3r.image_pairs import make_pairs
         | 
| 16 | 
            +
            from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
         | 
| 17 | 
            +
            from dust3r.utils.geometry import find_reciprocal_matches, xy_grid
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            from third_party.utils.camera_utils import remap_points
         | 
| 20 | 
            +
            from third_party.utils.img_utils import rgba_to_rgb, resize_with_aspect_ratio
         | 
| 21 | 
            +
            from third_party.utils.img_utils import compute_img_diff
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            from PIL.ImageOps import exif_transpose
         | 
| 24 | 
            +
            import torchvision.transforms as tvf
         | 
| 25 | 
            +
            ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
         | 
| 26 | 
            +
             | 
| 27 | 
            +
             | 
| 28 | 
            +
            def suppress_output(func):
         | 
| 29 | 
            +
                def wrapper(*args, **kwargs):
         | 
| 30 | 
            +
                    original_stdout = sys.stdout
         | 
| 31 | 
            +
                    original_stderr = sys.stderr
         | 
| 32 | 
            +
                    sys.stdout = io.StringIO()
         | 
| 33 | 
            +
                    sys.stderr = io.StringIO()
         | 
| 34 | 
            +
                    try:
         | 
| 35 | 
            +
                        return func(*args, **kwargs)
         | 
| 36 | 
            +
                    finally:
         | 
| 37 | 
            +
                        sys.stdout = original_stdout
         | 
| 38 | 
            +
                        sys.stderr = original_stderr
         | 
| 39 | 
            +
                return wrapper
         | 
| 40 | 
            +
             | 
| 41 | 
            +
            def _resize_pil_image(img, long_edge_size):
         | 
| 42 | 
            +
                S = max(img.size)
         | 
| 43 | 
            +
                if S > long_edge_size:
         | 
| 44 | 
            +
                    interp = Image.LANCZOS
         | 
| 45 | 
            +
                elif S <= long_edge_size:
         | 
| 46 | 
            +
                    interp = Image.BICUBIC
         | 
| 47 | 
            +
                new_size = tuple(int(round(x*long_edge_size/S)) for x in img.size)
         | 
| 48 | 
            +
                return img.resize(new_size, interp)
         | 
| 49 | 
            +
             | 
| 50 | 
            +
            def resize_images(imgs_list, size, square_ok=False):
         | 
| 51 | 
            +
                """ open and convert all images in a list or folder to proper input format for DUSt3R
         | 
| 52 | 
            +
                """
         | 
| 53 | 
            +
                imgs = []
         | 
| 54 | 
            +
                for img in imgs_list:
         | 
| 55 | 
            +
                    img = exif_transpose(Image.fromarray(img)).convert('RGB')
         | 
| 56 | 
            +
                    W1, H1 = img.size
         | 
| 57 | 
            +
                    if size == 224:
         | 
| 58 | 
            +
                        # resize short side to 224 (then crop)
         | 
| 59 | 
            +
                        img = _resize_pil_image(img, round(size * max(W1/H1, H1/W1)))
         | 
| 60 | 
            +
                    else:
         | 
| 61 | 
            +
                        # resize long side to 512
         | 
| 62 | 
            +
                        img = _resize_pil_image(img, size)
         | 
| 63 | 
            +
                    W, H = img.size
         | 
| 64 | 
            +
                    cx, cy = W//2, H//2
         | 
| 65 | 
            +
                    if size == 224:
         | 
| 66 | 
            +
                        half = min(cx, cy)
         | 
| 67 | 
            +
                        img = img.crop((cx-half, cy-half, cx+half, cy+half))
         | 
| 68 | 
            +
                    else:
         | 
| 69 | 
            +
                        halfw, halfh = ((2*cx)//16)*8, ((2*cy)//16)*8
         | 
| 70 | 
            +
                        if not (square_ok) and W == H:
         | 
| 71 | 
            +
                            halfh = 3*halfw/4
         | 
| 72 | 
            +
                        img = img.crop((cx-halfw, cy-halfh, cx+halfw, cy+halfh))
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                    W2, H2 = img.size
         | 
| 75 | 
            +
                    imgs.append(dict(img=ImgNorm(img)[None], true_shape=np.int32(
         | 
| 76 | 
            +
                        [img.size[::-1]]), idx=len(imgs), instance=str(len(imgs))))
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                return imgs
         | 
| 79 | 
            +
             | 
| 80 | 
            +
            @suppress_output
         | 
| 81 | 
            +
            def infer_match(images, model, vis=False, niter=300, lr=0.01, schedule='cosine', device="cuda:0"):
         | 
| 82 | 
            +
                batch_size = 1
         | 
| 83 | 
            +
                schedule = 'cosine'
         | 
| 84 | 
            +
                lr = 0.01
         | 
| 85 | 
            +
                niter = 300
         | 
| 86 | 
            +
                
         | 
| 87 | 
            +
                images_packed = resize_images(images, size=512, square_ok=True)
         | 
| 88 | 
            +
                # images_packed = images
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                pairs = make_pairs(images_packed, scene_graph='complete', prefilter=None, symmetrize=True)
         | 
| 91 | 
            +
                output = inference(pairs, model, device, batch_size=batch_size, verbose=False)
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                scene = global_aligner(output, device=device, mode=GlobalAlignerMode.PointCloudOptimizer)
         | 
| 94 | 
            +
                loss = scene.compute_global_alignment(init="mst", niter=niter, schedule=schedule, lr=lr)
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                # retrieve useful values from scene:
         | 
| 97 | 
            +
                imgs = scene.imgs
         | 
| 98 | 
            +
                # focals = scene.get_focals()
         | 
| 99 | 
            +
                # poses = scene.get_im_poses()
         | 
| 100 | 
            +
                pts3d = scene.get_pts3d()
         | 
| 101 | 
            +
                confidence_masks = scene.get_masks()
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                # visualize reconstruction
         | 
| 104 | 
            +
                # scene.show()
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                # find 2D-2D matches between the two images
         | 
| 107 | 
            +
                pts2d_list, pts3d_list = [], []
         | 
| 108 | 
            +
                for i in range(2):
         | 
| 109 | 
            +
                    conf_i = confidence_masks[i].cpu().numpy()
         | 
| 110 | 
            +
                    pts2d_list.append(xy_grid(*imgs[i].shape[:2][::-1])[conf_i])  # imgs[i].shape[:2] = (H, W)
         | 
| 111 | 
            +
                    pts3d_list.append(pts3d[i].detach().cpu().numpy()[conf_i])
         | 
| 112 | 
            +
                    if pts3d_list[-1].shape[0] == 0:
         | 
| 113 | 
            +
                        return np.zeros((0, 2)), np.zeros((0, 2))
         | 
| 114 | 
            +
                reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(*pts3d_list)
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                matches_im1 = pts2d_list[1][reciprocal_in_P2]
         | 
| 117 | 
            +
                matches_im0 = pts2d_list[0][nn2_in_P1][reciprocal_in_P2]
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                # visualize a few matches
         | 
| 120 | 
            +
                if vis == True:
         | 
| 121 | 
            +
                    print(f'found {num_matches} matches')
         | 
| 122 | 
            +
                    n_viz = 20
         | 
| 123 | 
            +
                    match_idx_to_viz = np.round(np.linspace(0, num_matches - 1, n_viz)).astype(int)
         | 
| 124 | 
            +
                    viz_matches_im0, viz_matches_im1 = matches_im0[match_idx_to_viz], matches_im1[match_idx_to_viz]
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                    H0, W0, H1, W1 = *imgs[0].shape[:2], *imgs[1].shape[:2]
         | 
| 127 | 
            +
                    img0 = np.pad(imgs[0], ((0, max(H1 - H0, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
         | 
| 128 | 
            +
                    img1 = np.pad(imgs[1], ((0, max(H0 - H1, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
         | 
| 129 | 
            +
                    img = np.concatenate((img0, img1), axis=1)
         | 
| 130 | 
            +
                    plt.figure()
         | 
| 131 | 
            +
                    plt.imshow(img)
         | 
| 132 | 
            +
                    cmap = plt.get_cmap('jet')
         | 
| 133 | 
            +
                    for i in range(n_viz):
         | 
| 134 | 
            +
                        (x0, y0), (x1, y1) = viz_matches_im0[i].T, viz_matches_im1[i].T
         | 
| 135 | 
            +
                        plt.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(i / (n_viz - 1)), scalex=False, scaley=False)
         | 
| 136 | 
            +
                    plt.show(block=True)
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                matches_im0 = remap_points(images[0].shape, matches_im0)
         | 
| 139 | 
            +
                matches_im1 = remap_points(images[1].shape, matches_im1)
         | 
| 140 | 
            +
                return matches_im0, matches_im1
         | 
| 141 | 
            +
             | 
| 142 | 
            +
             | 
| 143 | 
            +
            def point_transform(H, pt):
         | 
| 144 | 
            +
                """
         | 
| 145 | 
            +
                @param: H is homography matrix of dimension (3x3)
         | 
| 146 | 
            +
                @param: pt is the (x, y) point to be transformed
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                Return:
         | 
| 149 | 
            +
                        returns a transformed point ptrans = H*pt.
         | 
| 150 | 
            +
                """
         | 
| 151 | 
            +
                a = H[0, 0] * pt[0] + H[0, 1] * pt[1] + H[0, 2]
         | 
| 152 | 
            +
                b = H[1, 0] * pt[0] + H[1, 1] * pt[1] + H[1, 2]
         | 
| 153 | 
            +
                c = H[2, 0] * pt[0] + H[2, 1] * pt[1] + H[2, 2]
         | 
| 154 | 
            +
                return [a / c, b / c]
         | 
| 155 | 
            +
             | 
| 156 | 
            +
             | 
| 157 | 
            +
            def points_transform(H, pt_x, pt_y):
         | 
| 158 | 
            +
                """
         | 
| 159 | 
            +
                @param: H is homography matrix of dimension (3x3)
         | 
| 160 | 
            +
                @param: pt_x, pt_y are the x and y coordinates of the point(s) to be transformed
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                Return:
         | 
| 163 | 
            +
                        returns a transformed point ptrans = H*pt.
         | 
| 164 | 
            +
                """
         | 
| 165 | 
            +
                a = H[0, 0] * pt_x + H[0, 1] * pt_y + H[0, 2]
         | 
| 166 | 
            +
                b = H[1, 0] * pt_x + H[1, 1] * pt_y + H[1, 2]
         | 
| 167 | 
            +
                c = H[2, 0] * pt_x + H[2, 1] * pt_y + H[2, 2]
         | 
| 168 | 
            +
                return (a / c, b / c)
         | 
| 169 | 
            +
             | 
| 170 | 
            +
             | 
| 171 | 
            +
            def motion_propagate(old_points, new_points, old_size, new_size, H_size=(21, 21)):
         | 
| 172 | 
            +
                """
         | 
| 173 | 
            +
                @param: old_points are points in old_frame that are
         | 
| 174 | 
            +
                        matched feature points with new_frame
         | 
| 175 | 
            +
                @param: new_points are points in new_frame that are
         | 
| 176 | 
            +
                        matched feature points with old_frame
         | 
| 177 | 
            +
                @param: old_size, new_size are the (height, width) of the
         | 
| 178 | 
            +
                        old and new frames
         | 
| 179 | 
            +
                @param: H_size is the (rows, cols) resolution of the motion mesh
         | 
| 180 | 
            +
             | 
| 181 | 
            +
                Return:
         | 
| 182 | 
            +
                        returns a motion mesh in x-direction
         | 
| 183 | 
            +
                        and y-direction for old_frame
         | 
| 184 | 
            +
                """
         | 
| 185 | 
            +
                # spreads motion over the mesh for the old_frame
         | 
| 186 | 
            +
                x_motion = np.zeros(H_size)
         | 
| 187 | 
            +
                y_motion = np.zeros(H_size)
         | 
| 188 | 
            +
                mesh_x_num, mesh_y_num = H_size[0], H_size[1]
         | 
| 189 | 
            +
                pixels_x, pixels_y = (old_size[1]) / (mesh_x_num - 1), (old_size[0]) / (mesh_y_num - 1)
         | 
| 190 | 
            +
                radius = max(pixels_x, pixels_y) * 5
         | 
| 191 | 
            +
                sigma = radius / 3.0
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                H_global = None
         | 
| 194 | 
            +
                if old_points.shape[0] > 3:
         | 
| 195 | 
            +
                    # pre-warping with global homography
         | 
| 196 | 
            +
                    H_global, _ = cv2.findHomography(old_points, new_points, cv2.RANSAC)
         | 
| 197 | 
            +
                if H_global is None:
         | 
| 198 | 
            +
                    old_tmp = np.array([[0, 0], [0, old_size[0]], [old_size[1], 0], [old_size[1], old_size[0]]])
         | 
| 199 | 
            +
                    new_tmp = np.array([[0, 0], [0, new_size[0]], [new_size[1], 0], [new_size[1], new_size[0]]])
         | 
| 200 | 
            +
                    H_global, _ = cv2.findHomography(old_tmp, new_tmp, cv2.RANSAC)
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                for i in range(mesh_x_num):
         | 
| 203 | 
            +
                    for j in range(mesh_y_num):
         | 
| 204 | 
            +
                        pt = [pixels_x * i, pixels_y * j]
         | 
| 205 | 
            +
                        ptrans = point_transform(H_global, pt)
         | 
| 206 | 
            +
                        x_motion[i, j] = ptrans[0]
         | 
| 207 | 
            +
                        y_motion[i, j] = ptrans[1]
         | 
| 208 | 
            +
             | 
| 209 | 
            +
                # distribute feature motion vectors
         | 
| 210 | 
            +
                weighted_move_x = np.zeros(H_size)
         | 
| 211 | 
            +
                weighted_move_y = np.zeros(H_size)
         | 
| 212 | 
            +
                # build a KDTree over the matched points
         | 
| 213 | 
            +
                tree = KDTree(old_points)
         | 
| 214 | 
            +
                # compute the weights and motion values
         | 
| 215 | 
            +
                for i in range(mesh_x_num):
         | 
| 216 | 
            +
                    for j in range(mesh_y_num):
         | 
| 217 | 
            +
                        vertex = [pixels_x * i, pixels_y * j]
         | 
| 218 | 
            +
                        neighbor_indices = tree.query_ball_point(vertex, radius, workers=-1)
         | 
| 219 | 
            +
                        if len(neighbor_indices) > 0:
         | 
| 220 | 
            +
                            pts = old_points[neighbor_indices]
         | 
| 221 | 
            +
                            sts = new_points[neighbor_indices]
         | 
| 222 | 
            +
                            ptrans_x, ptrans_y = points_transform(H_global, pts[:, 0], pts[:, 1])
         | 
| 223 | 
            +
                            moves_x = sts[:, 0] - ptrans_x
         | 
| 224 | 
            +
                            moves_y = sts[:, 1] - ptrans_y
         | 
| 225 | 
            +
             | 
| 226 | 
            +
                            dists = np.sqrt((vertex[0] - pts[:, 0]) ** 2 + (vertex[1] - pts[:, 1]) ** 2)
         | 
| 227 | 
            +
                            weights_x = np.exp(-(dists ** 2) / (2 * sigma ** 2))
         | 
| 228 | 
            +
                            weights_y = np.exp(-(dists ** 2) / (2 * sigma ** 2))
         | 
| 229 | 
            +
             | 
| 230 | 
            +
                            weighted_move_x[i, j] = np.sum(weights_x * moves_x) / (np.sum(weights_x) + 0.1)
         | 
| 231 | 
            +
                            weighted_move_y[i, j] = np.sum(weights_y * moves_y) / (np.sum(weights_y) + 0.1)
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                x_motion_mesh = x_motion + weighted_move_x
         | 
| 234 | 
            +
                y_motion_mesh = y_motion + weighted_move_y
         | 
| 235 | 
            +
                '''
         | 
| 236 | 
            +
                # apply median filter (f-1) on obtained motion for each vertex
         | 
| 237 | 
            +
                x_motion_mesh = np.zeros((mesh_x_num, mesh_y_num), dtype=float)
         | 
| 238 | 
            +
                y_motion_mesh = np.zeros((mesh_x_num, mesh_y_num), dtype=float)
         | 
| 239 | 
            +
             | 
| 240 | 
            +
                for key in x_motion.keys():
         | 
| 241 | 
            +
                    try:
         | 
| 242 | 
            +
                        temp_x_motion[key].sort()
         | 
| 243 | 
            +
                        x_motion_mesh[key] = x_motion[key]+temp_x_motion[key][len(temp_x_motion[key])//2]
         | 
| 244 | 
            +
                    except KeyError:
         | 
| 245 | 
            +
                        x_motion_mesh[key] = x_motion[key]
         | 
| 246 | 
            +
                    try:
         | 
| 247 | 
            +
                        temp_y_motion[key].sort()
         | 
| 248 | 
            +
                        y_motion_mesh[key] = y_motion[key]+temp_y_motion[key][len(temp_y_motion[key])//2]
         | 
| 249 | 
            +
                    except KeyError:
         | 
| 250 | 
            +
                        y_motion_mesh[key] = y_motion[key]
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                # apply second median filter (f-2) over the motion mesh for outliers
         | 
| 253 | 
            +
                #x_motion_mesh = medfilt(x_motion_mesh, kernel_size=[3, 3])
         | 
| 254 | 
            +
                #y_motion_mesh = medfilt(y_motion_mesh, kernel_size=[3, 3])
         | 
| 255 | 
            +
                '''
         | 
| 256 | 
            +
                return x_motion_mesh, y_motion_mesh
         | 
| 257 | 
            +
             | 
| 258 | 
            +
             | 
| 259 | 
            +
            def mesh_warp_points(points, x_motion_mesh, y_motion_mesh, img_size):
         | 
| 260 | 
            +
                ptrans = []
         | 
| 261 | 
            +
                mesh_x_num, mesh_y_num = x_motion_mesh.shape
         | 
| 262 | 
            +
                pixels_x, pixels_y = (img_size[1]) / (mesh_x_num - 1), (img_size[0]) / (mesh_y_num - 1)
         | 
| 263 | 
            +
                for pt in points:
         | 
| 264 | 
            +
                    i = int(pt[0] // pixels_x)
         | 
| 265 | 
            +
                    j = int(pt[1] // pixels_y)
         | 
| 266 | 
            +
                    src = [[i * pixels_x, j * pixels_y],
         | 
| 267 | 
            +
                           [(i + 1) * pixels_x, j * pixels_y],
         | 
| 268 | 
            +
                           [i * pixels_x, (j + 1) * pixels_y],
         | 
| 269 | 
            +
                           [(i + 1) * pixels_x, (j + 1) * pixels_y]]
         | 
| 270 | 
            +
                    src = np.asarray(src)
         | 
| 271 | 
            +
                    dst = [[x_motion_mesh[i, j], y_motion_mesh[i, j]],
         | 
| 272 | 
            +
                           [x_motion_mesh[i + 1, j], y_motion_mesh[i + 1, j]],
         | 
| 273 | 
            +
                           [x_motion_mesh[i, j + 1], y_motion_mesh[i, j + 1]],
         | 
| 274 | 
            +
                           [x_motion_mesh[i + 1, j + 1], y_motion_mesh[i + 1, j + 1]]]
         | 
| 275 | 
            +
                    dst = np.asarray(dst)
         | 
| 276 | 
            +
                    H, _ = cv2.findHomography(src, dst, cv2.RANSAC)
         | 
| 277 | 
            +
                    x, y = points_transform(H, pt[0], pt[1])
         | 
| 278 | 
            +
                    ptrans.append([x, y])
         | 
| 279 | 
            +
             | 
| 280 | 
            +
                return np.array(ptrans)
         | 
| 281 | 
            +
             | 
| 282 | 
            +
             | 
| 283 | 
            +
            def mesh_warp_frame(frame, x_motion_mesh, y_motion_mesh, resize):
         | 
| 284 | 
            +
                """
         | 
| 285 | 
            +
                @param: frame is the current frame
         | 
| 286 | 
            +
                @param: x_motion_mesh is the motion_mesh to
         | 
| 287 | 
            +
                        be warped on frame along x-direction
         | 
| 288 | 
            +
                @param: y_motion_mesh is the motion mesh to
         | 
| 289 | 
            +
                        be warped on frame along y-direction
         | 
| 290 | 
            +
                @param: resize is the desired output size (tuple of height, width)
         | 
| 291 | 
            +
             | 
| 292 | 
            +
                Returns:
         | 
| 293 | 
            +
                        returns a mesh warped frame according
         | 
| 294 | 
            +
                        to given motion meshes x_motion_mesh,
         | 
| 295 | 
            +
                        y_motion_mesh, resized to the specified size
         | 
| 296 | 
            +
                """
         | 
| 297 | 
            +
             | 
| 298 | 
            +
                map_x = np.zeros(resize, np.float32)
         | 
| 299 | 
            +
                map_y = np.zeros(resize, np.float32)
         | 
| 300 | 
            +
             | 
| 301 | 
            +
                mesh_x_num, mesh_y_num = x_motion_mesh.shape
         | 
| 302 | 
            +
                pixels_x, pixels_y = (resize[1]) / (mesh_x_num - 1), (resize[0]) / (mesh_y_num - 1)
         | 
| 303 | 
            +
             | 
| 304 | 
            +
                for i in range(mesh_x_num - 1):
         | 
| 305 | 
            +
                    for j in range(mesh_y_num - 1):
         | 
| 306 | 
            +
                        src = [[i * pixels_x, j * pixels_y],
         | 
| 307 | 
            +
                               [(i + 1) * pixels_x, j * pixels_y],
         | 
| 308 | 
            +
                               [i * pixels_x, (j + 1) * pixels_y],
         | 
| 309 | 
            +
                               [(i + 1) * pixels_x, (j + 1) * pixels_y]]
         | 
| 310 | 
            +
                        src = np.asarray(src)
         | 
| 311 | 
            +
             | 
| 312 | 
            +
                        dst = [[x_motion_mesh[i, j], y_motion_mesh[i, j]],
         | 
| 313 | 
            +
                               [x_motion_mesh[i + 1, j], y_motion_mesh[i + 1, j]],
         | 
| 314 | 
            +
                               [x_motion_mesh[i, j + 1], y_motion_mesh[i, j + 1]],
         | 
| 315 | 
            +
                               [x_motion_mesh[i + 1, j + 1], y_motion_mesh[i + 1, j + 1]]]
         | 
| 316 | 
            +
                        dst = np.asarray(dst)
         | 
| 317 | 
            +
                        H, _ = cv2.findHomography(src, dst, cv2.RANSAC)
         | 
| 318 | 
            +
             | 
| 319 | 
            +
                        start_x = math.ceil(pixels_x * i)
         | 
| 320 | 
            +
                        end_x = math.ceil(pixels_x * (i + 1))
         | 
| 321 | 
            +
                        start_y = math.ceil(pixels_y * j)
         | 
| 322 | 
            +
                        end_y = math.ceil(pixels_y * (j + 1))
         | 
| 323 | 
            +
             | 
| 324 | 
            +
                        x, y = np.meshgrid(range(start_x, end_x), range(start_y, end_y), indexing='ij')
         | 
| 325 | 
            +
             | 
| 326 | 
            +
                        map_x[y, x], map_y[y, x] = points_transform(H, x, y)
         | 
| 327 | 
            +
             | 
| 328 | 
            +
                # deforms mesh and directly outputs the resized frame
         | 
| 329 | 
            +
                resized_frame = cv2.remap(frame, map_x, map_y,
         | 
| 330 | 
            +
                                          interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT,
         | 
| 331 | 
            +
                                          borderValue=(255, 255, 255))
         | 
| 332 | 
            +
                return resized_frame
         | 
| 333 | 
            +
             | 
| 334 | 
            +
             | 
| 335 | 
            +
            def infer_warp_mesh_img(src, dst, model, vis=False):
         | 
| 336 | 
            +
                if isinstance(src, str):
         | 
| 337 | 
            +
                    image1 = cv2.imread(src,   cv2.IMREAD_UNCHANGED)
         | 
| 338 | 
            +
                    image2 = cv2.imread(dst, cv2.IMREAD_UNCHANGED)
         | 
| 339 | 
            +
                    image1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)
         | 
| 340 | 
            +
                    image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)
         | 
| 341 | 
            +
                elif isinstance(src, Image.Image):
         | 
| 342 | 
            +
                    image1 = np.array(src)
         | 
| 343 | 
            +
                    image2 = np.array(dst)
         | 
| 344 | 
            +
                else:
         | 
| 345 | 
            +
                    assert isinstance(src, np.ndarray)
         | 
| 346 | 
            +
             | 
| 347 | 
            +
                image1 = rgba_to_rgb(image1)
         | 
| 348 | 
            +
                image2 = rgba_to_rgb(image2)
         | 
| 349 | 
            +
             | 
| 350 | 
            +
                image1_padded = resize_with_aspect_ratio(image1, image2)
         | 
| 351 | 
            +
                resized_image1 = cv2.resize(image1_padded, (image2.shape[1], image2.shape[0]), interpolation=cv2.INTER_AREA)
         | 
| 352 | 
            +
             | 
| 353 | 
            +
                matches_im0, matches_im1 = infer_match([resized_image1, image2], model, vis=vis)
         | 
| 354 | 
            +
                matches_im0 = matches_im0 * image1_padded.shape[0] / resized_image1.shape[0]
         | 
| 355 | 
            +
             | 
| 356 | 
            +
                # print('Estimate Mesh Grid')
         | 
| 357 | 
            +
                mesh_x, mesh_y = motion_propagate(matches_im1, matches_im0, image2.shape[:2], image1_padded.shape[:2])
         | 
| 358 | 
            +
             | 
| 359 | 
            +
                aligned_image = mesh_warp_frame(image1_padded, mesh_x, mesh_y, (image2.shape[0], image2.shape[1]))
         | 
| 360 | 
            +
             | 
| 361 | 
            +
                matches_im0_from_im1 = mesh_warp_points(matches_im1, mesh_x, mesh_y, (image2.shape[1], image2.shape[0]))
         | 
| 362 | 
            +
                
         | 
| 363 | 
            +
                info = compute_img_diff(aligned_image, image2, matches_im0, matches_im0_from_im1, vis=vis)
         | 
| 364 | 
            +
                
         | 
| 365 | 
            +
                return aligned_image, info
         | 
| 366 | 
            +
             | 
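Note: a minimal usage sketch of infer_warp_mesh_img defined above, assuming the DUSt3R checkpoint path used by MeshBaker further below and two placeholder image files; not part of the committed code.

# Sketch: align a generated view onto a rendered view via DUSt3R matching + mesh warping.
# The checkpoint path and image paths are placeholders.
from PIL import Image
from dust3r.model import AsymmetricCroCo3DStereo
from third_party.dust3r_utils import infer_warp_mesh_img

model = AsymmetricCroCo3DStereo.from_pretrained(
    "third_party/weights/DUSt3R_ViTLarge_BaseDecoder_512_dpt"
).to("cuda:0").eval()

aligned, info = infer_warp_mesh_img("generated_view.png", "rendered_view.png", model, vis=False)
print(info["mask_iou"], info["match_num"])        # alignment-quality measures returned in `info`
Image.fromarray(aligned).save("aligned_view.png")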
    	
        third_party/gen_baking.py
    ADDED
    
    | @@ -0,0 +1,288 @@ | |
| 1 | 
            +
            import os, sys, time
         | 
| 2 | 
            +
            from typing import List, Optional
         | 
| 3 | 
            +
            from iopath.common.file_io import PathManager
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import cv2
         | 
| 6 | 
            +
            import imageio
         | 
| 7 | 
            +
            import numpy as np
         | 
| 8 | 
            +
            from PIL import Image
         | 
| 9 | 
            +
            import matplotlib.pyplot as plt
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            import torch
         | 
| 12 | 
            +
            import torch.nn.functional as F
         | 
| 13 | 
            +
            from torchvision import transforms
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            import trimesh
         | 
| 16 | 
            +
            from pytorch3d.io import load_objs_as_meshes, load_obj, save_obj
         | 
| 17 | 
            +
            from pytorch3d.ops import interpolate_face_attributes
         | 
| 18 | 
            +
            from pytorch3d.common.datatypes import Device
         | 
| 19 | 
            +
            from pytorch3d.structures import Meshes
         | 
| 20 | 
            +
            from pytorch3d.renderer import (
         | 
| 21 | 
            +
                look_at_view_transform,
         | 
| 22 | 
            +
                FoVPerspectiveCameras,
         | 
| 23 | 
            +
                PointLights,
         | 
| 24 | 
            +
                DirectionalLights,
         | 
| 25 | 
            +
                AmbientLights,
         | 
| 26 | 
            +
                Materials,
         | 
| 27 | 
            +
                RasterizationSettings,
         | 
| 28 | 
            +
                MeshRenderer,
         | 
| 29 | 
            +
                MeshRasterizer,
         | 
| 30 | 
            +
                SoftPhongShader,
         | 
| 31 | 
            +
                TexturesUV,
         | 
| 32 | 
            +
    TexturesVertex,
    TexturesAtlas,
         | 
| 33 | 
            +
                camera_position_from_spherical_angles,
         | 
| 34 | 
            +
                BlendParams,
         | 
| 35 | 
            +
            )
         | 
| 36 | 
            +
             | 
| 37 | 
            +
             | 
| 38 | 
            +
            def erode_mask(src_mask, p=1 / 20.0):
         | 
| 39 | 
            +
                monoMaskImage = cv2.split(src_mask)[0]
         | 
| 40 | 
            +
                br = cv2.boundingRect(monoMaskImage)
         | 
| 41 | 
            +
                k = int(min(br[2], br[3]) * p)
         | 
| 42 | 
            +
                kernel = np.ones((k, k), dtype=np.uint8)
         | 
| 43 | 
            +
    dst_mask = cv2.erode(src_mask, kernel, iterations=1)
         | 
| 44 | 
            +
                return dst_mask
         | 
| 45 | 
            +
             | 
| 46 | 
            +
            def load_objs_as_meshes_fast(
         | 
| 47 | 
            +
                verts,
         | 
| 48 | 
            +
                faces,
         | 
| 49 | 
            +
                aux,
         | 
| 50 | 
            +
                device: Optional[Device] = None,
         | 
| 51 | 
            +
                load_textures: bool = True,
         | 
| 52 | 
            +
                create_texture_atlas: bool = False,
         | 
| 53 | 
            +
                texture_atlas_size: int = 4,
         | 
| 54 | 
            +
                texture_wrap: Optional[str] = "repeat",
         | 
| 55 | 
            +
                path_manager: Optional[PathManager] = None,
         | 
| 56 | 
            +
            ):
         | 
| 57 | 
            +
                tex = None
         | 
| 58 | 
            +
                if create_texture_atlas:
         | 
| 59 | 
            +
                    # TexturesAtlas type
         | 
| 60 | 
            +
                    tex = TexturesAtlas(atlas=[aux.texture_atlas.to(device)])
         | 
| 61 | 
            +
                else:
         | 
| 62 | 
            +
                    # TexturesUV type
         | 
| 63 | 
            +
                    tex_maps = aux.texture_images
         | 
| 64 | 
            +
                    if tex_maps is not None and len(tex_maps) > 0:
         | 
| 65 | 
            +
                        verts_uvs = aux.verts_uvs.to(device)  # (V, 2)
         | 
| 66 | 
            +
                        faces_uvs = faces.textures_idx.to(device)  # (F, 3)
         | 
| 67 | 
            +
                        image = list(tex_maps.values())[0].to(device)[None]
         | 
| 68 | 
            +
                        tex = TexturesUV(verts_uvs=[verts_uvs], faces_uvs=[faces_uvs], maps=image)
         | 
| 69 | 
            +
                mesh = Meshes( verts=[verts.to(device)], faces=[faces.verts_idx.to(device)], textures=tex)
         | 
| 70 | 
            +
                return mesh
         | 
| 71 | 
            +
             | 
| 72 | 
            +
             | 
| 73 | 
            +
            def get_triangle_to_triangle(tri_1, tri_2, img_refined):
         | 
| 74 | 
            +
                '''
         | 
| 75 | 
            +
                    args:
         | 
| 76 | 
            +
            tri_1: (3, 2) pixel coordinates of the source triangle in the rendered image
         | 
| 77 | 
            +
            tri_2: (3, 2) pixel coordinates of the target triangle in the UV texture map
            img_refined: image to sample colors from (here, the blended front-view render)
         | 
| 78 | 
            +
                '''
         | 
| 79 | 
            +
                r1 = cv2.boundingRect(tri_1)
         | 
| 80 | 
            +
                r2 = cv2.boundingRect(tri_2)
         | 
| 81 | 
            +
                
         | 
| 82 | 
            +
                tri_1_cropped = []
         | 
| 83 | 
            +
                tri_2_cropped = []
         | 
| 84 | 
            +
                for i in range(0, 3):
         | 
| 85 | 
            +
                    tri_1_cropped.append(((tri_1[i][1] - r1[1]), (tri_1[i][0] - r1[0])))
         | 
| 86 | 
            +
                    tri_2_cropped.append(((tri_2[i][1] - r2[1]), (tri_2[i][0] - r2[0])))
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                trans = cv2.getAffineTransform(np.float32(tri_1_cropped), np.float32(tri_2_cropped))
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                img_1_cropped = np.float32(img_refined[r1[0]:r1[0] + r1[2], r1[1]:r1[1] + r1[3]])
         | 
| 91 | 
            +
                
         | 
| 92 | 
            +
                mask = np.zeros((r2[2], r2[3], 3), dtype=np.float32)
         | 
| 93 | 
            +
                
         | 
| 94 | 
            +
                cv2.fillConvexPoly(mask, np.int32(tri_2_cropped), (1.0, 1.0, 1.0), 16, 0)
         | 
| 95 | 
            +
                
         | 
| 96 | 
            +
                img_2_cropped = cv2.warpAffine(
         | 
| 97 | 
            +
                    img_1_cropped, trans, (r2[3], r2[2]), None, 
         | 
| 98 | 
            +
                    flags = cv2.INTER_LINEAR,
         | 
| 99 | 
            +
                    borderMode = cv2.BORDER_REFLECT_101
         | 
| 100 | 
            +
                )
         | 
| 101 | 
            +
                return mask, img_2_cropped, r2
         | 
| 102 | 
            +
             | 
| 103 | 
            +
             | 
| 104 | 
            +
            def back_projection(
         | 
| 105 | 
            +
                obj_file, 
         | 
| 106 | 
            +
                init_texture_file, 
         | 
| 107 | 
            +
                front_view_file, 
         | 
| 108 | 
            +
                dst_dir, 
         | 
| 109 | 
            +
                render_resolution=512, 
         | 
| 110 | 
            +
                uv_resolution=600, 
         | 
| 111 | 
            +
                normalThreshold=0.3, # 0.3 
         | 
| 112 | 
            +
                rgb_thresh=820, # 520
         | 
| 113 | 
            +
                views=None, 
         | 
| 114 | 
            +
                camera_dist=1.5, 
         | 
| 115 | 
            +
                erode_scale=1/100.0, 
         | 
| 116 | 
            +
                device="cuda:0"
         | 
| 117 | 
            +
            ):
         | 
| 118 | 
            +
    # obj_file: OBJ mesh with UV coordinates
         | 
| 119 | 
            +
    # init_texture_file: initial unwrapped UV texture map
         | 
| 120 | 
            +
    # render_resolution: rendering resolution of the front view
         | 
| 121 | 
            +
    # uv_resolution: texture map resolution
         | 
| 122 | 
            +
    # normalThreshold: minimum face-to-camera cosine required for baking
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                os.makedirs(dst_dir, exist_ok=True)
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                if isinstance(front_view_file, str):
         | 
| 127 | 
            +
                    src = np.array(Image.open(front_view_file).convert("RGB"))
         | 
| 128 | 
            +
                elif isinstance(front_view_file, Image.Image):
         | 
| 129 | 
            +
                    src = np.array(front_view_file.convert("RGB"))
         | 
| 130 | 
            +
                else:
         | 
| 131 | 
            +
                    raise "need file_path or pil"
         | 
| 132 | 
            +
                
         | 
| 133 | 
            +
                image_size = (render_resolution, render_resolution)
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                init_texture = Image.open(init_texture_file)
         | 
| 136 | 
            +
                init_texture = init_texture.convert("RGB")
         | 
| 137 | 
            +
                # init_texture = init_texture.resize((uv_resolution, uv_resolution))
         | 
| 138 | 
            +
                init_texture = np.array(init_texture).astype(np.float32)  
         | 
| 139 | 
            +
                
         | 
| 140 | 
            +
                print("load obj", obj_file)
         | 
| 141 | 
            +
                verts, faces, aux = load_obj(obj_file, device=device)
         | 
| 142 | 
            +
                mesh = load_objs_as_meshes_fast(verts, faces, aux, device=device)
         | 
| 143 | 
            +
             | 
| 144 | 
            +
             | 
| 145 | 
            +
                t0 = time.time()
         | 
| 146 | 
            +
                verts_uvs = aux.verts_uvs
         | 
| 147 | 
            +
                triangle_uvs = verts_uvs[faces.textures_idx]
         | 
| 148 | 
            +
                triangle_uvs = torch.cat([
         | 
| 149 | 
            +
                    ((1 - triangle_uvs[..., 1]) * uv_resolution).unsqueeze(2),
         | 
| 150 | 
            +
                    (triangle_uvs[..., 0] * uv_resolution).unsqueeze(2),
         | 
| 151 | 
            +
                ], dim=-1)
         | 
| 152 | 
            +
                triangle_uvs = np.clip(np.round(np.float32(triangle_uvs.cpu())).astype(np.int64), 0, uv_resolution-1)
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                # import ipdb;ipdb.set_trace()
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                
         | 
| 157 | 
            +
                R0, T0 = look_at_view_transform(camera_dist, views[0][0], views[0][1])
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                cameras = FoVPerspectiveCameras(device=device, R=R0, T=T0, fov=49.1)
         | 
| 160 | 
            +
                
         | 
| 161 | 
            +
                camera_normal = camera_position_from_spherical_angles(1, views[0][0], views[0][1]).to(device)
         | 
| 162 | 
            +
                screen_coords = cameras.transform_points_screen(verts, image_size=image_size)[:, :2]
         | 
| 163 | 
            +
                screen_coords = torch.cat([screen_coords[..., 1, None], screen_coords[..., 0, None]], dim=-1)
         | 
| 164 | 
            +
                triangle_screen_coords = np.round(np.float32(screen_coords[faces.verts_idx].cpu())) # numpy.ndarray (90000, 3, 2)
         | 
| 165 | 
            +
                triangle_screen_coords = np.clip(triangle_screen_coords.astype(np.int64), 0, render_resolution-1)
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                renderer = MeshRenderer(
         | 
| 168 | 
            +
                    rasterizer=MeshRasterizer(
         | 
| 169 | 
            +
                        cameras=cameras,
         | 
| 170 | 
            +
                        raster_settings= RasterizationSettings(
         | 
| 171 | 
            +
                            image_size=image_size,
         | 
| 172 | 
            +
                            blur_radius=0.0,
         | 
| 173 | 
            +
                            faces_per_pixel=1,
         | 
| 174 | 
            +
                        ),
         | 
| 175 | 
            +
                    ),
         | 
| 176 | 
            +
                    shader=SoftPhongShader(
         | 
| 177 | 
            +
                        device=device,
         | 
| 178 | 
            +
                        cameras=cameras,
         | 
| 179 | 
            +
                        lights= AmbientLights(device=device),
         | 
| 180 | 
            +
                        blend_params=BlendParams(background_color=(1.0, 1.0, 1.0)),
         | 
| 181 | 
            +
                    )
         | 
| 182 | 
            +
                )
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                dst = renderer(mesh)
         | 
| 185 | 
            +
                dst = (dst[..., :3] * 255).squeeze(0).cpu().numpy().astype(np.uint8)
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                src_mask = np.ones((src.shape[0], src.shape[1]), dst.dtype)
         | 
| 188 | 
            +
                ids = np.where(dst.sum(-1) > 253 * 3)
         | 
| 189 | 
            +
                ids2 = np.where(src.sum(-1) > 250 * 3)
         | 
| 190 | 
            +
                src_mask[ids[0], ids[1]] = 0
         | 
| 191 | 
            +
                src_mask[ids2[0], ids2[1]] = 0
         | 
| 192 | 
            +
                src_mask = (src_mask > 0).astype(np.uint8) * 255
         | 
| 193 | 
            +
                
         | 
| 194 | 
            +
                monoMaskImage = cv2.split(src_mask)[0] # reducing the mask to a monochrome
         | 
| 195 | 
            +
                br = cv2.boundingRect(monoMaskImage) # bounding rect (x,y,width,height)
         | 
| 196 | 
            +
                center = (br[0] + br[2] // 2, br[1] + br[3] // 2)
         | 
| 197 | 
            +
             
         | 
| 198 | 
            +
                # seamlessClone
         | 
| 199 | 
            +
                try:
         | 
| 200 | 
            +
        images = cv2.seamlessClone(src, dst, src_mask, center, cv2.NORMAL_CLONE)  # sharper result than MIXED_CLONE
         | 
| 201 | 
            +
                    # images = cv2.seamlessClone(src, dst, src_mask, center, cv2.MIXED_CLONE)
         | 
| 202 | 
            +
                except Exception as err:
         | 
| 203 | 
            +
                    print(f"\n\n Warning seamlessClone error: {err} \n\n")
         | 
| 204 | 
            +
                    images = src
         | 
| 205 | 
            +
             | 
| 206 | 
            +
                Image.fromarray(src_mask).save(os.path.join(dst_dir, 'mask.jpeg'))
         | 
| 207 | 
            +
                Image.fromarray(src).save(os.path.join(dst_dir, 'src.jpeg'))
         | 
| 208 | 
            +
                Image.fromarray(dst).save(os.path.join(dst_dir, 'dst.jpeg'))
         | 
| 209 | 
            +
                Image.fromarray(images).save(os.path.join(dst_dir, 'blend.jpeg'))
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                fragments_scaled = renderer.rasterizer(mesh)  # pytorch3d.renderer.mesh.rasterizer.Fragments
         | 
| 212 | 
            +
                faces_covered = fragments_scaled.pix_to_face.unique()[1:] # torch.Tensor torch.Size([30025])
         | 
| 213 | 
            +
                face_normals = mesh.faces_normals_packed().to(device) # torch.Tensor torch.Size([90000, 3]) cuda:0
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                # faces:              pytorch3d.io.obj_io.Faces
         | 
| 216 | 
            +
                # faces.textures_idx: torch.Tensor torch.Size([90000, 3])
         | 
| 217 | 
            +
                # verts_uvs:          torch.Tensor torch.Size([49554, 2])
         | 
| 218 | 
            +
                triangle_uvs = verts_uvs[faces.textures_idx]
         | 
| 219 | 
            +
                triangle_uvs = [
         | 
| 220 | 
            +
                    ((1 - triangle_uvs[..., 1]) * uv_resolution).unsqueeze(2),
         | 
| 221 | 
            +
                    (triangle_uvs[..., 0] * uv_resolution).unsqueeze(2),
         | 
| 222 | 
            +
                ]
         | 
| 223 | 
            +
                triangle_uvs = torch.cat(triangle_uvs, dim=-1) # numpy.ndarray (90000, 3, 2)
         | 
| 224 | 
            +
                triangle_uvs = np.clip(np.round(np.float32(triangle_uvs.cpu())).astype(np.int64), 0, uv_resolution-1)
         | 
| 225 | 
            +
                
         | 
| 226 | 
            +
                t0 = time.time()
         | 
| 227 | 
            +
                
         | 
| 228 | 
            +
                
         | 
| 229 | 
            +
    SOFT_NORM = True  # faces with a large normal/camera angle: True -> down-weight by coeff, False -> skip them
         | 
| 230 | 
            +
                
         | 
| 231 | 
            +
                for k in faces_covered:
         | 
| 232 | 
            +
                    # todo: accelerate this for-loop
         | 
| 233 | 
            +
                    # if cosine between face-camera is too low, skip current face baking
         | 
| 234 | 
            +
                    face_normal = face_normals[k]
         | 
| 235 | 
            +
                    cosine = torch.sum((face_normal * camera_normal) ** 2)
         | 
| 236 | 
            +
                    if not SOFT_NORM and cosine < normalThreshold: continue
         | 
| 237 | 
            +
             | 
| 238 | 
            +
        # if the screen coordinate falls outside the subject mask, skip current face baking
         | 
| 239 | 
            +
                    out_of_subject = src_mask[triangle_screen_coords[k][0][0], triangle_screen_coords[k][0][1]]==0
         | 
| 240 | 
            +
                    if out_of_subject: continue
         | 
| 241 | 
            +
                        
         | 
| 242 | 
            +
                    coeff, img_2_cropped, r2 = get_triangle_to_triangle(triangle_screen_coords[k], triangle_uvs[k], images)
         | 
| 243 | 
            +
                    
         | 
| 244 | 
            +
        # if the color difference between new and old texture is too large, skip current face baking
         | 
| 245 | 
            +
                    err = np.abs(init_texture[r2[0]:r2[0]+r2[2], r2[1]:r2[1]+r2[3]]- img_2_cropped)
         | 
| 246 | 
            +
                    err = (err * coeff).sum(-1)
         | 
| 247 | 
            +
                    
         | 
| 248 | 
            +
                    # print(err.shape, np.max(err))
         | 
| 249 | 
            +
                    if (np.max(err) > rgb_thresh): continue
         | 
| 250 | 
            +
                    
         | 
| 251 | 
            +
                    color_for_debug = None
         | 
| 252 | 
            +
                    # if (np.max(err) > 400): color_for_debug = [255, 0, 0]
         | 
| 253 | 
            +
                    # if (np.max(err) > 450): color_for_debug = [0, 255, 0]
         | 
| 254 | 
            +
                    # if (np.max(err) > 500): color_for_debug = [0, 0, 255]
         | 
| 255 | 
            +
             | 
| 256 | 
            +
                    coeff = coeff.clip(0, 1)
         | 
| 257 | 
            +
                    
         | 
| 258 | 
            +
                    if SOFT_NORM:
         | 
| 259 | 
            +
                        coeff *= ((cosine.detach().cpu().numpy() - normalThreshold) / normalThreshold).clip(0,1)
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                    coeff *= (((rgb_thresh - err[...,None]) / rgb_thresh)**0.4).clip(0,1)
         | 
| 262 | 
            +
             | 
| 263 | 
            +
                    if color_for_debug is None:
         | 
| 264 | 
            +
                        init_texture[r2[0]:r2[0]+r2[2], r2[1]:r2[1]+r2[3]] = \
         | 
| 265 | 
            +
                            init_texture[r2[0]:r2[0]+r2[2], r2[1]:r2[1]+r2[3]] * ((1.0,1.0,1.0)-coeff) + img_2_cropped * coeff
         | 
| 266 | 
            +
                    else:
         | 
| 267 | 
            +
                        init_texture[r2[0]:r2[0]+r2[2], r2[1]:r2[1]+r2[3]] = color_for_debug
         | 
| 268 | 
            +
             | 
| 269 | 
            +
                print(f'View baking time: {time.time() - t0}')
         | 
| 270 | 
            +
             | 
| 271 | 
            +
                bake_dir = os.path.join(dst_dir, 'bake')
         | 
| 272 | 
            +
                os.makedirs(bake_dir, exist_ok=True)
         | 
| 273 | 
            +
                os.system(f'cp {obj_file} {bake_dir}')
         | 
| 274 | 
            +
                
         | 
| 275 | 
            +
    texture_img = Image.fromarray(init_texture.astype(np.uint8))
         | 
| 276 | 
            +
    texture_img.save(os.path.join(bake_dir, init_texture_file.split("/")[-1]))
         | 
| 277 | 
            +
                
         | 
| 278 | 
            +
                mtl_dir = obj_file.replace('.obj', '.mtl')
         | 
| 279 | 
            +
                if not os.path.exists(mtl_dir): mtl_dir = obj_file.replace("mesh.obj" ,"material.mtl")
         | 
| 280 | 
            +
                if not os.path.exists(mtl_dir): mtl_dir = obj_file.replace("mesh.obj" ,"texture.mtl")
         | 
| 281 | 
            +
    if not os.path.exists(mtl_dir): raise FileNotFoundError(f"No .mtl file found next to {obj_file}")
         | 
| 282 | 
            +
                os.system(f'cp {mtl_dir} {bake_dir}')
         | 
| 283 | 
            +
             | 
| 284 | 
            +
                # convert .obj to .glb file
         | 
| 285 | 
            +
                new_obj_pth = os.path.join(bake_dir, obj_file.split('/')[-1])
         | 
| 286 | 
            +
                new_glb_path = new_obj_pth.replace('.obj', '.glb')
         | 
| 287 | 
            +
                mesh = trimesh.load_mesh(new_obj_pth)
         | 
| 288 | 
            +
                mesh.export(new_glb_path, file_type='glb')
         | 
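Note: a short sketch of calling back_projection directly, mirroring the invocation made by MeshBaker below; all paths are placeholders following the pipeline's save-folder layout.

# Sketch: bake an aligned front view (elevation 0, azimuth 0) back into the UV texture.
from PIL import Image
from third_party.gen_baking import back_projection

back_projection(
    obj_file="outputs/test/mesh.obj",              # OBJ with UV coordinates
    init_texture_file="outputs/test/texture.png",  # current UV texture
    front_view_file=Image.open("outputs/test/aligned_0.jpg"),
    dst_dir="outputs/test/view_0",
    render_resolution=512,
    uv_resolution=1024,
    views=[[0, 0]],                                # [elevation, azimuth] of the baked view
    device="cuda:0",
)
# Writes mask/src/dst/blend previews into dst_dir and the re-baked
# mesh.obj / texture.png / mesh.glb into dst_dir/bake/.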
    	
        third_party/mesh_baker.py
    ADDED
    
    | @@ -0,0 +1,142 @@ | |
| 1 | 
            +
            import os, sys, time, traceback
         | 
| 2 | 
            +
            print("sys path insert", os.path.join(os.path.dirname(__file__), "dust3r"))
         | 
| 3 | 
            +
            sys.path.insert(0, os.path.join(os.path.dirname(__file__), "dust3r"))
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import cv2
         | 
| 6 | 
            +
            import numpy as np
         | 
| 7 | 
            +
            from PIL import Image, ImageSequence
         | 
| 8 | 
            +
            from einops import rearrange
         | 
| 9 | 
            +
            import torch
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            from infer.utils import seed_everything, timing_decorator
         | 
| 12 | 
            +
            from infer.utils import get_parameter_number, set_parameter_grad_false
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            from dust3r.inference import inference
         | 
| 15 | 
            +
            from dust3r.model import AsymmetricCroCo3DStereo
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            from third_party.gen_baking import back_projection
         | 
| 18 | 
            +
            from third_party.dust3r_utils import infer_warp_mesh_img
         | 
| 19 | 
            +
            from svrm.ldm.vis_util import render_func
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            class MeshBaker:
         | 
| 23 | 
            +
                def __init__(
         | 
| 24 | 
            +
                    self, 
         | 
| 25 | 
            +
                    align_model = "third_party/weights/DUSt3R_ViTLarge_BaseDecoder_512_dpt",
         | 
| 26 | 
            +
                    device = "cuda:0", 
         | 
| 27 | 
            +
                    align_times = 1,
         | 
| 28 | 
            +
                    iou_thresh = 0.8, 
         | 
| 29 | 
            +
                    force_baking_ele_list = None,
         | 
| 30 | 
            +
                    save_memory = False
         | 
| 31 | 
            +
                ):
         | 
| 32 | 
            +
                    self.device = device
         | 
| 33 | 
            +
                    self.save_memory = save_memory
         | 
| 34 | 
            +
                    self.align_model = AsymmetricCroCo3DStereo.from_pretrained(align_model)
         | 
| 35 | 
            +
                    self.align_model = self.align_model if save_memory else self.align_model.to(device)
         | 
| 36 | 
            +
                    self.align_times = align_times
         | 
| 37 | 
            +
                    self.align_model.eval()
         | 
| 38 | 
            +
                    self.iou_thresh = iou_thresh
         | 
| 39 | 
            +
                    self.force_baking_ele_list = [] if force_baking_ele_list is None else force_baking_ele_list
         | 
| 40 | 
            +
                    self.force_baking_ele_list = [int(_) for _ in self.force_baking_ele_list]
         | 
| 41 | 
            +
                    set_parameter_grad_false(self.align_model)
         | 
| 42 | 
            +
                    print('baking align model', get_parameter_number(self.align_model))
         | 
| 43 | 
            +
                
         | 
| 44 | 
            +
                def align_and_check(self, src, dst, align_times=3):
         | 
| 45 | 
            +
                    try:
         | 
| 46 | 
            +
                        st = time.time()
         | 
| 47 | 
            +
                        best_baking_flag = False
         | 
| 48 | 
            +
                        best_aligned_image = aligned_image = src
         | 
| 49 | 
            +
                        best_info = {'match_num': 1000, "mask_iou": self.iou_thresh-0.1}
         | 
| 50 | 
            +
                        for i in range(align_times):
         | 
| 51 | 
            +
                            aligned_image, info = infer_warp_mesh_img(aligned_image, dst, self.align_model, vis=False)
         | 
| 52 | 
            +
                            aligned_image = Image.fromarray(aligned_image)
         | 
| 53 | 
            +
                            print(f"{i}-th time align process, mask-iou is {info['mask_iou']}")
         | 
| 54 | 
            +
                            if info['mask_iou'] > best_info['mask_iou']:
         | 
| 55 | 
            +
                                best_aligned_image, best_info = aligned_image, info
         | 
| 56 | 
            +
                            if info['mask_iou'] < self.iou_thresh:
         | 
| 57 | 
            +
                                break
         | 
| 58 | 
            +
                        print(f"Best Baking Info:{best_info['mask_iou']}")
         | 
| 59 | 
            +
                        best_baking_flag = best_info['mask_iou'] > self.iou_thresh
         | 
| 60 | 
            +
                        return best_aligned_image, best_info, best_baking_flag
         | 
| 61 | 
            +
                    except Exception as e:
         | 
| 62 | 
            +
                        print(f"Error processing image: {e}")
         | 
| 63 | 
            +
                        traceback.print_exc()
         | 
| 64 | 
            +
                        return None, None, None
         | 
| 65 | 
            +
                    
         | 
| 66 | 
            +
                @timing_decorator("baking mesh")
         | 
| 67 | 
            +
                def __call__(self, *args, **kwargs):
         | 
| 68 | 
            +
                    if self.save_memory:
         | 
| 69 | 
            +
                        self.align_model = self.align_model.to(self.device)
         | 
| 70 | 
            +
                        torch.cuda.empty_cache()
         | 
| 71 | 
            +
                        res = self.call(*args, **kwargs)
         | 
| 72 | 
            +
                        self.align_model = self.align_model.to("cpu")
         | 
| 73 | 
            +
                    else:
         | 
| 74 | 
            +
                        res = self.call(*args, **kwargs)
         | 
| 75 | 
            +
                    torch.cuda.empty_cache()
         | 
| 76 | 
            +
                    return res
         | 
| 77 | 
            +
                
         | 
| 78 | 
            +
                def call(self, save_folder):
         | 
| 79 | 
            +
                    obj_path         = os.path.join(save_folder, "mesh.obj")
         | 
| 80 | 
            +
                    raw_texture_path = os.path.join(save_folder, "texture.png")
         | 
| 81 | 
            +
                    views_pil        = os.path.join(save_folder, "views.jpg")
         | 
| 82 | 
            +
                    views_gif        = os.path.join(save_folder, "views.gif")
         | 
| 83 | 
            +
                    cond_pil         = os.path.join(save_folder, "img_nobg.png")
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                    if os.path.exists(views_pil):
         | 
| 86 | 
            +
                        views_pil = Image.open(views_pil)
         | 
| 87 | 
            +
                        views = rearrange(np.asarray(views_pil, dtype=np.uint8), '(n h) (m w) c -> (n m) h w c', n=3, m=2)
         | 
| 88 | 
            +
                        views = [Image.fromarray(views[idx]).convert('RGB') for idx in [0,2,4,5,3,1]] 
         | 
| 89 | 
            +
                        cond_pil = Image.open(cond_pil).resize((512,512))
         | 
| 90 | 
            +
                    elif os.path.exists(views_gif):
         | 
| 91 | 
            +
                        views_gif_pil = Image.open(views_gif)
         | 
| 92 | 
            +
                        views = [img.convert('RGB') for img in ImageSequence.Iterator(views_gif_pil)]
         | 
| 93 | 
            +
                        cond_pil, views = views[0], views[1:]
         | 
| 94 | 
            +
                    else:
         | 
| 95 | 
            +
                        raise FileNotFoundError("views file not found")
         | 
| 96 | 
            +
                            
         | 
| 97 | 
            +
                    rendered_views = render_func(obj_path, elev=0, n_views=2)
         | 
| 98 | 
            +
                    
         | 
| 99 | 
            +
                    for ele_idx, ele in enumerate([0, 180]):
         | 
| 100 | 
            +
                        
         | 
| 101 | 
            +
                        if ele == 0:
         | 
| 102 | 
            +
                            aligned_cond, cond_info, _ = self.align_and_check(cond_pil, rendered_views[0], align_times=self.align_times)
         | 
| 103 | 
            +
                            aligned_cond.save(save_folder + f'/aligned_cond.jpg')
         | 
| 104 | 
            +
                    
         | 
| 105 | 
            +
                            aligned_img, info, _ = self.align_and_check(views[0], rendered_views[0], align_times=self.align_times)
         | 
| 106 | 
            +
                            aligned_img.save(save_folder + f'/aligned_{ele}.jpg')
         | 
| 107 | 
            +
                            
         | 
| 108 | 
            +
                            if info['mask_iou'] < cond_info['mask_iou']:
         | 
| 109 | 
            +
                                print("Using Cond Image to bake front view")
         | 
| 110 | 
            +
                                aligned_img = aligned_cond
         | 
| 111 | 
            +
                                info = cond_info
         | 
| 112 | 
            +
                            need_baking = info['mask_iou'] > self.iou_thresh
         | 
| 113 | 
            +
                        else:
         | 
| 114 | 
            +
                            aligned_img, info, need_baking = self.align_and_check(views[ele//60], rendered_views[ele_idx])
         | 
| 115 | 
            +
                            aligned_img.save(save_folder + f'/aligned_{ele}.jpg')
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                        if need_baking or (ele in self.force_baking_ele_list):
         | 
| 118 | 
            +
                            st = time.time()
         | 
| 119 | 
            +
                            view1_res = back_projection(
         | 
| 120 | 
            +
                                obj_file = obj_path,
         | 
| 121 | 
            +
                                init_texture_file = raw_texture_path,
         | 
| 122 | 
            +
                                front_view_file = aligned_img,
         | 
| 123 | 
            +
                                dst_dir = os.path.join(save_folder, f"view_{ele_idx}"),
         | 
| 124 | 
            +
                                render_resolution = aligned_img.size[0], 
         | 
| 125 | 
            +
                                uv_resolution = 1024,
         | 
| 126 | 
            +
                                views = [[0, ele]],
         | 
| 127 | 
            +
                                device = self.device
         | 
| 128 | 
            +
                            )
         | 
| 129 | 
            +
                            print(f"view_{ele_idx} elevation_{ele} baking finished at {time.time() - st}")
         | 
| 130 | 
            +
                            obj_path = os.path.join(save_folder, f"view_{ele_idx}/bake/mesh.obj")
         | 
| 131 | 
            +
                            raw_texture_path = os.path.join(save_folder, f"view_{ele_idx}/bake/texture.png")
         | 
| 132 | 
            +
                        else:
         | 
| 133 | 
            +
                            print(f"Skip view_{ele_idx} elevation_{ele} baking")
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                    print("Baking Finished")
         | 
| 136 | 
            +
                    return obj_path
         | 
| 137 | 
            +
                
         | 
| 138 | 
            +
             | 
| 139 | 
            +
            if __name__ == "__main__":
         | 
| 140 | 
            +
                baker = MeshBaker()
         | 
| 141 | 
            +
                obj_path = baker("./outputs/test")
         | 
| 142 | 
            +
                print(obj_path)
         | 
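Note: an illustrative construction of MeshBaker with its non-default options; the argument values are examples, and the save folder is expected to contain mesh.obj, texture.png, img_nobg.png and views.jpg (or views.gif) produced by the earlier pipeline stages.

baker = MeshBaker(
    device="cuda:0",
    align_times=3,                 # run up to 3 DUSt3R alignment passes per view
    iou_thresh=0.8,                # bake a view only if its mask IoU with the render exceeds this
    force_baking_ele_list=[180],   # always bake the back view, even if the IoU check fails
    save_memory=True,              # keep the alignment model on CPU between calls
)
baked_obj_path = baker("./outputs/test")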
    	
        third_party/utils/camera_utils.py
    ADDED
    
    | @@ -0,0 +1,90 @@ | |
| 1 | 
            +
            import math
         | 
| 2 | 
            +
import numpy as np
from scipy.spatial.transform import Rotation
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            def compute_extrinsic_matrix(elevation, azimuth, camera_distance):
         | 
| 5 | 
            +
                # Convert angles to radians
         | 
| 6 | 
            +
                elevation_rad = np.radians(elevation)
         | 
| 7 | 
            +
                azimuth_rad = np.radians(azimuth)
         | 
| 8 | 
            +
             | 
| 9 | 
            +
                R = np.array([
         | 
| 10 | 
            +
                    [np.cos(azimuth_rad), 0, -np.sin(azimuth_rad)],
         | 
| 11 | 
            +
                    [0, 1, 0],
         | 
| 12 | 
            +
                    [np.sin(azimuth_rad), 0, np.cos(azimuth_rad)],
         | 
| 13 | 
            +
                ], dtype=np.float32)
         | 
| 14 | 
            +
             | 
| 15 | 
            +
                R = R @ np.array([
         | 
| 16 | 
            +
                    [1, 0, 0],
         | 
| 17 | 
            +
                    [0, np.cos(elevation_rad), -np.sin(elevation_rad)],
         | 
| 18 | 
            +
                    [0, np.sin(elevation_rad), np.cos(elevation_rad)]
         | 
| 19 | 
            +
                ], dtype=np.float32)
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                # Construct translation matrix T (3x1)
         | 
| 22 | 
            +
                T = np.array([[camera_distance], [0], [0]], dtype=np.float32)
         | 
| 23 | 
            +
                T = R @ T
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                # Combined into a 4x4 transformation matrix
         | 
| 26 | 
            +
                extrinsic_matrix = np.vstack((np.hstack((R, T)), np.array([[0, 0, 0, 1]], dtype=np.float32)))
         | 
| 27 | 
            +
             | 
| 28 | 
            +
                return extrinsic_matrix
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            def transform_camera_pose(im_pose, ori_pose, new_pose):
         | 
| 32 | 
            +
                T = new_pose @ ori_pose.T
         | 
| 33 | 
            +
                transformed_poses = []
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                for pose in im_pose:
         | 
| 36 | 
            +
                    transformed_pose = T @ pose
         | 
| 37 | 
            +
                    transformed_poses.append(transformed_pose)
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                return transformed_poses
         | 
| 40 | 
            +
             | 
| 41 | 
            +
            def compute_fov(intrinsic_matrix):
         | 
| 42 | 
            +
    # Get the focal lengths from the intrinsic matrix
         | 
| 43 | 
            +
                fx = intrinsic_matrix[0, 0]
         | 
| 44 | 
            +
                fy = intrinsic_matrix[1, 1]
         | 
| 45 | 
            +
             | 
| 46 | 
            +
    w, h = intrinsic_matrix[0, 2] * 2, intrinsic_matrix[1, 2] * 2  # cx*2 -> width, cy*2 -> height
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                # Calculate horizontal and vertical FOV values
         | 
| 49 | 
            +
                fov_x = 2 * math.atan(w / (2 * fx)) * 180 / math.pi
         | 
| 50 | 
            +
                fov_y = 2 * math.atan(h / (2 * fy)) * 180 / math.pi
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                return fov_x, fov_y
         | 
| 53 | 
            +
             | 
| 54 | 
            +
             | 
| 55 | 
            +
             | 
| 56 | 
            +
            def rotation_matrix_to_quaternion(rotation_matrix):
         | 
| 57 | 
            +
                rot = Rotation.from_matrix(rotation_matrix)
         | 
| 58 | 
            +
                quaternion = rot.as_quat()
         | 
| 59 | 
            +
                return quaternion
         | 
| 60 | 
            +
             | 
| 61 | 
            +
            def quaternion_to_rotation_matrix(quaternion):
         | 
| 62 | 
            +
                rot = Rotation.from_quat(quaternion)
         | 
| 63 | 
            +
                rotation_matrix = rot.as_matrix()
         | 
| 64 | 
            +
                return rotation_matrix
         | 
| 65 | 
            +
             | 
| 66 | 
            +
            def remap_points(img_size, match, size=512):
         | 
| 67 | 
            +
                H, W, _ = img_size
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                S = max(W, H)
         | 
| 70 | 
            +
                new_W = int(round(W * size / S))
         | 
| 71 | 
            +
                new_H = int(round(H * size / S))
         | 
| 72 | 
            +
                cx, cy = new_W // 2, new_H // 2
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                # Calculate the coordinates of the transformed image center point
         | 
| 75 | 
            +
                halfw, halfh = ((2 * cx) // 16) * 8, ((2 * cy) // 16) * 8
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                dw, dh = cx - halfw, cy - halfh
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                # store point coordinates mapped back to the original image
         | 
| 80 | 
            +
                new_match = np.zeros_like(match)
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                # Map the transformed point coordinates back to the original image
         | 
| 83 | 
            +
                new_match[:, 0] = (match[:, 0] + dw) / new_W * W
         | 
| 84 | 
            +
                new_match[:, 1] = (match[:, 1] + dh) / new_H * H
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                #print(dw,new_W,W,dh,new_H,H)
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                return new_match
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                
         | 
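Note: a quick numeric sanity check for compute_fov. With fx = fy = 512 and principal point (256, 256) the implied image is 512 x 512, so both FOVs are 2 * atan(512 / (2 * 512)) ≈ 53.13 degrees.

import numpy as np
from third_party.utils.camera_utils import compute_fov

# Illustrative intrinsics: fx = fy = 512, cx = cy = 256.
K = np.array([[512.0,   0.0, 256.0],
              [  0.0, 512.0, 256.0],
              [  0.0,   0.0,   1.0]])
fov_x, fov_y = compute_fov(K)
print(fov_x, fov_y)  # both ~53.13 degrees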
    	
        third_party/utils/img_utils.py
    ADDED
    
    | @@ -0,0 +1,211 @@ | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import cv2
         | 
| 3 | 
            +
            import numpy as np
         | 
| 4 | 
            +
            from skimage.metrics import hausdorff_distance
         | 
| 5 | 
            +
            from matplotlib import pyplot as plt
         | 
| 6 | 
            +
             | 
| 7 | 
            +
             | 
| 8 | 
            +
            def get_input_imgs_path(input_data_dir):
         | 
| 9 | 
            +
                path = {}
         | 
| 10 | 
            +
                names = ['000', 'ori_000']
         | 
| 11 | 
            +
                for name in names:
         | 
| 12 | 
            +
                    jpg_path = os.path.join(input_data_dir, f"{name}.jpg")
         | 
| 13 | 
            +
                    png_path = os.path.join(input_data_dir, f"{name}.png")
         | 
| 14 | 
            +
                    if os.path.exists(jpg_path):
         | 
| 15 | 
            +
                        path[name] = jpg_path
         | 
| 16 | 
            +
                    elif os.path.exists(png_path):
         | 
| 17 | 
            +
                        path[name] = png_path
         | 
| 18 | 
            +
                return path
         | 
| 19 | 
            +
             | 
| 20 | 
            +
             | 
| 21 | 
            +
            def rgba_to_rgb(image, bg_color=[255, 255, 255]):
         | 
| 22 | 
            +
                if image.shape[-1] == 3: return image
         | 
| 23 | 
            +
                    
         | 
| 24 | 
            +
                rgba = image.astype(float)
         | 
| 25 | 
            +
                rgb = rgba[:, :, :3].copy()
         | 
| 26 | 
            +
                alpha = rgba[:, :, 3] / 255.0
         | 
| 27 | 
            +
             | 
| 28 | 
            +
                bg = np.ones((image.shape[0], image.shape[1], 3), dtype=np.float32) 
         | 
| 29 | 
            +
                bg = bg * np.array(bg_color, dtype=np.float32)
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                rgb = rgb * alpha[:, :, np.newaxis] + bg * (1 - alpha[:, :, np.newaxis])
         | 
| 32 | 
            +
                rgb = rgb.astype(np.uint8)
         | 
| 33 | 
            +
                
         | 
| 34 | 
            +
                return rgb
         | 
| 35 | 
            +
             | 
| 36 | 
            +
             | 
| 37 | 
            +
            def resize_with_aspect_ratio(image1, image2, pad_value=[255, 255, 255]):
         | 
| 38 | 
            +
                aspect_ratio1 = float(image1.shape[1]) / float(image1.shape[0])
         | 
| 39 | 
            +
                aspect_ratio2 = float(image2.shape[1]) / float(image2.shape[0])
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                top_pad, bottom_pad, left_pad, right_pad = 0, 0, 0, 0
         | 
| 42 | 
            +
                
         | 
| 43 | 
            +
                if aspect_ratio1 < aspect_ratio2:
         | 
| 44 | 
            +
                    new_width = (aspect_ratio2 * image1.shape[0])
         | 
| 45 | 
            +
                    right_pad = left_pad = int((new_width - image1.shape[1]) / 2)
         | 
| 46 | 
            +
                else:
         | 
| 47 | 
            +
                    new_height = (image1.shape[1] / aspect_ratio2)
         | 
| 48 | 
            +
                    bottom_pad = top_pad = int((new_height - image1.shape[0]) / 2)
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                image1_padded = cv2.copyMakeBorder(
         | 
| 51 | 
            +
                    image1, top_pad, bottom_pad, left_pad, right_pad, cv2.BORDER_CONSTANT, value=pad_value
         | 
| 52 | 
            +
                )
         | 
| 53 | 
            +
                return image1_padded
         | 
| 54 | 
            +
             | 
| 55 | 
            +
             | 
| 56 | 
            +
            def estimate_img_mask(image):
         | 
| 57 | 
            +
                # to gray
         | 
| 58 | 
            +
                gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                # segment
         | 
| 61 | 
            +
                # _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
         | 
| 62 | 
            +
                # mask_otsu = thresh.astype(bool)
         | 
| 63 | 
            +
                # thresh_gray = 240
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                edges = cv2.Canny(gray, 20, 50)
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                kernel = np.ones((3, 3), np.uint8)
         | 
| 68 | 
            +
                edges_dilated = cv2.dilate(edges, kernel, iterations=1)
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                contours, _ = cv2.findContours(edges_dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                mask = np.zeros_like(gray, dtype=np.uint8)
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                cv2.drawContours(mask, contours, -1, 255, thickness=cv2.FILLED)
         | 
| 75 | 
            +
                mask = mask.astype(bool)
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                return mask
         | 
| 78 | 
            +
             | 
| 79 | 
            +
             | 
def compute_img_diff(img1, img2, matches1, matches1_from_2, vis=False):
    scale = 0.125
    gray_trunc_thres = 25 / 255.0

    # Match: count correspondences whose displacement between the two views
    # stays below dist_threshold (1% of the coordinate spread of matches1)
    if matches1.shape[0] > 0:
        match_scale = np.max(np.ptp(matches1, axis=-1))
        match_dists = np.sqrt(np.sum((matches1 - matches1_from_2) ** 2, axis=-1))
        dist_threshold = match_scale * 0.01
        match_num = np.sum(match_dists <= dist_threshold)
        match_rate = np.mean(match_dists <= dist_threshold)
    else:
        match_num = 0
        match_rate = 0

    # IOU of the estimated foreground masks
    img1_mask = estimate_img_mask(img1)
    img2_mask = estimate_img_mask(img2)
    img_intersection = (img1_mask == 1) & (img2_mask == 1)
    img_union = (img1_mask == 1) | (img2_mask == 1)
    intersection = np.sum(img_intersection == 1)
    union = np.sum(img_union == 1)
    mask_iou = intersection / union if union != 0 else 0

    # Gray: blurred grayscale versions for photometric comparison
    height, width = img1.shape[:2]
    img1_gray = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
    img2_gray = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
    img1_gray = cv2.GaussianBlur(img1_gray, (7, 7), 0)
    img2_gray = cv2.GaussianBlur(img2_gray, (7, 7), 0)

    # Gray Diff: absolute difference on downscaled images, normalized by the union area
    img1_gray_small = cv2.resize(img1_gray, (int(width * scale), int(height * scale)),
                                 interpolation=cv2.INTER_LINEAR) / 255.0
    img2_gray_small = cv2.resize(img2_gray, (int(width * scale), int(height * scale)),
                                 interpolation=cv2.INTER_LINEAR) / 255.0
    img_gray_small_diff = np.abs(img1_gray_small - img2_gray_small)
    gray_diff = img_gray_small_diff.sum() / (union * scale) if union != 0 else 1

    # Truncated variant: ignore per-pixel differences below gray_trunc_thres
    img_gray_small_diff_trunc = img_gray_small_diff.copy()
    img_gray_small_diff_trunc[img_gray_small_diff < gray_trunc_thres] = 0
    gray_diff_trunc = img_gray_small_diff_trunc.sum() / (union * scale) if union != 0 else 1

    # Edge: Hausdorff distance between the two Canny edge maps
    img1_edge = cv2.Canny(img1_gray, 100, 200)
    img2_edge = cv2.Canny(img2_gray, 100, 200)
    bw_edges1 = (img1_edge > 0).astype(bool)
    bw_edges2 = (img2_edge > 0).astype(bool)
    hausdorff_dist = hausdorff_distance(bw_edges1, bw_edges2)

    if vis:
        fig, axs = plt.subplots(1, 4, figsize=(15, 5))
        axs[0].imshow(img1_gray, cmap='gray')
        axs[0].set_title('Img1')
        axs[1].imshow(img2_gray, cmap='gray')
        axs[1].set_title('Img2')
        axs[2].imshow(img1_mask)
        axs[2].set_title('Mask1')
        axs[3].imshow(img2_mask)
        axs[3].set_title('Mask2')
        plt.show()

        plt.figure()
        mask_cmp = np.zeros((height, width, 3))
        mask_cmp[img_intersection, 1] = 1
        mask_cmp[img_union, 0] = 1
        plt.imshow(mask_cmp)
        plt.show()

        fig, axs = plt.subplots(1, 4, figsize=(15, 5))
        axs[0].imshow(img1_gray_small, cmap='gray')
        axs[0].set_title('Img1 Gray')
        axs[1].imshow(img2_gray_small, cmap='gray')
        axs[1].set_title('Img2 Gray')
        axs[2].imshow(img_gray_small_diff, cmap='gray')
        axs[2].set_title('diff')
        axs[3].imshow(img_gray_small_diff_trunc, cmap='gray')
        axs[3].set_title('diff_trunc')
        plt.show()

        fig, axs = plt.subplots(1, 2, figsize=(15, 5))
        axs[0].imshow(img1_edge, cmap='gray')
        axs[0].set_title('img1_edge')
        axs[1].imshow(img2_edge, cmap='gray')
        axs[1].set_title('img2_edge')
        plt.show()

    # collect all statistics for the downstream success predictor
    info = {}
    info['match_num'] = match_num
    info['match_rate'] = match_rate
    info['mask_iou'] = mask_iou
    info['gray_diff'] = gray_diff
    info['gray_diff_trunc'] = gray_diff_trunc
    info['hausdorff_dist'] = hausdorff_dist
    return info

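compute_img_diff expects two (N, 2) arrays of pixel coordinates: matches1 holds points in img1, and matches1_from_2 holds the corresponding points from img2 expressed in img1's frame, so that their distances are directly comparable. In this repo those correspondences come from the matching utilities elsewhere in third_party; the sketch below fakes them with plain ORB keypoint matching purely to show the calling convention, and is only meaningful when the two images are already roughly aligned.

# Illustrative only: build (N, 2) match arrays with ORB and score the pair.
gray1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
gray2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
orb = cv2.ORB_create(nfeatures=2000)
kp1, des1 = orb.detectAndCompute(gray1, None)
kp2, des2 = orb.detectAndCompute(gray2, None)
matches = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True).match(des1, des2)

matches1 = np.float32([kp1[m.queryIdx].pt for m in matches])         # points in img1
matches1_from_2 = np.float32([kp2[m.trainIdx].pt for m in matches])  # matched points from img2

info = compute_img_diff(img1, img2, matches1, matches1_from_2, vis=False)
# info holds match_num, match_rate, mask_iou, gray_diff, gray_diff_trunc, hausdorff_dist
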
def predict_match_success_human(info):
    match_num = info['match_num']
    match_rate = info['match_rate']
    mask_iou = info['mask_iou']
    gray_diff = info['gray_diff']
    gray_diff_trunc = info['gray_diff_trunc']
    hausdorff_dist = info['hausdorff_dist']

    if mask_iou > 0.95:
        return True

    if match_num < 20 or match_rate < 0.7:
        return False

    if mask_iou > 0.80 and gray_diff < 0.040 and gray_diff_trunc < 0.010:
        return True

    if mask_iou > 0.70 and gray_diff < 0.050 and gray_diff_trunc < 0.008:
        return True

    '''
    if match_rate < 0.70 or match_num < 3000:
        return False

    if (mask_iou > 0.85 and hausdorff_dist < 20) or (gray_diff < 0.015 and gray_diff_trunc < 0.01) or match_rate >= 0.90:
        return True
    '''

    return False

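To make the heuristic concrete: the invented numbers below clear the first rejection gate (match_num >= 20 and match_rate >= 0.7) and are then accepted by the mask_iou > 0.80 branch.

# Invented values, purely to illustrate which rule fires.
example_info = {
    'match_num': 850, 'match_rate': 0.86, 'mask_iou': 0.88,
    'gray_diff': 0.031, 'gray_diff_trunc': 0.006, 'hausdorff_dist': 14.0,
}
assert predict_match_success_human(example_info) is True
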
def predict_match_success(info, model=None):
    if model is None:
        # fall back to the hand-tuned thresholds above
        return predict_match_success_human(info)
    else:
        feat_name = ['match_num', 'match_rate', 'mask_iou', 'gray_diff', 'gray_diff_trunc', 'hausdorff_dist']
        features = [info[f] for f in feat_name]
        pred = model.predict([features])[0]
        return pred >= 0.5
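Putting the pieces together: compute the per-pair statistics, then gate acceptance with either the hand-tuned rules or any fitted classifier that exposes a predict method over the six listed features. The scikit-learn lines are only an assumption about what such a model could look like; nothing like it ships with this repo.

# End-to-end sketch, reusing img1/img2 and the match arrays from the earlier sketch.
info = compute_img_diff(img1, img2, matches1, matches1_from_2)
ok = predict_match_success(info)          # hand-tuned thresholds
print("match accepted:", ok)

# Hypothetical learned gate: any estimator with .predict over the six features works, e.g.
#   from sklearn.linear_model import LogisticRegression
#   clf = LogisticRegression().fit(X_train, y_train)   # X_train columns ordered as feat_name
#   ok = predict_match_success(info, model=clf)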

