	first commit
- lib/attention.py +1 -1
- lib/ddpm_multi.py +1 -1
- lib/openaimodel.py +1 -1
- lib/util.py +2 -10
- lib/utils.py +117 -0
    	
lib/attention.py
CHANGED

@@ -18,7 +18,7 @@ from torch import nn, einsum
 from einops import rearrange, repeat
 from typing import Optional, Any
 
-from 
+from utils import checkpoint
 
 try:
     import xformers
    	
lib/ddpm_multi.py
CHANGED

@@ -30,7 +30,7 @@ from torchvision.utils import make_grid
 from pytorch_lightning.utilities.distributed import rank_zero_only
 from omegaconf import ListConfig
 
-from 
+from utils import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
 from lib.distributions import normal_kl, DiagonalGaussianDistribution
 from lib.autoencoder import IdentityFirstStage, AutoencoderKL
 from lib.util import make_beta_schedule, extract_into_tensor, noise_like
    	
lib/openaimodel.py
CHANGED

@@ -26,7 +26,7 @@ from lib.util import (
     timestep_embedding,
 )
 from attention import SpatialTransformer
-from 
+from utils import exists
 
 
 # dummy replace
    	
lib/util.py
CHANGED

@@ -25,16 +25,8 @@ import torch.nn as nn
 import numpy as np
 from einops import repeat
 
-
-
-def instantiate_from_config(config):
-    if not "target" in config:
-        if config == '__is_first_stage__':
-            return None
-        elif config == "__is_unconditional__":
-            return None
-        raise KeyError("Expected key `target` to instantiate.")
-    return get_obj_from_str(config["target"])(**config.get("params", dict()))
+from utils import instantiate_from_config
+
 
 def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
     if schedule == "linear":
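The removed `instantiate_from_config` definition now lives in the new lib/utils.py (below) and is imported back here. A minimal usage sketch follows; the `target` path and params are illustrative only, not part of this commit:

    from utils import instantiate_from_config

    # Resolve the dotted "target" import path and call it with "params" as keyword arguments.
    # torch.nn.Linear is just an example target; any importable callable works.
    config = {
        "target": "torch.nn.Linear",
        "params": {"in_features": 8, "out_features": 4},
    }
    layer = instantiate_from_config(config)  # equivalent to torch.nn.Linear(in_features=8, out_features=4)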
    	
lib/utils.py
ADDED

@@ -0,0 +1,117 @@
+'''
+ * Copyright (c) 2023 Salesforce, Inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: Apache License 2.0
+ * For full license text, see LICENSE.txt file in the repo root or http://www.apache.org/licenses/
+ * By Can Qin
+ * Modified from ControlNet repo: https://github.com/lllyasviel/ControlNet
+ * Copyright (c) 2023 Lvmin Zhang and Maneesh Agrawala
+'''
+
+import os
+import torch
+from omegaconf import OmegaConf
+import importlib
+import numpy as np
+
+
+from inspect import isfunction
+from PIL import Image, ImageDraw, ImageFont
+
+
+def log_txt_as_img(wh, xc, size=10):
+    # wh a tuple of (width, height)
+    # xc a list of captions to plot
+    b = len(xc)
+    txts = list()
+    for bi in range(b):
+        txt = Image.new("RGB", wh, color="white")
+        draw = ImageDraw.Draw(txt)
+        font = ImageFont.truetype('font/DejaVuSans.ttf', size=size)
+        nc = int(40 * (wh[0] / 256))
+        lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
+
+        try:
+            draw.text((0, 0), lines, fill="black", font=font)
+        except UnicodeEncodeError:
+            print("Cant encode string for logging. Skipping.")
+
+        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
+        txts.append(txt)
+    txts = np.stack(txts)
+    txts = torch.tensor(txts)
+    return txts
+
+
+def ismap(x):
+    if not isinstance(x, torch.Tensor):
+        return False
+    return (len(x.shape) == 4) and (x.shape[1] > 3)
+
+
+def isimage(x):
+    if not isinstance(x,torch.Tensor):
+        return False
+    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+
+
+def exists(x):
+    return x is not None
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
+
+def mean_flat(tensor):
+    """
+    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
+    Take the mean over all non-batch dimensions.
+    """
+    return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+def count_params(model, verbose=False):
+    total_params = sum(p.numel() for p in model.parameters())
+    if verbose:
+        print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
+    return total_params
+
+
+def get_state_dict(d):
+    return d.get('state_dict', d)
+
+
+def load_state_dict(ckpt_path, location='cpu'):
+    _, extension = os.path.splitext(ckpt_path)
+    if extension.lower() == ".safetensors":
+        import safetensors.torch
+        state_dict = safetensors.torch.load_file(ckpt_path, device=location)
+    else:
+        state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
+    state_dict = get_state_dict(state_dict)
+    print(f'Loaded state_dict from [{ckpt_path}]')
+    return state_dict
+
+def get_obj_from_str(string, reload=False):
+    module, cls = string.rsplit(".", 1)
+    if reload:
+        module_imp = importlib.import_module(module)
+        importlib.reload(module_imp)
+    return getattr(importlib.import_module(module, package=None), cls)
+
+def instantiate_from_config(config):
+    if not "target" in config:
+        if config == '__is_first_stage__':
+            return None
+        elif config == "__is_unconditional__":
+            return None
+        raise KeyError("Expected key `target` to instantiate.")
+    return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+def create_model(config_path):
+    config = OmegaConf.load(config_path)
+    model = instantiate_from_config(config.model).cpu()
+    print(f'Loaded model config from [{config_path}]')
+    return model
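Taken together, a plausible way these helpers combine is sketched below. It is only an illustration: the config and checkpoint paths are placeholders, and it assumes the instantiated model is a torch.nn.Module.

    from utils import create_model, load_state_dict, count_params

    # Build the model described by an OmegaConf YAML config (via instantiate_from_config),
    # then load weights from a .ckpt or .safetensors checkpoint. Paths are hypothetical.
    model = create_model('configs/model.yaml')
    model.load_state_dict(load_state_dict('checkpoints/model.ckpt', location='cpu'))
    count_params(model, verbose=True)  # prints the parameter count in millions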