Commit 2f4febc
Parent(s): f17a2ad

full_files

This view is limited to 50 files because it contains too many changes. See raw diff.
- app.py +146 -201
 - configs/inference/controlnet_c_3b_canny.yaml +14 -0
 - configs/inference/controlnet_c_3b_identity.yaml +17 -0
 - configs/inference/controlnet_c_3b_inpainting.yaml +15 -0
 - configs/inference/controlnet_c_3b_sr.yaml +15 -0
 - configs/inference/lora_c_3b.yaml +15 -0
 - configs/inference/stage_b_1b.yaml +13 -0
 - configs/inference/stage_b_3b.yaml +13 -0
 - configs/inference/stage_c_1b.yaml +7 -0
 - configs/inference/stage_c_3b.yaml +7 -0
 - configs/training/cfg_control_lr.yaml +47 -0
 - configs/training/lora_personalization.yaml +37 -0
 - configs/training/t2i.yaml +29 -0
 - core/__init__.py +372 -0
 - core/data/__init__.py +69 -0
 - core/data/bucketeer.py +88 -0
 - core/data/bucketeer_deg.py +91 -0
 - core/data/deg_kair_utils/utils_alignfaces.py +263 -0
 - core/data/deg_kair_utils/utils_blindsr.py +631 -0
 - core/data/deg_kair_utils/utils_bnorm.py +91 -0
 - core/data/deg_kair_utils/utils_deblur.py +655 -0
 - core/data/deg_kair_utils/utils_dist.py +201 -0
 - core/data/deg_kair_utils/utils_googledownload.py +93 -0
 - core/data/deg_kair_utils/utils_image.py +1016 -0
 - core/data/deg_kair_utils/utils_lmdb.py +205 -0
 - core/data/deg_kair_utils/utils_logger.py +66 -0
 - core/data/deg_kair_utils/utils_mat.py +88 -0
 - core/data/deg_kair_utils/utils_matconvnet.py +197 -0
 - core/data/deg_kair_utils/utils_model.py +330 -0
 - core/data/deg_kair_utils/utils_modelsummary.py +485 -0
 - core/data/deg_kair_utils/utils_option.py +255 -0
 - core/data/deg_kair_utils/utils_params.py +135 -0
 - core/data/deg_kair_utils/utils_receptivefield.py +62 -0
 - core/data/deg_kair_utils/utils_regularizers.py +104 -0
 - core/data/deg_kair_utils/utils_sisr.py +848 -0
 - core/data/deg_kair_utils/utils_video.py +493 -0
 - core/data/deg_kair_utils/utils_videoio.py +555 -0
 - core/scripts/__init__.py +0 -0
 - core/scripts/cli.py +41 -0
 - core/templates/__init__.py +1 -0
 - core/templates/diffusion.py +236 -0
 - core/utils/__init__.py +9 -0
 - core/utils/__pycache__/__init__.cpython-310.pyc +0 -0
 - core/utils/__pycache__/__init__.cpython-39.pyc +0 -0
 - core/utils/__pycache__/base_dto.cpython-310.pyc +0 -0
 - core/utils/__pycache__/base_dto.cpython-39.pyc +0 -0
 - core/utils/__pycache__/save_and_load.cpython-310.pyc +0 -0
 - core/utils/__pycache__/save_and_load.cpython-39.pyc +0 -0
 - core/utils/base_dto.py +56 -0
 - core/utils/save_and_load.py +59 -0
 
    	
app.py CHANGED
@@ -1,213 +1,158 @@
 import spaces
-import json
-import subprocess
 import os
-import [module name not captured in this extract]
-
-def run_command(command):
-    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
-    output, error = process.communicate()
-    if process.returncode != 0:
-        print(f"Error executing command: {command}")
-        print(error.decode('utf-8'))
-        exit(1)
-    return output.decode('utf-8')
-
-# Download CUDA installer
-download_command = "wget https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
-result = run_command(download_command)
-if result is None:
-    print("Failed to download CUDA installer.")
-    exit(1)
-
-# Run CUDA installer in silent mode
-install_command = "sh cuda_12.2.0_535.54.03_linux.run --silent --toolkit --samples --override"
-result = run_command(install_command)
-if result is None:
-    print("Failed to run CUDA installer.")
-    exit(1)
-
-print("CUDA installation process completed.")
-
-def install_packages():
-
-    # Clone the repository with submodules
-    run_command("git clone --recurse-submodules https://github.com/abetlen/llama-cpp-python.git")
-
-    # Change to the cloned directory
-    os.chdir("llama-cpp-python")
-
-    # Checkout the specific commit in the llama.cpp submodule
-    os.chdir("vendor/llama.cpp")
-    run_command("git checkout 50e0535")
-    os.chdir("../..")
-
-    # Upgrade pip
-    run_command("pip install --upgrade pip")
-
-    # Install all optional dependencies with CUDA support
-    run_command('CMAKE_ARGS="-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS -DCUDA_PATH=/usr/local/cuda-12.2 -DCUDAToolkit_ROOT=/usr/local/cuda-12.2 -DCUDAToolkit_INCLUDE_DIR=/usr/local/cuda-12/include -DCUDAToolkit_LIBRARY_DIR=/usr/local/cuda-12.2/lib64" FORCE_CMAKE=1 pip install -e .')
-
-    run_command("make clean && GGML_OPENBLAS=1 make -j")
-
-    # Reinstall the package with CUDA support
-    run_command('CMAKE_ARGS="-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS -DCUDA_PATH=/usr/local/cuda-12.2 -DCUDAToolkit_ROOT=/usr/local/cuda-12.2 -DCUDAToolkit_INCLUDE_DIR=/usr/local/cuda-12/include -DCUDAToolkit_LIBRARY_DIR=/usr/local/cuda-12.2/lib64" FORCE_CMAKE=1 pip install -e .')
-
-    # Install llama-cpp-agent
-    run_command("pip install llama-cpp-agent")
-
-    run_command("export PYTHONPATH=$PYTHONPATH:$(pwd)")
-
-    print("Installation complete!")
-
-try:
-    install_packages()
-
-    # Add a delay to allow for package registration
-    import time
-    time.sleep(5)
-
-    # Force Python to reload the site packages
-    import site
-    import importlib
-    importlib.reload(site)
-
-    # Now try to import the libraries
-    from llama_cpp import Llama
-    from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
-    from llama_cpp_agent.providers import LlamaCppPythonProvider
-    from llama_cpp_agent.chat_history import BasicChatHistory
-    from llama_cpp_agent.chat_history.messages import Roles
-
-    print("Libraries imported successfully!")
-except Exception as e:
-    print(f"Installation failed or libraries couldn't be imported: {str(e)}")
-    sys.exit(1)
-
+import requests
+import yaml
+import torch
 import gradio as gr
+from PIL import Image
+import sys
+sys.path.append(os.path.abspath('./'))
+from inference.utils import *
+from core.utils import load_or_fail
+from train import WurstCoreB
+from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
+from train import WurstCore_t2i as WurstCoreC
+import torch.nn.functional as F
+from core.utils import load_or_fail
+import numpy as np
+import random
+import math
+from einops import rearrange
 from huggingface_hub import hf_hub_download
 
-hf_hub_download(
-    repo_id="MaziyarPanahi/Mistral-Nemo-Instruct-2407-GGUF",
-    filename="Mistral-Nemo-Instruct-2407.Q5_K_M.gguf",
-    local_dir="./models"
-)
 
-[old lines 98-104 not captured in this extract]
+def download_file(url, folder_path, filename):
+    if not os.path.exists(folder_path):
+        os.makedirs(folder_path)
+    file_path = os.path.join(folder_path, filename)
+
+    if os.path.isfile(file_path):
+        print(f"File already exists: {file_path}")
+    else:
+        response = requests.get(url, stream=True)
+        if response.status_code == 200:
+            with open(file_path, 'wb') as file:
+                for chunk in response.iter_content(chunk_size=1024):
+                    file.write(chunk)
+            print(f"File successfully downloaded and saved: {file_path}")
+        else:
+            print(f"Error downloading the file. Status code: {response.status_code}")
+
+def download_models():
+    models = {
+        "STABLEWURST_A": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_a.safetensors?download=true", "models/StableWurst", "stage_a.safetensors"),
+        "STABLEWURST_PREVIEWER": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/previewer.safetensors?download=true", "models/StableWurst", "previewer.safetensors"),
+        "STABLEWURST_EFFNET": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/effnet_encoder.safetensors?download=true", "models/StableWurst", "effnet_encoder.safetensors"),
+        "STABLEWURST_B_LITE": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_lite_bf16.safetensors?download=true", "models/StableWurst", "stage_b_lite_bf16.safetensors"),
+        "STABLEWURST_C": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_bf16.safetensors?download=true", "models/StableWurst", "stage_c_bf16.safetensors"),
+        "ULTRAPIXEL_T2I": ("https://huggingface.co/roubaofeipi/UltraPixel/resolve/main/ultrapixel_t2i.safetensors?download=true", "models/UltraPixel", "ultrapixel_t2i.safetensors"),
+        "ULTRAPIXEL_LORA_CAT": ("https://huggingface.co/roubaofeipi/UltraPixel/resolve/main/lora_cat.safetensors?download=true", "models/UltraPixel", "lora_cat.safetensors"),
+    }
+
+    for model, (url, folder, filename) in models.items():
+        download_file(url, folder, filename)
+
+download_models()
+
+# Global variables
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+dtype = torch.bfloat16
+
+# Load configs and setup models
+with open("configs/training/t2i.yaml", "r", encoding="utf-8") as file:
+    config_c = yaml.safe_load(file)
+
+with open("configs/inference/stage_b_1b.yaml", "r", encoding="utf-8") as file:
+    config_b = yaml.safe_load(file)
+
+core = WurstCoreC(config_dict=config_c, device=device, training=False)
+core_b = WurstCoreB(config_dict=config_b, device=device, training=False)
+
+extras = core.setup_extras_pre()
+models = core.setup_models(extras)
+models.generator.eval().requires_grad_(False)
+
+extras_b = core_b.setup_extras_pre()
+models_b = core_b.setup_models(extras_b, skip_clip=True)
+models_b = WurstCoreB.Models(
+   **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model}
 )
-[old lines 106-130 not captured in this extract]
-    settings.top_k = top_k
-    settings.top_p = top_p
-    settings.max_tokens = max_tokens
-    settings.repeat_penalty = repeat_penalty
-    settings.stream = True
-
-    messages = BasicChatHistory()
-
-    for msn in history:
-        user = {
-            'role': Roles.user,
-            'content': msn[0]
-        }
-        assistant = {
-            'role': Roles.assistant,
-            'content': msn[1]
-        }
-        messages.add_message(user)
-        messages.add_message(assistant)
-
-[old lines 151-173 not captured in this extract]
-    gr.Textbox(
-    gr.Slider(minimum=
-    gr.Slider(minimum=
-    gr.Slider(
-        minimum=0.1,
-        maximum=1.0,
-        value=0.95,
-        step=0.05,
-        label="Top-p",
-    ),
-    gr.Slider(
-        minimum=0,
-        maximum=100,
-        value=40,
-        step=1,
-        label="Top-k",
-    ),
-    gr.Slider(
-        minimum=0.0,
-        maximum=2.0,
-        value=1.1,
-        step=0.1,
-        label="Repetition penalty",
-    ),
+models_b.generator.bfloat16().eval().requires_grad_(False)
+
+# Load pretrained model
+pretrained_path = "models/ultrapixel_t2i.safetensors"
+sdd = torch.load(pretrained_path, map_location='cpu')
+collect_sd = {k[7:]: v for k, v in sdd.items()}
+models.train_norm.load_state_dict(collect_sd)
+models.generator.eval()
+models.train_norm.eval()
+
+# Set up sampling configurations
+extras.sampling_configs.update({
+    'cfg': 4,
+    'shift': 1,
+    'timesteps': 20,
+    't_start': 1.0,
+    'sampler': DDPMSampler(extras.gdf)
+})
+
+extras_b.sampling_configs.update({
+    'cfg': 1.1,
+    'shift': 1,
+    'timesteps': 10,
+    't_start': 1.0
+})
+
+@spaces.GPU
+def generate_image(prompt, height, width, seed):
+    torch.manual_seed(seed)
+    random.seed(seed)
+    np.random.seed(seed)
+
+    batch_size = 1
+    height_lr, width_lr = get_target_lr_size(height / width, std_size=32)
+    stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
+    stage_c_latent_shape_lr, stage_b_latent_shape_lr = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size)
+
+    batch = {'captions': [prompt] * batch_size}
+    conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
+    unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
+
+    conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
+    unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
+
+    with torch.no_grad():
+        models.generator.cuda()
+        with torch.cuda.amp.autocast(dtype=dtype):
+            sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device)
+
+        models.generator.cpu()
+        torch.cuda.empty_cache()
+
+        conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
+        unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
+        conditions_b['effnet'] = sampled_c
+        unconditions_b['effnet'] = torch.zeros_like(sampled_c)
+
+        with torch.cuda.amp.autocast(dtype=dtype):
+            sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=True)
+
+        torch.cuda.empty_cache()
+        imgs = show_images(sampled)
+        return imgs[0]
+
+iface = gr.Interface(
+    fn=generate_image,
+    inputs=[
+        gr.Textbox(label="Prompt"),
+        gr.Slider(minimum=256, maximum=2560, step=32, label="Height", value=1024),
+        gr.Slider(minimum=256, maximum=5120, step=32, label="Width", value=1024),
+        gr.Number(label="Seed", value=42)
     ],
-[old lines 199-202 not captured in this extract]
-    title="Chat with Mistral-NeMo using llama.cpp",
-    description=description,
-    chatbot=gr.Chatbot(
-        scale=1,
-        likeable=False,
-        show_copy_button=True
-    )
+    outputs=gr.Image(type="pil"),
+    title="UltraPixel Image Generation",
+    description="Generate high-resolution images using UltraPixel model.",
+    theme='bethecloud/storj_theme'
 )
 
-if __name__ == "__main__":
-    demo.launch(debug=True)
+iface.launch()
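
Editor's note: the rewritten app.py keeps the huggingface_hub import but fetches every checkpoint with plain requests. For reference, a minimal sketch of the same downloads done through hf_hub_download instead (repo IDs and filenames are taken from the URLs in download_models above; this is an illustrative alternative, not part of the commit):

    from huggingface_hub import hf_hub_download

    # Same artifacts as download_models, via the Hub client (handles caching and retries).
    for repo_id, filename, local_dir in [
        ("stabilityai/StableWurst", "stage_a.safetensors", "models/StableWurst"),
        ("stabilityai/StableWurst", "previewer.safetensors", "models/StableWurst"),
        ("stabilityai/StableWurst", "effnet_encoder.safetensors", "models/StableWurst"),
        ("stabilityai/StableWurst", "stage_b_lite_bf16.safetensors", "models/StableWurst"),
        ("stabilityai/StableWurst", "stage_c_bf16.safetensors", "models/StableWurst"),
        ("roubaofeipi/UltraPixel", "ultrapixel_t2i.safetensors", "models/UltraPixel"),
        ("roubaofeipi/UltraPixel", "lora_cat.safetensors", "models/UltraPixel"),
    ]:
        hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)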
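Editor's note: the checkpoint-loading block relies on collect_sd = {k[7:]: v for k, v in sdd.items()}, which drops the first seven characters of every key. A minimal sketch of that idiom, assuming the UltraPixel checkpoint keys carry a 7-character wrapper prefix such as "module." (typical of checkpoints saved from DataParallel/DDP-wrapped models):

    import torch

    # Hypothetical state dict with DDP-style "module." prefixes.
    sdd = {"module.norm.weight": torch.ones(3), "module.norm.bias": torch.zeros(3)}
    collect_sd = {k[7:]: v for k, v in sdd.items()}  # "module.norm.weight" -> "norm.weight"
    assert set(collect_sd) == {"norm.weight", "norm.bias"}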
    	
configs/inference/controlnet_c_3b_canny.yaml ADDED
@@ -0,0 +1,14 @@
+# GLOBAL STUFF
+model_version: 3.6B
+dtype: bfloat16
+
+# ControlNet specific
+controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
+controlnet_filter: CannyFilter
+controlnet_filter_params:
+  resize: 224
+
+effnet_checkpoint_path: models/effnet_encoder.safetensors
+previewer_checkpoint_path: models/previewer.safetensors
+generator_checkpoint_path: models/stage_c_bf16.safetensors
+controlnet_checkpoint_path: models/canny.safetensors
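
Editor's note: these inference configs are plain YAML and can be read exactly the way app.py reads its stage configs. A small sketch using the file above (values like 3.6B are not valid YAML numbers, so they parse as strings):

    import yaml

    with open("configs/inference/controlnet_c_3b_canny.yaml", "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    print(config["model_version"])      # '3.6B'
    print(config["controlnet_blocks"])  # [0, 4, 8, 12, 51, 55, 59, 63]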
    	
configs/inference/controlnet_c_3b_identity.yaml ADDED
@@ -0,0 +1,17 @@
+# GLOBAL STUFF
+model_version: 3.6B
+dtype: bfloat16
+
+# ControlNet specific
+controlnet_bottleneck_mode: 'simple'
+controlnet_blocks: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]
+controlnet_filter: IdentityFilter
+controlnet_filter_params:
+  max_faces: 4
+  p_drop: 0.00
+  p_full: 0.0
+
+effnet_checkpoint_path: models/effnet_encoder.safetensors
+previewer_checkpoint_path: models/previewer.safetensors
+generator_checkpoint_path: models/stage_c_bf16.safetensors
+controlnet_checkpoint_path:
    	
configs/inference/controlnet_c_3b_inpainting.yaml ADDED
@@ -0,0 +1,15 @@
+# GLOBAL STUFF
+model_version: 3.6B
+dtype: bfloat16
+
+# ControlNet specific
+controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
+controlnet_filter: InpaintFilter
+controlnet_filter_params:
+  thresold: [0.04, 0.4]
+  p_outpaint: 0.4
+
+effnet_checkpoint_path: models/effnet_encoder.safetensors
+previewer_checkpoint_path: models/previewer.safetensors
+generator_checkpoint_path: models/stage_c_bf16.safetensors
+controlnet_checkpoint_path: models/inpainting.safetensors
    	
configs/inference/controlnet_c_3b_sr.yaml ADDED
@@ -0,0 +1,15 @@
+# GLOBAL STUFF
+model_version: 3.6B
+dtype: bfloat16
+
+# ControlNet specific
+controlnet_bottleneck_mode: 'large'
+controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
+controlnet_filter: SREffnetFilter
+controlnet_filter_params:
+  scale_factor: 0.5
+
+effnet_checkpoint_path: models/effnet_encoder.safetensors
+previewer_checkpoint_path: models/previewer.safetensors
+generator_checkpoint_path: models/stage_c_bf16.safetensors
+controlnet_checkpoint_path: models/super_resolution.safetensors
    	
configs/inference/lora_c_3b.yaml ADDED
@@ -0,0 +1,15 @@
+# GLOBAL STUFF
+model_version: 3.6B
+dtype: bfloat16
+
+# LoRA specific
+module_filters: ['.attn']
+rank: 4
+train_tokens:
+  # - ['^snail', null] # token starts with "snail" -> "snail" & "snails", don't need to be reinitialized
+  - ['[fernando]', '^dog</w>'] # custom token [snail], initialize as avg of snail & snails
+
+effnet_checkpoint_path: models/effnet_encoder.safetensors
+previewer_checkpoint_path: models/previewer.safetensors
+generator_checkpoint_path: models/stage_c_bf16.safetensors
+lora_checkpoint_path: models/lora_fernando_10k.safetensors
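
Editor's note: going by the comments in the file, each train_tokens entry pairs a new custom token with a regex over existing vocabulary entries, whose averaged embeddings initialize it. A rough sketch of that matching against a tiny hypothetical vocabulary (the real tokenizer and embedding lookup live in the training code):

    import re

    pattern = re.compile(r"^dog</w>")            # initializer regex from the entry above
    vocab = ["dog</w>", "dogs</w>", "cat</w>"]   # hypothetical tokenizer entries
    matches = [t for t in vocab if pattern.search(t)]
    print(matches)  # ['dog</w>'] -> source embeddings for the new token '[fernando]'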
    	
configs/inference/stage_b_1b.yaml ADDED
@@ -0,0 +1,13 @@
+# GLOBAL STUFF
+model_version: 700M
+dtype: bfloat16
+
+# For demonstration purposes in reconstruct_images.ipynb
+webdataset_path: path to your dataset
+batch_size: 1
+image_size: 2048
+grad_accum_steps: 1
+
+effnet_checkpoint_path: models/effnet_encoder.safetensors
+stage_a_checkpoint_path: models/stage_a.safetensors
+generator_checkpoint_path: models/stage_b_lite_bf16.safetensors
    	
configs/inference/stage_b_3b.yaml ADDED
@@ -0,0 +1,13 @@
+# GLOBAL STUFF
+model_version: 3B
+dtype: bfloat16
+
+# For demonstration purposes in reconstruct_images.ipynb
+webdataset_path: path to your dataset
+batch_size: 4
+image_size: 1024
+grad_accum_steps: 1
+
+effnet_checkpoint_path: models/effnet_encoder.safetensors
+stage_a_checkpoint_path: models/stage_a.safetensors
+generator_checkpoint_path: models/stage_b_lite_bf16.safetensors
    	
configs/inference/stage_c_1b.yaml ADDED
@@ -0,0 +1,7 @@
+# GLOBAL STUFF
+model_version: 1B
+dtype: bfloat16
+
+effnet_checkpoint_path: models/effnet_encoder.safetensors
+previewer_checkpoint_path: models/previewer.safetensors
+generator_checkpoint_path: models/stage_c_lite_bf16.safetensors
    	
configs/inference/stage_c_3b.yaml ADDED
@@ -0,0 +1,7 @@
+# GLOBAL STUFF
+model_version: 3.6B
+dtype: bfloat16
+
+effnet_checkpoint_path: models/effnet_encoder.safetensors
+previewer_checkpoint_path: models/previewer.safetensors
+generator_checkpoint_path: models/stage_c_bf16.safetensors
    	
        configs/training/cfg_control_lr.yaml
    ADDED
    
    | 
         @@ -0,0 +1,47 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
+# GLOBAL STUFF
+experiment_id: Ultrapixel_controlnet
+
+checkpoint_path: checkpoint output path
+output_path: visual results output path
+model_version: 3.6B
+dtype: float32
+# # WandB
+# wandb_project: StableCascade
+# wandb_entity: wandb_username
+#module_filters: ['.depthwise', '.mapper', '.attn', '.channelwise' ]
+#rank: 32
+# TRAINING PARAMS
+lr: 1.0e-4
+batch_size: 12
+#image_size: [1536, 2048, 2560, 3072, 4096]
+image_size: [1024, 2048, 2560, 3072, 3584, 3840, 4096, 4608]
+#image_size: [1024, 1536, 2048, 2560, 3072, 3584, 3840, 4096, 4608]
+#image_size: [1024, 1280]
+multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
+grad_accum_steps: 2
+updates: 40000
+backup_every: 5000
+save_every: 256
+warmup_updates: 1
+use_fsdp: True
+
+# ControlNet specific
+controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
+controlnet_filter: CannyFilter
+controlnet_filter_params:
+  resize: 224
+# offset_noise: 0.1
+
+# GDF
+adaptive_loss_weight: True
+
+ema_start_iters: 10
+ema_iters: 50
+ema_beta: 0.9
+
+webdataset_path: path to your training dataset
+effnet_checkpoint_path: models/effnet_encoder.safetensors
+previewer_checkpoint_path: models/previewer.safetensors
+generator_checkpoint_path: models/stage_c_bf16.safetensors
+controlnet_checkpoint_path: pretrained controlnet path
+
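The multi_aspect_ratio entries above are unquoted YAML scalars such as 1/1 and 9/16, which load as plain strings; the training code has to turn them into floats itself. A minimal sketch of that conversion (the config path is this file; the parsing shown here is illustrative, not part of the commit):

from fractions import Fraction
import yaml

with open("configs/training/cfg_control_lr.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# "1/1" -> 1.0, "9/16" -> 0.5625, ...
aspect_ratios = [float(Fraction(r)) for r in cfg["multi_aspect_ratio"]]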
    	
configs/training/lora_personalization.yaml
ADDED
@@ -0,0 +1,37 @@
+# GLOBAL STUFF
+experiment_id: roubao_cat_personalized
+
+checkpoint_path: checkpoint output path
+output_path: visual results output path
+model_version: 3.6B
+dtype: float32
+
+module_filters: ['.attn']
+rank: 4
+train_tokens:
+  # - ['^snail', null] # token starts with "snail" -> "snail" & "snails", don't need to be reinitialized
+  - ['[roubaobao]', '^cat</w>'] # custom token [roubaobao], initialized from the embedding of "cat"
+# TRAINING PARAMS
+lr: 1.0e-4
+batch_size: 4
+
+image_size: [1024, 2048, 2560, 3072, 3584, 3840, 4096, 4608]
+multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
+grad_accum_steps: 2
+updates: 40000
+backup_every: 5000
+save_every: 512
+warmup_updates: 1
+use_ddp: True
+
+# GDF
+adaptive_loss_weight: True
+
+
+tmp_prompt: a photo of a cat [roubaobao]
+webdataset_path: path to your personalized training dataset
+effnet_checkpoint_path: models/effnet_encoder.safetensors
+previewer_checkpoint_path: models/previewer.safetensors
+generator_checkpoint_path: models/stage_c_bf16.safetensors
+ultrapixel_path: models/ultrapixel_t2i.safetensors
+
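The train_tokens entry above pairs a new custom token ('[roubaobao]') with a regex ('^cat</w>') that selects existing vocabulary entries whose embeddings seed the new one. A minimal sketch of that initialization, assuming a CLIP-style vocab dict and an nn.Embedding token table; vocab, tok_emb, and the ids are illustrative stand-ins for the real tokenizer/text-encoder objects:

import re
import torch
from torch import nn

vocab = {"cat</w>": 2368, "dog</w>": 1929}   # token string -> id (example values)
tok_emb = nn.Embedding(49409, 1280)          # text-encoder token embedding table

# ids of existing tokens matched by the pattern '^cat</w>'
matched = [i for t, i in vocab.items() if re.search(r"^cat</w>", t)]

new_token_id = 49408                         # row appended for '[roubaobao]'
with torch.no_grad():                        # init the new row as the mean of the matches
    tok_emb.weight[new_token_id] = tok_emb.weight[torch.tensor(matched)].mean(dim=0)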
    	
configs/training/t2i.yaml
ADDED
@@ -0,0 +1,29 @@
+# GLOBAL STUFF
+experiment_id: ultrapixel_t2i
+#strc_fixlrt_norm3_lite_1024_hrft_newdata
+checkpoint_path: checkpoint output path   # output model directory
+output_path: visual results output path   # experiment output directory
+model_version: 3.6B    # finetune the large Stage C model of StableCascade
+dtype: float32
+
+
+# TRAINING PARAMS
+lr: 1.0e-4
+batch_size: 4   # gpu_number * num_per_gpu * grad_accum_steps
+image_size: [1024, 2048, 2560, 3072, 3584, 3840, 4096, 4608]  # possible image resolutions
+multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
+grad_accum_steps: 2
+updates: 40000
+backup_every: 5000
+save_every: 256
+warmup_updates: 1
+use_ddp: True
+
+# GDF
+adaptive_loss_weight: True
+
+
+webdataset_path: path to your training dataset
+effnet_checkpoint_path: models/effnet_encoder.safetensors
+previewer_checkpoint_path: models/previewer.safetensors
+generator_checkpoint_path: models/stage_c_bf16.safetensors
        core/__init__.py
    ADDED
    
    | 
         @@ -0,0 +1,372 @@ 
     | 
+import os
+import yaml
+import torch
+from torch import nn
+import wandb
+import json
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from torch.utils.data import Dataset, DataLoader
+
+from torch.distributed import init_process_group, destroy_process_group, barrier
+from torch.distributed.fsdp import (
+    FullyShardedDataParallel as FSDP,
+    FullStateDictConfig,
+    MixedPrecision,
+    ShardingStrategy,
+    StateDictType
+)
+
+from .utils import Base, EXPECTED, EXPECTED_TRAIN
+from .utils import create_folder_if_necessary, safe_save, load_or_fail
+
+# pylint: disable=unused-argument
+class WarpCore(ABC):
+    @dataclass(frozen=True)
+    class Config(Base):
+        experiment_id: str = EXPECTED_TRAIN
+        checkpoint_path: str = EXPECTED_TRAIN
+        output_path: str = EXPECTED_TRAIN
+        checkpoint_extension: str = "safetensors"
+        dist_file_subfolder: str = ""
+        allow_tf32: bool = True
+
+        wandb_project: str = None
+        wandb_entity: str = None
+
+    @dataclass()  # not frozen, meaning the fields are mutable
+    class Info():  # not inheriting from Base, because we don't want to enforce the default fields
+        wandb_run_id: str = None
+        total_steps: int = 0
+        iter: int = 0
+
+    @dataclass(frozen=True)
+    class Data(Base):
+        dataset: Dataset = EXPECTED
+        dataloader: DataLoader = EXPECTED
+        iterator: any = EXPECTED
+
+    @dataclass(frozen=True)
+    class Models(Base):
+        pass
+
+    @dataclass(frozen=True)
+    class Optimizers(Base):
+        pass
+
+    @dataclass(frozen=True)
+    class Schedulers(Base):
+        pass
+
+    @dataclass(frozen=True)
+    class Extras(Base):
+        pass
+    # ---------------------------------------
+    info: Info
+    config: Config
+
+    # FSDP stuff
+    fsdp_defaults = {
+        "sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,
+        "cpu_offload": None,
+        "mixed_precision": MixedPrecision(
+            param_dtype=torch.bfloat16,
+            reduce_dtype=torch.bfloat16,
+            buffer_dtype=torch.bfloat16,
+        ),
+        "limit_all_gathers": True,
+    }
+    fsdp_fullstate_save_policy = FullStateDictConfig(
+        offload_to_cpu=True, rank0_only=True
+    )
+    # ------------
+
+    # OVERRIDEABLE METHODS
+
+    # [optionally] set up extra stuff, will be called BEFORE the models & optimizers are set up
+    def setup_extras_pre(self) -> Extras:
+        return self.Extras()
+
+    # set up dataset & dataloader, return a dict containing the dataset, dataloader and/or iterator
+    @abstractmethod
+    def setup_data(self, extras: Extras) -> Data:
+        raise NotImplementedError("This method needs to be overridden")
+
+    # return a dict with all models that are going to be used in the training
+    @abstractmethod
+    def setup_models(self, extras: Extras) -> Models:
+        raise NotImplementedError("This method needs to be overridden")
+
+    # return a dict with all optimizers that are going to be used in the training
+    @abstractmethod
+    def setup_optimizers(self, extras: Extras, models: Models) -> Optimizers:
+        raise NotImplementedError("This method needs to be overridden")
+
+    # [optionally] return a dict with all schedulers that are going to be used in the training
+    def setup_schedulers(self, extras: Extras, models: Models, optimizers: Optimizers) -> Schedulers:
+        return self.Schedulers()
+
+    # [optionally] set up extra stuff, will be called AFTER the models & optimizers are set up
+    def setup_extras_post(self, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers) -> Extras:
+        return self.Extras.from_dict(extras.to_dict())
+
+    # perform the training here
+    @abstractmethod
+    def train(self, data: Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers):
+        raise NotImplementedError("This method needs to be overridden")
+    # ------------
+
+    def setup_info(self, full_path=None) -> Info:
+        if full_path is None:
+            full_path = (f"{self.config.checkpoint_path}/{self.config.experiment_id}/info.json")
+        info_dict = load_or_fail(full_path, wandb_run_id=None) or {}
+        info_dto = self.Info(**info_dict)
+        if info_dto.total_steps > 0 and self.is_main_node:
+            print(">>> RESUMING TRAINING FROM ITER ", info_dto.total_steps)
+        return info_dto
+
+    def setup_config(self, config_file_path=None, config_dict=None, training=True) -> Config:
+        if config_file_path is not None:
+            if config_file_path.endswith(".yml") or config_file_path.endswith(".yaml"):
+                with open(config_file_path, "r", encoding="utf-8") as file:
+                    loaded_config = yaml.safe_load(file)
+            elif config_file_path.endswith(".json"):
+                with open(config_file_path, "r", encoding="utf-8") as file:
+                    loaded_config = json.load(file)
+            else:
+                raise ValueError("Config file must be either a .yml|.yaml or .json file")
+            return self.Config.from_dict({**loaded_config, 'training': training})
+        if config_dict is not None:
+            return self.Config.from_dict({**config_dict, 'training': training})
+        return self.Config(training=training)
+
+    def setup_ddp(self, experiment_id, single_gpu=False):
+        if not single_gpu:
+            local_rank = int(os.environ.get("SLURM_LOCALID"))
+            process_id = int(os.environ.get("SLURM_PROCID"))
+            world_size = int(os.environ.get("SLURM_NNODES")) * torch.cuda.device_count()
+
+            self.process_id = process_id
+            self.is_main_node = process_id == 0
+            self.device = torch.device(local_rank)
+            self.world_size = world_size
+
+            dist_file_path = f"{os.getcwd()}/{self.config.dist_file_subfolder}dist_file_{experiment_id}"
+            # if os.path.exists(dist_file_path) and self.is_main_node:
+            #     os.remove(dist_file_path)
+
+            torch.cuda.set_device(local_rank)
+            init_process_group(
+                backend="nccl",
+                rank=process_id,
+                world_size=world_size,
+                init_method=f"file://{dist_file_path}",
+            )
+            print(f"[GPU {process_id}] READY")
+        else:
+            print("Running in single thread, DDP not enabled.")
+
+    def setup_wandb(self):
+        if self.is_main_node and self.config.wandb_project is not None:
+            self.info.wandb_run_id = self.info.wandb_run_id or wandb.util.generate_id()
+            wandb.init(project=self.config.wandb_project, entity=self.config.wandb_entity, name=self.config.experiment_id, id=self.info.wandb_run_id, resume="allow", config=self.config.to_dict())
+
+            if self.info.total_steps > 0:
+                wandb.alert(title=f"Training {self.info.wandb_run_id} resumed", text=f"Training {self.info.wandb_run_id} resumed from step {self.info.total_steps}")
+            else:
+                wandb.alert(title=f"Training {self.info.wandb_run_id} started", text=f"Training {self.info.wandb_run_id} started")
+
+    # LOAD UTILITIES ----------
+    def load_model(self, model, model_id=None, full_path=None, strict=True):
+        print('in line 181 load model', type(model), model_id, full_path, strict)
+        if model_id is not None and full_path is None:
+            full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{model_id}.{self.config.checkpoint_extension}"
+        elif full_path is None and model_id is None:
+            raise ValueError(
+                "This method expects either 'model_id' or 'full_path' to be defined"
+            )
+
+        checkpoint = load_or_fail(full_path, wandb_run_id=self.info.wandb_run_id if self.is_main_node else None)
+        if checkpoint is not None:
+            model.load_state_dict(checkpoint, strict=strict)
+            del checkpoint
+
+        return model
+
+    def load_optimizer(self, optim, optim_id=None, full_path=None, fsdp_model=None):
+        if optim_id is not None and full_path is None:
+            full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{optim_id}.pt"
+        elif full_path is None and optim_id is None:
+            raise ValueError(
+                "This method expects either 'optim_id' or 'full_path' to be defined"
+            )
+
+        checkpoint = load_or_fail(full_path, wandb_run_id=self.info.wandb_run_id if self.is_main_node else None)
+        if checkpoint is not None:
+            try:
+                if fsdp_model is not None:
+                    sharded_optimizer_state_dict = (
+                        FSDP.scatter_full_optim_state_dict(  # <---- FSDP
+                            checkpoint
+                            if (
+                                self.is_main_node
+                                or self.fsdp_defaults["sharding_strategy"]
+                                == ShardingStrategy.NO_SHARD
+                            )
+                            else None,
+                            fsdp_model,
+                        )
+                    )
+                    optim.load_state_dict(sharded_optimizer_state_dict)
+                    del checkpoint, sharded_optimizer_state_dict
+                else:
+                    optim.load_state_dict(checkpoint)
+            # pylint: disable=broad-except
+            except Exception as e:
+                print("!!! Failed loading optimizer, skipping... Exception:", e)
+
+        return optim
+
+    # SAVE UTILITIES ----------
+    def save_info(self, info, suffix=""):
+        full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/info{suffix}.json"
+        create_folder_if_necessary(full_path)
+        if self.is_main_node:
+            safe_save(vars(self.info), full_path)
+
+    def save_model(self, model, model_id=None, full_path=None, is_fsdp=False):
+        if model_id is not None and full_path is None:
+            full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{model_id}.{self.config.checkpoint_extension}"
+        elif full_path is None and model_id is None:
+            raise ValueError(
+                "This method expects either 'model_id' or 'full_path' to be defined"
+            )
+        create_folder_if_necessary(full_path)
+        if is_fsdp:
+            with FSDP.summon_full_params(model):
+                pass
+            with FSDP.state_dict_type(
+                model, StateDictType.FULL_STATE_DICT, self.fsdp_fullstate_save_policy
+            ):
+                checkpoint = model.state_dict()
+            if self.is_main_node:
+                safe_save(checkpoint, full_path)
+            del checkpoint
+        else:
+            if self.is_main_node:
+                checkpoint = model.state_dict()
+                safe_save(checkpoint, full_path)
+                del checkpoint
+
+    def save_optimizer(self, optim, optim_id=None, full_path=None, fsdp_model=None):
+        if optim_id is not None and full_path is None:
+            full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{optim_id}.pt"
+        elif full_path is None and optim_id is None:
+            raise ValueError(
+                "This method expects either 'optim_id' or 'full_path' to be defined"
+            )
+        create_folder_if_necessary(full_path)
+        if fsdp_model is not None:
+            optim_statedict = FSDP.full_optim_state_dict(fsdp_model, optim)
+            if self.is_main_node:
+                safe_save(optim_statedict, full_path)
+            del optim_statedict
+        else:
+            if self.is_main_node:
+                checkpoint = optim.state_dict()
+                safe_save(checkpoint, full_path)
+                del checkpoint
+    # -----
+
+    def __init__(self, config_file_path=None, config_dict=None, device="cpu", training=True):
+        # Temporary setup, will be overridden by setup_ddp if required
+        self.device = device
+        self.process_id = 0
+        self.is_main_node = True
+        self.world_size = 1
+        # ----
+
+        self.config: self.Config = self.setup_config(config_file_path, config_dict, training)
+        self.info: self.Info = self.setup_info()
+
+    def __call__(self, single_gpu=False):
+        self.setup_ddp(self.config.experiment_id, single_gpu=single_gpu)  # this will change the device to the CUDA rank
+        self.setup_wandb()
+        if self.config.allow_tf32:
+            torch.backends.cuda.matmul.allow_tf32 = True
+            torch.backends.cudnn.allow_tf32 = True
+
+        if self.is_main_node:
+            print()
+            print("**STARTING JOB WITH CONFIG:**")
+            print(yaml.dump(self.config.to_dict(), default_flow_style=False))
+            print("------------------------------------")
+            print()
+            print("**INFO:**")
+            print(yaml.dump(vars(self.info), default_flow_style=False))
+            print("------------------------------------")
+            print()
+
+        # SETUP STUFF
+        extras = self.setup_extras_pre()
+        assert extras is not None, "setup_extras_pre() must return a DTO"
+
+        data = self.setup_data(extras)
+        assert data is not None, "setup_data() must return a DTO"
+        if self.is_main_node:
+            print("**DATA:**")
+            print(yaml.dump({k: type(v).__name__ for k, v in data.to_dict().items()}, default_flow_style=False))
+            print("------------------------------------")
+            print()
+
+        models = self.setup_models(extras)
+        assert models is not None, "setup_models() must return a DTO"
+        if self.is_main_node:
+            print("**MODELS:**")
+            print(yaml.dump({
+                k: f"{type(v).__name__} - {f'trainable params {sum(p.numel() for p in v.parameters() if p.requires_grad)}' if isinstance(v, nn.Module) else 'Not a nn.Module'}" for k, v in models.to_dict().items()
+            }, default_flow_style=False))
+            print("------------------------------------")
+            print()
+
+        optimizers = self.setup_optimizers(extras, models)
+        assert optimizers is not None, "setup_optimizers() must return a DTO"
+        if self.is_main_node:
+            print("**OPTIMIZERS:**")
+            print(yaml.dump({k: type(v).__name__ for k, v in optimizers.to_dict().items()}, default_flow_style=False))
+            print("------------------------------------")
+            print()
+
+        schedulers = self.setup_schedulers(extras, models, optimizers)
+        assert schedulers is not None, "setup_schedulers() must return a DTO"
+        if self.is_main_node:
+            print("**SCHEDULERS:**")
+            print(yaml.dump({k: type(v).__name__ for k, v in schedulers.to_dict().items()}, default_flow_style=False))
+            print("------------------------------------")
+            print()
+
+        post_extras = self.setup_extras_post(extras, models, optimizers, schedulers)
+        assert post_extras is not None, "setup_extras_post() must return a DTO"
+        extras = self.Extras.from_dict({**extras.to_dict(), **post_extras.to_dict()})
+        if self.is_main_node:
+            print("**EXTRAS:**")
+            print(yaml.dump({k: f"{v}" for k, v in extras.to_dict().items()}, default_flow_style=False))
+            print("------------------------------------")
+            print()
+        # -------
+
+        # TRAIN
+        if self.is_main_node:
+            print("**TRAINING STARTING...**")
+        self.train(data, extras, models, optimizers, schedulers)
+
+        if single_gpu is False:
+            barrier()
+            destroy_process_group()
+        if self.is_main_node:
+            print()
+            print("------------------------------------")
+            print()
+            print("**TRAINING COMPLETE**")
+            if self.config.wandb_project is not None:
+                wandb.alert(title=f"Training {self.info.wandb_run_id} finished", text=f"Training {self.info.wandb_run_id} finished")
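WarpCore is a template: subclasses fill in the abstract setup_data/setup_models/setup_optimizers/train hooks, and __call__ then runs DDP/wandb setup, the setup_* chain, and training in order. A minimal single-GPU sketch of a subclass, assuming the repo's Base/EXPECTED DTO pattern works as in the other trainers; the toy model and data below are illustrative, not a trainer from this repo:

from dataclasses import dataclass
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

from core import WarpCore
from core.utils import EXPECTED

class ToyTrainer(WarpCore):
    @dataclass(frozen=True)
    class Models(WarpCore.Models):
        net: nn.Module = EXPECTED

    @dataclass(frozen=True)
    class Optimizers(WarpCore.Optimizers):
        opt: torch.optim.Optimizer = EXPECTED

    def setup_data(self, extras):
        ds = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
        dl = DataLoader(ds, batch_size=8)
        return self.Data(dataset=ds, dataloader=dl, iterator=iter(dl))

    def setup_models(self, extras):
        return self.Models(net=nn.Linear(4, 1).to(self.device))

    def setup_optimizers(self, extras, models):
        return self.Optimizers(opt=optim.AdamW(models.net.parameters(), lr=1e-3))

    def train(self, data, extras, models, optimizers, schedulers):
        for x, y in data.dataloader:
            loss = nn.functional.mse_loss(models.net(x.to(self.device)), y.to(self.device))
            optimizers.opt.zero_grad()
            loss.backward()
            optimizers.opt.step()

# ToyTrainer(config_dict={"experiment_id": "toy", "checkpoint_path": "/tmp/ckpt",
#                         "output_path": "/tmp/out"})(single_gpu=True)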
    	
core/data/__init__.py
ADDED
@@ -0,0 +1,69 @@
+import json
+import subprocess
+import yaml
+import os
+from .bucketeer import Bucketeer
+
+class MultiFilter():
+    def __init__(self, rules, default=False):
+        self.rules = rules
+        self.default = default
+
+    def __call__(self, x):
+        try:
+            x_json = x['json']
+            if isinstance(x_json, bytes):
+                x_json = json.loads(x_json)
+            validations = []
+            for k, r in self.rules.items():
+                if isinstance(k, tuple):
+                    v = r(*[x_json[kv] for kv in k])
+                else:
+                    v = r(x_json[k])
+                validations.append(v)
+            return all(validations)
+        except Exception:
+            return False
+
+class MultiGetter():
+    def __init__(self, rules):
+        self.rules = rules
+
+    def __call__(self, x_json):
+        if isinstance(x_json, bytes):
+            x_json = json.loads(x_json)
+        outputs = []
+        for k, r in self.rules.items():
+            if isinstance(k, tuple):
+                v = r(*[x_json[kv] for kv in k])
+            else:
+                v = r(x_json[k])
+            outputs.append(v)
+        if len(outputs) == 1:
+            outputs = outputs[0]
+        return outputs
+
+def setup_webdataset_path(paths, cache_path=None):
+    if cache_path is None or not os.path.exists(cache_path):
+        tar_paths = []
+        if isinstance(paths, str):
+            paths = [paths]
+        for path in paths:
+            if path.strip().endswith(".tar"):
+                # Avoid looking up s3 if we already have a tar file
+                tar_paths.append(path)
+                continue
         
     | 
| 56 | 
         
            +
                        bucket = "/".join(path.split("/")[:3])
         
     | 
| 57 | 
         
            +
                        result = subprocess.run([f"aws s3 ls {path} --recursive | awk '{{print $4}}'"], stdout=subprocess.PIPE, shell=True, check=True)
         
     | 
| 58 | 
         
            +
                        files = result.stdout.decode('utf-8').split()
         
     | 
| 59 | 
         
            +
                        files = [f"{bucket}/{f}" for f in files if f.endswith(".tar")]
         
     | 
| 60 | 
         
            +
                        tar_paths += files
         
     | 
| 61 | 
         
            +
             
     | 
| 62 | 
         
            +
                    with open(cache_path, 'w', encoding='utf-8') as outfile:
         
     | 
| 63 | 
         
            +
                        yaml.dump(tar_paths, outfile, default_flow_style=False)
         
     | 
| 64 | 
         
            +
                else:
         
     | 
| 65 | 
         
            +
                    with open(cache_path, 'r', encoding='utf-8') as file:
         
     | 
| 66 | 
         
            +
                        tar_paths = yaml.safe_load(file)
         
     | 
| 67 | 
         
            +
             
     | 
| 68 | 
         
            +
                tar_paths_str = ",".join([f"{p}" for p in tar_paths])
         
     | 
| 69 | 
         
            +
                return f"pipe:aws s3 cp {{ {tar_paths_str} }} -"
         
     | 
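A minimal usage sketch for the helpers above (the sample layout and rules are illustrative, not from this repo): MultiFilter validates a webdataset sample's JSON metadata against per-key predicates, MultiGetter extracts fields, and setup_webdataset_path returns a "pipe:aws s3 cp ..." URL that webdataset can stream tar shards from.

# Hypothetical sample in webdataset layout: metadata arrives as raw JSON bytes.
sample = {'json': b'{"width": 1024, "height": 768, "watermark_p": 0.1}'}

keep = MultiFilter(rules={
    ('width', 'height'): lambda w, h: w * h >= 512 * 512,  # tuple key -> rule receives several fields
    'watermark_p': lambda p: p < 0.5,
})
assert keep(sample) is True          # every rule passes; missing keys or exceptions yield False

get_size = MultiGetter(rules={('width', 'height'): lambda w, h: (w, h)})
assert get_size(sample['json']) == (1024, 768)   # a single rule is returned unwrapped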
    	
core/data/bucketeer.py
ADDED
@@ -0,0 +1,88 @@
import torch
import torchvision
import numpy as np
from torchtools.transforms import SmartCrop
import math

class Bucketeer():
    def __init__(self, dataloader, density=[256*256], factor=8, ratios=[1/1, 1/2, 3/4, 3/5, 4/5, 6/9, 9/16], reverse_list=True, randomize_p=0.3, randomize_q=0.2, crop_mode='random', p_random_ratio=0.0, interpolate_nearest=False):
        assert crop_mode in ['center', 'random', 'smart']
        self.crop_mode = crop_mode
        self.ratios = list(ratios)  # fix: copy, so the shared default list is never mutated across instances
        if reverse_list:
            for r in ratios:
                if 1/r not in self.ratios:
                    self.ratios.append(1/r)
        self.sizes = {}
        for dd in density:  # `density` is a list of target pixel counts (fix: the default was a bare int, which is not iterable)
            # one (h, w) bucket per aspect ratio at roughly `dd` pixels, snapped down to a multiple of `factor`
            self.sizes[dd] = [(int(((dd/r)**0.5//factor)*factor), int(((dd*r)**0.5//factor)*factor)) for r in self.ratios]

        self.batch_size = dataloader.batch_size
        self.iterator = iter(dataloader)
        all_sizes = []
        for k, vs in self.sizes.items():
            all_sizes += vs
        self.buckets = {s: [] for s in all_sizes}
        # fix: `density` is a list here, so size the smart crop from its largest entry
        self.smartcrop = SmartCrop(int(max(density)**0.5), randomize_p, randomize_q) if self.crop_mode == 'smart' else None
        self.p_random_ratio = p_random_ratio
        self.interpolate_nearest = interpolate_nearest

    def get_available_batch(self):
        for b in self.buckets:
            if len(self.buckets[b]) >= self.batch_size:
                batch = self.buckets[b][:self.batch_size]
                self.buckets[b] = self.buckets[b][self.batch_size:]
                return batch
        return None

    def get_closest_size(self, x):
        w, h = x.size(-1), x.size(-2)
        # pick the bucket ratio closest to the image's aspect ratio...
        best_size_idx = np.argmin([abs(w/h - r) for r in self.ratios])
        # ...then, across densities, the bucket whose pixel count is closest to the image's
        find_dict = {dd: abs(w*h - self.sizes[dd][best_size_idx][0]*self.sizes[dd][best_size_idx][1]) for dd in self.sizes}
        min_ = find_dict[list(find_dict.keys())[0]]
        find_size = self.sizes[list(find_dict.keys())[0]][best_size_idx]
        for dd, val in find_dict.items():
            if val < min_:
                min_ = val
                find_size = self.sizes[dd][best_size_idx]

        return find_size

    def get_resize_size(self, orig_size, tgt_size):
        if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0:
            alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size)))
            resize_size = max(alt_min, min(tgt_size))
        else:
            alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size)))
            resize_size = max(alt_max, max(tgt_size))

        return resize_size

    def __next__(self):
        batch = self.get_available_batch()
        while batch is None:
            elements = next(self.iterator)
            for dct in elements:
                img = dct['images']
                size = self.get_closest_size(img)
                resize_size = self.get_resize_size(img.shape[-2:], size)

                if self.interpolate_nearest:
                    img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST)
                else:
                    img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True)
                if self.crop_mode == 'center':
                    img = torchvision.transforms.functional.center_crop(img, size)
                elif self.crop_mode == 'random':
                    img = torchvision.transforms.RandomCrop(size)(img)
                elif self.crop_mode == 'smart':
                    self.smartcrop.output_size = size
                    img = self.smartcrop(img)

                self.buckets[size].append({**{'images': img}, **{k: dct[k] for k in dct if k != 'images'}})
            batch = self.get_available_batch()

        out = {k: [batch[i][k] for i in range(len(batch))] for k in batch[0]}
        return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()}
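To make the bucketing concrete, here is a hedged usage sketch. Bucketeer only needs an iterable dataloader with a .batch_size attribute whose batches are lists of dicts containing an 'images' tensor; the toy dataset below stands in for the real webdataset pipeline, and torchtools must be importable since bucketeer.py imports SmartCrop at module level even when crop_mode != 'smart'.

# Worked size computation for one bucket (dd = 256*256 = 65536, factor = 8, r = 3/4):
#   h = int(((dd / r) ** 0.5 // 8) * 8) = int((295.6... // 8) * 8) = 288
#   w = int(((dd * r) ** 0.5 // 8) * 8) = int((221.7... // 8) * 8) = 216
# so a 300x225 image is resized (short side -> 216) and cropped into the
# (288, 216) bucket, which holds about 0.95 * dd pixels.
import torch
from torch.utils.data import DataLoader, Dataset

class ToyImages(Dataset):  # illustrative stand-in for the real webdataset pipeline
    def __len__(self):
        return 64
    def __getitem__(self, i):
        return {'images': torch.rand(3, 300, 225), 'caption': f'sample {i}'}

loader = DataLoader(ToyImages(), batch_size=4, collate_fn=lambda batch: batch)
bucketeer = Bucketeer(loader, density=[256*256], factor=8, crop_mode='center')
batch = next(bucketeer)
print(batch['images'].shape)   # torch.Size([4, 3, 288, 216]); 'caption' stays a plain list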
    	
core/data/bucketeer_deg.py
ADDED
@@ -0,0 +1,91 @@
import torch
import torchvision
import numpy as np
from torchtools.transforms import SmartCrop
import math

class Bucketeer():
    def __init__(self, dataloader, density=[256*256], factor=8, ratios=[1/1, 1/2, 3/4, 3/5, 4/5, 6/9, 9/16], reverse_list=True, randomize_p=0.3, randomize_q=0.2, crop_mode='random', p_random_ratio=0.0, interpolate_nearest=False):
        assert crop_mode in ['center', 'random', 'smart']
        self.crop_mode = crop_mode
        self.ratios = list(ratios)  # fix: copy, so the shared default list is never mutated across instances
        if reverse_list:
            for r in ratios:
                if 1/r not in self.ratios:
                    self.ratios.append(1/r)
        self.sizes = {}
        for dd in density:  # `density` is a list of target pixel counts (fix: the default was a bare int, which is not iterable)
            self.sizes[dd] = [(int(((dd/r)**0.5//factor)*factor), int(((dd*r)**0.5//factor)*factor)) for r in self.ratios]
        print('in line 17 bucketeer', self.sizes)
        self.batch_size = dataloader.batch_size
        self.iterator = iter(dataloader)
        all_sizes = []
        for k, vs in self.sizes.items():
            all_sizes += vs
        self.buckets = {s: [] for s in all_sizes}
        # fix: `density` is a list here, so size the smart crop from its largest entry
        self.smartcrop = SmartCrop(int(max(density)**0.5), randomize_p, randomize_q) if self.crop_mode == 'smart' else None
        self.p_random_ratio = p_random_ratio
        self.interpolate_nearest = interpolate_nearest

    def get_available_batch(self):
        for b in self.buckets:
            if len(self.buckets[b]) >= self.batch_size:
                batch = self.buckets[b][:self.batch_size]
                self.buckets[b] = self.buckets[b][self.batch_size:]
                return batch
        return None

    def get_closest_size(self, x):
        w, h = x.size(-1), x.size(-2)
        # if self.p_random_ratio > 0 and np.random.rand() < self.p_random_ratio:
        #     best_size_idx = np.random.randint(len(self.ratios))
        #     # print('in line 41 get closest size', best_size_idx, x.shape, self.p_random_ratio)
        # else:
        best_size_idx = np.argmin([abs(w/h - r) for r in self.ratios])
        find_dict = {dd: abs(w*h - self.sizes[dd][best_size_idx][0]*self.sizes[dd][best_size_idx][1]) for dd in self.sizes}
        min_ = find_dict[list(find_dict.keys())[0]]
        find_size = self.sizes[list(find_dict.keys())[0]][best_size_idx]
        for dd, val in find_dict.items():
            if val < min_:
                min_ = val
                find_size = self.sizes[dd][best_size_idx]

        return find_size

    def get_resize_size(self, orig_size, tgt_size):
        if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0:
            alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size)))
            resize_size = max(alt_min, min(tgt_size))
        else:
            alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size)))
            resize_size = max(alt_max, max(tgt_size))
        # print('in line 50', orig_size, tgt_size, resize_size)
        return resize_size

    def __next__(self):
        batch = self.get_available_batch()
        while batch is None:
            elements = next(self.iterator)
            for dct in elements:
                img = dct['images']
                size = self.get_closest_size(img)
                resize_size = self.get_resize_size(img.shape[-2:], size)
                # print('in line 74', img.size(), resize_size)
                if self.interpolate_nearest:
                    img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST)
                else:
                    img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True)
                if self.crop_mode == 'center':
                    img = torchvision.transforms.functional.center_crop(img, size)
                elif self.crop_mode == 'random':
                    img = torchvision.transforms.RandomCrop(size)(img)
                elif self.crop_mode == 'smart':
                    self.smartcrop.output_size = size
                    img = self.smartcrop(img)
                print('in line 86 bucketeer', type(img), img.shape, torch.max(img), torch.min(img))
                self.buckets[size].append({**{'images': img}, **{k: dct[k] for k in dct if k != 'images'}})
            batch = self.get_available_batch()

        out = {k: [batch[i][k] for i in range(len(batch))] for k in batch[0]}
        return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()}
    	
        core/data/deg_kair_utils/utils_alignfaces.py
    ADDED
    
    | 
         @@ -0,0 +1,263 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            # -*- coding: utf-8 -*-
         
     | 
| 2 | 
         
            +
            """
         
     | 
| 3 | 
         
            +
            Created on Mon Apr 24 15:43:29 2017
         
     | 
| 4 | 
         
            +
            @author: zhaoy
         
     | 
| 5 | 
         
            +
            """
         
     | 
| 6 | 
         
            +
            import cv2
         
     | 
| 7 | 
         
            +
            import numpy as np
         
     | 
| 8 | 
         
            +
            from skimage import transform as trans
         
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
            # reference facial points, a list of coordinates (x,y)
         
     | 
| 11 | 
         
            +
            REFERENCE_FACIAL_POINTS = [
         
     | 
| 12 | 
         
            +
                [30.29459953, 51.69630051],
         
     | 
| 13 | 
         
            +
                [65.53179932, 51.50139999],
         
     | 
| 14 | 
         
            +
                [48.02519989, 71.73660278],
         
     | 
| 15 | 
         
            +
                [33.54930115, 92.3655014],
         
     | 
| 16 | 
         
            +
                [62.72990036, 92.20410156]
         
     | 
| 17 | 
         
            +
            ]
         
     | 
| 18 | 
         
            +
             
     | 
| 19 | 
         
            +
            DEFAULT_CROP_SIZE = (96, 112)
         
     | 
| 20 | 
         
            +
             
     | 
| 21 | 
         
            +
             
     | 
| 22 | 
         
            +
            def _umeyama(src, dst, estimate_scale=True, scale=1.0):
         
     | 
| 23 | 
         
            +
                """Estimate N-D similarity transformation with or without scaling.
         
     | 
| 24 | 
         
            +
                Parameters
         
     | 
| 25 | 
         
            +
                ----------
         
     | 
| 26 | 
         
            +
                src : (M, N) array
         
     | 
| 27 | 
         
            +
                    Source coordinates.
         
     | 
| 28 | 
         
            +
                dst : (M, N) array
         
     | 
| 29 | 
         
            +
                    Destination coordinates.
         
     | 
| 30 | 
         
            +
                estimate_scale : bool
         
     | 
| 31 | 
         
            +
                    Whether to estimate scaling factor.
         
     | 
| 32 | 
         
            +
                Returns
         
     | 
| 33 | 
         
            +
                -------
         
     | 
| 34 | 
         
            +
                T : (N + 1, N + 1)
         
     | 
| 35 | 
         
            +
                    The homogeneous similarity transformation matrix. The matrix contains
         
     | 
| 36 | 
         
            +
                    NaN values only if the problem is not well-conditioned.
         
     | 
| 37 | 
         
            +
                References
         
     | 
| 38 | 
         
            +
                ----------
         
     | 
| 39 | 
         
            +
                .. [1] "Least-squares estimation of transformation parameters between two
         
     | 
| 40 | 
         
            +
                        point patterns", Shinji Umeyama, PAMI 1991, :DOI:`10.1109/34.88573`
         
     | 
| 41 | 
         
            +
                """
         
     | 
| 42 | 
         
            +
             
     | 
| 43 | 
         
            +
                num = src.shape[0]
         
     | 
| 44 | 
         
            +
                dim = src.shape[1]
         
     | 
| 45 | 
         
            +
             
     | 
| 46 | 
         
            +
                # Compute mean of src and dst.
         
     | 
| 47 | 
         
            +
                src_mean = src.mean(axis=0)
         
     | 
| 48 | 
         
            +
                dst_mean = dst.mean(axis=0)
         
     | 
| 49 | 
         
            +
             
     | 
| 50 | 
         
            +
                # Subtract mean from src and dst.
         
     | 
| 51 | 
         
            +
                src_demean = src - src_mean
         
     | 
| 52 | 
         
            +
                dst_demean = dst - dst_mean
         
     | 
| 53 | 
         
            +
             
     | 
| 54 | 
         
            +
                # Eq. (38).
         
     | 
| 55 | 
         
            +
                A = dst_demean.T @ src_demean / num
         
     | 
| 56 | 
         
            +
             
     | 
| 57 | 
         
            +
                # Eq. (39).
         
     | 
| 58 | 
         
            +
                d = np.ones((dim,), dtype=np.double)
         
     | 
| 59 | 
         
            +
                if np.linalg.det(A) < 0:
         
     | 
| 60 | 
         
            +
                    d[dim - 1] = -1
         
     | 
| 61 | 
         
            +
             
     | 
| 62 | 
         
            +
                T = np.eye(dim + 1, dtype=np.double)
         
     | 
| 63 | 
         
            +
             
     | 
| 64 | 
         
            +
                U, S, V = np.linalg.svd(A)
         
     | 
| 65 | 
         
            +
             
     | 
| 66 | 
         
            +
                # Eq. (40) and (43).
         
     | 
| 67 | 
         
            +
                rank = np.linalg.matrix_rank(A)
         
     | 
| 68 | 
         
            +
                if rank == 0:
         
     | 
| 69 | 
         
            +
                    return np.nan * T
         
     | 
| 70 | 
         
            +
                elif rank == dim - 1:
         
     | 
| 71 | 
         
            +
                    if np.linalg.det(U) * np.linalg.det(V) > 0:
         
     | 
| 72 | 
         
            +
                        T[:dim, :dim] = U @ V
         
     | 
| 73 | 
         
            +
                    else:
         
     | 
| 74 | 
         
            +
                        s = d[dim - 1]
         
     | 
| 75 | 
         
            +
                        d[dim - 1] = -1
         
     | 
| 76 | 
         
            +
                        T[:dim, :dim] = U @ np.diag(d) @ V
         
     | 
| 77 | 
         
            +
                        d[dim - 1] = s
         
     | 
| 78 | 
         
            +
                else:
         
     | 
| 79 | 
         
            +
                    T[:dim, :dim] = U @ np.diag(d) @ V
         
     | 
| 80 | 
         
            +
             
     | 
| 81 | 
         
            +
                if estimate_scale:
         
     | 
| 82 | 
         
            +
                    # Eq. (41) and (42).
         
     | 
| 83 | 
         
            +
                    scale = 1.0 / src_demean.var(axis=0).sum() * (S @ d)
         
     | 
| 84 | 
         
            +
                else:
         
     | 
| 85 | 
         
            +
                    scale = scale
         
     | 
| 86 | 
         
            +
             
     | 
| 87 | 
         
            +
                T[:dim, dim] = dst_mean - scale * (T[:dim, :dim] @ src_mean.T)
         
     | 
| 88 | 
         
            +
                T[:dim, :dim] *= scale
         
     | 
| 89 | 
         
            +
             
     | 
| 90 | 
         
            +
                return T, scale
         
     | 
| 91 | 
         
            +
             
     | 
| 92 | 
         
            +
             
     | 
| 93 | 
         
            +
            class FaceWarpException(Exception):
         
     | 
| 94 | 
         
            +
                def __str__(self):
         
     | 
| 95 | 
         
            +
                    return 'In File {}:{}'.format(
         
     | 
| 96 | 
         
            +
                        __file__, super.__str__(self))
         
     | 
| 97 | 
         
            +
             
     | 
| 98 | 
         
            +
             
     | 
| 99 | 
         
            +
            def get_reference_facial_points(output_size=None,
         
     | 
| 100 | 
         
            +
                                            inner_padding_factor=0.0,
         
     | 
| 101 | 
         
            +
                                            outer_padding=(0, 0),
         
     | 
| 102 | 
         
            +
                                            default_square=False):
         
     | 
| 103 | 
         
            +
                tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
         
     | 
| 104 | 
         
            +
                tmp_crop_size = np.array(DEFAULT_CROP_SIZE)
         
     | 
| 105 | 
         
            +
             
     | 
| 106 | 
         
            +
                # 0) make the inner region a square
         
     | 
| 107 | 
         
            +
                if default_square:
         
     | 
| 108 | 
         
            +
                    size_diff = max(tmp_crop_size) - tmp_crop_size
         
     | 
| 109 | 
         
            +
                    tmp_5pts += size_diff / 2
         
     | 
| 110 | 
         
            +
                    tmp_crop_size += size_diff
         
     | 
| 111 | 
         
            +
             
     | 
| 112 | 
         
            +
                if (output_size and
         
     | 
| 113 | 
         
            +
                        output_size[0] == tmp_crop_size[0] and
         
     | 
| 114 | 
         
            +
                        output_size[1] == tmp_crop_size[1]):
         
     | 
| 115 | 
         
            +
                    print('output_size == DEFAULT_CROP_SIZE {}: return default reference points'.format(tmp_crop_size))
         
     | 
| 116 | 
         
            +
                    return tmp_5pts
         
     | 
| 117 | 
         
            +
             
     | 
| 118 | 
         
            +
                if (inner_padding_factor == 0 and
         
     | 
| 119 | 
         
            +
                        outer_padding == (0, 0)):
         
     | 
| 120 | 
         
            +
                    if output_size is None:
         
     | 
| 121 | 
         
            +
                        print('No paddings to do: return default reference points')
         
     | 
| 122 | 
         
            +
                        return tmp_5pts
         
     | 
| 123 | 
         
            +
                    else:
         
     | 
| 124 | 
         
            +
                        raise FaceWarpException(
         
     | 
| 125 | 
         
            +
                            'No paddings to do, output_size must be None or {}'.format(tmp_crop_size))
         
     | 
| 126 | 
         
            +
             
     | 
| 127 | 
         
            +
                # check output size
         
     | 
| 128 | 
         
            +
                if not (0 <= inner_padding_factor <= 1.0):
         
     | 
| 129 | 
         
            +
                    raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')
         
     | 
| 130 | 
         
            +
             
     | 
| 131 | 
         
            +
                if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0)
         
     | 
| 132 | 
         
            +
                        and output_size is None):
         
     | 
| 133 | 
         
            +
                    output_size = tmp_crop_size * \
         
     | 
| 134 | 
         
            +
                                  (1 + inner_padding_factor * 2).astype(np.int32)
         
     | 
| 135 | 
         
            +
                    output_size += np.array(outer_padding)
         
     | 
| 136 | 
         
            +
                    print('              deduced from paddings, output_size = ', output_size)
         
     | 
| 137 | 
         
            +
             
     | 
| 138 | 
         
            +
                if not (outer_padding[0] < output_size[0]
         
     | 
| 139 | 
         
            +
                        and outer_padding[1] < output_size[1]):
         
     | 
| 140 | 
         
            +
                    raise FaceWarpException('Not (outer_padding[0] < output_size[0]'
         
     | 
| 141 | 
         
            +
                                            'and outer_padding[1] < output_size[1])')
         
     | 
| 142 | 
         
            +
             
     | 
| 143 | 
         
            +
                # 1) pad the inner region according inner_padding_factor
         
     | 
| 144 | 
         
            +
                # print('---> STEP1: pad the inner region according inner_padding_factor')
         
     | 
| 145 | 
         
            +
                if inner_padding_factor > 0:
         
     | 
| 146 | 
         
            +
                    size_diff = tmp_crop_size * inner_padding_factor * 2
         
     | 
| 147 | 
         
            +
                    tmp_5pts += size_diff / 2
         
     | 
| 148 | 
         
            +
                    tmp_crop_size += np.round(size_diff).astype(np.int32)
         
     | 
| 149 | 
         
            +
             
     | 
| 150 | 
         
            +
                # print('              crop_size = ', tmp_crop_size)
         
     | 
| 151 | 
         
            +
                # print('              reference_5pts = ', tmp_5pts)
         
     | 
| 152 | 
         
            +
             
     | 
| 153 | 
         
            +
                # 2) resize the padded inner region
         
     | 
| 154 | 
         
            +
                # print('---> STEP2: resize the padded inner region')
         
     | 
| 155 | 
         
            +
                size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2
         
     | 
| 156 | 
         
            +
                # print('              crop_size = ', tmp_crop_size)
         
     | 
| 157 | 
         
            +
                # print('              size_bf_outer_pad = ', size_bf_outer_pad)
         
     | 
| 158 | 
         
            +
             
     | 
| 159 | 
         
            +
                if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
         
     | 
| 160 | 
         
            +
                    raise FaceWarpException('Must have (output_size - outer_padding)'
         
     | 
| 161 | 
         
            +
                                            '= some_scale * (crop_size * (1.0 + inner_padding_factor)')
         
     | 
| 162 | 
         
            +
             
     | 
| 163 | 
         
            +
                scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
         
     | 
| 164 | 
         
            +
                # print('              resize scale_factor = ', scale_factor)
         
     | 
| 165 | 
         
            +
                tmp_5pts = tmp_5pts * scale_factor
         
     | 
| 166 | 
         
            +
                #    size_diff = tmp_crop_size * (scale_factor - min(scale_factor))
         
     | 
| 167 | 
         
            +
                #    tmp_5pts = tmp_5pts + size_diff / 2
         
     | 
| 168 | 
         
            +
                tmp_crop_size = size_bf_outer_pad
         
     | 
| 169 | 
         
            +
                # print('              crop_size = ', tmp_crop_size)
         
     | 
| 170 | 
         
            +
                # print('              reference_5pts = ', tmp_5pts)
         
     | 
| 171 | 
         
            +
             
     | 
| 172 | 
         
            +
                # 3) add outer_padding to make output_size
         
     | 
| 173 | 
         
            +
                reference_5point = tmp_5pts + np.array(outer_padding)
         
     | 
| 174 | 
         
            +
                tmp_crop_size = output_size
         
     | 
| 175 | 
         
            +
                # print('---> STEP3: add outer_padding to make output_size')
         
     | 
| 176 | 
         
            +
                # print('              crop_size = ', tmp_crop_size)
         
     | 
| 177 | 
         
            +
                # print('              reference_5pts = ', tmp_5pts)
         
     | 
| 178 | 
         
            +
                #
         
     | 
| 179 | 
         
            +
                # print('===> end get_reference_facial_points\n')
         
     | 
| 180 | 
         
            +
             
     | 
| 181 | 
         
            +
                return reference_5point
         
     | 
| 182 | 
         
            +
             
     | 
| 183 | 
         
            +
             
     | 
| 184 | 
         
            +
            def get_affine_transform_matrix(src_pts, dst_pts):
         
     | 
| 185 | 
         
            +
                tfm = np.float32([[1, 0, 0], [0, 1, 0]])
         
     | 
| 186 | 
         
            +
                n_pts = src_pts.shape[0]
         
     | 
| 187 | 
         
            +
                ones = np.ones((n_pts, 1), src_pts.dtype)
         
     | 
| 188 | 
         
            +
                src_pts_ = np.hstack([src_pts, ones])
         
     | 
| 189 | 
         
            +
                dst_pts_ = np.hstack([dst_pts, ones])
         
     | 
| 190 | 
         
            +
             
     | 
| 191 | 
         
            +
                A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_)
         
     | 
| 192 | 
         
            +
             
     | 
| 193 | 
         
            +
                if rank == 3:
         
     | 
| 194 | 
         
            +
                    tfm = np.float32([
         
     | 
| 195 | 
         
            +
                        [A[0, 0], A[1, 0], A[2, 0]],
         
     | 
| 196 | 
         
            +
                        [A[0, 1], A[1, 1], A[2, 1]]
         
     | 
| 197 | 
         
            +
                    ])
         
     | 
| 198 | 
         
            +
                elif rank == 2:
         
     | 
| 199 | 
         
            +
                    tfm = np.float32([
         
     | 
| 200 | 
         
            +
                        [A[0, 0], A[1, 0], 0],
         
     | 
| 201 | 
         
            +
                        [A[0, 1], A[1, 1], 0]
         
     | 
| 202 | 
         
            +
                    ])
         
     | 
| 203 | 
         
            +
             
     | 
| 204 | 
         
            +
                return tfm
         
     | 
| 205 | 
         
            +
             
     | 
| 206 | 
         
            +
             
     | 
| 207 | 
         
            +
            def warp_and_crop_face(src_img,
         
     | 
| 208 | 
         
            +
                                   facial_pts,
         
     | 
| 209 | 
         
            +
                                   reference_pts=None,
         
     | 
| 210 | 
         
            +
                                   crop_size=(96, 112),
         
     | 
| 211 | 
         
            +
                                   align_type='smilarity'): #smilarity cv2_affine affine
         
     | 
| 212 | 
         
            +
                if reference_pts is None:
         
     | 
| 213 | 
         
            +
                    if crop_size[0] == 96 and crop_size[1] == 112:
         
     | 
| 214 | 
         
            +
                        reference_pts = REFERENCE_FACIAL_POINTS
         
     | 
| 215 | 
         
            +
                    else:
         
     | 
| 216 | 
         
            +
                        default_square = False
         
     | 
| 217 | 
         
            +
                        inner_padding_factor = 0
         
     | 
| 218 | 
         
            +
                        outer_padding = (0, 0)
         
     | 
| 219 | 
         
            +
                        output_size = crop_size
         
     | 
| 220 | 
         
            +
             
     | 
| 221 | 
         
        reference_pts = get_reference_facial_points(output_size,
                                                    inner_padding_factor,
                                                    outer_padding,
                                                    default_square)

    ref_pts = np.float32(reference_pts)
    ref_pts_shp = ref_pts.shape
    if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
        raise FaceWarpException(
            'reference_pts.shape must be (K,2) or (2,K) and K>2')

    if ref_pts_shp[0] == 2:
        ref_pts = ref_pts.T

    src_pts = np.float32(facial_pts)
    src_pts_shp = src_pts.shape
    if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
        raise FaceWarpException(
            'facial_pts.shape must be (K,2) or (2,K) and K>2')

    if src_pts_shp[0] == 2:
        src_pts = src_pts.T

    if src_pts.shape != ref_pts.shape:
        raise FaceWarpException(
            'facial_pts and reference_pts must have the same shape')

    if align_type == 'cv2_affine':  # '==' rather than 'is': identity checks on string literals are unreliable
        tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
        tfm_inv = cv2.getAffineTransform(ref_pts[0:3], src_pts[0:3])
    elif align_type == 'affine':
        tfm = get_affine_transform_matrix(src_pts, ref_pts)
        tfm_inv = get_affine_transform_matrix(ref_pts, src_pts)
    else:
        params, scale = _umeyama(src_pts, ref_pts)
        tfm = params[:2, :]

        params, _ = _umeyama(ref_pts, src_pts, False, scale=1.0 / scale)
        tfm_inv = params[:2, :]

    face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]), flags=3)

    return face_img, tfm_inv
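
# Minimal usage sketch (illustrative only, not part of the original file; the
# image path and landmark coordinates below are hypothetical):
#
#   src_img = cv2.imread('face.jpg')
#   facial_pts = np.array([[193, 240], [319, 240],   # eye centers
#                          [257, 314],               # nose tip
#                          [201, 371], [313, 371]])  # mouth corners
#   aligned, tfm_inv = warp_and_crop_face(src_img, facial_pts,
#                                         crop_size=(112, 112))
#   # tfm_inv maps points in the aligned crop back to src_img coordinates.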
    	
core/data/deg_kair_utils/utils_blindsr.py
ADDED
@@ -0,0 +1,631 @@
# -*- coding: utf-8 -*-
import numpy as np
import cv2
import torch

from core.data.deg_kair_utils import utils_image as util

import random
from scipy import ndimage
import scipy
import scipy.stats as ss
from scipy.interpolate import interp2d
from scipy.linalg import orth



"""
# --------------------------------------------
# Super-Resolution
# --------------------------------------------
#
# Kai Zhang ([email protected])
# https://github.com/cszn
# From 2019/03--2021/08
# --------------------------------------------
"""

def modcrop_np(img, sf):
    '''
    Args:
        img: numpy image, WxH or WxHxC
        sf: scale factor

    Return:
        cropped image
    '''
    w, h = img.shape[:2]
    im = np.copy(img)
    return im[:w - w % sf, :h - h % sf, ...]


"""
# --------------------------------------------
# anisotropic Gaussian kernels
# --------------------------------------------
"""
def analytic_kernel(k):
    """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
    k_size = k.shape[0]
    # Calculate the big kernel's size
    big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
    # Loop over the small kernel to fill the big one
    for r in range(k_size):
        for c in range(k_size):
            big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
    # Crop the edges of the big kernel to drop very small values and speed up the SR run time
    crop = k_size // 2
    cropped_big_k = big_k[crop:-crop, crop:-crop]
    # Normalize to 1
    return cropped_big_k / cropped_big_k.sum()
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
    """ generate an anisotropic Gaussian kernel
    Args:
        ksize : e.g., 15, kernel size
        theta : [0, pi], rotation angle range
        l1    : [0.1, 50], scaling of eigenvalues
        l2    : [0.1, l1], scaling of eigenvalues
        If l1 = l2, an isotropic Gaussian kernel is returned.

    Returns:
        k     : kernel
    """

    v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
    V = np.array([[v[0], v[1]], [v[1], -v[0]]])
    D = np.array([[l1, 0], [0, l2]])
    Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
    k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)

    return k


def gm_blur_kernel(mean, cov, size=15):
    center = size / 2.0 + 0.5
    k = np.zeros([size, size])
    for y in range(size):
        for x in range(size):
            cy = y - center + 1
            cx = x - center + 1
            k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)

    k = k / np.sum(k)
    return k

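
# Sanity-check sketch (illustrative): Sigma above is V D V^{-1} with
# eigenvalues l1, l2 and rotation angle theta, and gm_blur_kernel normalizes
# the sampled density, so the kernel should be non-negative and sum to one.
#
#   k = anisotropic_Gaussian(ksize=15, theta=np.pi / 4, l1=10, l2=2)
#   assert k.shape == (15, 15) and (k >= 0).all() and abs(k.sum() - 1.0) < 1e-6
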
def shift_pixel(x, sf, upper_left=True):
    """shift pixel for super-resolution with different scale factors
    Args:
        x: WxHxC or WxH
        sf: scale factor
        upper_left: shift direction
    """
    h, w = x.shape[:2]
    shift = (sf - 1) * 0.5
    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
    if upper_left:
        x1 = xv + shift
        y1 = yv + shift
    else:
        x1 = xv - shift
        y1 = yv - shift

    x1 = np.clip(x1, 0, w - 1)
    y1 = np.clip(y1, 0, h - 1)

    # note: scipy.interpolate.interp2d is deprecated in recent SciPy releases
    if x.ndim == 2:
        x = interp2d(xv, yv, x)(x1, y1)
    if x.ndim == 3:
        for i in range(x.shape[-1]):
            x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)

    return x
def blur(x, k):
    '''
    x: image, NxcxHxW
    k: kernel, Nx1xhxw
    '''
    n, c = x.shape[:2]
    p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
    x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
    k = k.repeat(1, c, 1, 1)
    k = k.view(-1, 1, k.shape[2], k.shape[3])
    x = x.view(1, -1, x.shape[2], x.shape[3])
    x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
    x = x.view(n, c, x.shape[2], x.shape[3])

    return x

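
# Usage sketch (illustrative): one kernel per image, applied to every channel
# through a grouped convolution; 'replicate' padding keeps the spatial size.
#
#   x = torch.rand(2, 3, 64, 64)   # N x C x H x W
#   k = torch.rand(2, 1, 7, 7)     # N x 1 x h x w
#   k = k / k.sum(dim=(-2, -1), keepdim=True)
#   assert blur(x, k).shape == x.shape
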
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
    """
    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
    # Kai Zhang
    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var
    # max_var = 2.5 * sf
    """
    # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
    lambda_1 = min_var + np.random.rand() * (max_var - min_var)
    lambda_2 = min_var + np.random.rand() * (max_var - min_var)
    theta = np.random.rand() * np.pi  # random theta
    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2

    # Set COV matrix using Lambdas and Theta
    LAMBDA = np.diag([lambda_1, lambda_2])
    Q = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta), np.cos(theta)]])
    SIGMA = Q @ LAMBDA @ Q.T
    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]

    # Set expectation position (shifting kernel for aligned image)
    MU = k_size // 2 - 0.5 * (scale_factor - 1)  # - 0.5 * (scale_factor - k_size % 2)
    MU = MU[None, None, :, None]

    # Create meshgrid for Gaussian
    [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
    Z = np.stack([X, Y], 2)[:, :, :, None]

    # Calculate the Gaussian for every pixel of the kernel
    ZZ = Z - MU
    ZZ_t = ZZ.transpose(0, 1, 3, 2)
    raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)

    # shift the kernel so it will be centered
    # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)

    # Normalize the kernel and return
    # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
    kernel = raw_kernel / np.sum(raw_kernel)
    return kernel

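
# Usage sketch (illustrative): a random anisotropic Gaussian for scale factor
# 4; the MU shift above centers the kernel for spatially aligned LR/HR pairs.
#
#   k = gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]))
#   assert k.shape == (15, 15) and abs(k.sum() - 1.0) < 1e-6
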
def fspecial_gaussian(hsize, sigma):
    hsize = [hsize, hsize]
    siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
    std = sigma
    [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
    arg = -(x * x + y * y) / (2 * std * std)
    h = np.exp(arg)
    h[h < np.finfo(float).eps * h.max()] = 0  # np.finfo: scipy.finfo is a deprecated NumPy alias
    sumh = h.sum()
    if sumh != 0:
        h = h / sumh
    return h


def fspecial_laplacian(alpha):
    alpha = max([0, min([alpha, 1])])
    h1 = alpha / (alpha + 1)
    h2 = (1 - alpha) / (alpha + 1)
    h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
    h = np.array(h)
    return h


def fspecial(filter_type, *args, **kwargs):
    '''
    python code from:
    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
    '''
    if filter_type == 'gaussian':
        return fspecial_gaussian(*args, **kwargs)
    if filter_type == 'laplacian':
        return fspecial_laplacian(*args, **kwargs)

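
# Consistency sketch (illustrative): the 2-D Gaussian above is separable, so
# it should match the outer product of OpenCV's normalized 1-D Gaussian window.
#
#   h = fspecial('gaussian', 15, 2.0)
#   g = cv2.getGaussianKernel(15, 2.0)
#   assert np.allclose(h, g @ g.T, atol=1e-8)
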
"""
# --------------------------------------------
# degradation models
# --------------------------------------------
"""


def bicubic_degradation(x, sf=3):
    '''
    Args:
        x: HxWxC image, [0, 1]
        sf: down-scale factor

    Return:
        bicubicly downsampled LR image
    '''
    x = util.imresize_np(x, scale=1 / sf)
    return x

def srmd_degradation(x, k, sf=3):
    ''' blur + bicubic downsampling

    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor

    Return:
        downsampled LR image

    Reference:
        @inproceedings{zhang2018learning,
          title={Learning a single convolutional super-resolution network for multiple degradations},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={3262--3271},
          year={2018}
        }
    '''
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'; ndimage.filters namespace is deprecated
    x = bicubic_degradation(x, sf=sf)
    return x

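
# Ordering sketch (illustrative): srmd_degradation blurs first and then
# downsamples, while dpsr_degradation below downsamples first and then blurs;
# for a non-trivial kernel the two outputs differ.
#
#   x = np.random.rand(64, 64, 3).astype(np.float32)
#   k = fspecial('gaussian', 15, 1.6)
#   assert srmd_degradation(x, k, sf=2).shape == (32, 32, 3)
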
def dpsr_degradation(x, k, sf=3):

    ''' bicubic downsampling + blur

    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor

    Return:
        downsampled LR image

    Reference:
        @inproceedings{zhang2019deep,
          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={1671--1681},
          year={2019}
        }
    '''
    x = bicubic_degradation(x, sf=sf)
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    return x

def classical_degradation(x, k, sf=3):
    ''' blur + downsampling

    Args:
        x: HxWxC image, [0, 1]/[0, 255]
        k: hxw, double
        sf: down-scale factor

    Return:
        downsampled LR image
    '''
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
    st = 0
    return x[st::sf, st::sf, ...]

def add_sharpening(img, weight=0.5, radius=50, threshold=10):
    """USM sharpening. borrowed from real-ESRGAN
    Input image: I; Blurry image: B.
    1. K = I + weight * (I - B)
    2. Mask = 1 if abs(I - B) > threshold, else: 0
    3. Blur the mask.
    4. Out = Mask * K + (1 - Mask) * I
    Args:
        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
        weight (float): Sharp weight. Default: 0.5.
        radius (float): Kernel size of Gaussian blur. Default: 50.
        threshold (int): residual threshold, on the [0, 255] scale.
    """
    if radius % 2 == 0:
        radius += 1
    blur = cv2.GaussianBlur(img, (radius, radius), 0)
    residual = img - blur
    mask = np.abs(residual) * 255 > threshold
    mask = mask.astype('float32')
    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)

    K = img + weight * residual
    K = np.clip(K, 0, 1)
    return soft_mask * K + (1 - soft_mask) * img

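
# Usage sketch (illustrative): unsharp masking with a soft mask, so only
# regions where |I - B| exceeds the threshold are sharpened.
#
#   img = np.random.rand(128, 128, 3).astype(np.float32)
#   sharp = add_sharpening(img, weight=0.5, radius=50, threshold=10)
#   assert sharp.shape == img.shape
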
def add_blur(img, sf=4):
    wd2 = 4.0 + sf
    wd = 2.0 + 0.2 * sf
    if random.random() < 0.5:
        l1 = wd2 * random.random()
        l2 = wd2 * random.random()
        k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
    else:
        k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
    img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')

    return img

def add_resize(img, sf=4):
    rnum = np.random.rand()
    if rnum > 0.8:    # up
        sf1 = random.uniform(1, 2)
    elif rnum < 0.7:  # down
        sf1 = random.uniform(0.5 / sf, 1)
    else:             # keep size
        sf1 = 1.0
    img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
    img = np.clip(img, 0.0, 1.0)

    return img

def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
    noise_level = random.randint(noise_level1, noise_level2)
    rnum = np.random.rand()
    if rnum > 0.6:    # add per-channel (color) Gaussian noise
        img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
    elif rnum < 0.4:  # add grayscale Gaussian noise
        img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
    else:             # add channel-correlated Gaussian noise
        L = noise_level2 / 255.
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
    img = np.clip(img, 0.0, 1.0)
    return img

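
# Usage sketch (illustrative): the three branches above add per-channel,
# grayscale (broadcast over channels), or channel-correlated Gaussian noise;
# the result is always clipped back to [0, 1].
#
#   img = np.random.rand(64, 64, 3).astype(np.float32)
#   noisy = add_Gaussian_noise(img)
#   assert 0.0 <= noisy.min() and noisy.max() <= 1.0
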
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
    noise_level = random.randint(noise_level1, noise_level2)
    img = np.clip(img, 0.0, 1.0)
    rnum = random.random()
    if rnum > 0.6:
        img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
    elif rnum < 0.4:
        img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
    else:
        L = noise_level2 / 255.
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
    img = np.clip(img, 0.0, 1.0)
    return img

def add_Poisson_noise(img):
    img = np.clip((img * 255.0).round(), 0, 255) / 255.
    vals = 10 ** (2 * random.random() + 2.0)  # exponent uniform in [2, 4], i.e. vals in [1e2, 1e4]
    if random.random() < 0.5:
        img = np.random.poisson(img * vals).astype(np.float32) / vals
    else:
        img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
        img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
        noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
        img += noise_gray[:, :, np.newaxis]
    img = np.clip(img, 0.0, 1.0)
    return img

def add_JPEG_noise(img):
    quality_factor = random.randint(30, 95)
    img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
    result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
    img = cv2.imdecode(encimg, 1)
    img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
    return img

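
# Usage sketch (illustrative): a float RGB image in [0, 1] goes through a real
# JPEG encode/decode round trip at a random quality factor in [30, 95].
#
#   img = np.random.rand(64, 64, 3).astype(np.float32)
#   jpg = add_JPEG_noise(img)
#   assert jpg.shape == img.shape and jpg.dtype == np.float32
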
def random_crop(lq, hq, sf=4, lq_patchsize=64):
    h, w = lq.shape[:2]
    rnd_h = random.randint(0, h - lq_patchsize)
    rnd_w = random.randint(0, w - lq_patchsize)
    lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]

    rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
    hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
    return lq, hq

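
# Usage sketch (illustrative): LR/HR patches stay aligned because the HR crop
# offset is the LR offset scaled by sf.
#
#   lq = np.random.rand(32, 32, 3)
#   hq = np.random.rand(128, 128, 3)
#   lq_p, hq_p = random_crop(lq, hq, sf=4, lq_patchsize=16)
#   assert lq_p.shape[:2] == (16, 16) and hq_p.shape[:2] == (64, 64)
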
def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    img: HxWxC, [0, 1], its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
    sf: scale factor
    isp_model: camera ISP model

    Returns
    -------
    img: low-quality patch, size: lq_patchsize x lq_patchsize x C, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsize x sf) x (lq_patchsize x sf) x C, range: [0, 1]
    """
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf

    h1, w1 = img.shape[:2]
    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop
    h, w = img.shape[:2]

    if h < lq_patchsize * sf or w < lq_patchsize * sf:
        raise ValueError(f'img size ({h1}X{w1}) is too small!')

    hq = img.copy()

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
        else:
            img = util.imresize_np(img, 1 / 2, True)
        img = np.clip(img, 0.0, 1.0)
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # make sure downsample2 runs before downsample3, which fixes the final LR size
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:

        if i == 0:
            img = add_blur(img, sf=sf)

        elif i == 1:
            img = add_blur(img, sf=sf)

        elif i == 2:
            a, b = img.shape[1], img.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
                img = img[0::sf, 0::sf, ...]  # nearest downsampling
            img = np.clip(img, 0.0, 1.0)

        elif i == 3:
            # downsample3: resize to the target LR size recorded before downsample2
            img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
            img = np.clip(img, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                img = add_JPEG_noise(img)

        elif i == 6:
            # add processed camera sensor noise
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop
    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)

    return img, hq

| 524 | 
         
            +
def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=False, lq_patchsize=64, isp_model=None):
    """
    This is an extended degradation model that combines
    the degradation models of BSRGAN and Real-ESRGAN
    ----------
    img: HxWxC, [0, 1], its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
    sf: scale factor
    shuffle_prob: probability of shuffling the degradation order
    use_sharp: whether to sharpen the img first

    Returns
    -------
    img: low-quality patch, size: lq_patchsize X lq_patchsize X C, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsize x sf) X (lq_patchsize x sf) X C, range: [0, 1]
    """

    h1, w1 = img.shape[:2]
    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop
    h, w = img.shape[:2]

    if h < lq_patchsize*sf or w < lq_patchsize*sf:
        raise ValueError(f'img size ({h1}X{w1}) is too small!')

    if use_sharp:
        img = add_sharpening(img)
    hq = img.copy()

    if random.random() < shuffle_prob:
        shuffle_order = random.sample(range(13), 13)
    else:
        shuffle_order = list(range(13))
        # local shuffle for noise; JPEG is always the last one
        shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
        shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))

    poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1

    for i in shuffle_order:
        if i == 0:
            img = add_blur(img, sf=sf)
        elif i == 1:
            img = add_resize(img, sf=sf)
        elif i == 2:
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
        elif i == 3:
            if random.random() < poisson_prob:
                img = add_Poisson_noise(img)
        elif i == 4:
            if random.random() < speckle_prob:
                img = add_speckle_noise(img)
        elif i == 5:
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)
        elif i == 6:
            img = add_JPEG_noise(img)
        elif i == 7:
            img = add_blur(img, sf=sf)
        elif i == 8:
            img = add_resize(img, sf=sf)
        elif i == 9:
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
        elif i == 10:
            if random.random() < poisson_prob:
                img = add_Poisson_noise(img)
        elif i == 11:
            if random.random() < speckle_prob:
                img = add_speckle_noise(img)
        elif i == 12:
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)
        else:
            print('check the shuffle!')

    # resize to desired size
    img = cv2.resize(img, (int(1/sf*hq.shape[1]), int(1/sf*hq.shape[0])), interpolation=random.choice([1, 2, 3]))

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop
    img, hq = random_crop(img, hq, sf, lq_patchsize)

    return img, hq


if __name__ == '__main__':
    img = util.imread_uint('utils/test.png', 3)
    img = util.uint2single(img)
    sf = 4

    for i in range(20):
        img_lq, img_hq = degradation_bsrgan(img, sf=sf, lq_patchsize=72)
        print(i)
        lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf*img_lq.shape[1]), int(sf*img_lq.shape[0])), interpolation=0)
        img_concat = np.concatenate([lq_nearest, util.single2uint(img_hq)], axis=1)
        util.imsave(img_concat, str(i)+'.png')

#    for i in range(10):
#        img_lq, img_hq = degradation_bsrgan_plus(img, sf=sf, shuffle_prob=0.1, use_sharp=True, lq_patchsize=64)
#        print(i)
#        lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf*img_lq.shape[1]), int(sf*img_lq.shape[0])), interpolation=0)
#        img_concat = np.concatenate([lq_nearest, util.single2uint(img_hq)], axis=1)
#        util.imsave(img_concat, str(i)+'.png')

#    run utils/utils_blindsr.py
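A quick way to exercise the pipeline above is to check the LQ/HQ shape contract stated in the docstrings. A minimal sketch (assuming a random array stands in for a real HQ image and that the helper functions behave as in the KAIR reference implementation this file is adapted from):

    import numpy as np
    img = np.random.rand(296, 296, 3).astype(np.float32)  # HxWxC in [0, 1], at least lq_patchsize*sf per side
    img_lq, img_hq = degradation_bsrgan(img, sf=4, lq_patchsize=72)
    assert img_lq.shape == (72, 72, 3)    # lq_patchsize x lq_patchsize x C
    assert img_hq.shape == (288, 288, 3)  # (lq_patchsize*sf) x (lq_patchsize*sf) x C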
    	
core/data/deg_kair_utils/utils_bnorm.py
ADDED
@@ -0,0 +1,91 @@
import torch
import torch.nn as nn


"""
# --------------------------------------------
# Batch Normalization
# --------------------------------------------

# Kai Zhang ([email protected])
# https://github.com/cszn
# 01/Jan/2019
# --------------------------------------------
"""


# --------------------------------------------
# remove/delete specified layer
# --------------------------------------------
def deleteLayer(model, layer_type=nn.BatchNorm2d):
    ''' Kai Zhang, 11/Jan/2019.
    '''
    for k, m in list(model.named_children()):
        if isinstance(m, layer_type):
            del model._modules[k]
        deleteLayer(m, layer_type)


# --------------------------------------------
# merge bn, "conv+bn" --> "conv"
# --------------------------------------------
def merge_bn(model):
    ''' Kai Zhang, 11/Jan/2019.
    merge all 'Conv+BN' (or 'TConv+BN') into 'Conv' (or 'TConv')
    based on https://github.com/pytorch/pytorch/pull/901
    '''
    prev_m = None
    for k, m in list(model.named_children()):
        if (isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d)) and (isinstance(prev_m, nn.Conv2d) or isinstance(prev_m, nn.Linear) or isinstance(prev_m, nn.ConvTranspose2d)):

            w = prev_m.weight.data

            if prev_m.bias is None:
                zeros = torch.Tensor(prev_m.out_channels).zero_().type(w.type())
                prev_m.bias = nn.Parameter(zeros)
            b = prev_m.bias.data

            invstd = m.running_var.clone().add_(m.eps).pow_(-0.5)
            if isinstance(prev_m, nn.ConvTranspose2d):
                w.mul_(invstd.view(1, w.size(1), 1, 1).expand_as(w))
            else:
                w.mul_(invstd.view(w.size(0), 1, 1, 1).expand_as(w))
            b.add_(-m.running_mean).mul_(invstd)
            if m.affine:
                if isinstance(prev_m, nn.ConvTranspose2d):
                    w.mul_(m.weight.data.view(1, w.size(1), 1, 1).expand_as(w))
                else:
                    w.mul_(m.weight.data.view(w.size(0), 1, 1, 1).expand_as(w))
                b.mul_(m.weight.data).add_(m.bias.data)

            del model._modules[k]
        prev_m = m
        merge_bn(m)


# --------------------------------------------
# add bn, "conv" --> "conv+bn"
# --------------------------------------------
def add_bn(model):
    ''' Kai Zhang, 11/Jan/2019.
    '''
    for k, m in list(model.named_children()):
        if (isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose2d)):
            b = nn.BatchNorm2d(m.out_channels, momentum=0.1, affine=True)
            b.weight.data.fill_(1)
            new_m = nn.Sequential(model._modules[k], b)
            model._modules[k] = new_m
        add_bn(m)


# --------------------------------------------
# tidy model after removing bn
# --------------------------------------------
def tidy_sequential(model):
    ''' Kai Zhang, 11/Jan/2019.
    '''
    for k, m in list(model.named_children()):
        if isinstance(m, nn.Sequential):
            if m.__len__() == 1:
                model._modules[k] = m.__getitem__(0)
        tidy_sequential(m)
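A minimal sanity check for merge_bn (a sketch, assuming the module is importable as core.data.deg_kair_utils.utils_bnorm from the repo root): in eval mode, folding BN into the preceding conv must leave the network output unchanged.

    import torch
    import torch.nn as nn
    from core.data.deg_kair_utils.utils_bnorm import merge_bn, tidy_sequential

    torch.manual_seed(0)
    model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8))
    model.eval()
    model[1].running_mean.normal_()          # give BN non-trivial statistics
    model[1].running_var.uniform_(0.5, 2.0)  # so the check is meaningful

    x = torch.randn(1, 3, 16, 16)
    y_before = model(x)
    merge_bn(model)         # folds BN into the conv and deletes the BN module
    tidy_sequential(model)  # unwraps any length-1 nn.Sequential left behind
    y_after = model(x)
    print(torch.allclose(y_before, y_after, atol=1e-5))  # expected: True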
    	
core/data/deg_kair_utils/utils_deblur.py
ADDED
@@ -0,0 +1,655 @@
# -*- coding: utf-8 -*-
import numpy as np
import scipy
from scipy import fftpack
import torch

from math import cos, sin
from numpy import zeros, ones, prod, array, pi, log, min, mod, arange, sum, mgrid, exp, pad, round
from numpy.random import randn, rand
from scipy.signal import convolve2d
import cv2
import random
# import utils_image as util

'''
modified by Kai Zhang (github: https://github.com/cszn)
03/03/2019
'''


def get_uperleft_denominator(img, kernel):
    '''
    img: HxWxC
    kernel: hxw
    denominator: HxWx1
    upperleft: HxWxC
    '''
    V = psf2otf(kernel, img.shape[:2])
    denominator = np.expand_dims(np.abs(V)**2, axis=2)
    upperleft = np.expand_dims(np.conj(V), axis=2) * np.fft.fft2(img, axes=[0, 1])
    return upperleft, denominator


def get_uperleft_denominator_pytorch(img, kernel):
    '''
    img: NxCxHxW
    kernel: Nx1xhxw
    denominator: Nx1xHxW
    upperleft: NxCxHxWx2
    '''
    V = p2o(kernel, img.shape[-2:])  # Nx1xHxWx2
    denominator = V[..., 0]**2 + V[..., 1]**2  # Nx1xHxW
    upperleft = cmul(cconj(V), rfft(img))  # Nx1xHxWx2 * NxCxHxWx2
    return upperleft, denominator


def c2c(x):
    return torch.from_numpy(np.stack([np.float32(x.real), np.float32(x.imag)], axis=-1))


def r2c(x):
    return torch.stack([x, torch.zeros_like(x)], -1)


def cdiv(x, y):
    a, b = x[..., 0], x[..., 1]
    c, d = y[..., 0], y[..., 1]
    cd2 = c**2 + d**2
    return torch.stack([(a*c+b*d)/cd2, (b*c-a*d)/cd2], -1)


def cabs(x):
    return torch.pow(x[..., 0]**2+x[..., 1]**2, 0.5)


def cmul(t1, t2):
    '''
    complex multiplication
    t1: NxCxHxWx2
    output: NxCxHxWx2
    '''
    real1, imag1 = t1[..., 0], t1[..., 1]
    real2, imag2 = t2[..., 0], t2[..., 1]
    return torch.stack([real1 * real2 - imag1 * imag2, real1 * imag2 + imag1 * real2], dim=-1)


def cconj(t, inplace=False):
    '''
    complex conjugation
    t: NxCxHxWx2
    output: NxCxHxWx2
    '''
    c = t.clone() if not inplace else t
    c[..., 1] *= -1
    return c

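# Added note: a quick self-check for the stacked real/imag helpers above, e.g.
#   cmul(c2c(np.array(1+2j)), c2c(np.array(3-1j)))  ->  tensor([5., 5.])
# which matches (1+2j)*(3-1j) = 5+5j in numpy's native complex arithmetic.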
def rfft(t):
    return torch.rfft(t, 2, onesided=False)


def irfft(t):
    return torch.irfft(t, 2, onesided=False)


def fft(t):
    return torch.fft(t, 2)


def ifft(t):
    return torch.ifft(t, 2)

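# Added note: torch.rfft/torch.irfft/torch.fft/torch.ifft used above are the
# legacy FFT API, deprecated in PyTorch 1.7 and removed in 1.8. On newer
# PyTorch, a drop-in sketch that keeps this file's (..., 2) real/imag layout is:
#
#     def rfft(t):
#         f = torch.fft.fft2(t)                     # complex tensor
#         return torch.stack([f.real, f.imag], -1)  # back to (..., 2) layout
#
#     def irfft(t):
#         return torch.fft.ifft2(torch.complex(t[..., 0], t[..., 1])).real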
def p2o(psf, shape):
    '''
    # psf: NxCxhxw
    # shape: [H,W]
    # otf: NxCxHxWx2
    '''
    otf = torch.zeros(psf.shape[:-2] + shape).type_as(psf)
    otf[..., :psf.shape[2], :psf.shape[3]].copy_(psf)
    for axis, axis_size in enumerate(psf.shape[2:]):
        otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2)
    otf = torch.rfft(otf, 2, onesided=False)
    n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf)))
    otf[..., 1][torch.abs(otf[..., 1]) < n_ops*2.22e-16] = torch.tensor(0).type_as(psf)
    return otf

# otf2psf: not sure where I got this one from. Maybe translated from Octave source code or whatever. It's just math.
def otf2psf(otf, outsize=None):
    insize = np.array(otf.shape)
    psf = np.fft.ifftn(otf, axes=(0, 1))
    for axis, axis_size in enumerate(insize):
        psf = np.roll(psf, np.floor(axis_size / 2).astype(int), axis=axis)
    if type(outsize) != type(None):
        insize = np.array(otf.shape)
        outsize = np.array(outsize)
        n = max(np.size(outsize), np.size(insize))
        # outsize = postpad(outsize(:), n, 1);
        # insize = postpad(insize(:) , n, 1);
        colvec_out = outsize.flatten().reshape((np.size(outsize), 1))
        colvec_in = insize.flatten().reshape((np.size(insize), 1))
        outsize = np.pad(colvec_out, ((0, max(0, n - np.size(colvec_out))), (0, 0)), mode="constant")
        insize = np.pad(colvec_in, ((0, max(0, n - np.size(colvec_in))), (0, 0)), mode="constant")

        pad = (insize - outsize) / 2
        if np.any(pad < 0):
            print("otf2psf error: OUTSIZE must be smaller than or equal to the OTF size")
        prepad = np.floor(pad)
        postpad = np.ceil(pad)
        dims_start = prepad.astype(int)
        dims_end = (insize - postpad).astype(int)
        for i in range(len(dims_start.shape)):
            psf = np.take(psf, range(dims_start[i][0], dims_end[i][0]), axis=i)
    n_ops = np.sum(otf.size * np.log2(otf.shape))
    psf = np.real_if_close(psf, tol=n_ops)
    return psf


# psf2otf copied/modified from https://github.com/aboucaud/pypher/blob/master/pypher/pypher.py
def psf2otf(psf, shape=None):
    """
    Convert a point-spread function to an optical transfer function.
    Computes the Fast Fourier Transform (FFT) of the point-spread
    function (PSF) array and creates the optical transfer function (OTF)
    array that is not influenced by the PSF off-centering.
    By default, the OTF array is the same size as the PSF array.
    To ensure that the OTF is not altered due to PSF off-centering, PSF2OTF
    post-pads the PSF array (down or to the right) with zeros to match
    dimensions specified in OUTSIZE, then circularly shifts the values of
    the PSF array up (or to the left) until the central pixel reaches (1,1)
    position.
    Parameters
    ----------
    psf : `numpy.ndarray`
        PSF array
    shape : int
        Output shape of the OTF array
    Returns
    -------
    otf : `numpy.ndarray`
        OTF array
    Notes
    -----
    Adapted from MATLAB psf2otf function
    """
    if type(shape) == type(None):
        shape = psf.shape
    shape = np.array(shape)
    if np.all(psf == 0):
        # return np.zeros_like(psf)
        return np.zeros(shape)
    if len(psf.shape) == 1:
        psf = psf.reshape((1, psf.shape[0]))
    inshape = psf.shape
    psf = zero_pad(psf, shape, position='corner')
    for axis, axis_size in enumerate(inshape):
        psf = np.roll(psf, -int(axis_size / 2), axis=axis)
    # Compute the OTF
    otf = np.fft.fft2(psf, axes=(0, 1))
    # Estimate the rough number of operations involved in the FFT
    # and discard the PSF imaginary part if within roundoff error
    # roundoff error  = machine epsilon = sys.float_info.epsilon
    # or np.finfo().eps
    n_ops = np.sum(psf.size * np.log2(psf.shape))
    otf = np.real_if_close(otf, tol=n_ops)
    return otf

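# Added self-check sketch for psf2otf: multiplying by the OTF in the Fourier
# domain should equal circular convolution with the PSF in the spatial domain
# (scipy.ndimage.convolve with mode='wrap'), e.g. for a symmetric kernel:
#
#     from scipy import ndimage
#     im = np.random.rand(32, 32)
#     psf = np.ones((5, 5)) / 25.0
#     out_fft = np.real(np.fft.ifft2(np.fft.fft2(im) * psf2otf(psf, im.shape)))
#     out_spatial = ndimage.convolve(im, psf, mode='wrap')
#     assert np.allclose(out_fft, out_spatial)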
| 202 | 
         
            +
            def zero_pad(image, shape, position='corner'):
         
     | 
| 203 | 
         
            +
                """
         
     | 
| 204 | 
         
            +
                Extends image to a certain size with zeros
         
     | 
| 205 | 
         
            +
                Parameters
         
     | 
| 206 | 
         
            +
                ----------
         
     | 
| 207 | 
         
            +
                image: real 2d `numpy.ndarray`
         
     | 
| 208 | 
         
            +
                    Input image
         
     | 
| 209 | 
         
            +
                shape: tuple of int
         
     | 
| 210 | 
         
            +
                    Desired output shape of the image
         
     | 
| 211 | 
         
            +
                position : str, optional
         
     | 
| 212 | 
         
            +
                    The position of the input image in the output one:
         
     | 
| 213 | 
         
            +
                        * 'corner'
         
     | 
| 214 | 
         
            +
                            top-left corner (default)
         
     | 
| 215 | 
         
            +
                        * 'center'
         
     | 
| 216 | 
         
            +
                            centered
         
     | 
| 217 | 
         
            +
                Returns
         
     | 
| 218 | 
         
            +
                -------
         
     | 
| 219 | 
         
            +
                padded_img: real `numpy.ndarray`
         
     | 
| 220 | 
         
            +
                    The zero-padded image
         
     | 
| 221 | 
         
            +
                """
         
     | 
| 222 | 
         
            +
                shape = np.asarray(shape, dtype=int)
         
     | 
| 223 | 
         
            +
                imshape = np.asarray(image.shape, dtype=int)
         
     | 
| 224 | 
         
            +
                if np.alltrue(imshape == shape):
         
     | 
| 225 | 
         
            +
                    return image
         
     | 
| 226 | 
         
            +
                if np.any(shape <= 0):
         
     | 
| 227 | 
         
            +
                    raise ValueError("ZERO_PAD: null or negative shape given")
         
     | 
| 228 | 
         
            +
                dshape = shape - imshape
         
     | 
| 229 | 
         
            +
                if np.any(dshape < 0):
         
     | 
| 230 | 
         
            +
                    raise ValueError("ZERO_PAD: target size smaller than source one")
         
     | 
| 231 | 
         
            +
                pad_img = np.zeros(shape, dtype=image.dtype)
         
     | 
| 232 | 
         
            +
                idx, idy = np.indices(imshape)
         
     | 
| 233 | 
         
            +
                if position == 'center':
         
     | 
| 234 | 
         
            +
                    if np.any(dshape % 2 != 0):
         
     | 
| 235 | 
         
            +
                        raise ValueError("ZERO_PAD: source and target shapes "
         
     | 
| 236 | 
         
            +
                                         "have different parity.")
         
     | 
| 237 | 
         
            +
                    offx, offy = dshape // 2
         
     | 
| 238 | 
         
            +
                else:
         
     | 
| 239 | 
         
            +
                    offx, offy = (0, 0)
         
     | 
| 240 | 
         
            +
                pad_img[idx + offx, idy + offy] = image
         
     | 
| 241 | 
         
            +
                return pad_img
         
     | 
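A minimal usage sketch for zero_pad (assuming numpy is imported as np, as elsewhere in this module; the shapes below are illustrative only):

    # 'center' requires the height/width deltas to be even; 'corner' does not.
    import numpy as np

    img = np.ones((4, 4))
    corner = zero_pad(img, (6, 7), position='corner')   # image at the top-left
    center = zero_pad(img, (8, 8), position='center')   # image centered
    assert corner.shape == (6, 7) and center.shape == (8, 8)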
| 242 | 
         
            +
             
     | 
| 243 | 
         
            +
             
     | 
| 244 | 
         
            +
            '''
         
     | 
| 245 | 
         
            +
            Reducing boundary artifacts
         
     | 
| 246 | 
         
            +
            '''
         
     | 
| 247 | 
         
            +
             
     | 
| 248 | 
         
            +
             
     | 
| 249 | 
         
            +
            def opt_fft_size(n):
         
     | 
| 250 | 
         
            +
                '''
         
     | 
| 251 | 
         
            +
                Kai Zhang (github: https://github.com/cszn)
         
     | 
| 252 | 
         
            +
                03/03/2019
         
     | 
| 253 | 
         
            +
                #  opt_fft_size.m
         
     | 
| 254 | 
         
            +
                # compute an optimal data length for Fourier transforms
         
     | 
| 255 | 
         
            +
                # written by Sunghyun Cho ([email protected])
         
     | 
| 256 | 
         
            +
                # persistent opt_fft_size_LUT;
         
     | 
| 257 | 
         
            +
                '''
         
     | 
| 258 | 
         
            +
             
     | 
| 259 | 
         
            +
                LUT_size = 2048
         
     | 
| 260 | 
         
            +
                # print("generate opt_fft_size_LUT")
         
     | 
| 261 | 
         
            +
                opt_fft_size_LUT = np.zeros(LUT_size)
         
     | 
| 262 | 
         
            +
             
     | 
| 263 | 
         
            +
                e2 = 1
         
     | 
| 264 | 
         
            +
                while e2 <= LUT_size:
         
     | 
| 265 | 
         
            +
                    e3 = e2
         
     | 
| 266 | 
         
            +
                    while e3 <= LUT_size:
         
     | 
| 267 | 
         
            +
                        e5 = e3
         
     | 
| 268 | 
         
            +
                        while e5 <= LUT_size:
         
     | 
| 269 | 
         
            +
                            e7 = e5
         
     | 
| 270 | 
         
            +
                            while e7 <= LUT_size:
         
     | 
| 271 | 
         
            +
                                if e7 <= LUT_size:
         
     | 
| 272 | 
         
            +
                                    opt_fft_size_LUT[e7-1] = e7
         
     | 
| 273 | 
         
            +
                                if e7*11 <= LUT_size:
         
     | 
| 274 | 
         
            +
                                    opt_fft_size_LUT[e7*11-1] = e7*11
         
     | 
| 275 | 
         
            +
                                if e7*13 <= LUT_size:
         
     | 
| 276 | 
         
            +
                                    opt_fft_size_LUT[e7*13-1] = e7*13
         
     | 
| 277 | 
         
            +
                                e7 = e7 * 7
         
     | 
| 278 | 
         
            +
                            e5 = e5 * 5
         
     | 
| 279 | 
         
            +
                        e3 = e3 * 3
         
     | 
| 280 | 
         
            +
                    e2 = e2 * 2
         
     | 
| 281 | 
         
            +
             
     | 
| 282 | 
         
            +
                nn = 0
         
     | 
| 283 | 
         
            +
                for i in range(LUT_size, 0, -1):
         
     | 
| 284 | 
         
            +
                    if opt_fft_size_LUT[i-1] != 0:
         
     | 
| 285 | 
         
            +
                        nn = i-1
         
     | 
| 286 | 
         
            +
                    else:
         
     | 
| 287 | 
         
            +
                        opt_fft_size_LUT[i-1] = nn+1
         
     | 
| 288 | 
         
            +
             
     | 
| 289 | 
         
            +
                m = np.zeros(len(n))
         
     | 
| 290 | 
         
            +
                for c in range(len(n)):
         
     | 
| 291 | 
         
            +
                    nn = n[c]
         
     | 
| 292 | 
         
            +
                    if nn <= LUT_size:
         
     | 
| 293 | 
         
            +
                        m[c] = opt_fft_size_LUT[nn-1]
         
     | 
| 294 | 
         
            +
                    else:
         
     | 
| 295 | 
         
            +
                        m[c] = -1
         
     | 
| 296 | 
         
            +
                return m
         
     | 
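A hedged sketch of what opt_fft_size returns: for each input length it looks up the next size whose factors are small primes (2, 3, 5, 7, optionally one factor of 11 or 13), which FFT implementations handle fastest; lengths beyond the 2048-entry LUT map to -1. The inputs below are illustrative:

    sizes = opt_fft_size([111, 128, 5000])
    # -> [112., 128., -1.]: 111 = 3 * 37 is awkward, so the next smooth
    # size 112 = 2**4 * 7 is used; 128 is already a power of two;
    # 5000 exceeds the LUT and is flagged with -1.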
| 297 | 
         
            +
             
     | 
| 298 | 
         
            +
             
     | 
| 299 | 
         
            +
            def wrap_boundary_liu(img, img_size):
         
     | 
| 300 | 
         
            +
             
     | 
| 301 | 
         
            +
                """
         
     | 
| 302 | 
         
            +
                Reducing boundary artifacts in image deconvolution
         
     | 
| 303 | 
         
            +
                Renting Liu, Jiaya Jia
         
     | 
| 304 | 
         
            +
                ICIP 2008
         
     | 
| 305 | 
         
            +
                """
         
     | 
| 306 | 
         
            +
                if img.ndim == 2:
         
     | 
| 307 | 
         
            +
                    ret = wrap_boundary(img, img_size)
         
     | 
| 308 | 
         
            +
                elif img.ndim == 3:
         
     | 
| 309 | 
         
            +
                    ret = [wrap_boundary(img[:, :, i], img_size) for i in range(3)]
         
     | 
| 310 | 
         
            +
                    ret = np.stack(ret, 2)
         
     | 
| 311 | 
         
            +
                return ret
         
     | 
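A sketch of the intended pipeline around wrap_boundary_liu (assumptions: a grayscale float image, numpy as np, and an illustrative 31x31 kernel): pad the image to an FFT-friendly size with smoothly wrapped boundaries, filter in the frequency domain, then crop back.

    import numpy as np

    img = np.random.rand(100, 100)
    target = opt_fft_size([100 + 31 - 1, 100 + 31 - 1])  # room for a 31x31 kernel
    wrapped = wrap_boundary_liu(img, target)
    # ... frequency-domain filtering on `wrapped` ...
    restored = wrapped[:100, :100]                       # crop back to the original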
| 312 | 
         
            +
             
     | 
| 313 | 
         
            +
             
     | 
| 314 | 
         
            +
            def wrap_boundary(img, img_size):
         
     | 
| 315 | 
         
            +
             
     | 
| 316 | 
         
            +
                """
         
     | 
| 317 | 
         
            +
                python code from:
         
     | 
| 318 | 
         
            +
                https://github.com/ys-koshelev/nla_deblur/blob/90fe0ab98c26c791dcbdf231fe6f938fca80e2a0/boundaries.py
         
     | 
| 319 | 
         
            +
                Reducing boundary artifacts in image deconvolution
         
     | 
| 320 | 
         
            +
                Renting Liu, Jiaya Jia
         
     | 
| 321 | 
         
            +
                ICIP 2008
         
     | 
| 322 | 
         
            +
                """
         
     | 
| 323 | 
         
            +
                (H, W) = np.shape(img)
         
     | 
| 324 | 
         
            +
                H_w = int(img_size[0]) - H
         
     | 
| 325 | 
         
            +
                W_w = int(img_size[1]) - W
         
     | 
| 326 | 
         
            +
             
     | 
| 327 | 
         
            +
                # ret = np.zeros((img_size[0], img_size[1]));
         
     | 
| 328 | 
         
            +
                alpha = 1
         
     | 
| 329 | 
         
            +
                HG = img[:, :]
         
     | 
| 330 | 
         
            +
             
     | 
| 331 | 
         
            +
                r_A = np.zeros((alpha*2+H_w, W))
         
     | 
| 332 | 
         
            +
                r_A[:alpha, :] = HG[-alpha:, :]
         
     | 
| 333 | 
         
            +
                r_A[-alpha:, :] = HG[:alpha, :]
         
     | 
| 334 | 
         
            +
                a = np.arange(H_w)/(H_w-1)
         
     | 
| 335 | 
         
            +
                # r_A(alpha+1:end-alpha, 1) = (1-a)*r_A(alpha,1) + a*r_A(end-alpha+1,1)
         
     | 
| 336 | 
         
            +
                r_A[alpha:-alpha, 0] = (1-a)*r_A[alpha-1, 0] + a*r_A[-alpha, 0]
         
     | 
| 337 | 
         
            +
                # r_A(alpha+1:end-alpha, end) = (1-a)*r_A(alpha,end) + a*r_A(end-alpha+1,end)
         
     | 
| 338 | 
         
            +
                r_A[alpha:-alpha, -1] = (1-a)*r_A[alpha-1, -1] + a*r_A[-alpha, -1]
         
     | 
| 339 | 
         
            +
             
     | 
| 340 | 
         
            +
                r_B = np.zeros((H, alpha*2+W_w))
         
     | 
| 341 | 
         
            +
                r_B[:, :alpha] = HG[:, -alpha:]
         
     | 
| 342 | 
         
            +
                r_B[:, -alpha:] = HG[:, :alpha]
         
     | 
| 343 | 
         
            +
                a = np.arange(W_w)/(W_w-1)
         
     | 
| 344 | 
         
            +
                r_B[0, alpha:-alpha] = (1-a)*r_B[0, alpha-1] + a*r_B[0, -alpha]
         
     | 
| 345 | 
         
            +
                r_B[-1, alpha:-alpha] = (1-a)*r_B[-1, alpha-1] + a*r_B[-1, -alpha]
         
     | 
| 346 | 
         
            +
             
     | 
| 347 | 
         
            +
                if alpha == 1:
         
     | 
| 348 | 
         
            +
                    A2 = solve_min_laplacian(r_A[alpha-1:, :])
         
     | 
| 349 | 
         
            +
                    B2 = solve_min_laplacian(r_B[:, alpha-1:])
         
     | 
| 350 | 
         
            +
                    r_A[alpha-1:, :] = A2
         
     | 
| 351 | 
         
            +
                    r_B[:, alpha-1:] = B2
         
     | 
| 352 | 
         
            +
                else:
         
     | 
| 353 | 
         
            +
                    A2 = solve_min_laplacian(r_A[alpha-1:-alpha+1, :])
         
     | 
| 354 | 
         
            +
                    r_A[alpha-1:-alpha+1, :] = A2
         
     | 
| 355 | 
         
            +
                    B2 = solve_min_laplacian(r_B[:, alpha-1:-alpha+1])
         
     | 
| 356 | 
         
            +
                    r_B[:, alpha-1:-alpha+1] = B2
         
     | 
| 357 | 
         
            +
                A = r_A
         
     | 
| 358 | 
         
            +
                B = r_B
         
     | 
| 359 | 
         
            +
             
     | 
| 360 | 
         
            +
                r_C = np.zeros((alpha*2+H_w, alpha*2+W_w))
         
     | 
| 361 | 
         
            +
                r_C[:alpha, :] = B[-alpha:, :]
         
     | 
| 362 | 
         
            +
                r_C[-alpha:, :] = B[:alpha, :]
         
     | 
| 363 | 
         
            +
                r_C[:, :alpha] = A[:, -alpha:]
         
     | 
| 364 | 
         
            +
                r_C[:, -alpha:] = A[:, :alpha]
         
     | 
| 365 | 
         
            +
             
     | 
| 366 | 
         
            +
                if alpha == 1:
         
     | 
| 367 | 
         
            +
        C2 = solve_min_laplacian(r_C[alpha-1:, alpha-1:])
         
     | 
| 368 | 
         
            +
                    r_C[alpha-1:, alpha-1:] = C2
         
     | 
| 369 | 
         
            +
                else:
         
     | 
| 370 | 
         
            +
                    C2 = solve_min_laplacian(r_C[alpha-1:-alpha+1, alpha-1:-alpha+1])
         
     | 
| 371 | 
         
            +
                    r_C[alpha-1:-alpha+1, alpha-1:-alpha+1] = C2
         
     | 
| 372 | 
         
            +
                C = r_C
         
     | 
| 373 | 
         
            +
                # return C
         
     | 
| 374 | 
         
            +
                A = A[alpha-1:-alpha-1, :]
         
     | 
| 375 | 
         
            +
                B = B[:, alpha:-alpha]
         
     | 
| 376 | 
         
            +
                C = C[alpha:-alpha, alpha:-alpha]
         
     | 
| 377 | 
         
            +
                ret = np.vstack((np.hstack((img, B)), np.hstack((A, C))))
         
     | 
| 378 | 
         
            +
                return ret
         
     | 
| 379 | 
         
            +
             
     | 
| 380 | 
         
            +
             
     | 
| 381 | 
         
            +
            def solve_min_laplacian(boundary_image):
         
     | 
| 382 | 
         
            +
                (H, W) = np.shape(boundary_image)
         
     | 
| 383 | 
         
            +
             
     | 
| 384 | 
         
            +
                # Laplacian
         
     | 
| 385 | 
         
            +
                f = np.zeros((H, W))
         
     | 
| 386 | 
         
            +
                # boundary image contains image intensities at boundaries
         
     | 
| 387 | 
         
            +
                boundary_image[1:-1, 1:-1] = 0
         
     | 
| 388 | 
         
            +
                j = np.arange(2, H)-1
         
     | 
| 389 | 
         
            +
                k = np.arange(2, W)-1
         
     | 
| 390 | 
         
            +
                f_bp = np.zeros((H, W))
         
     | 
| 391 | 
         
            +
                f_bp[np.ix_(j, k)] = -4*boundary_image[np.ix_(j, k)] + boundary_image[np.ix_(j, k+1)] + boundary_image[np.ix_(j, k-1)] + boundary_image[np.ix_(j-1, k)] + boundary_image[np.ix_(j+1, k)]
         
     | 
| 392 | 
         
            +
                
         
     | 
| 393 | 
         
            +
                del(j, k)
         
     | 
| 394 | 
         
            +
                f1 = f - f_bp  # subtract boundary points contribution
         
     | 
| 395 | 
         
            +
                del(f_bp, f)
         
     | 
| 396 | 
         
            +
             
     | 
| 397 | 
         
            +
                # DST Sine Transform algo starts here
         
     | 
| 398 | 
         
            +
                f2 = f1[1:-1,1:-1]
         
     | 
| 399 | 
         
            +
                del(f1)
         
     | 
| 400 | 
         
            +
             
     | 
| 401 | 
         
            +
    # compute sine transform
         
     | 
| 402 | 
         
            +
                if f2.shape[1] == 1:
         
     | 
| 403 | 
         
            +
                    tt = fftpack.dst(f2, type=1, axis=0)/2
         
     | 
| 404 | 
         
            +
                else:
         
     | 
| 405 | 
         
            +
                    tt = fftpack.dst(f2, type=1)/2
         
     | 
| 406 | 
         
            +
             
     | 
| 407 | 
         
            +
                if tt.shape[0] == 1:
         
     | 
| 408 | 
         
            +
                    f2sin = np.transpose(fftpack.dst(np.transpose(tt), type=1, axis=0)/2)
         
     | 
| 409 | 
         
            +
                else:
         
     | 
| 410 | 
         
            +
                    f2sin = np.transpose(fftpack.dst(np.transpose(tt), type=1)/2) 
         
     | 
| 411 | 
         
            +
                del(f2)
         
     | 
| 412 | 
         
            +
             
     | 
| 413 | 
         
            +
                # compute Eigen Values
         
     | 
| 414 | 
         
            +
                [x, y] = np.meshgrid(np.arange(1, W-1), np.arange(1, H-1))
         
     | 
| 415 | 
         
            +
                denom = (2*np.cos(np.pi*x/(W-1))-2) + (2*np.cos(np.pi*y/(H-1)) - 2)
         
     | 
| 416 | 
         
            +
             
     | 
| 417 | 
         
            +
                # divide
         
     | 
| 418 | 
         
            +
                f3 = f2sin/denom
         
     | 
| 419 | 
         
            +
                del(f2sin, x, y)
         
     | 
| 420 | 
         
            +
             
     | 
| 421 | 
         
            +
                # compute Inverse Sine Transform
         
     | 
| 422 | 
         
            +
                if f3.shape[0] == 1:
         
     | 
| 423 | 
         
            +
                    tt = fftpack.idst(f3*2, type=1, axis=1)/(2*(f3.shape[1]+1))
         
     | 
| 424 | 
         
            +
                else:
         
     | 
| 425 | 
         
            +
                    tt = fftpack.idst(f3*2, type=1, axis=0)/(2*(f3.shape[0]+1))
         
     | 
| 426 | 
         
            +
                del(f3)
         
     | 
| 427 | 
         
            +
                if tt.shape[1] == 1:
         
     | 
| 428 | 
         
            +
                    img_tt = np.transpose(fftpack.idst(np.transpose(tt)*2, type=1)/(2*(tt.shape[0]+1)))
         
     | 
| 429 | 
         
            +
                else:
         
     | 
| 430 | 
         
            +
                    img_tt = np.transpose(fftpack.idst(np.transpose(tt)*2, type=1, axis=0)/(2*(tt.shape[1]+1)))
         
     | 
| 431 | 
         
            +
                del(tt)
         
     | 
| 432 | 
         
            +
             
     | 
| 433 | 
         
            +
                # put solution in inner points; outer points obtained from boundary image
         
     | 
| 434 | 
         
            +
                img_direct = boundary_image
         
     | 
| 435 | 
         
            +
                img_direct[1:-1, 1:-1] = 0
         
     | 
| 436 | 
         
            +
                img_direct[1:-1, 1:-1] = img_tt
         
     | 
| 437 | 
         
            +
                return img_direct
         
     | 
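solve_min_laplacian fills the interior with the harmonic function matching the given boundary values: the interior Poisson system is diagonalized by a type-I discrete sine transform, whose eigenvalues are the denom term (2*cos(pi*x/(W-1)) - 2) + (2*cos(pi*y/(H-1)) - 2). A small sketch (numpy as np; note the function writes into its argument, hence the copy):

    import numpy as np

    b = np.zeros((8, 8))
    b[-1, :] = 1.0                      # bottom edge held at 1
    b[:, 0] = np.linspace(0, 1, 8)      # left/right edges ramp from 0 to 1
    b[:, -1] = np.linspace(0, 1, 8)
    filled = solve_min_laplacian(b.copy())   # smooth interior fill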
| 438 | 
         
            +
             
     | 
| 439 | 
         
            +
             
     | 
| 440 | 
         
            +
            """
         
     | 
| 441 | 
         
            +
            Created on Thu Jan 18 15:36:32 2018
         
     | 
| 442 | 
         
            +
            @author: italo
         
     | 
| 443 | 
         
            +
            https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
         
     | 
| 444 | 
         
            +
            """
         
     | 
| 445 | 
         
            +
             
     | 
| 446 | 
         
            +
            """
         
     | 
| 447 | 
         
            +
            Syntax
         
     | 
| 448 | 
         
            +
            h = fspecial(type)
         
     | 
| 449 | 
         
            +
            h = fspecial('average',hsize)
         
     | 
| 450 | 
         
            +
            h = fspecial('disk',radius)
         
     | 
| 451 | 
         
            +
            h = fspecial('gaussian',hsize,sigma)
         
     | 
| 452 | 
         
            +
            h = fspecial('laplacian',alpha)
         
     | 
| 453 | 
         
            +
            h = fspecial('log',hsize,sigma)
         
     | 
| 454 | 
         
            +
            h = fspecial('motion',len,theta)
         
     | 
| 455 | 
         
            +
            h = fspecial('prewitt')
         
     | 
| 456 | 
         
            +
            h = fspecial('sobel')
         
     | 
| 457 | 
         
            +
            """
         
     | 
| 458 | 
         
            +
             
     | 
| 459 | 
         
            +
             
     | 
| 460 | 
         
            +
            def fspecial_average(hsize=3):
         
     | 
| 461 | 
         
            +
                """Smoothing filter"""
         
     | 
| 462 | 
         
            +
                return np.ones((hsize, hsize))/hsize**2
         
     | 
| 463 | 
         
            +
             
     | 
| 464 | 
         
            +
             
     | 
| 465 | 
         
            +
            def fspecial_disk(radius):
         
     | 
| 466 | 
         
            +
                """Disk filter"""
         
     | 
| 467 | 
         
            +
    # NOTE: the partial port below is unreachable until this is finished
    raise NotImplementedError('fspecial_disk is not implemented')
         
     | 
| 468 | 
         
            +
                rad = 0.6
         
     | 
| 469 | 
         
            +
                crad = np.ceil(rad-0.5)
         
     | 
| 470 | 
         
            +
                [x, y] = np.meshgrid(np.arange(-crad, crad+1), np.arange(-crad, crad+1))
         
     | 
| 471 | 
         
            +
                maxxy = np.zeros(x.shape)
         
     | 
| 472 | 
         
            +
                maxxy[abs(x) >= abs(y)] = abs(x)[abs(x) >= abs(y)]
         
     | 
| 473 | 
         
            +
                maxxy[abs(y) >= abs(x)] = abs(y)[abs(y) >= abs(x)]
         
     | 
| 474 | 
         
            +
                minxy = np.zeros(x.shape)
         
     | 
| 475 | 
         
            +
                minxy[abs(x) <= abs(y)] = abs(x)[abs(x) <= abs(y)]
         
     | 
| 476 | 
         
            +
                minxy[abs(y) <= abs(x)] = abs(y)[abs(y) <= abs(x)]
         
     | 
| 477 | 
         
            +
                m1 = (rad**2 <  (maxxy+0.5)**2 + (minxy-0.5)**2)*(minxy-0.5) +\
         
     | 
| 478 | 
         
            +
                     (rad**2 >= (maxxy+0.5)**2 + (minxy-0.5)**2)*\
         
     | 
| 479 | 
         
            +
                     np.sqrt((rad**2 + 0j) - (maxxy + 0.5)**2)
         
     | 
| 480 | 
         
            +
                m2 = (rad**2 >  (maxxy-0.5)**2 + (minxy+0.5)**2)*(minxy+0.5) +\
         
     | 
| 481 | 
         
            +
                     (rad**2 <= (maxxy-0.5)**2 + (minxy+0.5)**2)*\
         
     | 
| 482 | 
         
            +
                     np.sqrt((rad**2 + 0j) - (maxxy - 0.5)**2)
         
     | 
| 483 | 
         
            +
                h = None
         
     | 
| 484 | 
         
            +
                return h
         
     | 
| 485 | 
         
            +
             
     | 
| 486 | 
         
            +
             
     | 
| 487 | 
         
            +
            def fspecial_gaussian(hsize, sigma):
         
     | 
| 488 | 
         
            +
                hsize = [hsize, hsize]
         
     | 
| 489 | 
         
            +
                siz = [(hsize[0]-1.0)/2.0, (hsize[1]-1.0)/2.0]
         
     | 
| 490 | 
         
            +
                std = sigma
         
     | 
| 491 | 
         
            +
                [x, y] = np.meshgrid(np.arange(-siz[1], siz[1]+1), np.arange(-siz[0], siz[0]+1))
         
     | 
| 492 | 
         
            +
                arg = -(x*x + y*y)/(2*std*std)
         
     | 
| 493 | 
         
            +
                h = np.exp(arg)
         
     | 
| 494 | 
         
            +
    h[h < np.finfo(float).eps * h.max()] = 0  # np.finfo: the scipy alias was removed
         
     | 
| 495 | 
         
            +
                sumh = h.sum()
         
     | 
| 496 | 
         
            +
                if sumh != 0:
         
     | 
| 497 | 
         
            +
                    h = h/sumh
         
     | 
| 498 | 
         
            +
                return h
         
     | 
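A quick sanity sketch for fspecial_gaussian (values illustrative): the kernel is normalized to unit sum and peaks at its center, matching MATLAB's fspecial('gaussian', hsize, sigma):

    h = fspecial_gaussian(5, 1)
    assert abs(h.sum() - 1.0) < 1e-12   # normalized
    assert h[2, 2] == h.max()           # peak at the center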
| 499 | 
         
            +
             
     | 
| 500 | 
         
            +
             
     | 
| 501 | 
         
            +
            def fspecial_laplacian(alpha):
         
     | 
| 502 | 
         
            +
                alpha = max([0, min([alpha,1])])
         
     | 
| 503 | 
         
            +
                h1 = alpha/(alpha+1)
         
     | 
| 504 | 
         
            +
                h2 = (1-alpha)/(alpha+1)
         
     | 
| 505 | 
         
            +
                h = [[h1, h2, h1], [h2, -4/(alpha+1), h2], [h1, h2, h1]]
         
     | 
| 506 | 
         
            +
                h = np.array(h)
         
     | 
| 507 | 
         
            +
                return h
         
     | 
| 508 | 
         
            +
             
     | 
| 509 | 
         
            +
             
     | 
| 510 | 
         
            +
            def fspecial_log(hsize, sigma):
         
     | 
| 511 | 
         
            +
    raise NotImplementedError('fspecial_log is not implemented')
         
     | 
| 512 | 
         
            +
             
     | 
| 513 | 
         
            +
             
     | 
| 514 | 
         
            +
            def fspecial_motion(motion_len, theta):
         
     | 
| 515 | 
         
            +
    raise NotImplementedError('fspecial_motion is not implemented')
         
     | 
| 516 | 
         
            +
             
     | 
| 517 | 
         
            +
             
     | 
| 518 | 
         
            +
            def fspecial_prewitt():
         
     | 
| 519 | 
         
            +
                return np.array([[1, 1, 1], [0, 0, 0], [-1, -1, -1]])
         
     | 
| 520 | 
         
            +
             
     | 
| 521 | 
         
            +
             
     | 
| 522 | 
         
            +
            def fspecial_sobel():
         
     | 
| 523 | 
         
            +
                return np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]])
         
     | 
| 524 | 
         
            +
             
     | 
| 525 | 
         
            +
             
     | 
| 526 | 
         
            +
            def fspecial(filter_type, *args, **kwargs):
         
     | 
| 527 | 
         
            +
                '''
         
     | 
| 528 | 
         
            +
                python code from:
         
     | 
| 529 | 
         
            +
                https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
         
     | 
| 530 | 
         
            +
                '''
         
     | 
| 531 | 
         
            +
                if filter_type == 'average':
         
     | 
| 532 | 
         
            +
                    return fspecial_average(*args, **kwargs)
         
     | 
| 533 | 
         
            +
                if filter_type == 'disk':
         
     | 
| 534 | 
         
            +
                    return fspecial_disk(*args, **kwargs)
         
     | 
| 535 | 
         
            +
                if filter_type == 'gaussian':
         
     | 
| 536 | 
         
            +
                    return fspecial_gaussian(*args, **kwargs)
         
     | 
| 537 | 
         
            +
                if filter_type == 'laplacian':
         
     | 
| 538 | 
         
            +
                    return fspecial_laplacian(*args, **kwargs)
         
     | 
| 539 | 
         
            +
                if filter_type == 'log':
         
     | 
| 540 | 
         
            +
                    return fspecial_log(*args, **kwargs)
         
     | 
| 541 | 
         
            +
                if filter_type == 'motion':
         
     | 
| 542 | 
         
            +
                    return fspecial_motion(*args, **kwargs)
         
     | 
| 543 | 
         
            +
                if filter_type == 'prewitt':
         
     | 
| 544 | 
         
            +
                    return fspecial_prewitt(*args, **kwargs)
         
     | 
| 545 | 
         
            +
                if filter_type == 'sobel':
         
     | 
| 546 | 
         
            +
                    return fspecial_sobel(*args, **kwargs)
         
     | 
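A hedged usage sketch for the dispatcher. Note that an unrecognized filter_type falls through every branch and returns None, so callers may want to guard for that:

    avg = fspecial('average', 3)    # 3x3 box filter, each weight 1/9
    sob = fspecial('sobel')         # vertical-gradient Sobel kernel
    oops = fspecial('median')       # -> None: no matching branch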
| 547 | 
         
            +
             
     | 
| 548 | 
         
            +
             
     | 
| 549 | 
         
            +
            def fspecial_gauss(size, sigma):
         
     | 
| 550 | 
         
            +
                x, y = mgrid[-size // 2 + 1 : size // 2 + 1, -size // 2 + 1 : size // 2 + 1]
         
     | 
| 551 | 
         
            +
                g = exp(-((x ** 2 + y ** 2) / (2.0 * sigma ** 2)))
         
     | 
| 552 | 
         
            +
                return g / g.sum()
         
     | 
| 553 | 
         
            +
             
     | 
| 554 | 
         
            +
             
     | 
| 555 | 
         
            +
            def blurkernel_synthesis(h=37, w=None):
         
     | 
| 556 | 
         
            +
                # https://github.com/tkkcc/prior/blob/879a0b6c117c810776d8cc6b63720bf29f7d0cc4/util/gen_kernel.py
         
     | 
| 557 | 
         
            +
                w = h if w is None else w
         
     | 
| 558 | 
         
            +
                kdims = [h, w]
         
     | 
| 559 | 
         
            +
                x = randomTrajectory(250)
         
     | 
| 560 | 
         
            +
                k = None
         
     | 
| 561 | 
         
            +
                while k is None:
         
     | 
| 562 | 
         
            +
                    k = kernelFromTrajectory(x)
         
     | 
| 563 | 
         
            +
             
     | 
| 564 | 
         
            +
                # center pad to kdims
         
     | 
| 565 | 
         
            +
                pad_width = ((kdims[0] - k.shape[0]) // 2, (kdims[1] - k.shape[1]) // 2)
         
     | 
| 566 | 
         
            +
                pad_width = [(pad_width[0],), (pad_width[1],)]
         
     | 
| 567 | 
         
            +
                
         
     | 
| 568 | 
         
            +
                if pad_width[0][0]<0 or pad_width[1][0]<0:
         
     | 
| 569 | 
         
            +
        k = k[0:h, 0:w]  # crop to the requested kernel size
         
     | 
| 570 | 
         
            +
                else:
         
     | 
| 571 | 
         
            +
                    k = pad(k, pad_width, "constant")
         
     | 
| 572 | 
         
            +
                x1,x2 = k.shape
         
     | 
| 573 | 
         
            +
                if np.random.randint(0, 4) == 1:
         
     | 
| 574 | 
         
            +
                    k = cv2.resize(k, (random.randint(x1, 5*x1), random.randint(x2, 5*x2)), interpolation=cv2.INTER_LINEAR)
         
     | 
| 575 | 
         
            +
                    y1, y2 = k.shape
         
     | 
| 576 | 
         
            +
                    k = k[(y1-x1)//2: (y1-x1)//2+x1, (y2-x2)//2: (y2-x2)//2+x2]
         
     | 
| 577 | 
         
            +
                    
         
     | 
| 578 | 
         
            +
                if sum(k)<0.1:
         
     | 
| 579 | 
         
            +
                    k = fspecial_gaussian(h, 0.1+6*np.random.rand(1))
         
     | 
| 580 | 
         
            +
                k = k / sum(k)
         
     | 
| 581 | 
         
            +
                # import matplotlib.pyplot as plt
         
     | 
| 582 | 
         
            +
                # plt.imshow(k, interpolation="nearest", cmap="gray")
         
     | 
| 583 | 
         
            +
                # plt.show()
         
     | 
| 584 | 
         
            +
                return k
         
     | 
| 585 | 
         
            +
             
     | 
| 586 | 
         
            +
             
     | 
| 587 | 
         
            +
            def kernelFromTrajectory(x):
         
     | 
| 588 | 
         
            +
                h = 5 - log(rand()) / 0.15
         
     | 
| 589 | 
         
            +
                h = round(min([h, 27])).astype(int)
         
     | 
| 590 | 
         
            +
                h = h + 1 - h % 2
         
     | 
| 591 | 
         
            +
                w = h
         
     | 
| 592 | 
         
            +
                k = zeros((h, w))
         
     | 
| 593 | 
         
            +
             
     | 
| 594 | 
         
            +
                xmin = min(x[0])
         
     | 
| 595 | 
         
            +
                xmax = max(x[0])
         
     | 
| 596 | 
         
            +
                ymin = min(x[1])
         
     | 
| 597 | 
         
            +
                ymax = max(x[1])
         
     | 
| 598 | 
         
            +
                xthr = arange(xmin, xmax, (xmax - xmin) / w)
         
     | 
| 599 | 
         
            +
                ythr = arange(ymin, ymax, (ymax - ymin) / h)
         
     | 
| 600 | 
         
            +
             
     | 
| 601 | 
         
            +
                for i in range(1, xthr.size):
         
     | 
| 602 | 
         
            +
                    for j in range(1, ythr.size):
         
     | 
| 603 | 
         
            +
                        idx = (
         
     | 
| 604 | 
         
            +
                            (x[0, :] >= xthr[i - 1])
         
     | 
| 605 | 
         
            +
                            & (x[0, :] < xthr[i])
         
     | 
| 606 | 
         
            +
                            & (x[1, :] >= ythr[j - 1])
         
     | 
| 607 | 
         
            +
                            & (x[1, :] < ythr[j])
         
     | 
| 608 | 
         
            +
                        )
         
     | 
| 609 | 
         
            +
                        k[i - 1, j - 1] = sum(idx)
         
     | 
| 610 | 
         
            +
                if sum(k) == 0:
         
     | 
| 611 | 
         
            +
                    return
         
     | 
| 612 | 
         
            +
                k = k / sum(k)
         
     | 
| 613 | 
         
            +
                k = convolve2d(k, fspecial_gauss(3, 1), "same")
         
     | 
| 614 | 
         
            +
                k = k / sum(k)
         
     | 
| 615 | 
         
            +
                return k
         
     | 
| 616 | 
         
            +
             
     | 
| 617 | 
         
            +
             
     | 
| 618 | 
         
            +
            def randomTrajectory(T):
         
     | 
| 619 | 
         
            +
                x = zeros((3, T))
         
     | 
| 620 | 
         
            +
                v = randn(3, T)
         
     | 
| 621 | 
         
            +
                r = zeros((3, T))
         
     | 
| 622 | 
         
            +
                trv = 1 / 1
         
     | 
| 623 | 
         
            +
                trr = 2 * pi / T
         
     | 
| 624 | 
         
            +
                for t in range(1, T):
         
     | 
| 625 | 
         
            +
                    F_rot = randn(3) / (t + 1) + r[:, t - 1]
         
     | 
| 626 | 
         
            +
                    F_trans = randn(3) / (t + 1)
         
     | 
| 627 | 
         
            +
                    r[:, t] = r[:, t - 1] + trr * F_rot
         
     | 
| 628 | 
         
            +
                    v[:, t] = v[:, t - 1] + trv * F_trans
         
     | 
| 629 | 
         
            +
                    st = v[:, t]
         
     | 
| 630 | 
         
            +
                    st = rot3D(st, r[:, t])
         
     | 
| 631 | 
         
            +
                    x[:, t] = x[:, t - 1] + st
         
     | 
| 632 | 
         
            +
                return x
         
     | 
| 633 | 
         
            +
             
     | 
| 634 | 
         
            +
             
     | 
| 635 | 
         
            +
            def rot3D(x, r):
         
     | 
| 636 | 
         
            +
                Rx = array([[1, 0, 0], [0, cos(r[0]), -sin(r[0])], [0, sin(r[0]), cos(r[0])]])
         
     | 
| 637 | 
         
            +
                Ry = array([[cos(r[1]), 0, sin(r[1])], [0, 1, 0], [-sin(r[1]), 0, cos(r[1])]])
         
     | 
| 638 | 
         
            +
                Rz = array([[cos(r[2]), -sin(r[2]), 0], [sin(r[2]), cos(r[2]), 0], [0, 0, 1]])
         
     | 
| 639 | 
         
            +
                R = Rz @ Ry @ Rx
         
     | 
| 640 | 
         
            +
                x = R @ x
         
     | 
| 641 | 
         
            +
                return x
         
     | 
| 642 | 
         
            +
             
     | 
| 643 | 
         
            +
             
     | 
| 644 | 
         
            +
            if __name__ == '__main__':
         
     | 
| 645 | 
         
            +
                a = opt_fft_size([111])
         
     | 
| 646 | 
         
            +
                print(a)
         
     | 
| 647 | 
         
            +
             
     | 
| 648 | 
         
            +
                print(fspecial('gaussian', 5, 1))
         
     | 
| 649 | 
         
            +
                
         
     | 
| 650 | 
         
            +
                print(p2o(torch.zeros(1,1,4,4).float(),(14,14)).shape)
         
     | 
| 651 | 
         
            +
             
     | 
| 652 | 
         
            +
                k = blurkernel_synthesis(11)
         
     | 
| 653 | 
         
            +
                import matplotlib.pyplot as plt
         
     | 
| 654 | 
         
            +
                plt.imshow(k, interpolation="nearest", cmap="gray")
         
     | 
| 655 | 
         
            +
                plt.show()
         
     | 
    	
        core/data/deg_kair_utils/utils_dist.py
    ADDED
    
    | 
         @@ -0,0 +1,201 @@ 
     | 
| 1 | 
         
            +
            # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py  # noqa: E501
         
     | 
| 2 | 
         
            +
            import functools
         
     | 
| 3 | 
         
            +
            import os
         
     | 
| 4 | 
         
            +
import pickle  # needed by all_gather below
import subprocess
         
     | 
| 5 | 
         
            +
            import torch
         
     | 
| 6 | 
         
            +
            import torch.distributed as dist
         
     | 
| 7 | 
         
            +
            import torch.multiprocessing as mp
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
            # ----------------------------------
         
     | 
| 11 | 
         
            +
            # init
         
     | 
| 12 | 
         
            +
            # ----------------------------------
         
     | 
| 13 | 
         
            +
            def init_dist(launcher, backend='nccl', **kwargs):
         
     | 
| 14 | 
         
            +
                if mp.get_start_method(allow_none=True) is None:
         
     | 
| 15 | 
         
            +
                    mp.set_start_method('spawn')
         
     | 
| 16 | 
         
            +
                if launcher == 'pytorch':
         
     | 
| 17 | 
         
            +
                    _init_dist_pytorch(backend, **kwargs)
         
     | 
| 18 | 
         
            +
                elif launcher == 'slurm':
         
     | 
| 19 | 
         
            +
                    _init_dist_slurm(backend, **kwargs)
         
     | 
| 20 | 
         
            +
                else:
         
     | 
| 21 | 
         
            +
                    raise ValueError(f'Invalid launcher type: {launcher}')
         
     | 
| 22 | 
         
            +
             
     | 
| 23 | 
         
            +
             
     | 
| 24 | 
         
            +
            def _init_dist_pytorch(backend, **kwargs):
         
     | 
| 25 | 
         
            +
                rank = int(os.environ['RANK'])
         
     | 
| 26 | 
         
            +
                num_gpus = torch.cuda.device_count()
         
     | 
| 27 | 
         
            +
                torch.cuda.set_device(rank % num_gpus)
         
     | 
| 28 | 
         
            +
                dist.init_process_group(backend=backend, **kwargs)
         
     | 
| 29 | 
         
            +
             
     | 
| 30 | 
         
            +
             
     | 
| 31 | 
         
            +
            def _init_dist_slurm(backend, port=None):
         
     | 
| 32 | 
         
            +
                """Initialize slurm distributed training environment.
         
     | 
| 33 | 
         
            +
                If argument ``port`` is not specified, then the master port will be system
         
     | 
| 34 | 
         
            +
                environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
         
     | 
| 35 | 
         
            +
                environment variable, then a default port ``29500`` will be used.
         
     | 
| 36 | 
         
            +
                Args:
         
     | 
| 37 | 
         
            +
                    backend (str): Backend of torch.distributed.
         
     | 
| 38 | 
         
            +
                    port (int, optional): Master port. Defaults to None.
         
     | 
| 39 | 
         
            +
                """
         
     | 
| 40 | 
         
            +
                proc_id = int(os.environ['SLURM_PROCID'])
         
     | 
| 41 | 
         
            +
                ntasks = int(os.environ['SLURM_NTASKS'])
         
     | 
| 42 | 
         
            +
                node_list = os.environ['SLURM_NODELIST']
         
     | 
| 43 | 
         
            +
                num_gpus = torch.cuda.device_count()
         
     | 
| 44 | 
         
            +
                torch.cuda.set_device(proc_id % num_gpus)
         
     | 
| 45 | 
         
            +
                addr = subprocess.getoutput(
         
     | 
| 46 | 
         
            +
                    f'scontrol show hostname {node_list} | head -n1')
         
     | 
| 47 | 
         
            +
                # specify master port
         
     | 
| 48 | 
         
            +
                if port is not None:
         
     | 
| 49 | 
         
            +
                    os.environ['MASTER_PORT'] = str(port)
         
     | 
| 50 | 
         
            +
                elif 'MASTER_PORT' in os.environ:
         
     | 
| 51 | 
         
            +
                    pass  # use MASTER_PORT in the environment variable
         
     | 
| 52 | 
         
            +
                else:
         
     | 
| 53 | 
         
            +
                    # 29500 is torch.distributed default port
         
     | 
| 54 | 
         
            +
                    os.environ['MASTER_PORT'] = '29500'
         
     | 
| 55 | 
         
            +
                os.environ['MASTER_ADDR'] = addr
         
     | 
| 56 | 
         
            +
                os.environ['WORLD_SIZE'] = str(ntasks)
         
     | 
| 57 | 
         
            +
                os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
         
     | 
| 58 | 
         
            +
                os.environ['RANK'] = str(proc_id)
         
     | 
| 59 | 
         
            +
                dist.init_process_group(backend=backend)
         
     | 
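A hedged launch sketch for _init_dist_slurm: the function reads SLURM_PROCID, SLURM_NTASKS and SLURM_NODELIST, so the script must run under srun (the flags below are illustrative):

    #   srun -N2 --ntasks-per-node=4 --gres=gpu:4 python train.py --launcher slurm
    # inside train.py:
    init_dist('slurm', backend='nccl', port=29500)
    rank, world_size = get_dist_info()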
| 60 | 
         
            +
             
     | 
| 61 | 
         
            +
             
     | 
| 62 | 
         
            +
             
     | 
| 63 | 
         
            +
            # ----------------------------------
         
     | 
| 64 | 
         
            +
            # get rank and world_size
         
     | 
| 65 | 
         
            +
            # ----------------------------------
         
     | 
| 66 | 
         
            +
            def get_dist_info():
         
     | 
| 67 | 
         
            +
                if dist.is_available():
         
     | 
| 68 | 
         
            +
                    initialized = dist.is_initialized()
         
     | 
| 69 | 
         
            +
                else:
         
     | 
| 70 | 
         
            +
                    initialized = False
         
     | 
| 71 | 
         
            +
                if initialized:
         
     | 
| 72 | 
         
            +
                    rank = dist.get_rank()
         
     | 
| 73 | 
         
            +
                    world_size = dist.get_world_size()
         
     | 
| 74 | 
         
            +
                else:
         
     | 
| 75 | 
         
            +
                    rank = 0
         
     | 
| 76 | 
         
            +
                    world_size = 1
         
     | 
| 77 | 
         
            +
                return rank, world_size
         
     | 
| 78 | 
         
            +
             
     | 
| 79 | 
         
            +
             
     | 
| 80 | 
         
            +
            def get_rank():
         
     | 
| 81 | 
         
            +
                if not dist.is_available():
         
     | 
| 82 | 
         
            +
                    return 0
         
     | 
| 83 | 
         
            +
             
     | 
| 84 | 
         
            +
                if not dist.is_initialized():
         
     | 
| 85 | 
         
            +
                    return 0
         
     | 
| 86 | 
         
            +
             
     | 
| 87 | 
         
            +
                return dist.get_rank()
         
     | 
| 88 | 
         
            +
             
     | 
| 89 | 
         
            +
             
     | 
| 90 | 
         
            +
            def get_world_size():
         
     | 
| 91 | 
         
            +
                if not dist.is_available():
         
     | 
| 92 | 
         
            +
                    return 1
         
     | 
| 93 | 
         
            +
             
     | 
| 94 | 
         
            +
                if not dist.is_initialized():
         
     | 
| 95 | 
         
            +
                    return 1
         
     | 
| 96 | 
         
            +
             
     | 
| 97 | 
         
            +
                return dist.get_world_size()
         
     | 
| 98 | 
         
            +
             
     | 
| 99 | 
         
            +
             
     | 
| 100 | 
         
            +
            def master_only(func):
         
     | 
| 101 | 
         
            +
             
     | 
| 102 | 
         
            +
                @functools.wraps(func)
         
     | 
| 103 | 
         
            +
                def wrapper(*args, **kwargs):
         
     | 
| 104 | 
         
            +
                    rank, _ = get_dist_info()
         
     | 
| 105 | 
         
            +
                    if rank == 0:
         
     | 
| 106 | 
         
            +
                        return func(*args, **kwargs)
         
     | 
| 107 | 
         
            +
             
     | 
| 108 | 
         
            +
                return wrapper
         
     | 
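A usage sketch for master_only: decorate side effects that should run on rank 0 only (logging, checkpointing); on every other rank the wrapped call is skipped and returns None. The function name below is hypothetical:

    @master_only
    def save_checkpoint(state, path):
        torch.save(state, path)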
| 109 | 
         
            +
             
     | 
| 110 | 
         
            +
             
     | 
| 111 | 
         
            +
             
     | 
| 112 | 
         
            +
             
     | 
| 113 | 
         
            +
             
     | 
| 114 | 
         
            +
             
     | 
| 115 | 
         
            +
            # ----------------------------------
         
     | 
| 116 | 
         
            +
            # operation across ranks
         
     | 
| 117 | 
         
            +
            # ----------------------------------
         
     | 
| 118 | 
         
            +
            def reduce_sum(tensor):
         
     | 
| 119 | 
         
            +
                if not dist.is_available():
         
     | 
| 120 | 
         
            +
                    return tensor
         
     | 
| 121 | 
         
            +
             
     | 
| 122 | 
         
            +
                if not dist.is_initialized():
         
     | 
| 123 | 
         
            +
                    return tensor
         
     | 
| 124 | 
         
            +
             
     | 
| 125 | 
         
            +
                tensor = tensor.clone()
         
     | 
| 126 | 
         
            +
                dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
         
     | 
| 127 | 
         
            +
             
     | 
| 128 | 
         
            +
                return tensor
         
     | 
| 129 | 
         
            +
             
     | 
| 130 | 
         
            +
             
     | 
| 131 | 
         
            +
            def gather_grad(params):
         
     | 
| 132 | 
         
            +
                world_size = get_world_size()
         
     | 
| 133 | 
         
            +
                
         
     | 
| 134 | 
         
            +
                if world_size == 1:
         
     | 
| 135 | 
         
        return

    for param in params:
        if param.grad is not None:
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data.div_(world_size)


def all_gather(data):
    world_size = get_world_size()

    if world_size == 1:
        return [data]

    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to('cuda')

    local_size = torch.IntTensor([tensor.numel()]).to('cuda')
    size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))

    if local_size != max_size:
        padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
        tensor = torch.cat((tensor, padding), 0)

    dist.all_gather(tensor_list, tensor)

    data_list = []

    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list


def reduce_loss_dict(loss_dict):
    world_size = get_world_size()

    if world_size < 2:
        return loss_dict

    with torch.no_grad():
        keys = []
        losses = []

        for k in sorted(loss_dict.keys()):
            keys.append(k)
            losses.append(loss_dict[k])

        losses = torch.stack(losses, 0)
        dist.reduce(losses, dst=0)

        if dist.get_rank() == 0:
            losses /= world_size

        reduced_losses = {k: v for k, v in zip(keys, losses)}

    return reduced_losses
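A quick usage sketch for reduce_loss_dict (not part of the diff; assumes torch.distributed has already been initialized and get_world_size() > 1; the loss names and tensors are illustrative):

# Average each loss term across ranks for logging. dist.reduce targets
# rank 0, so values on other ranks are only partially reduced; log from rank 0.
loss_dict = {'l_g': g_loss.detach(), 'l_d': d_loss.detach()}  # g_loss/d_loss are hypothetical loss tensors
reduced = reduce_loss_dict(loss_dict)
if dist.get_rank() == 0:
    log = {k: v.item() for k, v in reduced.items()}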
    	
core/data/deg_kair_utils/utils_googledownload.py
ADDED
@@ -0,0 +1,93 @@
import math
import requests
from tqdm import tqdm


'''
borrowed from
https://github.com/xinntao/BasicSR/blob/28883e15eedc3381d23235ff3cf7c454c4be87e6/basicsr/utils/download_util.py
'''


def sizeof_fmt(size, suffix='B'):
    """Get human readable file size.
    Args:
        size (int): File size.
        suffix (str): Suffix. Default: 'B'.
    Return:
        str: Formatted file size.
    """
    for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
        if abs(size) < 1024.0:
            return f'{size:3.1f} {unit}{suffix}'
        size /= 1024.0
    return f'{size:3.1f} Y{suffix}'


def download_file_from_google_drive(file_id, save_path):
    """Download files from google drive.
    Ref:
    https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive  # noqa E501
    Args:
        file_id (str): File id.
        save_path (str): Save path.
    """

    session = requests.Session()
    URL = 'https://docs.google.com/uc?export=download'
    params = {'id': file_id}

    response = session.get(URL, params=params, stream=True)
    token = get_confirm_token(response)
    if token:
        params['confirm'] = token
        response = session.get(URL, params=params, stream=True)

    # get file size
    response_file_size = session.get(
        URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
    if 'Content-Range' in response_file_size.headers:
        file_size = int(
            response_file_size.headers['Content-Range'].split('/')[1])
    else:
        file_size = None

    save_response_content(response, save_path, file_size)


def get_confirm_token(response):
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value
    return None


def save_response_content(response,
                          destination,
                          file_size=None,
                          chunk_size=32768):
    if file_size is not None:
        pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')

        readable_file_size = sizeof_fmt(file_size)
    else:
        pbar = None

    with open(destination, 'wb') as f:
        downloaded_size = 0
        for chunk in response.iter_content(chunk_size):
            downloaded_size += len(chunk)  # count actual bytes; the last chunk may be shorter than chunk_size
            if pbar is not None:
                pbar.update(1)
                pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} '
                                     f'/ {readable_file_size}')
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)
        if pbar is not None:
            pbar.close()


if __name__ == "__main__":
    file_id = '1WNULM1e8gRNvsngVscsQ8tpaOqJ4mYtv'
    save_path = 'BSRGAN.pth'
    download_file_from_google_drive(file_id, save_path)
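For reference, sizeof_fmt keeps dividing by 1024 until the value drops below one unit step, so (illustrative checks, not part of the file):

assert sizeof_fmt(1536) == '1.5 KB'
assert sizeof_fmt(3 * 1024 ** 2) == '3.0 MB'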
    	
core/data/deg_kair_utils/utils_image.py
ADDED
@@ -0,0 +1,1016 @@
import os
import math
import random
import numpy as np
import torch
import cv2
from torchvision.utils import make_grid
from datetime import datetime
# import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


'''
# --------------------------------------------
# Kai Zhang (github: https://github.com/cszn)
# 03/Mar/2019
# --------------------------------------------
# https://github.com/twhui/SRGAN-pyTorch
# https://github.com/xinntao/BasicSR
# --------------------------------------------
'''


IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def get_timestamp():
    return datetime.now().strftime('%y%m%d-%H%M%S')


def imshow(x, title=None, cbar=False, figsize=None):
    plt.figure(figsize=figsize)
    plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()


def surf(Z, cmap='rainbow', figsize=None):
    plt.figure(figsize=figsize)
    ax3 = plt.axes(projection='3d')

    w, h = Z.shape[:2]
    xx = np.arange(0, w, 1)
    yy = np.arange(0, h, 1)
    X, Y = np.meshgrid(xx, yy)
    ax3.plot_surface(X, Y, Z, cmap=cmap)
    #ax3.contour(X, Y, Z, zdim='z', offset=-2, cmap=cmap)
    plt.show()


'''
# --------------------------------------------
# get image paths
# --------------------------------------------
'''


def get_image_paths(dataroot):
    paths = None  # return None if dataroot is None
    if isinstance(dataroot, str):
        paths = sorted(_get_paths_from_images(dataroot))
    elif isinstance(dataroot, list):
        paths = []
        for i in dataroot:
            paths += sorted(_get_paths_from_images(i))
    return paths


def _get_paths_from_images(path):
    assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
    images = []
    for dirpath, _, fnames in sorted(os.walk(path)):
        for fname in sorted(fnames):
            if is_image_file(fname):
                img_path = os.path.join(dirpath, fname)
                images.append(img_path)
    assert images, '{:s} has no valid image file'.format(path)
    return images


'''
# --------------------------------------------
# split large images into small images
# --------------------------------------------
'''


def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
    w, h = img.shape[:2]
    patches = []
    if w > p_max and h > p_max:
        # np.int was removed in recent NumPy; the builtin int is equivalent here
        w1 = list(np.arange(0, w - p_size, p_size - p_overlap, dtype=int))
        h1 = list(np.arange(0, h - p_size, p_size - p_overlap, dtype=int))
        w1.append(w - p_size)
        h1.append(h - p_size)
        # print(w1)
        # print(h1)
        for i in w1:
            for j in h1:
                patches.append(img[i:i+p_size, j:j+p_size, :])
    else:
        patches.append(img)

    return patches
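To make the stride arithmetic above concrete (illustrative numbers, not part of the file): with w = 1200, p_size = 512 and p_overlap = 64, the step is p_size - p_overlap = 448, so

starts = list(np.arange(0, 1200 - 512, 448, dtype=int))  # [0, 448]
starts.append(1200 - 512)                                # [0, 448, 688]

Appending w - p_size shifts the final patch so it ends exactly at the image border instead of running past it.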
def imssave(imgs, img_path):
    """
    imgs: list, N images of size WxHxC
    """
    img_name, ext = os.path.splitext(os.path.basename(img_path))
    for i, img in enumerate(imgs):
        if img.ndim == 3:
            img = img[:, :, [2, 1, 0]]
        new_path = os.path.join(os.path.dirname(img_path), img_name+str('_{:04d}'.format(i))+'.png')
        cv2.imwrite(new_path, img)


def split_imageset(original_dataroot, target_dataroot, n_channels=3, p_size=512, p_overlap=96, p_max=800):
    """
    Split the large images from original_dataroot into small overlapped images of size (p_size)x(p_size),
    and save them into target_dataroot; only images larger than (p_max)x(p_max)
    will be split.

    Args:
        original_dataroot:
        target_dataroot:
        p_size: size of small images
        p_overlap: overlap between neighbouring patches; the patch size used in training is a good choice
        p_max: images smaller than (p_max)x(p_max) are kept unchanged.
    """
    paths = get_image_paths(original_dataroot)
    for img_path in paths:
        # img_name, ext = os.path.splitext(os.path.basename(img_path))
        img = imread_uint(img_path, n_channels=n_channels)
        patches = patches_from_image(img, p_size, p_overlap, p_max)
        imssave(patches, os.path.join(target_dataroot, os.path.basename(img_path)))
        #if original_dataroot == target_dataroot:
        #del img_path
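A typical call (hypothetical paths; target_dataroot should already exist, since imssave writes via cv2.imwrite and does not create directories):

split_imageset('trainsets/trainH', 'trainsets/trainH_patches',
               n_channels=3, p_size=512, p_overlap=96, p_max=800)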
'''
# --------------------------------------------
# makedir
# --------------------------------------------
'''


def mkdir(path):
    if not os.path.exists(path):
        os.makedirs(path)


def mkdirs(paths):
    if isinstance(paths, str):
        mkdir(paths)
    else:
        for path in paths:
            mkdir(path)


def mkdir_and_rename(path):
    if os.path.exists(path):
        new_name = path + '_archived_' + get_timestamp()
        print('Path already exists. Rename it to [{:s}]'.format(new_name))
        os.rename(path, new_name)
    os.makedirs(path)


'''
# --------------------------------------------
# read image from path
# opencv is fast, but reads BGR numpy images
# --------------------------------------------
'''


# --------------------------------------------
# get uint8 image of size HxWxn_channels (RGB)
# --------------------------------------------
def imread_uint(path, n_channels=3):
    #  input: path
    # output: HxWx3(RGB or GGG), or HxWx1 (G)
    if n_channels == 1:
        img = cv2.imread(path, 0)  # cv2.IMREAD_GRAYSCALE
        img = np.expand_dims(img, axis=2)  # HxWx1
    elif n_channels == 3:
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGR or G
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # GGG
        else:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB
    return img
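Shape conventions returned by imread_uint, for a hypothetical file 'example.png':

gray = imread_uint('example.png', n_channels=1)  # (H, W, 1) uint8, single channel
rgb = imread_uint('example.png', n_channels=3)   # (H, W, 3) uint8, RGB order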
# --------------------------------------------
# matlab's imwrite
# --------------------------------------------
def imsave(img, img_path):
    img = np.squeeze(img)
    if img.ndim == 3:
        img = img[:, :, [2, 1, 0]]
    cv2.imwrite(img_path, img)


def imwrite(img, img_path):
    img = np.squeeze(img)
    if img.ndim == 3:
        img = img[:, :, [2, 1, 0]]
    cv2.imwrite(img_path, img)


# --------------------------------------------
# get single image of size HxWxn_channels (BGR)
# --------------------------------------------
def read_img(path):
    # read image by cv2
    # return: Numpy float32, HWC, BGR, [0,1]
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # cv2.IMREAD_GRAYSCALE
    img = img.astype(np.float32) / 255.
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    # some images have 4 channels
    if img.shape[2] > 3:
        img = img[:, :, :3]
    return img


'''
# --------------------------------------------
# image format conversion
# --------------------------------------------
# numpy(single) <--->  numpy(uint)
# numpy(single) <--->  tensor
# numpy(uint)   <--->  tensor
# --------------------------------------------
'''


# --------------------------------------------
# numpy(single) [0, 1] <--->  numpy(uint)
# --------------------------------------------


def uint2single(img):
    return np.float32(img/255.)


def single2uint(img):
    return np.uint8((img.clip(0, 1)*255.).round())


def uint162single(img):
    return np.float32(img/65535.)


def single2uint16(img):
    return np.uint16((img.clip(0, 1)*65535.).round())


# --------------------------------------------
# numpy(uint) (HxWxC or HxW) <--->  tensor
# --------------------------------------------


# convert uint to 4-dimensional torch tensor
def uint2tensor4(img):
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)


# convert uint to 3-dimensional torch tensor
def uint2tensor3(img):
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)


# convert 2/3/4-dimensional torch tensor to uint
def tensor2uint(img):
    img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
    if img.ndim == 3:
        img = np.transpose(img, (1, 2, 0))
    return np.uint8((img*255.0).round())
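A round-trip sanity check for the uint <---> tensor helpers (illustrative, not part of the file):

img_u8 = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
t = uint2tensor4(img_u8)    # (1, 3, 64, 64) float tensor in [0, 1]
back = tensor2uint(t)       # (64, 64, 3) uint8
assert np.array_equal(back, img_u8)  # k/255.*255. rounds back to k exactly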
| 299 | 
         
            +
             
     | 


# --------------------------------------------
# numpy(single) (HxWxC) <--->  tensor
# --------------------------------------------


# convert single (HxWxC) to 3-dimensional torch tensor
def single2tensor3(img):
    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()


# convert single (HxWxC) to 4-dimensional torch tensor
def single2tensor4(img):
    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)


# convert torch tensor to single (HxWxC or HxW)
def tensor2single(img):
    img = img.data.squeeze().float().cpu().numpy()
    if img.ndim == 3:
        img = np.transpose(img, (1, 2, 0))

    return img


# convert torch tensor to single, keeping an explicit channel axis for 2D input
def tensor2single3(img):
    img = img.data.squeeze().float().cpu().numpy()
    if img.ndim == 3:
        img = np.transpose(img, (1, 2, 0))
    elif img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    return img


def single2tensor5(img):
    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)


def single32tensor5(img):
    return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)


def single42tensor4(img):
    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
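
# Shape-convention sketch (illustrative): "single" means float32 in [0, 1].
#   >>> import numpy as np
#   >>> x = np.zeros((16, 16, 3), dtype=np.float32)
#   >>> single2tensor4(x).shape
#   torch.Size([1, 3, 16, 16])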


# from skimage.io import imread, imsave
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    '''
    Converts a torch Tensor into an image Numpy array of BGR channel order
    Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
    '''
    tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # squeeze first, then clamp
    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  # to range [0,1]
    n_dim = tensor.dim()
    if n_dim == 4:
        n_img = len(tensor)
        img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 3:
        img_np = tensor.numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 2:
        img_np = tensor.numpy()
    else:
        raise TypeError(
            'Only 4D, 3D and 2D tensors are supported, but received a tensor with {:d} dimensions'.format(n_dim))
    if out_type == np.uint8:
        img_np = (img_np * 255.0).round()
        # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
    return img_np.astype(out_type)
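
# Usage sketch (illustrative): tensor2img returns BGR uint8, which pairs
# directly with cv2.imwrite.
#   >>> t = torch.rand(3, 32, 32)    # CHW, RGB, values in [0, 1]
#   >>> out = tensor2img(t)          # HxWxC uint8, channels flipped to BGR
#   >>> out.shape, out.dtype
#   ((32, 32, 3), dtype('uint8'))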


'''
# --------------------------------------------
# Augmentation, flip and/or rotate
# --------------------------------------------
# The following two are enough.
# (1) augment_img: numpy image of WxHxC or WxH
# (2) augment_img_tensor4: tensor image 1xCxWxH
# --------------------------------------------
'''


def augment_img(img, mode=0):
    '''Kai Zhang (github: https://github.com/cszn)
    '''
    if mode == 0:
        return img
    elif mode == 1:
        return np.flipud(np.rot90(img))
    elif mode == 2:
        return np.flipud(img)
    elif mode == 3:
        return np.rot90(img, k=3)
    elif mode == 4:
        return np.flipud(np.rot90(img, k=2))
    elif mode == 5:
        return np.rot90(img)
    elif mode == 6:
        return np.rot90(img, k=2)
    elif mode == 7:
        return np.flipud(np.rot90(img, k=3))
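
# The eight modes enumerate the dihedral group D4: identity, the three
# rotations, and four flipped variants. Quick self-check (illustrative):
#   >>> import numpy as np
#   >>> a = np.arange(6).reshape(2, 3)
#   >>> np.array_equal(augment_img(augment_img(a, mode=2), mode=2), a)  # flip is its own inverse
#   True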


def augment_img_tensor4(img, mode=0):
    '''Kai Zhang (github: https://github.com/cszn)
    '''
    if mode == 0:
        return img
    elif mode == 1:
        return img.rot90(1, [2, 3]).flip([2])
    elif mode == 2:
        return img.flip([2])
    elif mode == 3:
        return img.rot90(3, [2, 3])
    elif mode == 4:
        return img.rot90(2, [2, 3]).flip([2])
    elif mode == 5:
        return img.rot90(1, [2, 3])
    elif mode == 6:
        return img.rot90(2, [2, 3])
    elif mode == 7:
        return img.rot90(3, [2, 3]).flip([2])


def augment_img_tensor(img, mode=0):
    '''Kai Zhang (github: https://github.com/cszn)
    '''
    img_size = img.size()
    img_np = img.data.cpu().numpy()
    if len(img_size) == 3:
        img_np = np.transpose(img_np, (1, 2, 0))
    elif len(img_size) == 4:
        img_np = np.transpose(img_np, (2, 3, 1, 0))
    img_np = augment_img(img_np, mode=mode)
    img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
    if len(img_size) == 3:
        img_tensor = img_tensor.permute(2, 0, 1)
    elif len(img_size) == 4:
        img_tensor = img_tensor.permute(3, 2, 0, 1)

    return img_tensor.type_as(img)


def augment_img_np3(img, mode=0):
    if mode == 0:
        return img
    elif mode == 1:
        return img.transpose(1, 0, 2)
    elif mode == 2:
        return img[::-1, :, :]
    elif mode == 3:
        img = img[::-1, :, :]
        img = img.transpose(1, 0, 2)
        return img
    elif mode == 4:
        return img[:, ::-1, :]
    elif mode == 5:
        img = img[:, ::-1, :]
        img = img.transpose(1, 0, 2)
        return img
    elif mode == 6:
        img = img[:, ::-1, :]
        img = img[::-1, :, :]
        return img
    elif mode == 7:
        img = img[:, ::-1, :]
        img = img[::-1, :, :]
        img = img.transpose(1, 0, 2)
        return img


def augment_imgs(img_list, hflip=True, rot=True):
    # horizontal flip OR rotate
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    def _augment(img):
        if hflip:
            img = img[:, ::-1, :]
        if vflip:
            img = img[::-1, :, :]
        if rot90:
            img = img.transpose(1, 0, 2)
        return img

    return [_augment(img) for img in img_list]


'''
# --------------------------------------------
# modcrop and shave
# --------------------------------------------
'''


def modcrop(img_in, scale):
    # img_in: Numpy, HWC or HW
    img = np.copy(img_in)
    if img.ndim == 2:
        H, W = img.shape
        H_r, W_r = H % scale, W % scale
        img = img[:H - H_r, :W - W_r]
    elif img.ndim == 3:
        H, W, C = img.shape
        H_r, W_r = H % scale, W % scale
        img = img[:H - H_r, :W - W_r, :]
    else:
        raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
    return img


def shave(img_in, border=0):
    # img_in: Numpy, HWC or HW
    img = np.copy(img_in)
    h, w = img.shape[:2]
    img = img[border:h-border, border:w-border]
    return img
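
# modcrop arithmetic sketch (illustrative): trims H and W down to multiples of
# `scale` so that LR/HR image pairs align exactly after downscaling.
#   >>> import numpy as np
#   >>> modcrop(np.zeros((7, 10)), 2).shape   # 7 % 2 == 1, 10 % 2 == 0
#   (6, 10)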


'''
# --------------------------------------------
# image processing process on numpy image
# channel_convert(in_c, tar_type, img_list):
# rgb2ycbcr(img, only_y=True):
# bgr2ycbcr(img, only_y=True):
# ycbcr2rgb(img):
# --------------------------------------------
'''


def rgb2ycbcr(img, only_y=True):
    '''same as matlab rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    '''
    in_img_type = img.dtype
    img = img.astype(np.float32)  # astype returns a copy, so assign it (also avoids mutating the caller's array below)
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    if only_y:
        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
                              [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)


def ycbcr2rgb(img):
    '''same as matlab ycbcr2rgb
    Input:
        uint8, [0, 255]
        float, [0, 1]
    '''
    in_img_type = img.dtype
    img = img.astype(np.float32)  # astype returns a copy, so assign it
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
                          [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
    rlt = np.clip(rlt, 0, 255)
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
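
# Quick sanity check (illustrative): a white uint8 pixel maps to the
# video-range white level Y = 235, matching MATLAB's rgb2ycbcr.
#   >>> import numpy as np
#   >>> rgb2ycbcr(np.array([[[255, 255, 255]]], dtype=np.uint8))
#   array([[235]], dtype=uint8)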


def bgr2ycbcr(img, only_y=True):
    '''bgr version of rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    '''
    in_img_type = img.dtype
    img = img.astype(np.float32)  # astype returns a copy, so assign it
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)


def channel_convert(in_c, tar_type, img_list):
    # conversion among BGR, gray and y
    if in_c == 3 and tar_type == 'gray':  # BGR to gray
        gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
        return [np.expand_dims(img, axis=2) for img in gray_list]
    elif in_c == 3 and tar_type == 'y':  # BGR to y
        y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
        return [np.expand_dims(img, axis=2) for img in y_list]
    elif in_c == 1 and tar_type == 'RGB':  # gray/y to BGR
        return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
    else:
        return img_list


'''
# --------------------------------------------
# metric, PSNR, SSIM and PSNRB
# --------------------------------------------
'''


# --------------------------------------------
# PSNR
# --------------------------------------------
def calculate_psnr(img1, img2, border=0):
    # img1 and img2 have range [0, 255]
    #img1 = img1.squeeze()
    #img2 = img2.squeeze()
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    img1 = img1[border:h-border, border:w-border]
    img2 = img2[border:h-border, border:w-border]

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    mse = np.mean((img1 - img2)**2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))
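
# PSNR follows 20*log10(255/sqrt(MSE)) (illustrative check): two uint8 images
# that differ by exactly one gray level everywhere give MSE = 1, hence
# 20*log10(255) ~= 48.13 dB.
#   >>> import numpy as np
#   >>> a = np.zeros((4, 4), dtype=np.uint8)
#   >>> round(calculate_psnr(a, a + 1), 2)
#   48.13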


# --------------------------------------------
# SSIM
# --------------------------------------------
def calculate_ssim(img1, img2, border=0):
    '''calculate SSIM
    the same outputs as MATLAB's
    img1, img2: [0, 255]
    '''
    #img1 = img1.squeeze()
    #img2 = img2.squeeze()
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    img1 = img1[border:h-border, border:w-border]
    img2 = img2[border:h-border, border:w-border]

    if img1.ndim == 2:
        return ssim(img1, img2)
    elif img1.ndim == 3:
        if img1.shape[2] == 3:
            ssims = []
            for i in range(3):
                ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
            return np.array(ssims).mean()
        elif img1.shape[2] == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
    else:
        raise ValueError('Wrong input image dimensions.')


def ssim(img1, img2):
    # stabilisation constants C1 = (K1*L)^2, C2 = (K2*L)^2 with K1=0.01, K2=0.03, L=255
    C1 = (0.01 * 255)**2
    C2 = (0.03 * 255)**2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)  # 11x11 Gaussian window, sigma = 1.5
    window = np.outer(kernel, kernel.transpose())

    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                                            (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()
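
# Sanity check (illustrative): identical inputs score SSIM = 1 by construction,
# since numerator and denominator coincide at every window position.
#   >>> import numpy as np
#   >>> img = np.random.randint(0, 256, (32, 32), dtype=np.uint8)
#   >>> round(calculate_ssim(img, img), 6)
#   1.0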


def _blocking_effect_factor(im):
    block_size = 8

    block_horizontal_positions = torch.arange(7, im.shape[3] - 1, 8)
    block_vertical_positions = torch.arange(7, im.shape[2] - 1, 8)

    horizontal_block_difference = (
        (im[:, :, :, block_horizontal_positions] - im[:, :, :, block_horizontal_positions + 1]) ** 2
    ).sum(3).sum(2).sum(1)
    vertical_block_difference = (
        (im[:, :, block_vertical_positions, :] - im[:, :, block_vertical_positions + 1, :]) ** 2
    ).sum(3).sum(2).sum(1)

    nonblock_horizontal_positions = np.setdiff1d(torch.arange(0, im.shape[3] - 1), block_horizontal_positions)
    nonblock_vertical_positions = np.setdiff1d(torch.arange(0, im.shape[2] - 1), block_vertical_positions)

    horizontal_nonblock_difference = (
        (im[:, :, :, nonblock_horizontal_positions] - im[:, :, :, nonblock_horizontal_positions + 1]) ** 2
    ).sum(3).sum(2).sum(1)
    vertical_nonblock_difference = (
        (im[:, :, nonblock_vertical_positions, :] - im[:, :, nonblock_vertical_positions + 1, :]) ** 2
    ).sum(3).sum(2).sum(1)

    n_boundary_horiz = im.shape[2] * (im.shape[3] // block_size - 1)
    n_boundary_vert = im.shape[3] * (im.shape[2] // block_size - 1)
    boundary_difference = (horizontal_block_difference + vertical_block_difference) / (
        n_boundary_horiz + n_boundary_vert)

    n_nonboundary_horiz = im.shape[2] * (im.shape[3] - 1) - n_boundary_horiz
    n_nonboundary_vert = im.shape[3] * (im.shape[2] - 1) - n_boundary_vert
    nonboundary_difference = (horizontal_nonblock_difference + vertical_nonblock_difference) / (
        n_nonboundary_horiz + n_nonboundary_vert)

    scaler = np.log2(block_size) / np.log2(min([im.shape[2], im.shape[3]]))
    bef = scaler * (boundary_difference - nonboundary_difference)

    bef[boundary_difference <= nonboundary_difference] = 0
    return bef


def calculate_psnrb(img1, img2, border=0):
    """Calculate PSNR-B (Peak Signal-to-Noise Ratio including Blocking effects).
    Ref: Quality assessment of deblocked images, for JPEG image deblocking evaluation
    # https://gitlab.com/Queuecumber/quantization-guided-ac/-/blob/master/metrics/psnrb.py
    Args:
        img1 (ndarray): Images with range [0, 255].
        img2 (ndarray): Images with range [0, 255].
        border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the PSNR-B calculation.
    Returns:
        float: psnrb result.
    """

    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')

    if img1.ndim == 2:
        img1, img2 = np.expand_dims(img1, 2), np.expand_dims(img2, 2)

    h, w = img1.shape[:2]
    img1 = img1[border:h-border, border:w-border]
    img2 = img2[border:h-border, border:w-border]

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)

    # follow https://gitlab.com/Queuecumber/quantization-guided-ac/-/blob/master/metrics/psnrb.py
    img1 = torch.from_numpy(img1).permute(2, 0, 1).unsqueeze(0) / 255.
    img2 = torch.from_numpy(img2).permute(2, 0, 1).unsqueeze(0) / 255.

    total = 0
    for c in range(img1.shape[1]):
        mse = torch.nn.functional.mse_loss(img1[:, c:c + 1, :, :], img2[:, c:c + 1, :, :], reduction='none')
        bef = _blocking_effect_factor(img1[:, c:c + 1, :, :])

        mse = mse.view(mse.shape[0], -1).mean(1)
        total += 10 * torch.log10(1 / (mse + bef))

    return float(total) / img1.shape[1]


'''
# --------------------------------------------
# matlab's bicubic imresize (numpy and torch) [0, 1]
# --------------------------------------------
'''


# matlab 'imresize' function; currently only 'bicubic' is supported
def cubic(x):
    absx = torch.abs(x)
    absx2 = absx**2
    absx3 = absx**3
    return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
        (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
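
# This is the Keys cubic kernel with a = -0.5, the one MATLAB's bicubic uses.
# It interpolates: weight 1 at distance 0, weight 0 at every other integer
# (illustrative check):
#   >>> cubic(torch.tensor([0.0, 1.0, 2.0])).tolist()
#   [1.0, 0.0, 0.0]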


def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
    if (scale < 1) and (antialiasing):
        # Use a modified kernel to simultaneously interpolate and antialias (larger kernel width)
        kernel_width = kernel_width / scale

    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5+scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
    left = torch.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation?  Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    P = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
        1, P).expand(out_length, P)

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
    # apply cubic kernel
    if (scale < 1) and (antialiasing):
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)
    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, P)

    # If a column in weights is all zero, get rid of it. Only consider the first and last column.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, P - 2)
        weights = weights.narrow(1, 1, P - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, P - 2)
        weights = weights.narrow(1, 0, P - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)
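
# Weight-matrix shape sketch (illustrative): downscaling by 0.5 with
# antialiasing stretches the kernel support from 4 to 8 taps, and one weight
# row is produced per output pixel.
#   >>> w, idx, s, e = calculate_weights_indices(16, 8, 0.5, 'cubic', 4, True)
#   >>> w.shape[0]
#   8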


# --------------------------------------------
# imresize for tensor image [0, 1]
# --------------------------------------------
def imresize(img, scale, antialiasing=True):
    # Now the scale should be the same for H and W
    # input: img: pytorch tensor, CHW or HW [0,1]
    # output: CHW or HW [0,1] w/o round
    need_squeeze = True if img.dim() == 2 else False
    if need_squeeze:
        img.unsqueeze_(0)  # in-place: also reshapes the caller's 2D tensor
    in_C, in_H, in_W = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize.  The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying
    img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
    img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:, :sym_len_Hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[:, -sym_len_He:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    out_1 = torch.FloatTensor(in_C, out_H, in_W)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
    out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :, :sym_len_Ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, :, -sym_len_We:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(in_C, out_H, out_W)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
    if need_squeeze:
        out_2.squeeze_()
    return out_2
| 921 | 
         
            +
             
     | 
| 922 | 
         
            +
            # --------------------------------------------
         
     | 
| 923 | 
         
            +
            # imresize for numpy image [0, 1]
         
     | 
| 924 | 
         
            +
            # --------------------------------------------
         
     | 
| 925 | 
         
            +
            def imresize_np(img, scale, antialiasing=True):
         
     | 
| 926 | 
         
            +
                # Now the scale should be the same for H and W
         
     | 
| 927 | 
         
            +
                # input: img: Numpy, HWC or HW [0,1]
         
     | 
| 928 | 
         
            +
                # output: HWC or HW [0,1] w/o round
         
     | 
| 929 | 
         
            +
                img = torch.from_numpy(img)
         
     | 
| 930 | 
         
            +
                need_squeeze = True if img.dim() == 2 else False
         
     | 
| 931 | 
         
            +
                if need_squeeze:
         
     | 
| 932 | 
         
            +
                    img.unsqueeze_(2)
         
     | 
| 933 | 
         
            +
             
     | 
| 934 | 
         
            +
                in_H, in_W, in_C = img.size()
         
     | 
| 935 | 
         
            +
                out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
         
     | 
| 936 | 
         
            +
                kernel_width = 4
         
     | 
| 937 | 
         
            +
                kernel = 'cubic'
         
     | 
| 938 | 
         
            +
             
     | 
| 939 | 
         
            +
                # Return the desired dimension order for performing the resize.  The
         
     | 
| 940 | 
         
            +
                # strategy is to perform the resize first along the dimension with the
         
     | 
| 941 | 
         
            +
                # smallest scale factor.
         
     | 
| 942 | 
         
            +
                # Now we do not support this.
         
     | 
| 943 | 
         
            +
             
     | 
| 944 | 
         
            +
                # get weights and indices
         
     | 
| 945 | 
         
            +
                weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
         
     | 
| 946 | 
         
            +
                    in_H, out_H, scale, kernel, kernel_width, antialiasing)
         
     | 
| 947 | 
         
            +
                weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
         
     | 
| 948 | 
         
            +
                    in_W, out_W, scale, kernel, kernel_width, antialiasing)
         
     | 
| 949 | 
         
            +
                # process H dimension
         
     | 
| 950 | 
         
            +
                # symmetric copying
         
     | 
| 951 | 
         
            +
                img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
         
     | 
| 952 | 
         
            +
                img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
         
     | 
| 953 | 
         
            +
             
     | 
| 954 | 
         
            +
                sym_patch = img[:sym_len_Hs, :, :]
         
     | 
| 955 | 
         
            +
                inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
         
     | 
| 956 | 
         
            +
                sym_patch_inv = sym_patch.index_select(0, inv_idx)
         
     | 
| 957 | 
         
            +
                img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
         
     | 
| 958 | 
         
            +
             
     | 
| 959 | 
         
            +
                sym_patch = img[-sym_len_He:, :, :]
         
     | 
| 960 | 
         
            +
                inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
         
     | 
| 961 | 
         
            +
                sym_patch_inv = sym_patch.index_select(0, inv_idx)
         
     | 
| 962 | 
         
            +
                img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
         
     | 
| 963 | 
         
            +
             
     | 
| 964 | 
         
            +
                out_1 = torch.FloatTensor(out_H, in_W, in_C)
         
     | 
| 965 | 
         
            +
                kernel_width = weights_H.size(1)
         
     | 
| 966 | 
         
            +
                for i in range(out_H):
         
     | 
| 967 | 
         
            +
                    idx = int(indices_H[i][0])
         
     | 
| 968 | 
         
            +
                    for j in range(out_C):
         
     | 
| 969 | 
         
            +
                        out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
         
     | 
| 970 | 
         
            +
             
     | 
| 971 | 
         
            +
                # process W dimension
         
     | 
| 972 | 
         
            +
                # symmetric copying
         
     | 
| 973 | 
         
            +
                out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
         
     | 
| 974 | 
         
            +
                out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
         
     | 
| 975 | 
         
            +
             
     | 
| 976 | 
         
            +
                sym_patch = out_1[:, :sym_len_Ws, :]
         
     | 
| 977 | 
         
            +
                inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
         
     | 
| 978 | 
         
            +
                sym_patch_inv = sym_patch.index_select(1, inv_idx)
         
     | 
| 979 | 
         
            +
                out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
         
     | 
| 980 | 
         
            +
             
     | 
| 981 | 
         
            +
                sym_patch = out_1[:, -sym_len_We:, :]
         
     | 
| 982 | 
         
            +
                inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
         
     | 
| 983 | 
         
            +
                sym_patch_inv = sym_patch.index_select(1, inv_idx)
         
     | 
| 984 | 
         
            +
                out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
         
     | 
| 985 | 
         
            +
             
     | 
| 986 | 
         
            +
                out_2 = torch.FloatTensor(out_H, out_W, in_C)
         
     | 
| 987 | 
         
            +
                kernel_width = weights_W.size(1)
         
     | 
| 988 | 
         
            +
                for i in range(out_W):
         
     | 
| 989 | 
         
            +
                    idx = int(indices_W[i][0])
         
     | 
| 990 | 
         
            +
                    for j in range(out_C):
         
     | 
| 991 | 
         
            +
                        out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
         
     | 
| 992 | 
         
            +
                if need_squeeze:
         
     | 
| 993 | 
         
            +
                    out_2.squeeze_()
         
     | 
| 994 | 
         
            +
             
     | 
| 995 | 
         
            +
                return out_2.numpy()
         
     | 
| 996 | 
         
            +
             
     | 
| 997 | 
         
            +
             
     | 
| 998 | 
         
            +
            if __name__ == '__main__':
         
     | 
| 999 | 
         
            +
                img = imread_uint('test.bmp', 3)
         
     | 
| 1000 | 
         
            +
            #    img = uint2single(img)
         
     | 
| 1001 | 
         
            +
            #    img_bicubic = imresize_np(img, 1/4)
         
     | 
| 1002 | 
         
            +
            #    imshow(single2uint(img_bicubic))
         
     | 
| 1003 | 
         
            +
            #
         
     | 
| 1004 | 
         
            +
            #    img_tensor = single2tensor4(img)
         
     | 
| 1005 | 
         
            +
            #    for i in range(8):
         
     | 
| 1006 | 
         
            +
            #        imshow(np.concatenate((augment_img(img, i), tensor2single(augment_img_tensor4(img_tensor, i))), 1))
         
     | 
| 1007 | 
         
            +
                
         
     | 
| 1008 | 
         
            +
            #    patches = patches_from_image(img, p_size=128, p_overlap=0, p_max=200)
         
     | 
| 1009 | 
         
            +
            #    imssave(patches,'a.png')
         
     | 
| 1010 | 
         
            +
             
     | 
| 1011 | 
         
            +
             
     | 
| 1012 | 
         
            +
                
         
     | 
| 1013 | 
         
            +
                
         
     | 
| 1014 | 
         
            +
                
         
     | 
| 1015 | 
         
            +
                
         
     | 
| 1016 | 
         
            +
                
         
     | 
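For reference, a minimal sketch of exercising the bicubic resize above, mirroring the commented-out __main__ lines. The module path and the test image are illustrative assumptions, not part of this commit:

# Minimal usage sketch (assumes this file is importable as shown and that
# 'test.bmp' exists; both are hypothetical).
from core.data.deg_kair_utils import utils_image as util

img = util.imread_uint('test.bmp', 3)   # uint8 HWC image
img = util.uint2single(img)             # float32 in [0, 1], as imresize_np expects
img_lr = util.imresize_np(img, 1 / 4)   # antialiased bicubic 4x downscale
print(img.shape, '->', img_lr.shape)    # e.g. (256, 256, 3) -> (64, 64, 3)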
    	
        core/data/deg_kair_utils/utils_lmdb.py
    ADDED
    
@@ -0,0 +1,205 @@
import cv2
import lmdb
import sys
from multiprocessing import Pool
from os import path as osp
from tqdm import tqdm


def make_lmdb_from_imgs(data_path,
                        lmdb_path,
                        img_path_list,
                        keys,
                        batch=5000,
                        compress_level=1,
                        multiprocessing_read=False,
                        n_thread=40,
                        map_size=None):
    """Make lmdb from images.

    Contents of lmdb. The file structure is:
    example.lmdb
    ├── data.mdb
    ├── lock.mdb
    ├── meta_info.txt

    The data.mdb and lock.mdb are standard lmdb files and you can refer to
    https://lmdb.readthedocs.io/en/release/ for more details.

    The meta_info.txt is a specified txt file to record the meta information
    of our datasets. It will be automatically created when preparing
    datasets by our provided dataset tools.
    Each line in the txt file records 1) image name (with extension),
    2) image shape, and 3) compression level, separated by a white space.

    For example, the meta information could be:
    `000_00000000.png (720,1280,3) 1`, which means:
    1) image name (with extension): 000_00000000.png;
    2) image shape: (720,1280,3);
    3) compression level: 1

    We use the image name without extension as the lmdb key.

    If `multiprocessing_read` is True, it will read all the images to memory
    using multiprocessing. Thus, your server needs to have enough memory.

    Args:
        data_path (str): Data path for reading images.
        lmdb_path (str): Lmdb save path.
        img_path_list (list[str]): Image path list.
        keys (list[str]): Used for lmdb keys.
        batch (int): After processing batch images, lmdb commits.
            Default: 5000.
        compress_level (int): Compress level when encoding images. Default: 1.
        multiprocessing_read (bool): Whether to use multiprocessing to read all
            the images to memory. Default: False.
        n_thread (int): For multiprocessing.
        map_size (int | None): Map size for lmdb env. If None, use the
            estimated size from images. Default: None
    """

    assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, '
                                             f'but got {len(img_path_list)} and {len(keys)}')
    print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
    print(f'Total images: {len(img_path_list)}')
    if not lmdb_path.endswith('.lmdb'):
        raise ValueError("lmdb_path must end with '.lmdb'.")
    if osp.exists(lmdb_path):
        print(f'Folder {lmdb_path} already exists. Exit.')
        sys.exit(1)

    if multiprocessing_read:
        # read all the images to memory (multiprocessing)
        dataset = {}  # use dict to keep the order for multiprocessing
        shapes = {}
        print(f'Read images with multiprocessing, #thread: {n_thread} ...')
        pbar = tqdm(total=len(img_path_list), unit='image')

        def callback(arg):
            """get the image data and update pbar."""
            key, dataset[key], shapes[key] = arg
            pbar.update(1)
            pbar.set_description(f'Read {key}')

        pool = Pool(n_thread)
        for path, key in zip(img_path_list, keys):
            pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback)
        pool.close()
        pool.join()
        pbar.close()
        print(f'Finish reading {len(img_path_list)} images.')

    # create lmdb environment
    if map_size is None:
        # obtain data size for one image
        img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
        _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
        data_size_per_img = img_byte.nbytes
        print('Data size per image is: ', data_size_per_img)
        data_size = data_size_per_img * len(img_path_list)
        map_size = data_size * 10

    env = lmdb.open(lmdb_path, map_size=map_size)

    # write data to lmdb
    pbar = tqdm(total=len(img_path_list), unit='chunk')
    txn = env.begin(write=True)
    txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
    for idx, (path, key) in enumerate(zip(img_path_list, keys)):
        pbar.update(1)
        pbar.set_description(f'Write {key}')
        key_byte = key.encode('ascii')
        if multiprocessing_read:
            img_byte = dataset[key]
            h, w, c = shapes[key]
        else:
            _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level)
            h, w, c = img_shape

        txn.put(key_byte, img_byte)
        # write meta information
        txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
        if idx % batch == 0:
            txn.commit()
            txn = env.begin(write=True)
    pbar.close()
    txn.commit()
    env.close()
    txt_file.close()
    print('\nFinish writing lmdb.')


def read_img_worker(path, key, compress_level):
    """Read image worker.

    Args:
        path (str): Image path.
        key (str): Image key.
        compress_level (int): Compress level when encoding images.

    Returns:
        str: Image key.
        byte: Image byte.
        tuple[int]: Image shape.
    """

    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    # deal with `libpng error: Read Error`
    if img is None:
        print(f'To deal with `libpng error: Read Error`, use PIL to load {path}')
        from PIL import Image
        import numpy as np
        img = Image.open(path)
        img = np.asanyarray(img)
        img = img[:, :, [2, 1, 0]]

    if img.ndim == 2:
        h, w = img.shape
        c = 1
    else:
        h, w, c = img.shape
    _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
    return (key, img_byte, (h, w, c))


class LmdbMaker():
    """LMDB Maker.

    Args:
        lmdb_path (str): Lmdb save path.
        map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
        batch (int): After processing batch images, lmdb commits.
            Default: 5000.
        compress_level (int): Compress level when encoding images. Default: 1.
    """

    def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1):
        if not lmdb_path.endswith('.lmdb'):
            raise ValueError("lmdb_path must end with '.lmdb'.")
        if osp.exists(lmdb_path):
            print(f'Folder {lmdb_path} already exists. Exit.')
            sys.exit(1)

        self.lmdb_path = lmdb_path
        self.batch = batch
        self.compress_level = compress_level
        self.env = lmdb.open(lmdb_path, map_size=map_size)
        self.txn = self.env.begin(write=True)
        self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
        self.counter = 0

    def put(self, img_byte, key, img_shape):
        self.counter += 1
        key_byte = key.encode('ascii')
        self.txn.put(key_byte, img_byte)
        # write meta information
        h, w, c = img_shape
        self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
        if self.counter % self.batch == 0:
            self.txn.commit()
            self.txn = self.env.begin(write=True)

    def close(self):
        self.txn.commit()
        self.env.close()
        self.txt_file.close()
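A short usage sketch for the incremental LmdbMaker path. The file names and the map size are illustrative assumptions, not part of the commit:

# Hypothetical demo: write two images into demo.lmdb via LmdbMaker.
from core.data.deg_kair_utils.utils_lmdb import LmdbMaker, read_img_worker

maker = LmdbMaker('demo.lmdb', map_size=1024**3)   # 1 GiB map size (assumed)
for name in ['0001.png', '0002.png']:              # hypothetical image files
    key, img_byte, shape = read_img_worker(name, name[:-4], compress_level=1)
    maker.put(img_byte, key, shape)                # also appends a meta_info.txt line
maker.close()                                      # commits the final partial batch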
    	
        core/data/deg_kair_utils/utils_logger.py
    ADDED
    
@@ -0,0 +1,66 @@
import sys
import datetime
import logging


'''
# --------------------------------------------
# Kai Zhang (github: https://github.com/cszn)
# 03/Mar/2019
# --------------------------------------------
# https://github.com/xinntao/BasicSR
# --------------------------------------------
'''


def log(*args, **kwargs):
    print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args, **kwargs)


'''
# --------------------------------------------
# logger
# --------------------------------------------
'''


def logger_info(logger_name, log_path='default_logger.log'):
    ''' set up logger
    modified by Kai Zhang (github: https://github.com/cszn)
    '''
    log = logging.getLogger(logger_name)
    if log.hasHandlers():
        print('LogHandlers exist!')
    else:
        print('LogHandlers setup!')
        level = logging.INFO
        formatter = logging.Formatter('%(asctime)s.%(msecs)03d : %(message)s', datefmt='%y-%m-%d %H:%M:%S')
        fh = logging.FileHandler(log_path, mode='a')
        fh.setFormatter(formatter)
        log.setLevel(level)
        log.addHandler(fh)
        # print(len(log.handlers))

        sh = logging.StreamHandler()
        sh.setFormatter(formatter)
        log.addHandler(sh)


'''
# --------------------------------------------
# print to file and std_out simultaneously
# --------------------------------------------
'''


class logger_print(object):
    def __init__(self, log_path="default.log"):
        self.terminal = sys.stdout
        self.log = open(log_path, 'a')

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)  # write the message

    def flush(self):
        pass
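A minimal sketch of the logger_info flow, assuming a hypothetical logger name and log path:

# Hypothetical demo: one named logger writing to train.log and stdout.
import logging
from core.data.deg_kair_utils.utils_logger import logger_info

logger_info('train', log_path='train.log')   # attaches file + stream handlers once
logging.getLogger('train').info('step 100, loss 0.123')  # appears in both sinks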
    	
        core/data/deg_kair_utils/utils_mat.py
    ADDED
    
@@ -0,0 +1,88 @@
import os
import json
import scipy.io as spio
import pandas as pd


def loadmat(filename):
    '''
    this function should be called instead of direct spio.loadmat
    as it cures the problem of not properly recovering python dictionaries
    from mat files. It calls the function check keys to cure all entries
    which are still mat-objects
    '''
    data = spio.loadmat(filename, struct_as_record=False, squeeze_me=True)
    return dict_to_nonedict(_check_keys(data))

def _check_keys(dict):
    '''
    checks if entries in dictionary are mat-objects. If yes
    todict is called to change them to nested dictionaries
    '''
    for key in dict:
        if isinstance(dict[key], spio.matlab.mio5_params.mat_struct):
            dict[key] = _todict(dict[key])
    return dict

def _todict(matobj):
    '''
    A recursive function which constructs from matobjects nested dictionaries
    '''
    dict = {}
    for strg in matobj._fieldnames:
        elem = matobj.__dict__[strg]
        if isinstance(elem, spio.matlab.mio5_params.mat_struct):
            dict[strg] = _todict(elem)
        else:
            dict[strg] = elem
    return dict


def dict_to_nonedict(opt):
    if isinstance(opt, dict):
        new_opt = dict()
        for key, sub_opt in opt.items():
            new_opt[key] = dict_to_nonedict(sub_opt)
        return NoneDict(**new_opt)
    elif isinstance(opt, list):
        return [dict_to_nonedict(sub_opt) for sub_opt in opt]
    else:
        return opt


class NoneDict(dict):
    def __missing__(self, key):
        return None


def mat2json(mat_path=None, filepath = None):
    """
    Converts .mat file to .json and writes new file
    Parameters
    ----------
    mat_path: Str
        path/filename of the .mat file
    filepath: Str
        if the result should be saved as json, provide this path; otherwise nothing is written
    Returns
    -------
    the converted dictionary, serialized as a json string
    Examples
    --------
    >>> mat2json(blah blah)
    """

    matlabFile = loadmat(mat_path)
    #pop all those dumb fields that don't let you jsonize file
    matlabFile.pop('__header__')
    matlabFile.pop('__version__')
    matlabFile.pop('__globals__')
    #jsonize the file - orientation is 'index'
    matlabFile = pd.Series(matlabFile).to_json()

    if filepath:
        json_path = os.path.splitext(os.path.split(mat_path)[1])[0] + '.json'
        with open(json_path, 'w') as f:
                f.write(matlabFile)
    return matlabFile
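A small sketch of the loadmat/NoneDict behaviour, with a hypothetical file name:

# Hypothetical demo: read a .mat file; missing keys come back as None.
from core.data.deg_kair_utils.utils_mat import loadmat

opt = loadmat('params.mat')      # nested NoneDicts instead of mat_struct objects
print(opt['no_such_field'])      # -> None, via NoneDict.__missing__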
    	
        core/data/deg_kair_utils/utils_matconvnet.py
    ADDED
    
@@ -0,0 +1,197 @@
# -*- coding: utf-8 -*-
import numpy as np
import torch
from collections import OrderedDict

# import scipy.io as io
import hdf5storage

"""
# --------------------------------------------
# Convert matconvnet SimpleNN model into pytorch model
# --------------------------------------------
# Kai Zhang ([email protected])
# https://github.com/cszn
# 28/Nov/2019
# --------------------------------------------
"""


def weights2tensor(x, squeeze=False, in_features=None, out_features=None):
    """Modified version of https://github.com/albanie/pytorch-mcn
    Adjust memory layout and load weights as torch tensor
    Args:
        x (ndarray): a numpy array, corresponding to a set of network weights
           stored in column major order
        squeeze (bool) [False]: whether to squeeze the tensor (i.e. remove
           singletons from the trailing dimensions). So after converting to
           pytorch layout (C_out, C_in, H, W), if the shape is (A, B, 1, 1)
           it will be reshaped to a matrix with shape (A,B).
        in_features (int :: None): used to reshape weights for a linear block.
        out_features (int :: None): used to reshape weights for a linear block.
    Returns:
        torch.tensor: a permuted set of weights, matching the pytorch layout
        convention
    """
    if x.ndim == 4:
        x = x.transpose((3, 2, 0, 1))
# for FFDNet, pixel-shuffle layer
#        if x.shape[1]==13:
#            x=x[:,[0,2,1,3,  4,6,5,7, 8,10,9,11, 12],:,:]
#        if x.shape[0]==12:
#            x=x[[0,2,1,3,  4,6,5,7, 8,10,9,11],:,:,:]
#        if x.shape[1]==5:
#            x=x[:,[0,2,1,3,  4],:,:]
#        if x.shape[0]==4:
#            x=x[[0,2,1,3],:,:,:]
## for SRMD, pixel-shuffle layer
#        if x.shape[0]==12:
#            x=x[[0,2,1,3,  4,6,5,7, 8,10,9,11],:,:,:]
#        if x.shape[0]==27:
#            x=x[[0,3,6,1,4,7,2,5,8, 0+9,3+9,6+9,1+9,4+9,7+9,2+9,5+9,8+9, 0+18,3+18,6+18,1+18,4+18,7+18,2+18,5+18,8+18],:,:,:]
#        if x.shape[0]==48:
#            x=x[[0,4,8,12,1,5,9,13,2,6,10,14,3,7,11,15,  0+16,4+16,8+16,12+16,1+16,5+16,9+16,13+16,2+16,6+16,10+16,14+16,3+16,7+16,11+16,15+16,  0+32,4+32,8+32,12+32,1+32,5+32,9+32,13+32,2+32,6+32,10+32,14+32,3+32,7+32,11+32,15+32],:,:,:]

    elif x.ndim == 3:  # add by Kai
        x = x[:,:,:,None]
        x = x.transpose((3, 2, 0, 1))
    elif x.ndim == 2:
         
            +
                    if x.shape[1] == 1:
         
     | 
| 60 | 
         
            +
                        x = x.flatten()
         
     | 
| 61 | 
         
            +
                if squeeze:
         
     | 
| 62 | 
         
            +
                    if in_features and out_features:
         
     | 
| 63 | 
         
            +
                        x = x.reshape((out_features, in_features))
         
     | 
| 64 | 
         
            +
                    x = np.squeeze(x)
         
     | 
| 65 | 
         
            +
                return torch.from_numpy(np.ascontiguousarray(x))
         
     | 
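For reference, a minimal sketch of the layout conversion weights2tensor performs; the random arrays below are stand-ins for weights read from a .mat file:

# MatConvNet stores conv weights as (H, W, C_in, C_out); weights2tensor
# permutes them to PyTorch's (C_out, C_in, H, W) and returns a contiguous tensor.
w_mcn = np.random.randn(3, 3, 3, 64)    # (H, W, C_in, C_out), stand-in data
w_torch = weights2tensor(w_mcn)
assert tuple(w_torch.shape) == (64, 3, 3, 3)

b_mcn = np.random.randn(64, 1)          # biases arrive as a column vector
b_torch = weights2tensor(b_mcn)
assert tuple(b_torch.shape) == (64,)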

def save_model(network, save_path):
    state_dict = network.state_dict()
    for key, param in state_dict.items():
        state_dict[key] = param.cpu()
    torch.save(state_dict, save_path)


if __name__ == '__main__':

#    from utils import utils_logger
#    import logging
#    utils_logger.logger_info('a', 'a.log')
#    logger = logging.getLogger('a')
#
    # mcn = hdf5storage.loadmat('/model_zoo/matfile/FFDNet_Clip_gray.mat')
    mcn = hdf5storage.loadmat('models/modelcolor.mat')

    #logger.info(mcn['CNNdenoiser'][0][0][0][1][0][0][0][0])

    mat_net = OrderedDict()
    for idx in range(25):
        mat_net[str(idx)] = OrderedDict()
        count = -1

        print(idx)
        for i in range(13):

            if mcn['CNNdenoiser'][0][idx][0][i][0][0][0][0] == 'conv':

                count += 1
                w = mcn['CNNdenoiser'][0][idx][0][i][0][1][0][0]
                # print(w.shape)
                w = weights2tensor(w)
                # print(w.shape)

                b = mcn['CNNdenoiser'][0][idx][0][i][0][1][0][1]
                b = weights2tensor(b)
                print(b.shape)

                mat_net[str(idx)]['model.{:d}.weight'.format(count*2)] = w
                mat_net[str(idx)]['model.{:d}.bias'.format(count*2)] = b

    torch.save(mat_net, 'model_zoo/modelcolor.pth')


#    from models.network_dncnn import IRCNN as net
#    network = net(in_nc=3, out_nc=3, nc=64)
#    state_dict = network.state_dict()
#
#    #show_kv(state_dict)
#
#    for i in range(len(mcn['net'][0][0][0])):
#        print(mcn['net'][0][0][0][i][0][0][0][0])
#
#    count = -1
#    mat_net = OrderedDict()
#    for i in range(len(mcn['net'][0][0][0])):
#        if mcn['net'][0][0][0][i][0][0][0][0] == 'conv':
#
#            count += 1
#            w = mcn['net'][0][0][0][i][0][1][0][0]
#            print(w.shape)
#            w = weights2tensor(w)
#            print(w.shape)
#
#            b = mcn['net'][0][0][0][i][0][1][0][1]
#            b = weights2tensor(b)
#            print(b.shape)
#
#            mat_net['model.{:d}.weight'.format(count*2)] = w
#            mat_net['model.{:d}.bias'.format(count*2)] = b
#
#    torch.save(mat_net, 'E:/pytorch/KAIR_ongoing/model_zoo/ffdnet_gray_clip.pth')
#
#
#
#    crt_net = torch.load('E:/pytorch/KAIR_ongoing/model_zoo/imdn_x4.pth')
#    def show_kv(net):
#        for k, v in net.items():
#            print(k)
#
#    show_kv(crt_net)


#    from models.network_dncnn import DnCNN as net
#    network = net(in_nc=2, out_nc=1, nc=64, nb=20, act_mode='R')

#    from models.network_srmd import SRMD as net
#    #network = net(in_nc=1, out_nc=1, nc=64, nb=15, act_mode='R')
#    network = net(in_nc=19, out_nc=3, nc=128, nb=12, upscale=4, act_mode='R', upsample_mode='pixelshuffle')
#
#    from models.network_rrdb import RRDB as net
#    network = net(in_nc=3, out_nc=3, nc=64, nb=23, gc=32, upscale=4, act_mode='L', upsample_mode='upconv')
#
#    state_dict = network.state_dict()
#    for key, param in state_dict.items():
#        print(key)

#    from models.network_imdn import IMDN as net
#    network = net(in_nc=3, out_nc=3, nc=64, nb=8, upscale=4, act_mode='L', upsample_mode='pixelshuffle')
#    state_dict = network.state_dict()
#    mat_net = OrderedDict()
#    for ((key, param), (key2, param2)) in zip(state_dict.items(), crt_net.items()):
#        mat_net[key] = param2
#    torch.save(mat_net, 'model_zoo/imdn_x4_1.pth')
#

#    net_old = torch.load('net_old.pth')
#    def show_kv(net):
#        for k, v in net.items():
#            print(k)
#
#    show_kv(net_old)
#    from models.network_dpsr import MSRResNet_prior as net
#    model = net(in_nc=4, out_nc=3, nc=96, nb=16, upscale=4, act_mode='R', upsample_mode='pixelshuffle')
#    state_dict = network.state_dict()
#    net_new = OrderedDict()
#    for ((key, param), (key_old, param_old)) in zip(state_dict.items(), net_old.items()):
#        net_new[key] = param_old
#    torch.save(net_new, 'net_new.pth')


    # print(key)
    #  print(param.size())


    # run utils/utils_matconvnet.py
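A minimal sketch of consuming the converted checkpoint, mirroring the IRCNN snippet commented out above (the path, network signature, and key layout are taken from that snippet and the conversion loop; treat them as assumptions):

# Load one of the 25 converted denoisers back into an IRCNN-style network.
import torch
from models.network_dncnn import IRCNN as net

mat_net = torch.load('model_zoo/modelcolor.pth')
model = net(in_nc=3, out_nc=3, nc=64)
model.load_state_dict(mat_net['0'])   # first of the 25 noise levels
model.eval()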
    	
core/data/deg_kair_utils/utils_model.py
ADDED
@@ -0,0 +1,330 @@
# -*- coding: utf-8 -*-
import numpy as np
import torch
from utils import utils_image as util
import re
import glob
import os


'''
# --------------------------------------------
# Model
# --------------------------------------------
# Kai Zhang (github: https://github.com/cszn)
# 03/Mar/2019
# --------------------------------------------
'''


def find_last_checkpoint(save_dir, net_type='G', pretrained_path=None):
    """
    # ---------------------------------------
    # Kai Zhang (github: https://github.com/cszn)
    # 03/Mar/2019
    # ---------------------------------------
    Args:
        save_dir: model folder
        net_type: 'G' or 'D' or 'optimizerG' or 'optimizerD'
        pretrained_path: pretrained model path. If save_dir does not have any model, load from pretrained_path

    Return:
        init_iter: iteration number
        init_path: model path
    # ---------------------------------------
    """

    file_list = glob.glob(os.path.join(save_dir, '*_{}.pth'.format(net_type)))
    if file_list:
        iter_exist = []
        for file_ in file_list:
            iter_current = re.findall(r"(\d+)_{}\.pth".format(net_type), file_)
            iter_exist.append(int(iter_current[0]))
        init_iter = max(iter_exist)
        init_path = os.path.join(save_dir, '{}_{}.pth'.format(init_iter, net_type))
    else:
        init_iter = 0
        init_path = pretrained_path
    return init_iter, init_path
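A minimal usage sketch for find_last_checkpoint (directory and file names are hypothetical):

# With 10000_G.pth and 20000_G.pth in 'checkpoints/', this returns
# (20000, 'checkpoints/20000_G.pth'); with an empty folder it falls back
# to (0, 'pretrained/G.pth').
init_iter, init_path = find_last_checkpoint('checkpoints', net_type='G',
                                            pretrained_path='pretrained/G.pth')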


def test_mode(model, L, mode=0, refield=32, min_size=256, sf=1, modulo=1):
    '''
    # ---------------------------------------
    # Kai Zhang (github: https://github.com/cszn)
    # 03/Mar/2019
    # ---------------------------------------
    Args:
        model: trained model
        L: input low-quality image
        mode:
            (0) normal: test(model, L)
            (1) pad: test_pad(model, L, modulo=16)
            (2) split: test_split(model, L, refield=32, min_size=256, sf=1, modulo=1)
            (3) x8: test_x8(model, L, modulo=1) ^_^
            (4) split and x8: test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1)
        refield: effective receptive field of the network, 32 is enough
            useful when split, i.e., mode=2, 4
        min_size: min_size x min_size image, e.g., 256 x 256 image
            useful when split, i.e., mode=2, 4
        sf: scale factor for super-resolution, otherwise 1
        modulo: 1 if split
            useful when pad, i.e., mode=1

    Returns:
        E: estimated image
    # ---------------------------------------
    '''
    if mode == 0:
        E = test(model, L)
    elif mode == 1:
        E = test_pad(model, L, modulo, sf)
    elif mode == 2:
        E = test_split(model, L, refield, min_size, sf, modulo)
    elif mode == 3:
        E = test_x8(model, L, modulo, sf)
    elif mode == 4:
        E = test_split_x8(model, L, refield, min_size, sf, modulo)
    return E


'''
# --------------------------------------------
# normal (0)
# --------------------------------------------
'''


def test(model, L):
    E = model(L)
    return E


'''
# --------------------------------------------
# pad (1)
# --------------------------------------------
'''


def test_pad(model, L, modulo=16, sf=1):
    h, w = L.size()[-2:]
    paddingBottom = int(np.ceil(h/modulo)*modulo-h)
    paddingRight = int(np.ceil(w/modulo)*modulo-w)
    L = torch.nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(L)
    E = model(L)
    E = E[..., :h*sf, :w*sf]
    return E
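The padding arithmetic of test_pad in a concrete case (a sketch; the sizes are arbitrary):

# For a 401x600 input and modulo=16, the input is replication-padded to the
# next multiples of 16 (416x608), run through the model, then cropped back
# to 401*sf x 600*sf.
h, w, modulo = 401, 600, 16
paddingBottom = int(np.ceil(h/modulo)*modulo - h)   # 416 - 401 = 15
paddingRight = int(np.ceil(w/modulo)*modulo - w)    # 608 - 600 = 8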


'''
# --------------------------------------------
# split (function)
# --------------------------------------------
'''


def test_split_fn(model, L, refield=32, min_size=256, sf=1, modulo=1):
    """
    Args:
        model: trained model
        L: input low-quality image
        refield: effective receptive field of the network, 32 is enough
        min_size: min_size x min_size image, e.g., 256 x 256 image
        sf: scale factor for super-resolution, otherwise 1
        modulo: 1 if split

    Returns:
        E: estimated result
    """
    h, w = L.size()[-2:]
    if h*w <= min_size**2:
        L = torch.nn.ReplicationPad2d((0, int(np.ceil(w/modulo)*modulo-w), 0, int(np.ceil(h/modulo)*modulo-h)))(L)
        E = model(L)
        E = E[..., :h*sf, :w*sf]
    else:
        top = slice(0, (h//2//refield+1)*refield)
        bottom = slice(h - (h//2//refield+1)*refield, h)
        left = slice(0, (w//2//refield+1)*refield)
        right = slice(w - (w//2//refield+1)*refield, w)
        Ls = [L[..., top, left], L[..., top, right], L[..., bottom, left], L[..., bottom, right]]

        if h * w <= 4*(min_size**2):
            Es = [model(Ls[i]) for i in range(4)]
        else:
            Es = [test_split_fn(model, Ls[i], refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(4)]

        b, c = Es[0].size()[:2]
        E = torch.zeros(b, c, sf * h, sf * w).type_as(L)

        E[..., :h//2*sf, :w//2*sf] = Es[0][..., :h//2*sf, :w//2*sf]
        E[..., :h//2*sf, w//2*sf:w*sf] = Es[1][..., :h//2*sf, (-w + w//2)*sf:]
        E[..., h//2*sf:h*sf, :w//2*sf] = Es[2][..., (-h + h//2)*sf:, :w//2*sf]
        E[..., h//2*sf:h*sf, w//2*sf:w*sf] = Es[3][..., (-h + h//2)*sf:, (-w + w//2)*sf:]
    return E
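How the overlapping quadrants come out for a concrete size (a sketch):

# For a 600x600 input with refield=32, each crop extends past the midpoint
# to a refield-aligned boundary, so the four crops overlap and the network
# sees enough context around each seam.
h = w = 600
refield = 32
top = slice(0, (h//2//refield + 1)*refield)          # slice(0, 320)
bottom = slice(h - (h//2//refield + 1)*refield, h)   # slice(280, 600)
# The crops cover rows 0-320 and 280-600: a 40-row overlap around the
# midpoint (300); only the halves up to the midpoint are copied into E.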


'''
# --------------------------------------------
# split (2)
# --------------------------------------------
'''


def test_split(model, L, refield=32, min_size=256, sf=1, modulo=1):
    E = test_split_fn(model, L, refield=refield, min_size=min_size, sf=sf, modulo=modulo)
    return E


'''
# --------------------------------------------
# x8 (3)
# --------------------------------------------
'''


def test_x8(model, L, modulo=1, sf=1):
    E_list = [test_pad(model, util.augment_img_tensor4(L, mode=i), modulo=modulo, sf=sf) for i in range(8)]
    for i in range(len(E_list)):
        if i == 3 or i == 5:
            E_list[i] = util.augment_img_tensor4(E_list[i], mode=8 - i)
        else:
            E_list[i] = util.augment_img_tensor4(E_list[i], mode=i)
    output_cat = torch.stack(E_list, dim=0)
    E = output_cat.mean(dim=0, keepdim=False)
    return E


'''
# --------------------------------------------
# split and x8 (4)
# --------------------------------------------
'''


def test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1):
    E_list = [test_split_fn(model, util.augment_img_tensor4(L, mode=i), refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(8)]
    for i in range(len(E_list)):
        if i == 3 or i == 5:
            E_list[i] = util.augment_img_tensor4(E_list[i], mode=8 - i)
        else:
            E_list[i] = util.augment_img_tensor4(E_list[i], mode=i)
    output_cat = torch.stack(E_list, dim=0)
    E = output_cat.mean(dim=0, keepdim=False)
    return E
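test_x8 and test_split_x8 implement the standard geometric self-ensemble: run the model on the 8 dihedral transforms of the input, invert each transform on the output, and average. A self-contained sketch of the idea using torch.rot90/torch.flip as a stand-in for util.augment_img_tensor4 (whose exact mode numbering lives in utils_image):

import torch

def dihedral(x, k, flip):
    # one of the 8 symmetries of the square: rotate by k*90 degrees, optionally flip
    x = torch.rot90(x, k, dims=(-2, -1))
    return torch.flip(x, dims=(-1,)) if flip else x

def dihedral_inv(x, k, flip):
    # exact inverse: undo the flip first, then rotate back
    if flip:
        x = torch.flip(x, dims=(-1,))
    return torch.rot90(x, -k, dims=(-2, -1))

def self_ensemble(model, L):
    outs = [dihedral_inv(model(dihedral(L, k, f)), k, f)
            for k in range(4) for f in (False, True)]
    return torch.stack(outs, dim=0).mean(dim=0)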


'''
# ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-
# _^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^
# ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-
'''


'''
# --------------------------------------------
# print
# --------------------------------------------
'''


# --------------------------------------------
# print model
# --------------------------------------------
def print_model(model):
    msg = describe_model(model)
    print(msg)


# --------------------------------------------
# print params
# --------------------------------------------
def print_params(model):
    msg = describe_params(model)
    print(msg)


'''
# --------------------------------------------
# information
# --------------------------------------------
'''


# --------------------------------------------
# model information
# --------------------------------------------
def info_model(model):
    msg = describe_model(model)
    return msg


# --------------------------------------------
# params information
# --------------------------------------------
def info_params(model):
    msg = describe_params(model)
    return msg


'''
# --------------------------------------------
# description
# --------------------------------------------
'''


# --------------------------------------------
# model name and total number of parameters
# --------------------------------------------
def describe_model(model):
    if isinstance(model, torch.nn.DataParallel):
        model = model.module
    msg = '\n'
    msg += 'model name: {}'.format(model.__class__.__name__) + '\n'
    msg += 'Params number: {}'.format(sum(map(lambda x: x.numel(), model.parameters()))) + '\n'
    msg += 'Net structure:\n{}'.format(str(model)) + '\n'
    return msg


# --------------------------------------------
# parameters description
# --------------------------------------------
def describe_params(model):
    if isinstance(model, torch.nn.DataParallel):
        model = model.module
    msg = '\n'
    msg += ' | {:^6s} | {:^6s} | {:^6s} | {:^6s} | {:^20s} || {:s}'.format('mean', 'min', 'max', 'std', 'shape', 'param_name') + '\n'
    for name, param in model.state_dict().items():
        if 'num_batches_tracked' not in name:
            v = param.data.clone().float()
            msg += ' | {:>6.3f} | {:>6.3f} | {:>6.3f} | {:>6.3f} | {} || {:s}'.format(v.mean(), v.min(), v.max(), v.std(), v.shape, name) + '\n'
    return msg


if __name__ == '__main__':

    class Net(torch.nn.Module):
        def __init__(self, in_channels=3, out_channels=3):
            super(Net, self).__init__()
            self.conv = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)

        def forward(self, x):
            x = self.conv(x)
            return x

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    model = Net()
    model = model.eval()
    print_model(model)
    print_params(model)
    x = torch.randn((2, 3, 401, 401))
    torch.cuda.empty_cache()
    with torch.no_grad():
        for mode in range(5):
            y = test_mode(model, x, mode, refield=32, min_size=256, sf=1, modulo=1)
            print(y.shape)

    # run utils/utils_model.py
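For the demo Net in the __main__ block, the 'Params number' line printed by describe_model can be checked by hand (a quick sanity sketch):

# Conv2d(3, 3, kernel_size=3): 3*3*3*3 = 81 weights plus 3 biases = 84 parameters.
net = Net()
assert sum(p.numel() for p in net.parameters()) == 84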
    	
core/data/deg_kair_utils/utils_modelsummary.py
ADDED
@@ -0,0 +1,485 @@
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
import torch.nn as nn
import torch
import numpy as np

'''
---- 1) FLOPs: floating point operations
---- 2) #Activations: the number of elements of all 'Conv2d' outputs
---- 3) #Conv2d: the number of 'Conv2d' layers
# --------------------------------------------
# Kai Zhang (github: https://github.com/cszn)
# 21/July/2020
# --------------------------------------------
# Reference
https://github.com/sovrasov/flops-counter.pytorch.git

# If you use this code, please consider the following citation:

@inproceedings{zhang2020aim,
  title={AIM 2020 Challenge on Efficient Super-Resolution: Methods and Results},
  author={Kai Zhang and Martin Danelljan and Yawei Li and Radu Timofte and others},
  booktitle={European Conference on Computer Vision Workshops},
  year={2020}
}
# --------------------------------------------
'''

def get_model_flops(model, input_res, print_per_layer_stat=True,
                    input_constructor=None):
    assert type(input_res) is tuple, 'Please provide the size of the input image.'
    assert len(input_res) >= 3, 'Input image should have 3 dimensions.'
    flops_model = add_flops_counting_methods(model)
    flops_model.eval().start_flops_count()
    if input_constructor:
        input = input_constructor(input_res)
        _ = flops_model(**input)
    else:
        device = list(flops_model.parameters())[-1].device
        batch = torch.FloatTensor(1, *input_res).to(device)
        _ = flops_model(batch)

    if print_per_layer_stat:
        print_model_with_flops(flops_model)
    flops_count = flops_model.compute_average_flops_cost()
    flops_model.stop_flops_count()

    return flops_count


def get_model_activation(model, input_res, input_constructor=None):
    assert type(input_res) is tuple, 'Please provide the size of the input image.'
    assert len(input_res) >= 3, 'Input image should have 3 dimensions.'
    activation_model = add_activation_counting_methods(model)
    activation_model.eval().start_activation_count()
    if input_constructor:
        input = input_constructor(input_res)
        _ = activation_model(**input)
    else:
        device = list(activation_model.parameters())[-1].device
        batch = torch.FloatTensor(1, *input_res).to(device)
        _ = activation_model(batch)

    activation_count, num_conv = activation_model.compute_average_activation_cost()
    activation_model.stop_activation_count()

    return activation_count, num_conv


def get_model_complexity_info(model, input_res, print_per_layer_stat=True, as_strings=True,
                              input_constructor=None):
    assert type(input_res) is tuple
    assert len(input_res) >= 3
    flops_model = add_flops_counting_methods(model)
    flops_model.eval().start_flops_count()
    if input_constructor:
        input = input_constructor(input_res)
        _ = flops_model(**input)
    else:
        batch = torch.FloatTensor(1, *input_res)
        _ = flops_model(batch)

    if print_per_layer_stat:
        print_model_with_flops(flops_model)
    flops_count = flops_model.compute_average_flops_cost()
    params_count = get_model_parameters_number(flops_model)
    flops_model.stop_flops_count()

    if as_strings:
        return flops_to_string(flops_count), params_to_string(params_count)

    return flops_count, params_count

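The three entry points above are self-contained: wrap any nn.Module, pass the input size as a (C, H, W) tuple, and read the counts back. A minimal usage sketch (the toy model is illustrative, not part of the file):

import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.Conv2d(16, 3, 3, padding=1))
flops = get_model_flops(net, (3, 64, 64), print_per_layer_stat=False)  # multiply-accumulates for one image
acts, n_conv = get_model_activation(net, (3, 64, 64))                  # Conv2d output elements, Conv2d layer count
print(flops_to_string(flops), acts, n_conv)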
def flops_to_string(flops, units='GMac', precision=2):
    if units is None:
        if flops // 10**9 > 0:
            return str(round(flops / 10.**9, precision)) + ' GMac'
        elif flops // 10**6 > 0:
            return str(round(flops / 10.**6, precision)) + ' MMac'
        elif flops // 10**3 > 0:
            return str(round(flops / 10.**3, precision)) + ' KMac'
        else:
            return str(flops) + ' Mac'
    else:
        if units == 'GMac':
            return str(round(flops / 10.**9, precision)) + ' ' + units
        elif units == 'MMac':
            return str(round(flops / 10.**6, precision)) + ' ' + units
        elif units == 'KMac':
            return str(round(flops / 10.**3, precision)) + ' ' + units
        else:
            return str(flops) + ' Mac'


def params_to_string(params_num):
    if params_num // 10 ** 6 > 0:
        return str(round(params_num / 10 ** 6, 2)) + ' M'
    elif params_num // 10 ** 3:
        return str(round(params_num / 10 ** 3, 2)) + ' k'
    else:
        return str(params_num)

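For reference, the formatters behave as follows (values chosen for illustration):

flops_to_string(231211008)        # '0.23 GMac' (default units='GMac')
flops_to_string(231211008, None)  # '231.21 MMac' (auto-scaled)
params_to_string(1330000)         # '1.33 M'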
def print_model_with_flops(model, units='GMac', precision=3):
    total_flops = model.compute_average_flops_cost()

    def accumulate_flops(self):
        if is_supported_instance(self):
            # __batch_counter__ is never set anywhere in this file, so default to 1
            return self.__flops__ / getattr(model, '__batch_counter__', 1)
        else:
            total = 0
            for m in self.children():
                total += m.accumulate_flops()
            return total

    def flops_repr(self):
        accumulated_flops_cost = self.accumulate_flops()
        return ', '.join([flops_to_string(accumulated_flops_cost, units=units, precision=precision),
                          '{:.3%} MACs'.format(accumulated_flops_cost / total_flops),
                          self.original_extra_repr()])

    def add_extra_repr(m):
        m.accumulate_flops = accumulate_flops.__get__(m)
        flops_extra_repr = flops_repr.__get__(m)
        if m.extra_repr != flops_extra_repr:
            m.original_extra_repr = m.extra_repr
            m.extra_repr = flops_extra_repr
            assert m.extra_repr != m.original_extra_repr

    def del_extra_repr(m):
        if hasattr(m, 'original_extra_repr'):
            m.extra_repr = m.original_extra_repr
            del m.original_extra_repr
        if hasattr(m, 'accumulate_flops'):
            del m.accumulate_flops

    model.apply(add_extra_repr)
    print(model)
    model.apply(del_extra_repr)

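Applied to the toy two-conv model from the earlier sketch, print_model_with_flops temporarily patches each module's extra_repr, so the printed model comes out annotated roughly as follows (values follow the default units and precision; the exact formatting is schematic):

Sequential(
  0.004 GMac, 100.000% MACs,
  (0): Conv2d(0.002 GMac, 49.091% MACs, 3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(0.0 GMac, 1.818% MACs, )
  (2): Conv2d(0.002 GMac, 49.091% MACs, 16, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)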
def get_model_parameters_number(model):
    params_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return params_num


def add_flops_counting_methods(net_main_module):
    # adding additional methods to the existing module object,
    # this is done this way so that each function has access to self object
    net_main_module.start_flops_count = start_flops_count.__get__(net_main_module)
    net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module)
    net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module)
    net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(net_main_module)

    net_main_module.reset_flops_count()
    return net_main_module

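The function.__get__(instance) calls above are the descriptor protocol used directly: they bind a plain function as a method of one specific object without touching its class, which is how each counter method gains access to self. A standalone illustration (names are hypothetical):

def describe(self):
    return 'value = {}'.format(self.value)

class Box:
    pass

b = Box()
b.value = 3
b.describe = describe.__get__(b)  # bound to this one instance only
print(b.describe())               # value = 3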
def compute_average_flops_cost(self):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.

    Returns current mean flops consumption per image.
    """

    flops_sum = 0
    for module in self.modules():
        if is_supported_instance(module):
            flops_sum += module.__flops__

    return flops_sum


def start_flops_count(self):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.

    Activates the computation of mean flops consumption per image.
    Call it before you run the network.
    """
    self.apply(add_flops_counter_hook_function)


def stop_flops_count(self):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.

    Stops computing the mean flops consumption per image.
    Call whenever you want to pause the computation.
    """
    self.apply(remove_flops_counter_hook_function)


def reset_flops_count(self):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.

    Resets statistics computed so far.
    """
    self.apply(add_flops_counter_variable_or_reset)


def add_flops_counter_hook_function(module):
    if is_supported_instance(module):
        if hasattr(module, '__flops_handle__'):
            return

        if isinstance(module, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d)):
            handle = module.register_forward_hook(conv_flops_counter_hook)
        elif isinstance(module, (nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6)):
            handle = module.register_forward_hook(relu_flops_counter_hook)
        elif isinstance(module, nn.Linear):
            handle = module.register_forward_hook(linear_flops_counter_hook)
        elif isinstance(module, nn.BatchNorm2d):
            handle = module.register_forward_hook(bn_flops_counter_hook)
        else:
            handle = module.register_forward_hook(empty_flops_counter_hook)
        module.__flops_handle__ = handle


def remove_flops_counter_hook_function(module):
    if is_supported_instance(module):
        if hasattr(module, '__flops_handle__'):
            module.__flops_handle__.remove()
            del module.__flops_handle__


def add_flops_counter_variable_or_reset(module):
    if is_supported_instance(module):
        module.__flops__ = 0


# ---- Internal functions
def is_supported_instance(module):
    if isinstance(module,
                  (
                      nn.Conv2d, nn.ConvTranspose2d,
                      nn.BatchNorm2d,
                      nn.Linear,
                      nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6,
                  )):
        return True

    return False

def conv_flops_counter_hook(conv_module, input, output):
    # Can have multiple inputs, getting the first one
    # input = input[0]

    batch_size = output.shape[0]
    output_dims = list(output.shape[2:])

    kernel_dims = list(conv_module.kernel_size)
    in_channels = conv_module.in_channels
    out_channels = conv_module.out_channels
    groups = conv_module.groups

    filters_per_channel = out_channels // groups
    conv_per_position_flops = np.prod(kernel_dims) * in_channels * filters_per_channel

    active_elements_count = batch_size * np.prod(output_dims)
    overall_conv_flops = int(conv_per_position_flops) * int(active_elements_count)

    conv_module.__flops__ += int(overall_conv_flops)

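Concretely: a 3x3 Conv2d with in_channels=64, out_channels=128, groups=1 and a 56x56 output map at batch size 1 costs 3*3*64*128 = 73,728 MACs per output position; with 56*56 = 3,136 active positions the hook adds 73,728 * 3,136 = 231,211,008, roughly 0.23 GMac. Biases are ignored, and one multiply-accumulate counts as one operation, so these "FLOPs" are really MACs, consistent with the GMac units used by flops_to_string. For grouped convolutions, in_channels * (out_channels // groups) equals the usual (in_channels // groups) * out_channels, so the same expression also covers depthwise layers.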
def relu_flops_counter_hook(module, input, output):
    # one op per output element
    active_elements_count = output.numel()
    module.__flops__ += int(active_elements_count)


def linear_flops_counter_hook(module, input, output):
    input = input[0]
    if len(input.shape) == 1:
        batch_size = 1
        module.__flops__ += int(batch_size * input.shape[0] * output.shape[0])
    else:
        batch_size = input.shape[0]
        module.__flops__ += int(batch_size * input.shape[1] * output.shape[1])


def bn_flops_counter_hook(module, input, output):
    # one op per element to normalize; doubled when the affine scale/shift is applied
    # (an earlier draft in the original counted np.prod(input.shape) instead)
    batch = output.shape[0]
    output_dims = output.shape[2:]
    channels = module.num_features
    batch_flops = batch * channels * np.prod(output_dims)
    if module.affine:
        batch_flops *= 2
    module.__flops__ += int(batch_flops)

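One caveat worth flagging: linear_flops_counter_hook only handles 1-D and 2-D inputs correctly. For a 3-D input of shape (B, T, C_in), as in transformer token batches, input.shape[1] is T, so the count becomes B*T*T rather than B*T*C_in*C_out. A rank-agnostic variant (a sketch, not part of the original file) is:

def linear_flops_counter_hook_general(module, input, output):
    # MACs = number of output elements times in_features, for any input rank
    module.__flops__ += int(output.numel() * module.in_features)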
# ---- Count the number of convolutional layers and the activation
def add_activation_counting_methods(net_main_module):
    # adding additional methods to the existing module object,
    # this is done this way so that each function has access to self object
    net_main_module.start_activation_count = start_activation_count.__get__(net_main_module)
    net_main_module.stop_activation_count = stop_activation_count.__get__(net_main_module)
    net_main_module.reset_activation_count = reset_activation_count.__get__(net_main_module)
    net_main_module.compute_average_activation_cost = compute_average_activation_cost.__get__(net_main_module)

    net_main_module.reset_activation_count()
    return net_main_module


def compute_average_activation_cost(self):
    """
    A method that will be available after add_activation_counting_methods() is called
    on a desired net object.

    Returns current mean activation consumption per image.
    """

    activation_sum = 0
    num_conv = 0
    for module in self.modules():
        if is_supported_instance_for_activation(module):
            activation_sum += module.__activation__
            num_conv += module.__num_conv__
    return activation_sum, num_conv


def start_activation_count(self):
    """
    A method that will be available after add_activation_counting_methods() is called
    on a desired net object.

    Activates the computation of mean activation consumption per image.
    Call it before you run the network.
    """
    self.apply(add_activation_counter_hook_function)


def stop_activation_count(self):
    """
    A method that will be available after add_activation_counting_methods() is called
    on a desired net object.

    Stops computing the mean activation consumption per image.
    Call whenever you want to pause the computation.
    """
    self.apply(remove_activation_counter_hook_function)


def reset_activation_count(self):
    """
    A method that will be available after add_activation_counting_methods() is called
    on a desired net object.

    Resets statistics computed so far.
    """
    self.apply(add_activation_counter_variable_or_reset)


def add_activation_counter_hook_function(module):
    if is_supported_instance_for_activation(module):
        if hasattr(module, '__activation_handle__'):
            return

        if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
            handle = module.register_forward_hook(conv_activation_counter_hook)
            module.__activation_handle__ = handle


def remove_activation_counter_hook_function(module):
    if is_supported_instance_for_activation(module):
        if hasattr(module, '__activation_handle__'):
            module.__activation_handle__.remove()
            del module.__activation_handle__


def add_activation_counter_variable_or_reset(module):
    if is_supported_instance_for_activation(module):
        module.__activation__ = 0
        module.__num_conv__ = 0


def is_supported_instance_for_activation(module):
    if isinstance(module,
                  (
                      nn.Conv2d, nn.ConvTranspose2d,
                  )):
        return True

    return False


def conv_activation_counter_hook(module, input, output):
    """
    Calculate the activations in the convolutional operation.
    Reference: Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár, Designing Network Design Spaces.
    :param module:
    :param input:
    :param output:
    :return:
    """
    module.__activation__ += output.numel()
    module.__num_conv__ += 1

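By this definition a Conv2d that emits a (1, 64, 56, 56) tensor contributes 64 * 56 * 56 = 200,704 activations and increments the layer count by one; the metric is simply output.numel() summed over all Conv2d and ConvTranspose2d layers, matching the activations measure from the RegNet design-space paper cited in the docstring above.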
def empty_flops_counter_hook(module, input, output):
    module.__flops__ += 0


def upsample_flops_counter_hook(module, input, output):
    output_size = output[0]
    batch_size = output_size.shape[0]
    output_elements_count = batch_size
    for val in output_size.shape[1:]:
        output_elements_count *= val
    module.__flops__ += int(output_elements_count)


def pool_flops_counter_hook(module, input, output):
    input = input[0]
    module.__flops__ += int(np.prod(input.shape))


def dconv_flops_counter_hook(dconv_module, input, output):
    input = input[0]

    batch_size = input.shape[0]
    output_dims = list(output.shape[2:])

    m_channels, in_channels, kernel_dim1, _, = dconv_module.weight.shape
    out_channels, _, kernel_dim2, _, = dconv_module.projection.shape

    conv_per_position_flops1 = kernel_dim1 ** 2 * in_channels * m_channels
    conv_per_position_flops2 = kernel_dim2 ** 2 * out_channels * m_channels
    active_elements_count = batch_size * np.prod(output_dims)

    overall_conv_flops = (conv_per_position_flops1 + conv_per_position_flops2) * active_elements_count
    overall_flops = overall_conv_flops

    dconv_module.__flops__ += int(overall_flops)
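Note that upsample_flops_counter_hook, pool_flops_counter_hook and dconv_flops_counter_hook above are defined but never registered: add_flops_counter_hook_function only dispatches on conv, ReLU-family, Linear and BatchNorm2d modules, so pooling and upsampling layers fall through to empty_flops_counter_hook and contribute nothing. If those costs matter for your model, one possible extension (a sketch, not part of the original file) is:

def add_pool_upsample_hooks(module):
    # complements add_flops_counter_hook_function for layer types it skips
    if hasattr(module, '__flops_handle__'):
        return
    if isinstance(module, (nn.MaxPool2d, nn.AvgPool2d, nn.AdaptiveAvgPool2d)):
        module.__flops_handle__ = module.register_forward_hook(pool_flops_counter_hook)
    elif isinstance(module, nn.Upsample):
        module.__flops_handle__ = module.register_forward_hook(upsample_flops_counter_hook)

These classes would also need to be added to is_supported_instance so that reset_flops_count initializes __flops__ on them and compute_average_flops_cost includes them in the sum.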
    	
core/data/deg_kair_utils/utils_option.py
ADDED
@@ -0,0 +1,255 @@
import os
from collections import OrderedDict
from datetime import datetime
import json
import re
import glob


'''
# --------------------------------------------
# Kai Zhang (github: https://github.com/cszn)
# 03/Mar/2019
# --------------------------------------------
# https://github.com/xinntao/BasicSR
# --------------------------------------------
'''


def get_timestamp():
    return datetime.now().strftime('_%y%m%d_%H%M%S')

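get_timestamp() yields a suffix such as '_200721_143000' (two-digit year, date, time). parse() below consumes JSON option files that may carry //-style line comments; the comments are stripped before json.loads, so a minimal file it would accept looks like this (contents illustrative):

{
  "task": "sr_x4",            // experiment name, becomes part of the output paths
  "scale": 4,
  "n_channels": 3,
  "datasets": {
    "train": {"dataroot_H": "~/data/HR", "dataroot_L": "~/data/LR"}
  },
  "path": {"root": "experiments"}
}

One caveat: the stripper is line.split('//')[0], so any value containing '//' (for example an http:// URL) would be truncated.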
def parse(opt_path, is_train=True):

    # ----------------------------------------
    # remove comments starting with '//'
    # ----------------------------------------
    json_str = ''
    with open(opt_path, 'r') as f:
        for line in f:
            line = line.split('//')[0] + '\n'
            json_str += line

    # ----------------------------------------
    # initialize opt
    # ----------------------------------------
    opt = json.loads(json_str, object_pairs_hook=OrderedDict)

    opt['opt_path'] = opt_path
    opt['is_train'] = is_train

    # ----------------------------------------
    # set default
    # ----------------------------------------
    if 'merge_bn' not in opt:
        opt['merge_bn'] = False
        opt['merge_bn_startpoint'] = -1

    if 'scale' not in opt:
        opt['scale'] = 1

    # ----------------------------------------
    # datasets
    # ----------------------------------------
    for phase, dataset in opt['datasets'].items():
        phase = phase.split('_')[0]
        dataset['phase'] = phase
        dataset['scale'] = opt['scale']  # broadcast
        dataset['n_channels'] = opt['n_channels']  # broadcast
        if 'dataroot_H' in dataset and dataset['dataroot_H'] is not None:
            dataset['dataroot_H'] = os.path.expanduser(dataset['dataroot_H'])
        if 'dataroot_L' in dataset and dataset['dataroot_L'] is not None:
            dataset['dataroot_L'] = os.path.expanduser(dataset['dataroot_L'])

    # ----------------------------------------
    # path
    # ----------------------------------------
    for key, path in opt['path'].items():
        if path and key in opt['path']:
            opt['path'][key] = os.path.expanduser(path)

    path_task = os.path.join(opt['path']['root'], opt['task'])
    opt['path']['task'] = path_task
    opt['path']['log'] = path_task
         
     | 
| 75 | 
         
            +
                opt['path']['options'] = os.path.join(path_task, 'options')
         
     | 
| 76 | 
         
            +
             
     | 
| 77 | 
         
            +
                if is_train:
         
     | 
| 78 | 
         
            +
                    opt['path']['models'] = os.path.join(path_task, 'models')
         
     | 
| 79 | 
         
            +
                    opt['path']['images'] = os.path.join(path_task, 'images')
         
     | 
| 80 | 
         
            +
                else:  # test
         
     | 
| 81 | 
         
            +
                    opt['path']['images'] = os.path.join(path_task, 'test_images')
         
     | 
| 82 | 
         
            +
             
     | 
| 83 | 
         
            +
                # ----------------------------------------
         
     | 
| 84 | 
         
            +
                # network
         
     | 
| 85 | 
         
            +
                # ----------------------------------------
         
     | 
| 86 | 
         
            +
                opt['netG']['scale'] = opt['scale'] if 'scale' in opt else 1
         
     | 
| 87 | 
         
            +
             
     | 
| 88 | 
         
            +
                # ----------------------------------------
         
     | 
| 89 | 
         
            +
                # GPU devices
         
     | 
| 90 | 
         
            +
                # ----------------------------------------
         
     | 
| 91 | 
         
            +
                gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
         
     | 
| 92 | 
         
            +
                os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
         
     | 
| 93 | 
         
            +
                print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
         
     | 
| 94 | 
         
            +
             
     | 
| 95 | 
         
            +
                # ----------------------------------------
         
     | 
| 96 | 
         
            +
                # default setting for distributeddataparallel
         
     | 
| 97 | 
         
            +
                # ----------------------------------------
         
     | 
| 98 | 
         
            +
                if 'find_unused_parameters' not in opt:
         
     | 
| 99 | 
         
            +
                    opt['find_unused_parameters'] = True
         
     | 
| 100 | 
         
            +
                if 'use_static_graph' not in opt:
         
     | 
| 101 | 
         
            +
                    opt['use_static_graph'] = False
         
     | 
| 102 | 
         
            +
                if 'dist' not in opt:
         
     | 
| 103 | 
         
            +
                    opt['dist'] = False
         
     | 
| 104 | 
         
            +
                opt['num_gpu'] = len(opt['gpu_ids'])
         
     | 
| 105 | 
         
            +
                print('number of GPUs is: ' + str(opt['num_gpu']))
         
     | 
| 106 | 
         
            +
             
     | 
| 107 | 
         
            +
                # ----------------------------------------
         
     | 
| 108 | 
         
            +
                # default setting for perceptual loss
         
     | 
| 109 | 
         
            +
                # ----------------------------------------
         
     | 
| 110 | 
         
            +
                if 'F_feature_layer' not in opt['train']:
         
     | 
| 111 | 
         
            +
                    opt['train']['F_feature_layer'] = 34  # 25; [2,7,16,25,34]
         
     | 
| 112 | 
         
            +
                if 'F_weights' not in opt['train']:
         
     | 
| 113 | 
         
            +
                    opt['train']['F_weights'] = 1.0  # 1.0; [0.1,0.1,1.0,1.0,1.0]
         
     | 
| 114 | 
         
            +
                if 'F_lossfn_type' not in opt['train']:
         
     | 
| 115 | 
         
            +
                    opt['train']['F_lossfn_type'] = 'l1'
         
     | 
| 116 | 
         
            +
                if 'F_use_input_norm' not in opt['train']:
         
     | 
| 117 | 
         
            +
                    opt['train']['F_use_input_norm'] = True
         
     | 
| 118 | 
         
            +
                if 'F_use_range_norm' not in opt['train']:
         
     | 
| 119 | 
         
            +
                    opt['train']['F_use_range_norm'] = False
         
     | 
| 120 | 
         
            +
             
     | 
| 121 | 
         
            +
                # ----------------------------------------
         
     | 
| 122 | 
         
            +
                # default setting for optimizer
         
     | 
| 123 | 
         
            +
                # ----------------------------------------
         
     | 
| 124 | 
         
            +
                if 'G_optimizer_type' not in opt['train']:
         
     | 
| 125 | 
         
            +
                    opt['train']['G_optimizer_type'] = "adam"
         
     | 
| 126 | 
         
            +
                if 'G_optimizer_betas' not in opt['train']:
         
     | 
| 127 | 
         
            +
                    opt['train']['G_optimizer_betas'] = [0.9,0.999]
         
     | 
| 128 | 
         
            +
                if 'G_scheduler_restart_weights' not in opt['train']:
         
     | 
| 129 | 
         
            +
                    opt['train']['G_scheduler_restart_weights'] = 1
         
     | 
| 130 | 
         
            +
                if 'G_optimizer_wd' not in opt['train']:
         
     | 
| 131 | 
         
            +
                    opt['train']['G_optimizer_wd'] = 0
         
     | 
| 132 | 
         
            +
                if 'G_optimizer_reuse' not in opt['train']:
         
     | 
| 133 | 
         
            +
                    opt['train']['G_optimizer_reuse'] = False
         
     | 
| 134 | 
         
            +
                if 'netD' in opt and 'D_optimizer_reuse' not in opt['train']:
         
     | 
| 135 | 
         
            +
                    opt['train']['D_optimizer_reuse'] = False
         
     | 
| 136 | 
         
            +
             
     | 
| 137 | 
         
            +
                # ----------------------------------------
         
     | 
| 138 | 
         
            +
                # default setting of strict for model loading
         
     | 
| 139 | 
         
            +
                # ----------------------------------------
         
     | 
| 140 | 
         
            +
                if 'G_param_strict' not in opt['train']:
         
     | 
| 141 | 
         
            +
                    opt['train']['G_param_strict'] = True
         
     | 
| 142 | 
         
            +
                if 'netD' in opt and 'D_param_strict' not in opt['path']:
         
     | 
| 143 | 
         
            +
                    opt['train']['D_param_strict'] = True
         
     | 
| 144 | 
         
            +
                if 'E_param_strict' not in opt['path']:
         
     | 
| 145 | 
         
            +
                    opt['train']['E_param_strict'] = True
         
     | 
| 146 | 
         
            +
             
     | 
| 147 | 
         
            +
                # ----------------------------------------
         
     | 
| 148 | 
         
            +
                # Exponential Moving Average
         
     | 
| 149 | 
         
            +
                # ----------------------------------------
         
     | 
| 150 | 
         
            +
                if 'E_decay' not in opt['train']:
         
     | 
| 151 | 
         
            +
                    opt['train']['E_decay'] = 0
         
     | 
| 152 | 
         
            +
             
     | 
| 153 | 
         
            +
                # ----------------------------------------
         
     | 
| 154 | 
         
            +
                # default setting for discriminator
         
     | 
| 155 | 
         
            +
                # ----------------------------------------
         
     | 
| 156 | 
         
            +
                if 'netD' in opt:
         
     | 
| 157 | 
         
            +
                    if 'net_type' not in opt['netD']:
         
     | 
| 158 | 
         
            +
                        opt['netD']['net_type'] = 'discriminator_patchgan'  # discriminator_unet
         
     | 
| 159 | 
         
            +
                    if 'in_nc' not in opt['netD']:
         
     | 
| 160 | 
         
            +
                        opt['netD']['in_nc'] = 3
         
     | 
| 161 | 
         
            +
                    if 'base_nc' not in opt['netD']:
         
     | 
| 162 | 
         
            +
                        opt['netD']['base_nc'] = 64
         
     | 
| 163 | 
         
            +
                    if 'n_layers' not in opt['netD']:
         
     | 
| 164 | 
         
            +
                        opt['netD']['n_layers'] = 3
         
     | 
| 165 | 
         
            +
                    if 'norm_type' not in opt['netD']:
         
     | 
| 166 | 
         
            +
                        opt['netD']['norm_type'] = 'spectral'
         
     | 
| 167 | 
         
            +
             
     | 
| 168 | 
         
            +
             
     | 
| 169 | 
         
            +
                return opt
         
     | 
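
# For reference, a minimal options file that parse() accepts; every value is
# an illustrative placeholder, not a config shipped with this repo. Note the
# '//' comments, which parse() strips before json.loads:
#
#     {
#       "task": "demo_x4"                 // placeholder experiment name
#       , "gpu_ids": [0]
#       , "scale": 4
#       , "n_channels": 3
#       , "path": {"root": "experiments"}           // outputs land in experiments/demo_x4/
#       , "datasets": {"train": {"dataroot_H": "~/trainsets/HR"}}   // '~' is expanded
#       , "netG": {"net_type": "placeholder_net"}
#       , "train": {}                     // missing keys fall back to the defaults above
#     }
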
def find_last_checkpoint(save_dir, net_type='G', pretrained_path=None):
    """
    Args:
        save_dir: model folder
        net_type: 'G' or 'D' or 'optimizerG' or 'optimizerD'
        pretrained_path: pretrained model path. If save_dir does not have any model, load from pretrained_path

    Return:
        init_iter: iteration number
        init_path: model path
    """
    file_list = glob.glob(os.path.join(save_dir, '*_{}.pth'.format(net_type)))
    if file_list:
        iter_exist = []
        for file_ in file_list:
            iter_current = re.findall(r"(\d+)_{}\.pth".format(net_type), file_)  # escape '.' so it matches literally
            iter_exist.append(int(iter_current[0]))
        init_iter = max(iter_exist)
        init_path = os.path.join(save_dir, '{}_{}.pth'.format(init_iter, net_type))
    else:
        init_iter = 0
        init_path = pretrained_path
    return init_iter, init_path


'''
# --------------------------------------------
# convert the opt into json file
# --------------------------------------------
'''


def save(opt):
    opt_path = opt['opt_path']
    opt_path_copy = opt['path']['options']
    dirname, filename_ext = os.path.split(opt_path)
    filename, ext = os.path.splitext(filename_ext)
    dump_path = os.path.join(opt_path_copy, filename + get_timestamp() + ext)
    with open(dump_path, 'w') as dump_file:
        json.dump(opt, dump_file, indent=2)


'''
# --------------------------------------------
# dict to string for logger
# --------------------------------------------
'''


def dict2str(opt, indent_l=1):
    msg = ''
    for k, v in opt.items():
        if isinstance(v, dict):
            msg += ' ' * (indent_l * 2) + k + ':[\n'
            msg += dict2str(v, indent_l + 1)
            msg += ' ' * (indent_l * 2) + ']\n'
        else:
            msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
    return msg


'''
# --------------------------------------------
# convert OrderedDict to NoneDict,
# return None for missing key
# --------------------------------------------
'''


def dict_to_nonedict(opt):
    if isinstance(opt, dict):
        new_opt = dict()
        for key, sub_opt in opt.items():
            new_opt[key] = dict_to_nonedict(sub_opt)
        return NoneDict(**new_opt)
    elif isinstance(opt, list):
        return [dict_to_nonedict(sub_opt) for sub_opt in opt]
    else:
        return opt


class NoneDict(dict):
    def __missing__(self, key):
        return None
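
A minimal usage sketch for the helpers above; the import path and the options-file location are assumptions, not names fixed by this file:

# hypothetical usage; the module path and 'options/train.json' are placeholders
from core.data.deg_kair_utils import utils_option as option

opt = option.parse('options/train.json', is_train=True)
opt = option.dict_to_nonedict(opt)      # missing keys now read as None instead of raising
init_iter, init_path = option.find_last_checkpoint(opt['path']['models'],
                                                   net_type='G',
                                                   pretrained_path=None)
print('resume from iteration {}: {}'.format(init_iter, init_path))
print(option.dict2str(opt))             # readable dump for the training log
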
    	
core/data/deg_kair_utils/utils_params.py
ADDED
@@ -0,0 +1,135 @@
import torch

import torchvision

from models import basicblock as B


def show_kv(net):
    for k, v in net.items():
        print(k)


# should run train debug mode first to get an initial model
# crt_net = torch.load('../../experiments/debug_SRResNet_bicx4_in3nf64nb16/models/8_G.pth')
#
# for k, v in crt_net.items():
#     print(k)
# for k, v in crt_net.items():
#     if k in pretrained_net:
#         crt_net[k] = pretrained_net[k]
#         print('replace ... ', k)

# x2 -> x4
# crt_net['model.5.weight'] = pretrained_net['model.2.weight']
# crt_net['model.5.bias'] = pretrained_net['model.2.bias']
# crt_net['model.8.weight'] = pretrained_net['model.5.weight']
# crt_net['model.8.bias'] = pretrained_net['model.5.bias']
# crt_net['model.10.weight'] = pretrained_net['model.7.weight']
# crt_net['model.10.bias'] = pretrained_net['model.7.bias']
# torch.save(crt_net, '../pretrained_tmp.pth')

# x2 -> x3
'''
in_filter = pretrained_net['model.2.weight']  # 256, 64, 3, 3
new_filter = torch.Tensor(576, 64, 3, 3)
new_filter[0:256, :, :, :] = in_filter
new_filter[256:512, :, :, :] = in_filter
new_filter[512:, :, :, :] = in_filter[0:576-512, :, :, :]
crt_net['model.2.weight'] = new_filter

in_bias = pretrained_net['model.2.bias']  # 256
new_bias = torch.Tensor(576)
new_bias[0:256] = in_bias
new_bias[256:512] = in_bias
new_bias[512:] = in_bias[0:576 - 512]
crt_net['model.2.bias'] = new_bias

torch.save(crt_net, '../pretrained_tmp.pth')
'''

# x2 -> x8
'''
crt_net['model.5.weight'] = pretrained_net['model.2.weight']
crt_net['model.5.bias'] = pretrained_net['model.2.bias']
crt_net['model.8.weight'] = pretrained_net['model.2.weight']
crt_net['model.8.bias'] = pretrained_net['model.2.bias']
crt_net['model.11.weight'] = pretrained_net['model.5.weight']
crt_net['model.11.bias'] = pretrained_net['model.5.bias']
crt_net['model.13.weight'] = pretrained_net['model.7.weight']
crt_net['model.13.bias'] = pretrained_net['model.7.bias']
torch.save(crt_net, '../pretrained_tmp.pth')
'''

# x3/4/8 RGB -> Y

def rgb2gray_net(net, only_input=True):
    # fold the 3-channel input filters of the first conv layer into a single
    # gray channel using Rec.601 luminance weights
    if only_input:
        in_filter = net['0.weight']
        in_new_filter = in_filter[:, 0, :, :]*0.2989 + in_filter[:, 1, :, :]*0.587 + in_filter[:, 2, :, :]*0.114
        in_new_filter.unsqueeze_(1)
        net['0.weight'] = in_new_filter

#    out_filter = pretrained_net['model.13.weight']
#    out_new_filter = out_filter[0, :, :, :] * 0.2989 + out_filter[1, :, :, :] * 0.587 + \
#        out_filter[2, :, :, :] * 0.114
#    out_new_filter.unsqueeze_(0)
#    crt_net['model.13.weight'] = out_new_filter
#    out_bias = pretrained_net['model.13.bias']
#    out_new_bias = out_bias[0] * 0.2989 + out_bias[1] * 0.587 + out_bias[2] * 0.114
#    out_new_bias = torch.Tensor(1).fill_(out_new_bias)
#    crt_net['model.13.bias'] = out_new_bias

#    torch.save(crt_net, '../pretrained_tmp.pth')

    return net
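
# A self-contained check of the luminance folding that rgb2gray_net performs
# on the first conv filter (the toy tensor below is an assumption, not data
# from this repo):
#
#     w_rgb = torch.randn(64, 3, 3, 3)   # (out_channels, in_channels=3, k, k), made-up filter
#     w_gray = (w_rgb[:, 0]*0.2989 + w_rgb[:, 1]*0.587 + w_rgb[:, 2]*0.114).unsqueeze(1)
#     print(w_gray.shape)                # torch.Size([64, 1, 3, 3])
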
if __name__ == '__main__':

    net = torchvision.models.vgg19(pretrained=True)
    for k, v in net.features.named_parameters():
        if k == '0.weight':
            in_new_filter = v[:, 0, :, :]*0.2989 + v[:, 1, :, :]*0.587 + v[:, 2, :, :]*0.114
            in_new_filter.unsqueeze_(1)
            print(in_new_filter.shape)
            print(in_new_filter[0, 0, 0, 0])
        if k == '0.bias':
            in_new_bias = v
            print(v[0])

    print(net.features[0])

    net.features[0] = B.conv(1, 64, mode='C')

    print(net.features[0])
    net.features[0].weight.data = in_new_filter
    net.features[0].bias.data = in_new_bias

    for k, v in net.features.named_parameters():
        if k == '0.weight':
            print(v[0, 0, 0, 0])
        if k == '0.bias':
            print(v[0])

    # transfer parameters of an old model to a new one, matching layers by
    # position; model_path and model are undefined placeholders here, so the
    # block is left commented out
    # model_old = torch.load(model_path)
    # state_dict = model.state_dict()
    # for (key, param), (key2, param2) in zip(model_old.items(), state_dict.items()):
    #     state_dict[key2] = param
    #     print([key, key2])
    # torch.save(state_dict, 'model_new.pth')

    # rgb2gray_net(net)
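
A minimal, runnable sketch of the positional state-dict transfer used in the commented block above, with two toy models standing in for the real old/new networks (both are assumptions):

import torch
import torch.nn as nn

# two structurally identical toy models; in practice the old one comes from torch.load
old_model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 3, 3, padding=1))
new_model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 3, 3, padding=1))

state_dict = new_model.state_dict()
for (key_old, param_old), (key_new, _) in zip(old_model.state_dict().items(), state_dict.items()):
    state_dict[key_new] = param_old    # copy by position, so layer names may differ
new_model.load_state_dict(state_dict)
torch.save(state_dict, 'model_new.pth')
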
    	
core/data/deg_kair_utils/utils_receptivefield.py
ADDED
@@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-

# online calculation: https://fomoro.com/research/article/receptive-field-calculator#

# [filter size, stride, padding]
# Assume the two spatial dimensions are the same.
# Each kernel requires the following parameters:
# - k_i: kernel size
# - s_i: stride
# - p_i: padding (if padding is uneven, right padding will be higher than left padding; "SAME" option in tensorflow)
#
# Each layer i requires the following parameters to be fully represented:
# - n_i: number of features (the data layer has n_1 = imagesize)
# - j_i: distance (projected to image pixel distance) between centers of two adjacent features
# - r_i: receptive field of a feature in layer i
# - start_i: position of the first feature's receptive field in layer i (idx starts from 0, negative means the center falls into padding)

import math

def outFromIn(conv, layerIn):
    n_in = layerIn[0]
    j_in = layerIn[1]
    r_in = layerIn[2]
    start_in = layerIn[3]
    k = conv[0]
    s = conv[1]
    p = conv[2]

    n_out = math.floor((n_in - k + 2*p)/s) + 1
    actualP = (n_out-1)*s - n_in + k
    pR = math.ceil(actualP/2)
    pL = math.floor(actualP/2)

    j_out = j_in * s
    r_out = r_in + (k - 1)*j_in
    start_out = start_in + ((k-1)/2 - pL)*j_in
    return n_out, j_out, r_out, start_out

def printLayer(layer, layer_name):
    print(layer_name + ":")
    print(" n features: %s  jump: %s  receptive size: %s  start: %s " % (layer[0], layer[1], layer[2], layer[3]))


layerInfos = []
if __name__ == '__main__':

    convnet = [[3,1,1], [3,1,1], [3,1,1], [4,2,1], [2,2,0], [3,1,1]]
    # only the first len(convnet) names are used
    layer_names = ['conv1', 'conv2', 'conv3', 'conv4', 'conv5', 'conv6', 'conv7', 'conv8', 'conv9', 'conv10', 'conv11', 'conv12']
    imsize = 128

    print("-------Net summary------")
    currentLayer = [imsize, 1, 1, 0.5]
    printLayer(currentLayer, "input image")
    for i in range(len(convnet)):
        currentLayer = outFromIn(convnet[i], currentLayer)
        layerInfos.append(currentLayer)
        printLayer(currentLayer, layer_names[i])

# run utils/utils_receptivefield.py
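
A quick hand check of the recurrences in outFromIn for the first two [3,1,1] convolutions on a 128-pixel input (run inside this module; the numbers match what the script prints):

layer0 = [128, 1, 1, 0.5]               # n, jump, receptive field, start (input image)
layer1 = outFromIn([3, 1, 1], layer0)   # (128, 1, 3, 0.5): r grows by (k-1)*j = 2
layer2 = outFromIn([3, 1, 1], layer1)   # (128, 1, 5, 0.5): r grows by another 2
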
    	
core/data/deg_kair_utils/utils_regularizers.py
ADDED
@@ -0,0 +1,104 @@
import torch
import torch.nn as nn


'''
# --------------------------------------------
# Kai Zhang (github: https://github.com/cszn)
# 03/Mar/2019
# --------------------------------------------
'''


# --------------------------------------------
# SVD Orthogonal Regularization
# --------------------------------------------
def regularizer_orth(m):
    """
    # ----------------------------------------
    # SVD Orthogonal Regularization
    # ----------------------------------------
    # Applies orthogonal regularization during training: singular values of
    # each conv kernel outside [0.5, 1.5] are nudged back by 1e-4.
    # This function is to be called by the torch.nn.Module.apply() method,
    # which applies it to every layer of the model.
    # usage: net.apply(regularizer_orth)
    # ----------------------------------------
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        w = m.weight.data.clone()
        c_out, c_in, f1, f2 = w.size()
        # dtype = m.weight.data.type()
        w = w.permute(2, 3, 1, 0).contiguous().view(f1*f2*c_in, c_out)
        # self.netG.apply(svd_orthogonalization)
        u, s, v = torch.svd(w)
        s[s > 1.5] = s[s > 1.5] - 1e-4
        s[s < 0.5] = s[s < 0.5] + 1e-4
        w = torch.mm(torch.mm(u, torch.diag(s)), v.t())
        m.weight.data = w.view(f1, f2, c_in, c_out).permute(3, 2, 0, 1)  # .type(dtype)


# --------------------------------------------
# SVD Orthogonal Regularization (thresholds relative to the mean singular value)
# --------------------------------------------
def regularizer_orth2(m):
    """
    # ----------------------------------------
    # Like regularizer_orth, but the clamping thresholds are 0.5x and 1.5x
    # the mean singular value instead of the fixed values 0.5 and 1.5.
    # This function is to be called by the torch.nn.Module.apply() method.
    # usage: net.apply(regularizer_orth2)
    # ----------------------------------------
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        w = m.weight.data.clone()
        c_out, c_in, f1, f2 = w.size()
        # dtype = m.weight.data.type()
        w = w.permute(2, 3, 1, 0).contiguous().view(f1*f2*c_in, c_out)
        u, s, v = torch.svd(w)
        s_mean = s.mean()
        s[s > 1.5*s_mean] = s[s > 1.5*s_mean] - 1e-4
        s[s < 0.5*s_mean] = s[s < 0.5*s_mean] + 1e-4
        w = torch.mm(torch.mm(u, torch.diag(s)), v.t())
        m.weight.data = w.view(f1, f2, c_in, c_out).permute(3, 2, 0, 1)  # .type(dtype)


def regularizer_clip(m):
    """
    # ----------------------------------------
    # Nudges Conv/Linear weights and biases outside [-1.5, 1.5] back toward
    # the range by 1e-4 per call.
    # usage: net.apply(regularizer_clip)
    # ----------------------------------------
    """
    eps = 1e-4
    c_min = -1.5
    c_max = 1.5

    classname = m.__class__.__name__
    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
        w = m.weight.data.clone()
        w[w > c_max] -= eps
        w[w < c_min] += eps
        m.weight.data = w

        if m.bias is not None:
            b = m.bias.data.clone()
            b[b > c_max] -= eps
            b[b < c_min] += eps
            m.bias.data = b

#    elif classname.find('BatchNorm2d') != -1:
#
#       rv = m.running_var.data.clone()
#       rm = m.running_mean.data.clone()
#
#        if m.affine:
#            m.weight.data
#            m.bias.data
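
A minimal usage sketch for the regularizers above; the toy network and the training-step context are assumptions:

# apply a regularizer to every matching layer, typically after optimizer.step()
import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.Conv2d(16, 3, 3, padding=1))
net.apply(regularizer_orth)   # nudge each conv kernel's singular values toward [0.5, 1.5]
net.apply(regularizer_clip)   # nudge weights/biases outside [-1.5, 1.5] back by 1e-4
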
    	
core/data/deg_kair_utils/utils_sisr.py
ADDED
@@ -0,0 +1,848 @@
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
# -*- coding: utf-8 -*-
from utils import utils_image as util
import random

import scipy
import scipy.stats as ss
import scipy.io as io
from scipy import ndimage
from scipy.interpolate import interp2d

import numpy as np
import torch


"""
# --------------------------------------------
# Super-Resolution
# --------------------------------------------
#
# Kai Zhang ([email protected])
# https://github.com/cszn
# modified by Kai Zhang (github: https://github.com/cszn)
# 03/03/2020
# --------------------------------------------
"""


"""
# --------------------------------------------
# anisotropic Gaussian kernels
# --------------------------------------------
"""


def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
    """ generate an anisotropic Gaussian kernel
    Args:
        ksize : e.g., 15, kernel size
        theta : [0,  pi], rotation angle range
        l1    : [0.1,50], scaling of eigenvalues
        l2    : [0.1,l1], scaling of eigenvalues
        If l1 = l2, will get an isotropic Gaussian kernel.
    Returns:
        k     : kernel
    """

    v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
    V = np.array([[v[0], v[1]], [v[1], -v[0]]])
    D = np.array([[l1, 0], [0, l2]])
    Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
    k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)

    return k


def gm_blur_kernel(mean, cov, size=15):
    center = size / 2.0 + 0.5
    k = np.zeros([size, size])
    for y in range(size):
        for x in range(size):
            cy = y - center + 1
            cx = x - center + 1
            k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)

    k = k / np.sum(k)
    return k
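
# --- Editor's sketch (not part of the original file): minimal usage of the two
# generators above. `anisotropic_Gaussian` rotates a diagonal covariance by
# `theta` and hands it to `gm_blur_kernel`, which evaluates the Gaussian pdf on
# a ksize x ksize grid and normalizes it to sum to 1. The `_demo_*` names here
# and below are editorial additions, not part of the KAIR API.
def _demo_anisotropic_gaussian():
    k_iso = anisotropic_Gaussian(ksize=15, theta=0.0, l1=2.0, l2=2.0)        # l1 == l2 -> isotropic
    k_aniso = anisotropic_Gaussian(ksize=15, theta=np.pi/4, l1=6.0, l2=1.0)  # elongated, rotated 45 deg
    assert k_iso.shape == (15, 15)
    assert np.isclose(k_aniso.sum(), 1.0)  # kernels are normalized
    return k_iso, k_aniso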


"""
# --------------------------------------------
# calculate PCA projection matrix
# --------------------------------------------
"""


def get_pca_matrix(x, dim_pca=15):
    """
    Args:
        x: 225x10000 matrix
        dim_pca: 15
    Returns:
        pca_matrix: 15x225
    """
    C = np.dot(x, x.T)
    w, v = scipy.linalg.eigh(C)
    pca_matrix = v[:, -dim_pca:].T

    return pca_matrix


def show_pca(x):
    """
    x: PCA projection matrix, e.g., 15x225
    """
    for i in range(x.shape[0]):
        xc = np.reshape(x[i, :], (int(np.sqrt(x.shape[1])), -1), order="F")
        util.surf(xc)


def cal_pca_matrix(path='PCA_matrix.mat', ksize=15, l_max=12.0, dim_pca=15, num_samples=500):
    kernels = np.zeros([ksize*ksize, num_samples], dtype=np.float32)
    for i in range(num_samples):

        theta = np.pi*np.random.rand(1)
        l1    = 0.1+l_max*np.random.rand(1)
        l2    = 0.1+(l1-0.1)*np.random.rand(1)

        k = anisotropic_Gaussian(ksize=ksize, theta=theta[0], l1=l1[0], l2=l2[0])

        # util.imshow(k)

        kernels[:, i] = np.reshape(k, (-1), order="F")  # k.flatten(order='F')

    # io.savemat('k.mat', {'k': kernels})

    pca_matrix = get_pca_matrix(kernels, dim_pca=dim_pca)

    io.savemat(path, {'p': pca_matrix})

    return pca_matrix
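
# --- Editor's sketch (not part of the original file): the PCA helpers compress
# ksize*ksize blur kernels into a dim_pca-dimensional code, as used for
# SRMD-style degradation maps; `cal_pca_matrix` does the same sampling but also
# writes the projection matrix to a .mat file.
def _demo_kernel_pca(ksize=15, dim_pca=15, num_samples=200):
    kernels = np.zeros([ksize * ksize, num_samples], dtype=np.float32)
    for i in range(num_samples):
        theta = np.pi * np.random.rand(1)
        k = anisotropic_Gaussian(ksize=ksize, theta=theta[0], l1=6.0, l2=2.0)
        kernels[:, i] = np.reshape(k, (-1), order="F")
    pca_matrix = get_pca_matrix(kernels, dim_pca=dim_pca)  # dim_pca x (ksize*ksize)
    codes = pca_matrix @ kernels                           # dim_pca x num_samples
    return pca_matrix, codes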


"""
# --------------------------------------------
# shifted anisotropic Gaussian kernels
# --------------------------------------------
"""


def shifted_anisotropic_Gaussian(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
    """
    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
    # Kai Zhang
    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var
    # max_var = 2.5 * sf
    """
    # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
    lambda_1 = min_var + np.random.rand() * (max_var - min_var)
    lambda_2 = min_var + np.random.rand() * (max_var - min_var)
    theta = np.random.rand() * np.pi  # random theta
    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2

    # Set COV matrix using Lambdas and Theta
    LAMBDA = np.diag([lambda_1, lambda_2])
    Q = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta), np.cos(theta)]])
    SIGMA = Q @ LAMBDA @ Q.T
    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]

    # Set expectation position (shifting kernel for aligned image)
    MU = k_size // 2 - 0.5*(scale_factor - 1)  # - 0.5 * (scale_factor - k_size % 2)
    MU = MU[None, None, :, None]

    # Create meshgrid for Gaussian
    [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
    Z = np.stack([X, Y], 2)[:, :, :, None]

    # Calculate Gaussian for every pixel of the kernel
    ZZ = Z - MU
    ZZ_t = ZZ.transpose(0, 1, 3, 2)
    raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)

    # shift the kernel so it will be centered
    # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)

    # Normalize the kernel and return
    # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
    kernel = raw_kernel / np.sum(raw_kernel)
    return kernel
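
# --- Editor's sketch (not part of the original file): the kernel expectation is
# shifted by 0.5*(sf - 1) so that blurring followed by sf-fold subsampling stays
# aligned with the upper-left HR grid convention.
def _demo_shifted_kernel(sf=4):
    k = shifted_anisotropic_Gaussian(k_size=np.array([15, 15]),
                                     scale_factor=np.array([sf, sf]),
                                     min_var=0.6, max_var=10., noise_level=0)
    assert np.isclose(k.sum(), 1.0)
    return k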


def gen_kernel(k_size=np.array([25, 25]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=12., noise_level=0):
    """
    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
    # Kai Zhang
    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var
    # max_var = 2.5 * sf
    """
    sf = random.choice([1, 2, 3, 4])
    scale_factor = np.array([sf, sf])
    # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
    lambda_1 = min_var + np.random.rand() * (max_var - min_var)
    lambda_2 = min_var + np.random.rand() * (max_var - min_var)
    theta = np.random.rand() * np.pi  # random theta
    noise = 0  # -noise_level + np.random.rand(*k_size) * noise_level * 2

    # Set COV matrix using Lambdas and Theta
    LAMBDA = np.diag([lambda_1, lambda_2])
    Q = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta), np.cos(theta)]])
    SIGMA = Q @ LAMBDA @ Q.T
    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]

    # Set expectation position (shifting kernel for aligned image)
    MU = k_size // 2 - 0.5*(scale_factor - 1)  # - 0.5 * (scale_factor - k_size % 2)
    MU = MU[None, None, :, None]

    # Create meshgrid for Gaussian
    [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
    Z = np.stack([X, Y], 2)[:, :, :, None]

    # Calculate Gaussian for every pixel of the kernel
    ZZ = Z - MU
    ZZ_t = ZZ.transpose(0, 1, 3, 2)
    raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)

    # shift the kernel so it will be centered
    # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)

    # Normalize the kernel and return
    # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
    kernel = raw_kernel / np.sum(raw_kernel)
    return kernel


"""
# --------------------------------------------
# degradation models
# --------------------------------------------
"""


def bicubic_degradation(x, sf=3):
    '''
    Args:
        x: HxWxC image, [0, 1]
        sf: down-scale factor
    Return:
        bicubically downsampled LR image
    '''
    x = util.imresize_np(x, scale=1/sf)
    return x


def srmd_degradation(x, k, sf=3):
    ''' blur + bicubic downsampling
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2018learning,
          title={Learning a single convolutional super-resolution network for multiple degradations},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={3262--3271},
          year={2018}
        }
    '''
    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'
    x = bicubic_degradation(x, sf=sf)
    return x


def dpsr_degradation(x, k, sf=3):

    ''' bicubic downsampling + blur
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2019deep,
          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={1671--1681},
          year={2019}
        }
    '''
    x = bicubic_degradation(x, sf=sf)
    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    return x


def classical_degradation(x, k, sf=3):
    ''' blur + downsampling

    Args:
        x: HxWxC image, [0, 1]/[0, 255]
        k: hxw, double
        sf: down-scale factor

    Return:
        downsampled LR image
    '''
    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
    st = 0
    return x[st::sf, st::sf, ...]
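
# --- Editor's sketch (not part of the original file): the three models above
# differ only in operator order and sampler:
#   srmd_degradation:      blur with k, then bicubic sf-fold downsampling
#   dpsr_degradation:      bicubic sf-fold downsampling, then blur with k
#   classical_degradation: blur with k, then direct decimation x[0::sf, 0::sf]
def _demo_degradations(x, sf=4):
    k = anisotropic_Gaussian(ksize=15, theta=0.3, l1=4.0, l2=2.0)
    return (srmd_degradation(x, k, sf=sf),
            dpsr_degradation(x, k, sf=sf),
            classical_degradation(x, k, sf=sf))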


def modcrop_np(img, sf):
    '''
    Args:
        img: numpy image, WxH or WxHxC
        sf: scale factor
    Return:
        cropped image
    '''
    w, h = img.shape[:2]
    im = np.copy(img)
    return im[:w - w % sf, :h - h % sf, ...]
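
# --- Editor's sketch (not part of the original file): `modcrop_np` trims the
# image so both spatial sizes are divisible by sf, keeping HR/LR pairs aligned.
def _demo_modcrop(sf=4):
    img = np.zeros((101, 103, 3))
    assert modcrop_np(img, sf).shape == (101 - 101 % sf, 103 - 103 % sf, 3)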


'''
# =================
# Numpy
# =================
'''


def shift_pixel(x, sf, upper_left=True):
    """shift pixel for super-resolution with different scale factors
    Args:
        x: WxHxC or WxH, image or kernel
        sf: scale factor
        upper_left: shift direction
    """
    h, w = x.shape[:2]
    shift = (sf-1)*0.5
    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
    if upper_left:
        x1 = xv + shift
        y1 = yv + shift
    else:
        x1 = xv - shift
        y1 = yv - shift

    x1 = np.clip(x1, 0, w-1)
    y1 = np.clip(y1, 0, h-1)

    if x.ndim == 2:
        x = interp2d(xv, yv, x)(x1, y1)
    if x.ndim == 3:
        for i in range(x.shape[-1]):
            x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)

    return x


'''
# =================
# pytorch
# =================
'''


def splits(a, sf):
    '''
    a: tensor NxCxWxHx2
    sf: scale factor
    out: tensor NxCx(W/sf)x(H/sf)x2x(sf^2)
    '''
    b = torch.stack(torch.chunk(a, sf, dim=2), dim=5)
    b = torch.cat(torch.chunk(b, sf, dim=3), dim=5)
    return b
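
# --- Editor's sketch (not part of the original file): `splits` regroups an
# NxCxWxHx2 tensor into its sf*sf polyphase (aliased) components along a new
# trailing axis, which is what the Fourier-domain solver below averages over.
def _demo_splits(sf=2):
    a = torch.randn(1, 3, 8, 8, 2)         # old-style (..., 2) real/imag layout
    b = splits(a, sf)
    assert b.shape == (1, 3, 4, 4, 2, 4)   # NxCx(W/sf)x(H/sf)x2x(sf^2)
    return b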


def c2c(x):
    return torch.from_numpy(np.stack([np.float32(x.real), np.float32(x.imag)], axis=-1))


def r2c(x):
    return torch.stack([x, torch.zeros_like(x)], -1)


def cdiv(x, y):
    a, b = x[..., 0], x[..., 1]
    c, d = y[..., 0], y[..., 1]
    cd2 = c**2 + d**2
    return torch.stack([(a*c+b*d)/cd2, (b*c-a*d)/cd2], -1)


def csum(x, y):
    return torch.stack([x[..., 0] + y, x[..., 1]], -1)


def cabs(x):
    return torch.pow(x[..., 0]**2+x[..., 1]**2, 0.5)


def cmul(t1, t2):
    '''
    complex multiplication
    t1: NxCxHxWx2
    output: NxCxHxWx2
    '''
    real1, imag1 = t1[..., 0], t1[..., 1]
    real2, imag2 = t2[..., 0], t2[..., 1]
    return torch.stack([real1 * real2 - imag1 * imag2, real1 * imag2 + imag1 * real2], dim=-1)


def cconj(t, inplace=False):
    '''
    # complex conjugation
    t: NxCxHxWx2
    output: NxCxHxWx2
    '''
    c = t.clone() if not inplace else t
    c[..., 1] *= -1
    return c
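
# --- Editor's sketch (not part of the original file): the c-prefixed helpers
# implement complex arithmetic on the (..., 2) real/imag layout; multiplying and
# then dividing by the same value recovers the input up to float rounding.
def _demo_complex_ops():
    x = torch.randn(2, 3, 4, 4, 2)
    y = torch.full((2, 3, 4, 4, 2), 1.0)
    y[..., 1] = 0.5                        # y = 1 + 0.5j everywhere
    z = cdiv(cmul(x, y), y)
    assert torch.allclose(z, x, atol=1e-5)
    return z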


def rfft(t):
    return torch.rfft(t, 2, onesided=False)


def irfft(t):
    return torch.irfft(t, 2, onesided=False)


def fft(t):
    return torch.fft(t, 2)


def ifft(t):
    return torch.ifft(t, 2)


def p2o(psf, shape):
    '''
    Args:
        psf: NxCxhxw
        shape: [H,W]

    Returns:
        otf: NxCxHxWx2
    '''
    otf = torch.zeros(psf.shape[:-2] + shape).type_as(psf)
    otf[..., :psf.shape[2], :psf.shape[3]].copy_(psf)
    for axis, axis_size in enumerate(psf.shape[2:]):
        otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2)
    otf = torch.rfft(otf, 2, onesided=False)
    n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf)))
    otf[..., 1][torch.abs(otf[..., 1]) < n_ops*2.22e-16] = torch.tensor(0).type_as(psf)
    return otf
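
# --- Editor's note (assumption): rfft/irfft/fft/ifft above wrap the pre-1.8
# torch.fft API, which was removed in modern PyTorch. A rough torch>=1.8
# equivalent of p2o returning a native complex tensor (instead of the (..., 2)
# layout) might look like this hypothetical `_p2o_complex`:
def _p2o_complex(psf, shape):
    otf = torch.zeros(psf.shape[:-2] + tuple(shape), dtype=psf.dtype, device=psf.device)
    otf[..., :psf.shape[2], :psf.shape[3]].copy_(psf)
    for axis, axis_size in enumerate(psf.shape[2:]):
        otf = torch.roll(otf, -int(axis_size / 2), dims=axis + 2)  # center PSF at origin
    return torch.fft.fftn(otf, dim=(-2, -1))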


'''
# =================
# PyTorch
# =================
'''


def INVLS_pytorch(FB, FBC, F2B, FR, tau, sf=2):
    '''
    FB: NxCxWxHx2
    F2B: NxCxWxHx2

    x1 = FB.*FR;
    FBR = BlockMM(nr,nc,Nb,m,x1);
    invW = BlockMM(nr,nc,Nb,m,F2B);
    invWBR = FBR./(invW + tau*Nb);
    fun = @(block_struct) block_struct.data.*invWBR;
    FCBinvWBR = blockproc(FBC,[nr,nc],fun);
    FX = (FR-FCBinvWBR)/tau;
    Xest = real(ifft2(FX));
    '''
    x1 = cmul(FB, FR)
    FBR = torch.mean(splits(x1, sf), dim=-1, keepdim=False)
    invW = torch.mean(splits(F2B, sf), dim=-1, keepdim=False)
    invWBR = cdiv(FBR, csum(invW, tau))
    FCBinvWBR = cmul(FBC, invWBR.repeat(1, 1, sf, sf, 1))
    FX = (FR-FCBinvWBR)/tau
    Xest = torch.irfft(FX, 2, onesided=False)
    return Xest
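
# --- Editor's note (not part of the original file): INVLS_pytorch evaluates a
# closed-form Fourier-domain solution of a data subproblem of the form
#   min_x || y - (x conv k) downsampled by sf ||^2 + tau * || x - z ||^2,
# in the style of Zhang et al.'s plug-and-play SR solvers. The `splits` + mean
# calls average the sf^2 aliased copies of FB.*FR and |FB|^2, so inverting the
# blur-plus-downsampling operator reduces to the pointwise division in `cdiv`
# rather than a large matrix inverse.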


def real2complex(x):
    return torch.stack([x, torch.zeros_like(x)], -1)


def modcrop(img, sf):
    '''
    img: tensor image, NxCxWxH or CxWxH or WxH
    sf: scale factor
    '''
    w, h = img.shape[-2:]
    im = img.clone()
    return im[..., :w - w % sf, :h - h % sf]


def upsample(x, sf=3, center=False):
    '''
    x: tensor image, NxCxWxH
    '''
    st = (sf-1)//2 if center else 0
    z = torch.zeros((x.shape[0], x.shape[1], x.shape[2]*sf, x.shape[3]*sf)).type_as(x)
    z[..., st::sf, st::sf].copy_(x)
    return z


def downsample(x, sf=3, center=False):
    st = (sf-1)//2 if center else 0
    return x[..., st::sf, st::sf]
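
# --- Editor's sketch (not part of the original file): `upsample` is sf-fold
# zero-insertion (the adjoint of `downsample`), so downsampling right after
# upsampling returns the input exactly.
def _demo_up_down(sf=3):
    x = torch.randn(1, 3, 8, 8)
    z = upsample(x, sf=sf)                     # 1x3x24x24, zeros off-grid
    assert torch.equal(downsample(z, sf=sf), x)
    return z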


def circular_pad(x, pad):
    '''
    # x[N, 1, W, H] -> x[N, 1, W + 2 pad, H + 2 pad] (periodic padding)
    '''
    x = torch.cat([x, x[:, :, 0:pad, :]], dim=2)
    x = torch.cat([x, x[:, :, :, 0:pad]], dim=3)
    x = torch.cat([x[:, :, -2 * pad:-pad, :], x], dim=2)
    x = torch.cat([x[:, :, :, -2 * pad:-pad], x], dim=3)
    return x


def pad_circular(input, padding):
    # type: (Tensor, List[int]) -> Tensor
    """
    Arguments
    :param input: tensor of shape :math:`(N, C_{\text{in}}, H, [W, D]))`
    :param padding: (tuple): m-elem tuple where m is the degree of convolution
    Returns
    :return: tensor of shape :math:`(N, C_{\text{in}}, [D + 2 * padding[0],
                                     H + 2 * padding[1]], W + 2 * padding[2]))`
    """
    offset = 3
    for dimension in range(input.dim() - offset + 1):
        input = dim_pad_circular(input, padding[dimension], dimension + offset)
    return input


def dim_pad_circular(input, padding, dimension):
    # type: (Tensor, int, int) -> Tensor
    input = torch.cat([input, input[[slice(None)] * (dimension - 1) +
                      [slice(0, padding)]]], dim=dimension - 1)
    input = torch.cat([input[[slice(None)] * (dimension - 1) +
                      [slice(-2 * padding, -padding)]], input], dim=dimension - 1)
    return input


def imfilter(x, k):
    '''
    x: image, NxcxHxW
    k: kernel, cx1xhxw
    '''
    x = pad_circular(x, padding=((k.shape[-2]-1)//2, (k.shape[-1]-1)//2))
    x = torch.nn.functional.conv2d(x, k, groups=x.shape[1])
    return x
| 547 | 
         
            +
             
     | 
| 548 | 
         
            +
             
     | 
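# Example (illustrative sketch): depthwise circular filtering with a box kernel;
# the circular padding keeps the spatial size unchanged.
#   x = torch.randn(1, 3, 32, 32)
#   k = torch.full((3, 1, 5, 5), 1. / 25)      # one 5x5 kernel per channel
#   y = imfilter(x, k)                         # (1, 3, 32, 32), periodic boundary
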
def G(x, k, sf=3, center=False):
    '''
    x: image, NxcxHxW
    k: kernel, cx1xhxw
    sf: scale factor
    center: keep the first sample or the middle sample of each sf x sf block

    Matlab function:
    tmp = imfilter(x,h,'circular');
    y = downsample2(tmp,K);
    '''
    x = downsample(imfilter(x, k), sf=sf, center=center)
    return x


def Gt(x, k, sf=3, center=False):
    '''
    x: image, NxcxHxW
    k: kernel, cx1xhxw
    sf: scale factor
    center: keep the first sample or the middle sample of each sf x sf block

    Matlab function:
    tmp = upsample2(x,K);
    y = imfilter(tmp,h,'circular');
    '''
    x = imfilter(upsample(x, sf=sf, center=center), k)
    return x

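# Example (illustrative sketch): G is blur-then-decimate and Gt is zero-fill
# upsample-then-blur; for a flip-symmetric kernel (such as the box kernel
# below) Gt acts as the adjoint of G, so <G(x), y> == <x, Gt(y)> up to
# floating-point error.
#   x = torch.randn(1, 1, 12, 12)
#   y = torch.randn(1, 1, 4, 4)
#   k = torch.ones(1, 1, 3, 3) / 9
#   lhs = (G(x, k, sf=3) * y).sum()
#   rhs = (x * Gt(y, k, sf=3)).sum()           # expected: lhs ~ rhs
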
def interpolation_down(x, sf, center=False):
    mask = torch.zeros_like(x)
    if center:
        start = torch.tensor((sf-1)//2)
        mask[..., start::sf, start::sf] = torch.tensor(1).type_as(x)
        LR = x[..., start::sf, start::sf]
    else:
        mask[..., ::sf, ::sf] = torch.tensor(1).type_as(x)
        LR = x[..., ::sf, ::sf]
    y = x.mul(mask)

    return LR, y, mask


'''
# =================
Numpy
# =================
'''


def blockproc(im, blocksize, fun):
    xblocks = np.split(im, range(blocksize[0], im.shape[0], blocksize[0]), axis=0)
    xblocks_proc = []
    for xb in xblocks:
        yblocks = np.split(xb, range(blocksize[1], im.shape[1], blocksize[1]), axis=1)
        yblocks_proc = []
        for yb in yblocks:
            yb_proc = fun(yb)
            yblocks_proc.append(yb_proc)
        xblocks_proc.append(np.concatenate(yblocks_proc, axis=1))

    proc = np.concatenate(xblocks_proc, axis=0)

    return proc

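# Example (illustrative sketch): apply a function to non-overlapping 2x2 blocks,
# mirroring MATLAB's blockproc.
#   im = np.arange(16.).reshape(4, 4)
#   out = blockproc(im, (2, 2), lambda b: b - b.mean())   # per-block zero mean
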
def fun_reshape(a):
    return np.reshape(a, (-1, 1, a.shape[-1]), order='F')


def fun_mul(a, b):
    return a*b


def BlockMM(nr, nc, Nb, m, x1):
    '''
    myfun = @(block_struct) reshape(block_struct.data,m,1);
    x1 = blockproc(x1,[nr nc],myfun);
    x1 = reshape(x1,m,Nb);
    x1 = sum(x1,2);
    x = reshape(x1,nr,nc);
    '''
    fun = fun_reshape
    x1 = blockproc(x1, blocksize=(nr, nc), fun=fun)
    x1 = np.reshape(x1, (m, Nb, x1.shape[-1]), order='F')
    x1 = np.sum(x1, 1)
    x = np.reshape(x1, (nr, nc, x1.shape[-1]), order='F')
    return x


def INVLS(FB, FBC, F2B, FR, tau, Nb, nr, nc, m):
    '''
    x1 = FB.*FR;
    FBR = BlockMM(nr,nc,Nb,m,x1);
    invW = BlockMM(nr,nc,Nb,m,F2B);
    invWBR = FBR./(invW + tau*Nb);
    fun = @(block_struct) block_struct.data.*invWBR;
    FCBinvWBR = blockproc(FBC,[nr,nc],fun);
    FX = (FR-FCBinvWBR)/tau;
    Xest = real(ifft2(FX));
    '''
    x1 = FB*FR
    FBR = BlockMM(nr, nc, Nb, m, x1)
    invW = BlockMM(nr, nc, Nb, m, F2B)
    invWBR = FBR/(invW + tau*Nb)
    FCBinvWBR = blockproc(FBC, [nr, nc], lambda im: fun_mul(im, invWBR))
    FX = (FR-FCBinvWBR)/tau
    Xest = np.real(np.fft.ifft2(FX, axes=(0, 1)))
    return Xest

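# Note on INVLS (an illustrative reading, not stated in the original): FB, FBC
# and F2B are the blur OTF, its complex conjugate and its squared magnitude,
# and FR is the FFT of the data-dependent right-hand side. Aggregating over
# the Nb decimation blocks with BlockMM evaluates the closed-form FFT-domain
# solution of the regularized least-squares step, so no large linear system
# is ever formed or inverted explicitly.
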
def psf2otf(psf, shape=None):
    """
    Convert point-spread function to optical transfer function.
    Computes the Fast Fourier Transform (FFT) of the point-spread
    function (PSF) array and creates the optical transfer function (OTF)
    array that is not influenced by the PSF off-centering.
    By default, the OTF array is the same size as the PSF array.
    To ensure that the OTF is not altered due to PSF off-centering, PSF2OTF
    post-pads the PSF array (down or to the right) with zeros to match
    dimensions specified in OUTSIZE, then circularly shifts the values of
    the PSF array up (or to the left) until the central pixel reaches (1,1)
    position.
    Parameters
    ----------
    psf : `numpy.ndarray`
        PSF array
    shape : tuple of int
        Output shape of the OTF array
    Returns
    -------
    otf : `numpy.ndarray`
        OTF array
    Notes
    -----
    Adapted from MATLAB psf2otf function
    """
    if shape is None:
        shape = psf.shape
    shape = np.array(shape)
    if np.all(psf == 0):
        # return np.zeros_like(psf)
        return np.zeros(shape)
    if len(psf.shape) == 1:
        psf = psf.reshape((1, psf.shape[0]))
    inshape = psf.shape
    psf = zero_pad(psf, shape, position='corner')
    # Circularly shift the PSF so its center lands at index (0, 0)
    for axis, axis_size in enumerate(inshape):
        psf = np.roll(psf, -int(axis_size / 2), axis=axis)
    # Compute the OTF
    otf = np.fft.fft2(psf, axes=(0, 1))
    # Estimate the rough number of operations involved in the FFT
    # and discard the PSF imaginary part if within roundoff error
    # roundoff error  = machine epsilon = sys.float_info.epsilon
    # or np.finfo().eps
    n_ops = np.sum(psf.size * np.log2(psf.shape))
    otf = np.real_if_close(otf, tol=n_ops)
    return otf

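# Example (illustrative sketch): a centered symmetric PSF yields an
# (approximately) real OTF whose DC term equals the kernel sum.
#   psf = np.ones((3, 3)) / 9.
#   otf = psf2otf(psf, shape=(16, 16))
#   # otf.shape == (16, 16); np.isclose(otf[0, 0], 1.0)
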
def zero_pad(image, shape, position='corner'):
    """
    Extends image to a certain size with zeros
    Parameters
    ----------
    image: real 2d `numpy.ndarray`
        Input image
    shape: tuple of int
        Desired output shape of the image
    position : str, optional
        The position of the input image in the output one:
            * 'corner'
                top-left corner (default)
            * 'center'
                centered
    Returns
    -------
    padded_img: real `numpy.ndarray`
        The zero-padded image
    """
    shape = np.asarray(shape, dtype=int)
    imshape = np.asarray(image.shape, dtype=int)
    if np.all(imshape == shape):
        return image
    if np.any(shape <= 0):
        raise ValueError("ZERO_PAD: null or negative shape given")
    dshape = shape - imshape
    if np.any(dshape < 0):
        raise ValueError("ZERO_PAD: target size smaller than source one")
    pad_img = np.zeros(shape, dtype=image.dtype)
    idx, idy = np.indices(imshape)
    if position == 'center':
        if np.any(dshape % 2 != 0):
            raise ValueError("ZERO_PAD: source and target shapes "
                             "have different parity.")
        offx, offy = dshape // 2
    else:
        offx, offy = (0, 0)
    pad_img[idx + offx, idy + offy] = image
    return pad_img

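# Example (illustrative sketch): corner vs. centered placement.
#   a = np.ones((2, 2))
#   zero_pad(a, (4, 4))                        # ones in the top-left 2x2 block
#   zero_pad(a, (4, 4), position='center')     # ones in the central 2x2 block
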
def upsample_np(x, sf=3, center=False):
    st = (sf-1)//2 if center else 0
    z = np.zeros((x.shape[0]*sf, x.shape[1]*sf, x.shape[2]))
    z[st::sf, st::sf, ...] = x
    return z


def downsample_np(x, sf=3, center=False):
    st = (sf-1)//2 if center else 0
    return x[st::sf, st::sf, ...]


def imfilter_np(x, k):
    '''
    x: image, HxWxC
    k: kernel, hxw
    '''
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    return x


def G_np(x, k, sf=3, center=False):
    '''
    x: image, HxWxC
    k: kernel, hxw

    Matlab function:
    tmp = imfilter(x,h,'circular');
    y = downsample2(tmp,K);
    '''
    x = downsample_np(imfilter_np(x, k), sf=sf, center=center)
    return x


def Gt_np(x, k, sf=3, center=False):
    '''
    x: image, HxWxC
    k: kernel, hxw

    Matlab function:
    tmp = upsample2(x,K);
    y = imfilter(tmp,h,'circular');
    '''
    x = imfilter_np(upsample_np(x, sf=sf, center=center), k)
    return x

if __name__ == '__main__':
    img = util.imread_uint('test.bmp', 3)

    img = util.uint2single(img)
    k = anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6)
    util.imshow(k*10)

    for sf in [2, 3, 4]:

        # modcrop
        img = modcrop_np(img, sf=sf)

        # 1) bicubic degradation
        img_b = bicubic_degradation(img, sf=sf)
        print(img_b.shape)

        # 2) srmd degradation
        img_s = srmd_degradation(img, k, sf=sf)
        print(img_s.shape)

        # 3) dpsr degradation
        img_d = dpsr_degradation(img, k, sf=sf)
        print(img_d.shape)

        # 4) classical degradation
        img_d = classical_degradation(img, k, sf=sf)
        print(img_d.shape)

    k = anisotropic_Gaussian(ksize=7, theta=0.25*np.pi, l1=0.01, l2=0.01)
    #print(k)
#    util.imshow(k*10)

    k = shifted_anisotropic_Gaussian(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.8, max_var=10.8, noise_level=0.0)
#    util.imshow(k*10)

    # PCA
#    pca_matrix = cal_pca_matrix(ksize=15, l_max=10.0, dim_pca=15, num_samples=12500)
#    print(pca_matrix.shape)
#    show_pca(pca_matrix)
    # run utils/utils_sisr.py

core/data/deg_kair_utils/utils_video.py
ADDED

@@ -0,0 +1,493 @@
import os
import math
import cv2
import numpy as np
import torch
import random
from os import path as osp
from torch.nn import functional as F
from abc import ABCMeta, abstractmethod
from torchvision.utils import make_grid  # used by tensor2img below


def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory to find the interested files.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator for all the interested files with relative paths.
    """

    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, suffix, recursive):
        for entry in os.scandir(dir_path):
            if not entry.name.startswith('.') and entry.is_file():
                if full_path:
                    return_path = entry.path
                else:
                    return_path = osp.relpath(entry.path, root)

                if suffix is None:
                    yield return_path
                elif return_path.endswith(suffix):
                    yield return_path
            else:
                if recursive:
                    yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
                else:
                    continue

    return _scandir(dir_path, suffix=suffix, recursive=recursive)

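# Example (illustrative sketch; 'clips/000' is a hypothetical folder):
#   frame_paths = sorted(scandir('clips/000', suffix='.png', full_path=True))
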
def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
    """Read a sequence of images from a given folder path.

    Args:
        path (list[str] | str): List of image paths or image folder path.
        require_mod_crop (bool): Require mod crop for each image.
            Default: False.
        scale (int): Scale factor for mod_crop. Default: 1.
        return_imgname(bool): Whether return image names. Default False.

    Returns:
        Tensor: size (t, c, h, w), RGB, [0, 1].
        list[str]: Returned image name list.
    """
    if isinstance(path, list):
        img_paths = path
    else:
        img_paths = sorted(list(scandir(path, full_path=True)))
    imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]

    if require_mod_crop:
        imgs = [mod_crop(img, scale) for img in imgs]
    imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
    imgs = torch.stack(imgs, dim=0)

    if return_imgname:
        imgnames = [osp.splitext(osp.basename(path))[0] for path in img_paths]
        return imgs, imgnames
    else:
        return imgs

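# Example (illustrative sketch; 'clips/000' is a hypothetical folder of frames):
#   clip = read_img_seq('clips/000')                       # (t, 3, H, W), RGB in [0, 1]
#   clip, names = read_img_seq('clips/000', return_imgname=True)
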
def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Numpy array to tensor.

    Args:
        imgs (list[ndarray] | ndarray): Input images.
        bgr2rgb (bool): Whether to change bgr to rgb.
        float32 (bool): Whether to change to float32.

    Returns:
        list[tensor] | tensor: Tensor images. If returned results only have
            one element, just return tensor.
    """

    def _totensor(img, bgr2rgb, float32):
        if img.shape[2] == 3 and bgr2rgb:
            if img.dtype == 'float64':
                img = img.astype('float32')
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = torch.from_numpy(img.transpose(2, 0, 1))
        if float32:
            img = img.float()
        return img

    if isinstance(imgs, list):
        return [_totensor(img, bgr2rgb, float32) for img in imgs]
    else:
        return _totensor(imgs, bgr2rgb, float32)

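# Example (illustrative sketch): HWC BGR float image -> CHW RGB float tensor.
#   bgr = cv2.imread('frame.png').astype(np.float32) / 255.   # hypothetical file
#   t = img2tensor(bgr)                                       # shape (3, H, W), RGB order
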
def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
    """Convert torch Tensors into image numpy arrays.

    After clamping to [min, max], values will be normalized to [0, 1].

    Args:
        tensor (Tensor or list[Tensor]): Accept shapes:
            1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
            2) 3D Tensor of shape (3/1 x H x W);
            3) 2D Tensor of shape (H x W).
            Tensor channel should be in RGB order.
        rgb2bgr (bool): Whether to change rgb to bgr.
        out_type (numpy type): output types. If ``np.uint8``, transform outputs
            to uint8 type with range [0, 255]; otherwise, float type with
            range [0, 1]. Default: ``np.uint8``.
        min_max (tuple[int]): min and max values for clamp.

    Returns:
        (ndarray or list[ndarray]): 3D ndarray of shape (H x W x C) OR 2D
        ndarray of shape (H x W). The channel order is BGR.
    """
    if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')

    if torch.is_tensor(tensor):
        tensor = [tensor]
    result = []
    for _tensor in tensor:
        _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
        _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])

        n_dim = _tensor.dim()
        if n_dim == 4:
            img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
            img_np = img_np.transpose(1, 2, 0)
            if rgb2bgr:
                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 3:
            img_np = _tensor.numpy()
            img_np = img_np.transpose(1, 2, 0)
            if img_np.shape[2] == 1:  # gray image
                img_np = np.squeeze(img_np, axis=2)
            else:
                if rgb2bgr:
                    img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 2:
            img_np = _tensor.numpy()
        else:
            raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
        if out_type == np.uint8:
            # Unlike MATLAB, np.uint8() WILL NOT round by default.
            img_np = (img_np * 255.0).round()
        img_np = img_np.astype(out_type)
        result.append(img_np)
    if len(result) == 1:
        result = result[0]
    return result

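# Example (illustrative sketch): round-trip with img2tensor.
#   t = torch.rand(3, 8, 8)                    # RGB, values in [0, 1]
#   img = tensor2img(t)                        # uint8 BGR HWC, values in [0, 255]
#   # img.shape == (8, 8, 3) and img.dtype == np.uint8
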
def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
    """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).

    We use vertical flip and transpose for rotation implementation.
    All the images in the list use the same augmentation.

    Args:
        imgs (list[ndarray] | ndarray): Images to be augmented. If the input
            is an ndarray, it will be transformed to a list.
        hflip (bool): Horizontal flip. Default: True.
        rotation (bool): Rotation. Default: True.
        flows (list[ndarray]): Flows to be augmented. If the input is an
            ndarray, it will be transformed to a list.
            Dimension is (h, w, 2). Default: None.
        return_status (bool): Return the status of flip and rotation.
            Default: False.

    Returns:
        list[ndarray] | ndarray: Augmented images and flows. If returned
            results only have one element, just return ndarray.

    """
    hflip = hflip and random.random() < 0.5
    vflip = rotation and random.random() < 0.5
    rot90 = rotation and random.random() < 0.5

    def _augment(img):
        if hflip:  # horizontal
            cv2.flip(img, 1, img)
        if vflip:  # vertical
            cv2.flip(img, 0, img)
        if rot90:
            img = img.transpose(1, 0, 2)
        return img

    def _augment_flow(flow):
        if hflip:  # horizontal
            cv2.flip(flow, 1, flow)
            flow[:, :, 0] *= -1
        if vflip:  # vertical
            cv2.flip(flow, 0, flow)
            flow[:, :, 1] *= -1
        if rot90:
            flow = flow.transpose(1, 0, 2)
            flow = flow[:, :, [1, 0]]
        return flow

    if not isinstance(imgs, list):
        imgs = [imgs]
    imgs = [_augment(img) for img in imgs]
    if len(imgs) == 1:
        imgs = imgs[0]

    if flows is not None:
        if not isinstance(flows, list):
            flows = [flows]
        flows = [_augment_flow(flow) for flow in flows]
     | 
| 230 | 
         
            +
                    if len(flows) == 1:
         
     | 
| 231 | 
         
            +
                        flows = flows[0]
         
     | 
| 232 | 
         
            +
                    return imgs, flows
         
     | 
| 233 | 
         
            +
                else:
         
     | 
| 234 | 
         
            +
                    if return_status:
         
     | 
| 235 | 
         
            +
                        return imgs, (hflip, vflip, rot90)
         
     | 
| 236 | 
         
            +
                    else:
         
     | 
| 237 | 
         
            +
                        return imgs
         
     | 
| 238 | 
         
            +
             
     | 
| 239 | 
         
            +
             
     | 
| 240 | 
         
            +
            def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
         
     | 
| 241 | 
         
            +
                """Paired random crop. Support Numpy array and Tensor inputs.
         
     | 
| 242 | 
         
            +
             
     | 
| 243 | 
         
            +
                It crops lists of lq and gt images with corresponding locations.
         
     | 
| 244 | 
         
            +
             
     | 
| 245 | 
         
            +
                Args:
         
     | 
| 246 | 
         
            +
                    img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
         
     | 
| 247 | 
         
            +
                        should have the same shape. If the input is an ndarray, it will
         
     | 
| 248 | 
         
            +
                        be transformed to a list containing itself.
         
     | 
| 249 | 
         
            +
                    img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
         
     | 
| 250 | 
         
            +
                        should have the same shape. If the input is an ndarray, it will
         
     | 
| 251 | 
         
            +
                        be transformed to a list containing itself.
         
     | 
| 252 | 
         
            +
                    gt_patch_size (int): GT patch size.
         
     | 
| 253 | 
         
            +
                    scale (int): Scale factor.
         
     | 
| 254 | 
         
            +
                    gt_path (str): Path to ground-truth. Default: None.
         
     | 
| 255 | 
         
            +
             
     | 
| 256 | 
         
            +
                Returns:
         
     | 
| 257 | 
         
            +
                    list[ndarray] | ndarray: GT images and LQ images. If returned results
         
     | 
| 258 | 
         
            +
                        only have one element, just return ndarray.
         
     | 
| 259 | 
         
            +
                """
         
     | 
| 260 | 
         
            +
             
     | 
| 261 | 
         
            +
                if not isinstance(img_gts, list):
         
     | 
| 262 | 
         
            +
                    img_gts = [img_gts]
         
     | 
| 263 | 
         
            +
                if not isinstance(img_lqs, list):
         
     | 
| 264 | 
         
            +
                    img_lqs = [img_lqs]
         
     | 
| 265 | 
         
            +
             
     | 
| 266 | 
         
            +
                # determine input type: Numpy array or Tensor
         
     | 
| 267 | 
         
            +
                input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'
         
     | 
| 268 | 
         
            +
             
     | 
| 269 | 
         
            +
                if input_type == 'Tensor':
         
     | 
| 270 | 
         
            +
                    h_lq, w_lq = img_lqs[0].size()[-2:]
         
     | 
| 271 | 
         
            +
                    h_gt, w_gt = img_gts[0].size()[-2:]
         
     | 
| 272 | 
         
            +
                else:
         
     | 
| 273 | 
         
            +
                    h_lq, w_lq = img_lqs[0].shape[0:2]
         
     | 
| 274 | 
         
            +
                    h_gt, w_gt = img_gts[0].shape[0:2]
         
     | 
| 275 | 
         
            +
                lq_patch_size = gt_patch_size // scale
         
     | 
| 276 | 
         
            +
             
     | 
| 277 | 
         
            +
                if h_gt != h_lq * scale or w_gt != w_lq * scale:
         
     | 
| 278 | 
         
            +
                    raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ',
         
     | 
| 279 | 
         
            +
                                     f'multiplication of LQ ({h_lq}, {w_lq}).')
         
     | 
| 280 | 
         
            +
                if h_lq < lq_patch_size or w_lq < lq_patch_size:
         
     | 
| 281 | 
         
            +
                    raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
         
     | 
| 282 | 
         
            +
                                     f'({lq_patch_size}, {lq_patch_size}). '
         
     | 
| 283 | 
         
            +
                                     f'Please remove {gt_path}.')
         
     | 
| 284 | 
         
            +
             
     | 
| 285 | 
         
            +
                # randomly choose top and left coordinates for lq patch
         
     | 
| 286 | 
         
            +
                top = random.randint(0, h_lq - lq_patch_size)
         
     | 
| 287 | 
         
            +
                left = random.randint(0, w_lq - lq_patch_size)
         
     | 
| 288 | 
         
            +
             
     | 
| 289 | 
         
            +
                # crop lq patch
         
     | 
| 290 | 
         
            +
                if input_type == 'Tensor':
         
     | 
| 291 | 
         
            +
                    img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
         
     | 
| 292 | 
         
            +
                else:
         
     | 
| 293 | 
         
            +
                    img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
         
     | 
| 294 | 
         
            +
             
     | 
| 295 | 
         
            +
                # crop corresponding gt patch
         
     | 
| 296 | 
         
            +
                top_gt, left_gt = int(top * scale), int(left * scale)
         
     | 
| 297 | 
         
            +
                if input_type == 'Tensor':
         
     | 
| 298 | 
         
            +
                    img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
         
     | 
| 299 | 
         
            +
                else:
         
     | 
| 300 | 
         
            +
                    img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
         
     | 
| 301 | 
         
            +
                if len(img_gts) == 1:
         
     | 
| 302 | 
         
            +
                    img_gts = img_gts[0]
         
     | 
| 303 | 
         
            +
                if len(img_lqs) == 1:
         
     | 
| 304 | 
         
            +
                    img_lqs = img_lqs[0]
         
     | 
| 305 | 
         
            +
                return img_gts, img_lqs
         
     | 
| 306 | 
         
            +
             
     | 
| 307 | 
         
            +
             
     | 
| 308 | 
         
            +
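# --- hedged usage sketch, not part of the original file -----------------
# Cropping a 4x-downscaled LQ/GT pair: the LQ crop location is sampled
# first and the GT location is derived by multiplying with the scale, so
# the two patches stay aligned.
def _demo_paired_random_crop():
    gt = np.random.rand(128, 128, 3).astype(np.float32)
    lq = np.random.rand(32, 32, 3).astype(np.float32)
    gt_patch, lq_patch = paired_random_crop(gt, lq, gt_patch_size=64, scale=4)
    assert gt_patch.shape[:2] == (64, 64) and lq_patch.shape[:2] == (16, 16)

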
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py  # noqa: E501
class BaseStorageBackend(metaclass=ABCMeta):
    """Abstract class of storage backends.

    All backends need to implement two apis: ``get()`` and ``get_text()``.
    ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
    as texts.
    """

    @abstractmethod
    def get(self, filepath):
        pass

    @abstractmethod
    def get_text(self, filepath):
        pass


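# --- hedged illustration, not part of the original file -----------------
# Any new backend only has to implement get()/get_text(); this toy
# in-memory variant (a hypothetical DictBackend) shows the contract.
class DictBackend(BaseStorageBackend):
    """Toy backend that serves bytes from a dict, for tests."""

    def __init__(self, store):
        self._store = store

    def get(self, filepath):
        return self._store[str(filepath)]

    def get_text(self, filepath):
        return self._store[str(filepath)].decode('utf-8')

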
class MemcachedBackend(BaseStorageBackend):
    """Memcached storage backend.

    Attributes:
        server_list_cfg (str): Config file for memcached server list.
        client_cfg (str): Config file for memcached client.
        sys_path (str | None): Additional path to be appended to `sys.path`.
            Default: None.
    """

    def __init__(self, server_list_cfg, client_cfg, sys_path=None):
        if sys_path is not None:
            import sys
            sys.path.append(sys_path)
        try:
            import mc
        except ImportError:
            raise ImportError('Please install memcached to enable MemcachedBackend.')

        self.server_list_cfg = server_list_cfg
        self.client_cfg = client_cfg
        self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
        # mc.pyvector serves as a pointer to a memory cache
        self._mc_buffer = mc.pyvector()

    def get(self, filepath):
        filepath = str(filepath)
        import mc
        self._client.Get(filepath, self._mc_buffer)
        value_buf = mc.ConvertBuffer(self._mc_buffer)
        return value_buf

    def get_text(self, filepath):
        raise NotImplementedError


class HardDiskBackend(BaseStorageBackend):
    """Raw hard disks storage backend."""

    def get(self, filepath):
        filepath = str(filepath)
        with open(filepath, 'rb') as f:
            value_buf = f.read()
        return value_buf

    def get_text(self, filepath):
        filepath = str(filepath)
        with open(filepath, 'r') as f:
            value_buf = f.read()
        return value_buf


class LmdbBackend(BaseStorageBackend):
    """Lmdb storage backend.

    Args:
        db_paths (str | list[str]): Lmdb database paths.
        client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
        readonly (bool, optional): Lmdb environment parameter. If True,
            disallow any write operations. Default: True.
        lock (bool, optional): Lmdb environment parameter. If False, when
            concurrent access occurs, do not lock the database. Default: False.
        readahead (bool, optional): Lmdb environment parameter. If False,
            disable the OS filesystem readahead mechanism, which may improve
            random read performance when a database is larger than RAM.
            Default: False.

    Attributes:
        db_paths (list): Lmdb database paths.
        _client (dict): One lmdb env per client key.
    """

    def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
        try:
            import lmdb
        except ImportError:
            raise ImportError('Please install lmdb to enable LmdbBackend.')

        if isinstance(client_keys, str):
            client_keys = [client_keys]

        if isinstance(db_paths, list):
            self.db_paths = [str(v) for v in db_paths]
        elif isinstance(db_paths, str):
            self.db_paths = [str(db_paths)]
        assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
                                                        f'but received {len(client_keys)} and {len(self.db_paths)}.')

        self._client = {}
        for client, path in zip(client_keys, self.db_paths):
            self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)

    def get(self, filepath, client_key):
        """Get values according to the filepath from one lmdb named client_key.

        Args:
            filepath (str | obj:`Path`): Here, filepath is the lmdb key.
            client_key (str): Used for distinguishing different lmdb envs.
        """
        filepath = str(filepath)
        assert client_key in self._client, (f'client_key {client_key} is not in lmdb clients.')
        client = self._client[client_key]
        with client.begin(write=False) as txn:
            value_buf = txn.get(filepath.encode('ascii'))
        return value_buf

    def get_text(self, filepath):
        raise NotImplementedError


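# --- hedged usage sketch, not part of the original file -----------------
# One lmdb env is opened per client key; get() then dispatches on that key
# and returns the raw encoded bytes stored under `key` (or None if the key
# is absent). The lmdb paths are caller-supplied placeholders.
def _demo_lmdb_backend(gt_lmdb, lq_lmdb, key):
    backend = LmdbBackend(db_paths=[gt_lmdb, lq_lmdb], client_keys=['gt', 'lq'])
    return backend.get(key, client_key='gt')

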
class FileClient(object):
    """A general file client to access files in different backends.

    The client loads a file or text in a specified backend from its path
    and returns it as a binary file. It can also register other backend
    accessors with a given name and backend class.

    Attributes:
        backend (str): The storage backend type. Options are "disk",
            "memcached" and "lmdb".
        client (:obj:`BaseStorageBackend`): The backend object.
    """

    _backends = {
        'disk': HardDiskBackend,
        'memcached': MemcachedBackend,
        'lmdb': LmdbBackend,
    }

    def __init__(self, backend='disk', **kwargs):
        if backend not in self._backends:
            raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
                             f' are {list(self._backends.keys())}')
        self.backend = backend
        self.client = self._backends[backend](**kwargs)

    def get(self, filepath, client_key='default'):
        # client_key is used only for lmdb, where different fileclients have
        # different lmdb environments.
        if self.backend == 'lmdb':
            return self.client.get(filepath, client_key)
        else:
            return self.client.get(filepath)

    def get_text(self, filepath):
        return self.client.get_text(filepath)


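# --- hedged usage sketch, not part of the original file -----------------
# The default disk backend ignores client_key and returns raw bytes, which
# pairs naturally with imfrombytes() defined below.
def _demo_file_client(path):
    client = FileClient(backend='disk')
    return client.get(path)

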
def imfrombytes(content, flag='color', float32=False):
    """Read an image from bytes.

    Args:
        content (bytes): Image bytes got from files or other streams.
        flag (str): Flags specifying the color type of a loaded image,
            candidates are `color`, `grayscale` and `unchanged`.
        float32 (bool): Whether to change to float32. If True, will also
            normalize to [0, 1]. Default: False.

    Returns:
        ndarray: Loaded image array.
    """
    img_np = np.frombuffer(content, np.uint8)
    imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
    img = cv2.imdecode(img_np, imread_flags[flag])
    if float32:
        img = img.astype(np.float32) / 255.
    return img
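# --- hedged usage sketch, not part of the original file -----------------
# The usual pairing of FileClient and imfrombytes: read raw bytes, then
# decode them into a float32 BGR image normalized to [0, 1].
def _demo_read_image(path):
    file_client = FileClient(backend='disk')
    img_bytes = file_client.get(path)
    return imfrombytes(img_bytes, float32=True)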
    	
core/data/deg_kair_utils/utils_videoio.py
ADDED
@@ -0,0 +1,555 @@
import os
import cv2
import numpy as np
import torch
import random
from os import path as osp
from torchvision.utils import make_grid
import sys
from pathlib import Path
import six
from collections import OrderedDict
import math
import glob
import av
import io
from cv2 import (CAP_PROP_FOURCC, CAP_PROP_FPS, CAP_PROP_FRAME_COUNT,
                 CAP_PROP_FRAME_HEIGHT, CAP_PROP_FRAME_WIDTH,
                 CAP_PROP_POS_FRAMES, VideoWriter_fourcc)

if sys.version_info <= (3, 3):
    FileNotFoundError = IOError
else:
    FileNotFoundError = FileNotFoundError


def is_str(x):
    """Whether the input is a string instance."""
    return isinstance(x, six.string_types)


def is_filepath(x):
    return is_str(x) or isinstance(x, Path)


def fopen(filepath, *args, **kwargs):
    if is_str(filepath):
        return open(filepath, *args, **kwargs)
    elif isinstance(filepath, Path):
        return filepath.open(*args, **kwargs)
    raise ValueError('`filepath` should be a string or a Path')


def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
    if not osp.isfile(filename):
        raise FileNotFoundError(msg_tmpl.format(filename))


def mkdir_or_exist(dir_name, mode=0o777):
    if dir_name == '':
        return
    dir_name = osp.expanduser(dir_name)
    os.makedirs(dir_name, mode=mode, exist_ok=True)


def symlink(src, dst, overwrite=True, **kwargs):
    if os.path.lexists(dst) and overwrite:
        os.remove(dst)
    os.symlink(src, dst, **kwargs)


def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True):
    """Scan a directory to find the interested files.

    Args:
        dir_path (str | :obj:`Path`): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        case_sensitive (bool, optional): If set to False, ignore the case of
            suffix. Default: True.

    Returns:
        A generator for all the interested files with relative paths.
    """
    if isinstance(dir_path, (str, Path)):
        dir_path = str(dir_path)
    else:
        raise TypeError('"dir_path" must be a string or Path object')

    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    if suffix is not None and not case_sensitive:
        suffix = suffix.lower() if isinstance(suffix, str) else tuple(
            item.lower() for item in suffix)

    root = dir_path

    def _scandir(dir_path, suffix, recursive, case_sensitive):
        for entry in os.scandir(dir_path):
            if not entry.name.startswith('.') and entry.is_file():
                rel_path = osp.relpath(entry.path, root)
                _rel_path = rel_path if case_sensitive else rel_path.lower()
                if suffix is None or _rel_path.endswith(suffix):
                    yield rel_path
            elif recursive and os.path.isdir(entry.path):
                # scan recursively if entry.path is a directory
                yield from _scandir(entry.path, suffix, recursive,
                                    case_sensitive)

    return _scandir(dir_path, suffix, recursive, case_sensitive)


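# --- hedged usage sketch, not part of the original file -----------------
# Collect all PNG/JPG files under dir_path, case-insensitively, as paths
# relative to dir_path; scandir returns a generator, so sort it to get a
# deterministic list.
def _demo_scandir(dir_path):
    return sorted(scandir(dir_path, suffix=('.png', '.jpg'),
                          recursive=True, case_sensitive=False))

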
class Cache:

    def __init__(self, capacity):
        self._cache = OrderedDict()
        self._capacity = int(capacity)
        if capacity <= 0:
            raise ValueError('capacity must be a positive integer')

    @property
    def capacity(self):
        return self._capacity

    @property
    def size(self):
        return len(self._cache)

    def put(self, key, val):
        if key in self._cache:
            return
        if len(self._cache) >= self.capacity:
            self._cache.popitem(last=False)
        self._cache[key] = val

    def get(self, key, default=None):
        val = self._cache[key] if key in self._cache else default
        return val


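# --- hedged usage sketch, not part of the original file -----------------
# The cache evicts in FIFO order (popitem(last=False) on an OrderedDict)
# once capacity is reached, and put() silently ignores existing keys.
def _demo_cache():
    cache = Cache(capacity=2)
    cache.put('a', 1)
    cache.put('b', 2)
    cache.put('c', 3)              # evicts 'a'
    assert cache.get('a') is None and cache.get('c') == 3

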
class VideoReader:
    """Video class with similar usage to a list object.

    This video wrapper class provides convenient apis to access frames.
    There is a known issue with OpenCV's VideoCapture class: jumping to a
    certain frame may be inaccurate. It is fixed in this class by checking
    the position after each jump.
    A cache is used when decoding videos, so if the same frame is visited a
    second time there is no need to decode it again if it is stored in the
    cache.

    """

    def __init__(self, filename, cache_capacity=10):
        # Check whether the video path is a url
        if not filename.startswith(('https://', 'http://')):
            check_file_exist(filename, 'Video file not found: ' + filename)
        self._vcap = cv2.VideoCapture(filename)
        assert cache_capacity > 0
        self._cache = Cache(cache_capacity)
        self._position = 0
        # get basic info
        self._width = int(self._vcap.get(CAP_PROP_FRAME_WIDTH))
        self._height = int(self._vcap.get(CAP_PROP_FRAME_HEIGHT))
        self._fps = self._vcap.get(CAP_PROP_FPS)
        self._frame_cnt = int(self._vcap.get(CAP_PROP_FRAME_COUNT))
        self._fourcc = self._vcap.get(CAP_PROP_FOURCC)

    @property
    def vcap(self):
        """:obj:`cv2.VideoCapture`: The raw VideoCapture object."""
        return self._vcap

    @property
    def opened(self):
        """bool: Indicate whether the video is opened."""
        return self._vcap.isOpened()

    @property
    def width(self):
        """int: Width of video frames."""
        return self._width

    @property
    def height(self):
        """int: Height of video frames."""
        return self._height

    @property
    def resolution(self):
        """tuple: Video resolution (width, height)."""
        return (self._width, self._height)

    @property
    def fps(self):
        """float: FPS of the video."""
        return self._fps

    @property
    def frame_cnt(self):
        """int: Total frames of the video."""
        return self._frame_cnt

    @property
    def fourcc(self):
        """str: "Four character code" of the video."""
        return self._fourcc

    @property
    def position(self):
        """int: Current cursor position, indicating frame decoded."""
        return self._position

    def _get_real_position(self):
        return int(round(self._vcap.get(CAP_PROP_POS_FRAMES)))

    def _set_real_position(self, frame_id):
        self._vcap.set(CAP_PROP_POS_FRAMES, frame_id)
        pos = self._get_real_position()
        for _ in range(frame_id - pos):
            self._vcap.read()
        self._position = frame_id

    def read(self):
        """Read the next frame.

        If the next frame has been decoded before and is in the cache, it is
        returned directly; otherwise it is decoded, cached and returned.

        Returns:
            ndarray or None: Return the frame if successful, otherwise None.
        """
        # pos = self._position
        if self._cache:
            img = self._cache.get(self._position)
            if img is not None:
                ret = True
            else:
                if self._position != self._get_real_position():
                    self._set_real_position(self._position)
                ret, img = self._vcap.read()
                if ret:
         
     | 
| 233 | 
         
            +
                                self._cache.put(self._position, img)
         
     | 
| 234 | 
         
            +
                    else:
         
     | 
| 235 | 
         
            +
                        ret, img = self._vcap.read()
         
     | 
| 236 | 
         
            +
                    if ret:
         
     | 
| 237 | 
         
            +
                        self._position += 1
         
     | 
| 238 | 
         
            +
                    return img
         
     | 
| 239 | 
         
            +
             
     | 
| 240 | 
         
            +
                def get_frame(self, frame_id):
         
     | 
| 241 | 
         
            +
                    """Get frame by index.
         
     | 
| 242 | 
         
            +
             
     | 
| 243 | 
         
            +
                    Args:
         
     | 
| 244 | 
         
            +
                        frame_id (int): Index of the expected frame, 0-based.
         
     | 
| 245 | 
         
            +
             
     | 
| 246 | 
         
            +
                    Returns:
         
     | 
| 247 | 
         
            +
                        ndarray or None: Return the frame if successful, otherwise None.
         
     | 
| 248 | 
         
            +
                    """
         
     | 
| 249 | 
         
            +
                    if frame_id < 0 or frame_id >= self._frame_cnt:
         
     | 
| 250 | 
         
            +
                        raise IndexError(
         
     | 
| 251 | 
         
            +
                            f'"frame_id" must be between 0 and {self._frame_cnt - 1}')
         
     | 
| 252 | 
         
            +
                    if frame_id == self._position:
         
     | 
| 253 | 
         
            +
                        return self.read()
         
     | 
| 254 | 
         
            +
                    if self._cache:
         
     | 
| 255 | 
         
            +
                        img = self._cache.get(frame_id)
         
     | 
| 256 | 
         
            +
                        if img is not None:
         
     | 
| 257 | 
         
            +
                            self._position = frame_id + 1
         
     | 
| 258 | 
         
            +
                            return img
         
     | 
| 259 | 
         
            +
                    self._set_real_position(frame_id)
         
     | 
| 260 | 
         
            +
                    ret, img = self._vcap.read()
         
     | 
| 261 | 
         
            +
                    if ret:
         
     | 
| 262 | 
         
            +
                        if self._cache:
         
     | 
| 263 | 
         
            +
                            self._cache.put(self._position, img)
         
     | 
| 264 | 
         
            +
                        self._position += 1
         
     | 
| 265 | 
         
            +
                    return img
         
     | 
| 266 | 
         
            +
             
     | 
| 267 | 
         
            +
                def current_frame(self):
         
     | 
| 268 | 
         
            +
                    """Get the current frame (frame that is just visited).
         
     | 
| 269 | 
         
            +
             
     | 
| 270 | 
         
            +
                    Returns:
         
     | 
| 271 | 
         
            +
                        ndarray or None: If the video is fresh, return None, otherwise
         
     | 
| 272 | 
         
            +
                        return the frame.
         
     | 
| 273 | 
         
            +
                    """
         
     | 
| 274 | 
         
            +
                    if self._position == 0:
         
     | 
| 275 | 
         
            +
                        return None
         
     | 
| 276 | 
         
            +
                    return self._cache.get(self._position - 1)
         
     | 
| 277 | 
         
            +
             
     | 
| 278 | 
         
            +
                def cvt2frames(self,
         
     | 
| 279 | 
         
            +
                               frame_dir,
         
     | 
| 280 | 
         
            +
                               file_start=0,
         
     | 
| 281 | 
         
            +
                               filename_tmpl='{:06d}.jpg',
         
     | 
| 282 | 
         
            +
                               start=0,
         
     | 
| 283 | 
         
            +
                               max_num=0,
         
     | 
| 284 | 
         
            +
                               show_progress=False):
         
     | 
| 285 | 
         
            +
                    """Convert a video to frame images.
         
     | 
| 286 | 
         
            +
             
     | 
| 287 | 
         
            +
                    Args:
         
     | 
| 288 | 
         
            +
                        frame_dir (str): Output directory to store all the frame images.
         
     | 
| 289 | 
         
            +
                        file_start (int): Filenames will start from the specified number.
         
     | 
| 290 | 
         
            +
                        filename_tmpl (str): Filename template with the index as the
         
     | 
| 291 | 
         
            +
                            placeholder.
         
     | 
| 292 | 
         
            +
                        start (int): The starting frame index.
         
     | 
| 293 | 
         
            +
                        max_num (int): Maximum number of frames to be written.
         
     | 
| 294 | 
         
            +
                        show_progress (bool): Whether to show a progress bar.
         
     | 
| 295 | 
         
            +
                    """
         
     | 
| 296 | 
         
            +
                    mkdir_or_exist(frame_dir)
         
     | 
| 297 | 
         
            +
                    if max_num == 0:
         
     | 
| 298 | 
         
            +
                        task_num = self.frame_cnt - start
         
     | 
| 299 | 
         
            +
                    else:
         
     | 
| 300 | 
         
            +
                        task_num = min(self.frame_cnt - start, max_num)
         
     | 
| 301 | 
         
            +
                    if task_num <= 0:
         
     | 
| 302 | 
         
            +
                        raise ValueError('start must be less than total frame number')
         
     | 
| 303 | 
         
            +
                    if start > 0:
         
     | 
| 304 | 
         
            +
                        self._set_real_position(start)
         
     | 
| 305 | 
         
            +
             
     | 
| 306 | 
         
            +
                    def write_frame(file_idx):
         
     | 
| 307 | 
         
            +
                        img = self.read()
         
     | 
| 308 | 
         
            +
                        if img is None:
         
     | 
| 309 | 
         
            +
                            return
         
     | 
| 310 | 
         
            +
                        filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
         
     | 
| 311 | 
         
            +
                        cv2.imwrite(filename, img)
         
     | 
| 312 | 
         
            +
             
     | 
| 313 | 
         
            +
                    if show_progress:
         
     | 
| 314 | 
         
            +
                        pass
         
     | 
| 315 | 
         
            +
                        #track_progress(write_frame, range(file_start,file_start + task_num))
         
     | 
| 316 | 
         
            +
                    else:
         
     | 
| 317 | 
         
            +
                        for i in range(task_num):
         
     | 
| 318 | 
         
            +
                            write_frame(file_start + i)
         
     | 
| 319 | 
         
            +
             
     | 
| 320 | 
         
            +
                def __len__(self):
         
     | 
| 321 | 
         
            +
                    return self.frame_cnt
         
     | 
| 322 | 
         
            +
             
     | 
| 323 | 
         
            +
                def __getitem__(self, index):
         
     | 
| 324 | 
         
            +
                    if isinstance(index, slice):
         
     | 
| 325 | 
         
            +
                        return [
         
     | 
| 326 | 
         
            +
                            self.get_frame(i)
         
     | 
| 327 | 
         
            +
                            for i in range(*index.indices(self.frame_cnt))
         
     | 
| 328 | 
         
            +
                        ]
         
     | 
| 329 | 
         
            +
                    # support negative indexing
         
     | 
| 330 | 
         
            +
                    if index < 0:
         
     | 
| 331 | 
         
            +
                        index += self.frame_cnt
         
     | 
| 332 | 
         
            +
                        if index < 0:
         
     | 
| 333 | 
         
            +
                            raise IndexError('index out of range')
         
     | 
| 334 | 
         
            +
                    return self.get_frame(index)
         
     | 
| 335 | 
         
            +
             
     | 
| 336 | 
         
            +
                def __iter__(self):
         
     | 
| 337 | 
         
            +
                    self._set_real_position(0)
         
     | 
| 338 | 
         
            +
                    return self
         
     | 
| 339 | 
         
            +
             
     | 
| 340 | 
         
            +
                def __next__(self):
         
     | 
| 341 | 
         
            +
                    img = self.read()
         
     | 
| 342 | 
         
            +
                    if img is not None:
         
     | 
| 343 | 
         
            +
                        return img
         
     | 
| 344 | 
         
            +
                    else:
         
     | 
| 345 | 
         
            +
                        raise StopIteration
         
     | 
| 346 | 
         
            +
             
     | 
| 347 | 
         
            +
                next = __next__
         
     | 
| 348 | 
         
            +
             
     | 
| 349 | 
         
            +
                def __enter__(self):
         
     | 
| 350 | 
         
            +
                    return self
         
     | 
| 351 | 
         
            +
             
     | 
| 352 | 
         
            +
                def __exit__(self, exc_type, exc_value, traceback):
         
     | 
| 353 | 
         
            +
                    self._vcap.release()
         
     | 
| 354 | 
         
            +
             
     | 
| 355 | 
         
            +
             
     | 
| 356 | 
         
            +
def frames2video(frame_dir,
                 video_file,
                 fps=30,
                 fourcc='XVID',
                 filename_tmpl='{:06d}.jpg',
                 start=0,
                 end=0,
                 show_progress=False):
    """Read the frame images from a directory and join them as a video.

    Args:
        frame_dir (str): The directory containing video frames.
        video_file (str): Output filename.
        fps (float): FPS of the output video.
        fourcc (str): Fourcc of the output video; this should be compatible
            with the output file type.
        filename_tmpl (str): Filename template with the index as the variable.
        start (int): Starting frame index.
        end (int): Ending frame index.
        show_progress (bool): Whether to show a progress bar.
    """
    if end == 0:
        ext = filename_tmpl.split('.')[-1]
        end = len([name for name in scandir(frame_dir, ext)])
    first_file = osp.join(frame_dir, filename_tmpl.format(start))
    check_file_exist(first_file, 'The start frame not found: ' + first_file)
    img = cv2.imread(first_file)
    height, width = img.shape[:2]
    resolution = (width, height)
    vwriter = cv2.VideoWriter(video_file, VideoWriter_fourcc(*fourcc), fps,
                              resolution)

    def write_frame(file_idx):
        filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
        img = cv2.imread(filename)
        vwriter.write(img)

    if show_progress:
        pass
        # track_progress(write_frame, range(start, end))
    else:
        for i in range(start, end):
            write_frame(i)
    vwriter.release()


def video2images(video_path, output_dir):
    vidcap = cv2.VideoCapture(video_path)
    in_fps = vidcap.get(cv2.CAP_PROP_FPS)
    print('video fps:', in_fps)
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)
    loaded, frame = vidcap.read()
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    print(f'number of total frames is: {total_frames:06}')
    for i_frame in range(total_frames):
        if i_frame % 100 == 0:
            print(f'{i_frame:06} / {total_frames:06}')
        frame_name = os.path.join(output_dir, f'{i_frame:06}' + '.png')
        cv2.imwrite(frame_name, frame)
        loaded, frame = vidcap.read()


def images2video(image_dir, video_path, fps=24, image_ext='png'):
    '''
    # codec = cv2.VideoWriter_fourcc(*'XVID')
    # codec = cv2.VideoWriter_fourcc('A','V','C','1')
    # codec = cv2.VideoWriter_fourcc('Y','U','V','1')
    # codec = cv2.VideoWriter_fourcc('P','I','M','1')
    # codec = cv2.VideoWriter_fourcc('M','J','P','G')
    codec = cv2.VideoWriter_fourcc('M','P','4','2')
    # codec = cv2.VideoWriter_fourcc('D','I','V','3')
    # codec = cv2.VideoWriter_fourcc('D','I','V','X')
    # codec = cv2.VideoWriter_fourcc('U','2','6','3')
    # codec = cv2.VideoWriter_fourcc('I','2','6','3')
    # codec = cv2.VideoWriter_fourcc('F','L','V','1')
    # codec = cv2.VideoWriter_fourcc('H','2','6','4')
    # codec = cv2.VideoWriter_fourcc('A','Y','U','V')
    # codec = cv2.VideoWriter_fourcc('I','U','Y','V')
    Commonly used codecs:
    cv2.VideoWriter_fourcc("I", "4", "2", "0")
        YUV color encoding with 4:2:0 chroma subsampling; good
        compatibility, but produces very large videos; extension .avi
    cv2.VideoWriter_fourcc("P", "I", "M", "1")
        MPEG-1 encoding; extension .avi
    cv2.VideoWriter_fourcc("X", "V", "I", "D")
        MPEG-4 encoding; average video size; extension .avi
    cv2.VideoWriter_fourcc("T", "H", "E", "O")
        Ogg Vorbis; extension .ogv
    cv2.VideoWriter_fourcc("F", "L", "V", "1")
        Flash video; extension .flv
    '''
    image_files = sorted(glob.glob(os.path.join(image_dir, '*.{}'.format(image_ext))))
    print(len(image_files))
    height, width, _ = cv2.imread(image_files[0]).shape
    out_fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')  # cv2.VideoWriter_fourcc(*'MP4V')
    out_video = cv2.VideoWriter(video_path, out_fourcc, fps, (width, height))

    for image_file in image_files:
        img = cv2.imread(image_file)
        img = cv2.resize(img, (width, height), interpolation=3)  # 3 == cv2.INTER_AREA
        out_video.write(img)
    out_video.release()


def add_video_compression(imgs):
    codec_type = ['libx264', 'h264', 'mpeg4']
    codec_prob = [1 / 3., 1 / 3., 1 / 3.]
    codec = random.choices(codec_type, codec_prob)[0]
    # codec = 'mpeg4'
    bitrate = [1e4, 1e5]
    bitrate = np.random.randint(bitrate[0], bitrate[1] + 1)

    buf = io.BytesIO()
    with av.open(buf, 'w', 'mp4') as container:
        stream = container.add_stream(codec, rate=1)
        stream.height = imgs[0].shape[0]
        stream.width = imgs[0].shape[1]
        stream.pix_fmt = 'yuv420p'
        stream.bit_rate = bitrate

        for img in imgs:
            img = np.uint8((img.clip(0, 1) * 255.).round())
            frame = av.VideoFrame.from_ndarray(img, format='rgb24')
            frame.pict_type = 'NONE'
            # pdb.set_trace()
            for packet in stream.encode(frame):
                container.mux(packet)

        # Flush stream
        for packet in stream.encode():
            container.mux(packet)

    outputs = []
    with av.open(buf, 'r', 'mp4') as container:
        if container.streams.video:
            for frame in container.decode(**{'video': 0}):
                outputs.append(
                    frame.to_rgb().to_ndarray().astype(np.float32) / 255.)

    # outputs = np.stack(outputs, axis=0)
    return outputs


if __name__ == '__main__':

    # -----------------------------------
    # test VideoReader(filename, cache_capacity=10)
    # -----------------------------------
#    video_reader = VideoReader('utils/test.mp4')
#    from utils import utils_image as util
#    inputs = []
#    for frame in video_reader:
#        print(frame.dtype)
#        util.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
#        #util.imshow(np.flip(frame, axis=2))

    # -----------------------------------
    # test video2images(video_path, output_dir)
    # -----------------------------------
#    video2images('utils/test.mp4', 'frames')

    # -----------------------------------
    # test images2video(image_dir, video_path, fps=24, image_ext='png')
    # -----------------------------------
#    images2video('frames', 'video_02.mp4', fps=30, image_ext='png')

    # -----------------------------------
    # test frames2video(frame_dir, video_file, fps=30, fourcc='XVID', filename_tmpl='{:06d}.png')
    # -----------------------------------
#    frames2video('frames', 'video_01.mp4', filename_tmpl='{:06d}.png')

    # -----------------------------------
    # test add_video_compression(imgs)
    # -----------------------------------
#    imgs = []
#    image_ext = 'png'
#    frames = 'frames'
#    from utils import utils_image as util
#    image_files = sorted(glob.glob(os.path.join(frames, '*.{}'.format(image_ext))))
#    for i, image_file in enumerate(image_files):
#        if i < 7:
#            img = util.imread_uint(image_file, 3)
#            img = util.uint2single(img)
#            imgs.append(img)
#
#    results = add_video_compression(imgs)
#    for i, img in enumerate(results):
#        util.imshow(util.single2uint(img))
#        util.imsave(util.single2uint(img),f'{i:05}.png')

    # run utils/utils_video.py
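For orientation, a minimal usage sketch of the VideoReader and frames2video utilities above. It assumes a local 'utils/test.mp4' (the same placeholder used in the commented tests) and that the module is importable from its path in this commit:

from core.data.deg_kair_utils.utils_videoio import VideoReader, frames2video

# Random access and metadata; cache_capacity mirrors the commented test above.
with VideoReader('utils/test.mp4', cache_capacity=10) as vr:
    print(vr.resolution, vr.fps, len(vr))
    first = vr[0]    # __getitem__ supports indexing ...
    last = vr[-1]    # ... including negative indices

# Dump every frame with a fresh reader (its cursor starts at frame 0),
# then join the images back into an XVID-encoded .avi.
with VideoReader('utils/test.mp4') as vr:
    vr.cvt2frames('frames', filename_tmpl='{:06d}.png')
frames2video('frames', 'video_01.avi', fps=30, filename_tmpl='{:06d}.png')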
    	
core/scripts/__init__.py
ADDED
File without changes

core/scripts/cli.py
ADDED
@@ -0,0 +1,41 @@
import sys
import argparse
from .. import WarpCore
from .. import templates


def template_init(args):
    return '''


    '''.strip()


def init_template(args):
    parser = argparse.ArgumentParser(description='WarpCore template init tool')
    parser.add_argument('-t', '--template', type=str, default='WarpCore')
    args = parser.parse_args(args)

    if args.template == 'WarpCore':
        template_cls = WarpCore
    else:
        try:
            template_cls = __import__(args.template)
        except ModuleNotFoundError:
            template_cls = getattr(templates, args.template)
    print(template_cls)


def main():
    if len(sys.argv) < 2:
        print('Usage: core <command>')
        sys.exit(1)
    if sys.argv[1] == 'init':
        init_template(sys.argv[2:])
    else:
        print('Unknown command')
        sys.exit(1)


if __name__ == '__main__':
    main()
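A quick sketch of driving this CLI in-process (hypothetical invocation; `sys.argv` is patched only to mimic the command line `core init -t DiffusionCore`):

import sys
from core.scripts import cli

sys.argv = ['core', 'init', '-t', 'DiffusionCore']
cli.main()  # __import__('DiffusionCore') raises ModuleNotFoundError, so the
            # name is resolved via getattr(templates, ...) and the class printed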
    	
core/templates/__init__.py
ADDED
@@ -0,0 +1 @@
from .diffusion import DiffusionCore

core/templates/diffusion.py
ADDED
@@ -0,0 +1,236 @@
from .. import WarpCore
from ..utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary
from abc import abstractmethod
from dataclasses import dataclass
import torch
from torch import nn
from torch.utils.data import DataLoader
from gdf import GDF
import numpy as np
from tqdm import tqdm
import wandb

import webdataset as wds
from webdataset.handlers import warn_and_continue
from torch.distributed import barrier
from enum import Enum

class TargetReparametrization(Enum):
    EPSILON = 'epsilon'
    X0 = 'x0'

class DiffusionCore(WarpCore):
    @dataclass(frozen=True)
    class Config(WarpCore.Config):
        # TRAINING PARAMS
        lr: float = EXPECTED_TRAIN
        grad_accum_steps: int = EXPECTED_TRAIN
        batch_size: int = EXPECTED_TRAIN
        updates: int = EXPECTED_TRAIN
        warmup_updates: int = EXPECTED_TRAIN
        save_every: int = 500
        backup_every: int = 20000
        use_fsdp: bool = True

        # EMA UPDATE
        ema_start_iters: int = None
        ema_iters: int = None
        ema_beta: float = None

        # GDF setting
        gdf_target_reparametrization: TargetReparametrization = None  # epsilon or x0

    @dataclass()  # not frozen, so fields are mutable. Doesn't support EXPECTED
    class Info(WarpCore.Info):
        ema_loss: float = None

    @dataclass(frozen=True)
    class Models(WarpCore.Models):
        generator: nn.Module = EXPECTED
        generator_ema: nn.Module = None  # optional

    @dataclass(frozen=True)
    class Optimizers(WarpCore.Optimizers):
        generator: any = EXPECTED

    @dataclass(frozen=True)
    class Schedulers(WarpCore.Schedulers):
        generator: any = None

    @dataclass(frozen=True)
    class Extras(WarpCore.Extras):
        gdf: GDF = EXPECTED
        sampling_configs: dict = EXPECTED

    # --------------------------------------------
    info: Info
    config: Config

    @abstractmethod
    def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
        raise NotImplementedError("This method needs to be overridden")

    @abstractmethod
    def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
        raise NotImplementedError("This method needs to be overridden")

    @abstractmethod
    def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False):
        raise NotImplementedError("This method needs to be overridden")

    @abstractmethod
    def webdataset_path(self, extras: Extras):
        raise NotImplementedError("This method needs to be overridden")

    @abstractmethod
    def webdataset_filters(self, extras: Extras):
        raise NotImplementedError("This method needs to be overridden")

    @abstractmethod
    def webdataset_preprocessors(self, extras: Extras):
        raise NotImplementedError("This method needs to be overridden")

    @abstractmethod
    def sample(self, models: Models, data: WarpCore.Data, extras: Extras):
        raise NotImplementedError("This method needs to be overridden")
    # -------------

    def setup_data(self, extras: Extras) -> WarpCore.Data:
        # SETUP DATASET
        dataset_path = self.webdataset_path(extras)
        preprocessors = self.webdataset_preprocessors(extras)
        filters = self.webdataset_filters(extras)

        handler = warn_and_continue  # None
        # handler = None
        dataset = wds.WebDataset(
            dataset_path, resampled=True, handler=handler
        ).select(filters).shuffle(690, handler=handler).decode(
            "pilrgb", handler=handler
        ).to_tuple(
            *[p[0] for p in preprocessors], handler=handler
        ).map_tuple(
            *[p[1] for p in preprocessors], handler=handler
        ).map(lambda x: {p[2]: x[i] for i, p in enumerate(preprocessors)})

        # SETUP DATALOADER
        real_batch_size = self.config.batch_size // (self.world_size * self.config.grad_accum_steps)
        dataloader = DataLoader(
            dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True
        )

        return self.Data(dataset=dataset, dataloader=dataloader, iterator=iter(dataloader))
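    # Note on setup_data: `config.batch_size` is the global batch size. Each
    # of the `world_size` processes loads
    # batch_size // (world_size * grad_accum_steps) samples per iteration, so
    # one optimizer step still sees the configured global batch
    # (e.g. 1024 // (8 * 2) = 64 samples per process and iteration).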
+    def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models):
+        batch = next(data.iterator)
+
+        with torch.no_grad():
+            conditions = self.get_conditions(batch, models, extras)
+            latents = self.encode_latents(batch, models, extras)
+            noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1)
+
+        # FORWARD PASS
+        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+            pred = models.generator(noised, noise_cond, **conditions)
+            if self.config.gdf_target_reparametrization == TargetReparametrization.EPSILON:
+                pred = extras.gdf.undiffuse(noised, logSNR, pred)[1]  # transform whatever prediction to epsilon to use in the loss
+                target = noise
+            elif self.config.gdf_target_reparametrization == TargetReparametrization.X0:
+                pred = extras.gdf.undiffuse(noised, logSNR, pred)[0]  # transform whatever prediction to x0 to use in the loss
+                target = latents
+            loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
+            loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps
+
+        return loss, loss_adjusted
+
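On the two reparametrization branches in `forward_pass`: they rely on the usual variance-preserving identity noised = alpha*x0 + sigma*eps with logSNR = log(alpha^2/sigma^2), so given the noised sample either target can be recovered from the other. A self-contained sketch of that identity (the actual conversion happens inside `extras.gdf.undiffuse`, whose code is not in this diff):

    import torch

    def alpha_sigma(logSNR):
        # variance-preserving convention: alpha^2 = sigmoid(logSNR), sigma^2 = sigmoid(-logSNR)
        return torch.sigmoid(logSNR).sqrt(), torch.sigmoid(-logSNR).sqrt()

    logSNR = torch.tensor(1.5)
    alpha, sigma = alpha_sigma(logSNR)
    x0, eps = torch.randn(8), torch.randn(8)
    noised = alpha * x0 + sigma * eps
    # epsilon- and x0-parametrizations are interchangeable given `noised`:
    assert torch.allclose((noised - alpha * x0) / sigma, eps, atol=1e-5)
    assert torch.allclose((noised - sigma * eps) / alpha, x0, atol=1e-5)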
+    def train(self, data: WarpCore.Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers):
+        start_iter = self.info.iter+1
+        max_iters = self.config.updates * self.config.grad_accum_steps
+        if self.is_main_node:
+            print(f"STARTING AT STEP: {start_iter}/{max_iters}")
+
+        pbar = tqdm(range(start_iter, max_iters+1)) if self.is_main_node else range(start_iter, max_iters+1)  # <--- DDP
+        models.generator.train()
+        for i in pbar:
+            # FORWARD PASS
+            loss, loss_adjusted = self.forward_pass(data, extras, models)
+
+            # BACKWARD PASS
+            if i % self.config.grad_accum_steps == 0 or i == max_iters:
+                loss_adjusted.backward()
+                grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0)
+                optimizers_dict = optimizers.to_dict()
+                for k in optimizers_dict:
+                    optimizers_dict[k].step()
+                schedulers_dict = schedulers.to_dict()
+                for k in schedulers_dict:
+                    schedulers_dict[k].step()
+                models.generator.zero_grad(set_to_none=True)
+                self.info.total_steps += 1
+            else:
+                with models.generator.no_sync():
+                    loss_adjusted.backward()
+            self.info.iter = i
+
+            # UPDATE EMA
+            if models.generator_ema is not None and i % self.config.ema_iters == 0:
+                update_weights_ema(
+                    models.generator_ema, models.generator,
+                    beta=(self.config.ema_beta if i > self.config.ema_start_iters else 0)
+                )
+
+            # UPDATE LOSS METRICS
+            self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01
+
+            # NOTE: the NaN checks are parenthesized; 'and' binds tighter than 'or', so without
+            # parentheses the alert could fire off the main node or with wandb disabled.
+            if self.is_main_node and self.config.wandb_project is not None and (np.isnan(loss.mean().item()) or np.isnan(grad_norm.item())):
+                wandb.alert(
+                    title=f"NaN value encountered in training run {self.info.wandb_run_id}",
+                    text=f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}",
+                    wait_duration=60*30
+                )
+
+            if self.is_main_node:
+                logs = {
+                    'loss': self.info.ema_loss,
+                    'raw_loss': loss.mean().item(),
+                    'grad_norm': grad_norm.item(),
+                    'lr': optimizers.generator.param_groups[0]['lr'],
+                    'total_steps': self.info.total_steps,
+                }
+
+                pbar.set_postfix(logs)
+                if self.config.wandb_project is not None:
+                    wandb.log(logs)
+
+            if i == 1 or i % (self.config.save_every*self.config.grad_accum_steps) == 0 or i == max_iters:
+                # SAVE AND CHECKPOINT STUFF
+                if np.isnan(loss.mean().item()):
+                    if self.is_main_node and self.config.wandb_project is not None:
+                        tqdm.write("Skipping sampling & checkpoint because the loss is NaN")
+                        wandb.alert(title=f"Skipping sampling & checkpoint for training run {self.config.run_id}", text=f"Skipping sampling & checkpoint at {self.info.total_steps} for training run {self.info.wandb_run_id} iters because loss is NaN")
+                else:
+                    self.save_checkpoints(models, optimizers)
+                    if self.is_main_node:
+                        create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/')
+                    self.sample(models, data, extras)
+
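The backward branch above is the standard DDP gradient-accumulation pattern: `no_sync()` suppresses the gradient all-reduce on intermediate micro-steps, and the synchronizing backward plus optimizer step run only every `grad_accum_steps` iterations. A schematic version with generic names (`model` is assumed to be DDP-wrapped; this is not the repo's class structure):

    import torch.nn as nn

    def micro_step(model, optimizer, loss_adjusted, i, grad_accum_steps, max_iters):
        # loss_adjusted is pre-divided by grad_accum_steps, so the accumulated
        # gradients sum to one averaged full-batch gradient per update.
        if i % grad_accum_steps == 0 or i == max_iters:
            loss_adjusted.backward()                           # DDP all-reduce happens here
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
        else:
            with model.no_sync():                              # accumulate locally, skip all-reduce
                loss_adjusted.backward()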
+    def models_to_save(self):
+        return ['generator', 'generator_ema']
+
+    def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None):
+        barrier()
+        suffix = '' if suffix is None else suffix
+        self.save_info(self.info, suffix=suffix)
+        models_dict = models.to_dict()
+        optimizers_dict = optimizers.to_dict()
+        for key in self.models_to_save():
+            model = models_dict[key]
+            if model is not None:
+                self.save_model(model, f"{key}{suffix}", is_fsdp=self.config.use_fsdp)
+        for key in optimizers_dict:
+            optimizer = optimizers_dict[key]
+            if optimizer is not None:
+                self.save_optimizer(optimizer, f'{key}_optim{suffix}', fsdp_model=models.generator if self.config.use_fsdp else None)
+        if suffix == '' and self.info.total_steps > 1 and self.info.total_steps % self.config.backup_every == 0:
+            self.save_checkpoints(models, optimizers, suffix=f"_{self.info.total_steps//1000}k")
+        torch.cuda.empty_cache()
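On the recursion at the end of `save_checkpoints`: besides the rolling (suffix-less) checkpoint, every `backup_every` optimizer steps a copy is written under a step-count suffix. A quick illustration of the naming (hypothetical values):

    backup_every = 20_000  # hypothetical config.backup_every
    for total_steps in (19_999, 20_000, 40_000, 50_001):
        if total_steps > 1 and total_steps % backup_every == 0:
            print(f"generator_{total_steps//1000}k")  # -> generator_20k, generator_40k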
    	
        core/utils/__init__.py
    ADDED
    
@@ -0,0 +1,9 @@
+from .base_dto import Base, nested_dto, EXPECTED, EXPECTED_TRAIN
+from .save_and_load import create_folder_if_necessary, safe_save, load_or_fail
+
+# MOVE IT SOMEWHERE ELSE
+def update_weights_ema(tgt_model, src_model, beta=0.999):
+    for self_params, src_params in zip(tgt_model.parameters(), src_model.parameters()):
+        self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1-beta)
+    for self_buffers, src_buffers in zip(tgt_model.buffers(), src_model.buffers()):
+        self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1-beta)
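In `update_weights_ema`, `beta` is the retention factor (new = beta*old + (1-beta)*src): the trainer passes `beta=0` before `ema_start_iters`, which makes the EMA an exact copy of the online weights, then switches to `ema_beta` for slow tracking. A toy check using the function exactly as defined above:

    import torch
    import torch.nn as nn
    from core.utils import update_weights_ema

    online, ema = nn.Linear(4, 4), nn.Linear(4, 4)
    update_weights_ema(ema, online, beta=0.0)    # beta=0: hard copy (before ema_start_iters)
    assert torch.equal(ema.weight, online.weight)
    update_weights_ema(ema, online, beta=0.999)  # afterwards: EMA tracks the online weights slowly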
    	
core/utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (763 Bytes)

core/utils/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (804 Bytes)

core/utils/__pycache__/base_dto.cpython-310.pyc
ADDED
Binary file (3.09 kB)

core/utils/__pycache__/base_dto.cpython-39.pyc
ADDED
Binary file (3.11 kB)

core/utils/__pycache__/save_and_load.cpython-310.pyc
ADDED
Binary file (2.19 kB)

core/utils/__pycache__/save_and_load.cpython-39.pyc
ADDED
Binary file (2.2 kB)
    	
        core/utils/base_dto.py
    ADDED
    
@@ -0,0 +1,56 @@
+import dataclasses
+from dataclasses import dataclass, _MISSING_TYPE
+from munch import Munch
+
+EXPECTED = "___REQUIRED___"
+EXPECTED_TRAIN = "___REQUIRED_TRAIN___"
+
+# pylint: disable=invalid-field-call
+def nested_dto(x, raw=False):
+    return dataclasses.field(default_factory=lambda: x if raw else Munch.fromDict(x))
+
+@dataclass(frozen=True)
+class Base:
+    training: bool = None
+    def __new__(cls, **kwargs):
+        training = kwargs.get('training', True)
+        setteable_fields = cls.setteable_fields(**kwargs)
+        mandatory_fields = cls.mandatory_fields(**kwargs)
+        invalid_kwargs = [
+            {k: v} for k, v in kwargs.items() if k not in setteable_fields or v == EXPECTED or (v == EXPECTED_TRAIN and training is not False)
+        ]
+        print(mandatory_fields)
+        assert (
+            len(invalid_kwargs) == 0
+        ), f"Invalid fields detected when initializing this DTO: {invalid_kwargs}.\nDeclare this field and set it to None or EXPECTED in order to make it setteable."
+        missing_kwargs = [f for f in mandatory_fields if f not in kwargs]
+        assert (
+            len(missing_kwargs) == 0
+        ), f"Required fields missing initializing this DTO: {missing_kwargs}."
+        return object.__new__(cls)
+
+
+    @classmethod
+    def setteable_fields(cls, **kwargs):
+        return [f.name for f in dataclasses.fields(cls) if f.default is None or isinstance(f.default, _MISSING_TYPE) or f.default == EXPECTED or f.default == EXPECTED_TRAIN]
+
+    @classmethod
+    def mandatory_fields(cls, **kwargs):
+        training = kwargs.get('training', True)
+        return [f.name for f in dataclasses.fields(cls) if isinstance(f.default, _MISSING_TYPE) and isinstance(f.default_factory, _MISSING_TYPE) or f.default == EXPECTED or (f.default == EXPECTED_TRAIN and training is not False)]
+
+    @classmethod
+    def from_dict(cls, kwargs):
+        for k in kwargs:
+            if isinstance(kwargs[k], (dict, list, tuple)):
+                kwargs[k] = Munch.fromDict(kwargs[k])
+        return cls(**kwargs)
+
+    def to_dict(self):
+        # selfdict = dataclasses.asdict(self) # needs to pickle stuff, doesn't support some more complex classes
+        selfdict = {}
+        for k in dataclasses.fields(self):
+            selfdict[k.name] = getattr(self, k.name)
+            if isinstance(selfdict[k.name], Munch):
+                selfdict[k.name] = selfdict[k.name].toDict()
+        return selfdict
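A sketch of how this DTO base is meant to be used (the `TrainingConfig` subclass here is hypothetical): a field defaulting to `EXPECTED` must always be supplied, `EXPECTED_TRAIN` only when `training` is not False, and unknown kwargs are rejected by `__new__`.

    from dataclasses import dataclass
    from core.utils import Base, EXPECTED, EXPECTED_TRAIN

    @dataclass(frozen=True)
    class TrainingConfig(Base):           # hypothetical subclass
        experiment_id: str = EXPECTED     # always required
        lr: float = EXPECTED_TRAIN        # required unless training=False
        batch_size: int = None            # optional

    cfg = TrainingConfig(experiment_id="run1", lr=1e-4)           # ok
    infer = TrainingConfig(experiment_id="run1", training=False)  # ok: lr only needed for training
    # TrainingConfig(lr=1e-4) would raise: experiment_id is missing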
    	
        core/utils/save_and_load.py
    ADDED
    
@@ -0,0 +1,59 @@
+import os
+import torch
+import json
+from pathlib import Path
+import safetensors.torch  # the torch submodule must be imported explicitly for save_file below
+import wandb
+
+
+def create_folder_if_necessary(path):
+    path = "/".join(path.split("/")[:-1])
+    Path(path).mkdir(parents=True, exist_ok=True)
+
+
+def safe_save(ckpt, path):
+    try:
+        os.remove(f"{path}.bak")
+    except OSError:
+        pass
+    try:
+        os.rename(path, f"{path}.bak")
+    except OSError:
+        pass
+    if path.endswith(".pt") or path.endswith(".ckpt"):
+        torch.save(ckpt, path)
+    elif path.endswith(".json"):
+        with open(path, "w", encoding="utf-8") as f:
+            json.dump(ckpt, f, indent=4)
+    elif path.endswith(".safetensors"):
+        safetensors.torch.save_file(ckpt, path)
+    else:
+        raise ValueError(f"File extension not supported: {path}")
+
+
+def load_or_fail(path, wandb_run_id=None):
+    accepted_extensions = [".pt", ".ckpt", ".json", ".safetensors"]
+    try:
+        assert any(
+            [path.endswith(ext) for ext in accepted_extensions]
+        ), f"Automatic loading not supported for this extension: {path}"
+        if not os.path.exists(path):
+            checkpoint = None
+        elif path.endswith(".pt") or path.endswith(".ckpt"):
+            checkpoint = torch.load(path, map_location="cpu")
+        elif path.endswith(".json"):
+            with open(path, "r", encoding="utf-8") as f:
+                checkpoint = json.load(f)
+        elif path.endswith(".safetensors"):
+            checkpoint = {}
+            with safetensors.safe_open(path, framework="pt", device="cpu") as f:
+                for key in f.keys():
+                    checkpoint[key] = f.get_tensor(key)
+        return checkpoint
+    except Exception as e:
+        if wandb_run_id is not None:
+            wandb.alert(
+                title=f"Corrupt checkpoint for run {wandb_run_id}",
+                text=f"Training {wandb_run_id} tried to load checkpoint {path} and failed",
+            )
+        raise e
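A round-trip sketch for `safe_save`/`load_or_fail` (paths and tensors are hypothetical): `safe_save` rotates any existing file to `.bak` before writing, and `load_or_fail` returns None for a missing path instead of raising.

    import torch
    from core.utils import safe_save, load_or_fail, create_folder_if_necessary

    path = "output/demo/generator.safetensors"  # hypothetical location
    create_folder_if_necessary(path)            # creates output/demo/

    state = {"weight": torch.randn(2, 2)}
    safe_save(state, path)                      # old file, if any, becomes generator.safetensors.bak
    restored = load_or_fail(path)
    assert torch.equal(restored["weight"], state["weight"])
    assert load_or_fail("output/demo/missing.safetensors") is None  # missing file -> None, not an exception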