Upload 27 files
Browse filesAdding sgmse folder
- sgmse/backbones/__init__.py +6 -0
- sgmse/backbones/dcunet.py +627 -0
- sgmse/backbones/ncsnpp.py +419 -0
- sgmse/backbones/ncsnpp_48k.py +424 -0
- sgmse/backbones/ncsnpp_utils/layers.py +662 -0
- sgmse/backbones/ncsnpp_utils/layerspp.py +274 -0
- sgmse/backbones/ncsnpp_utils/normalization.py +215 -0
- sgmse/backbones/ncsnpp_utils/op/__init__.py +1 -0
- sgmse/backbones/ncsnpp_utils/op/fused_act.py +97 -0
- sgmse/backbones/ncsnpp_utils/op/fused_bias_act.cpp +21 -0
- sgmse/backbones/ncsnpp_utils/op/fused_bias_act_kernel.cu +99 -0
- sgmse/backbones/ncsnpp_utils/op/upfirdn2d.cpp +23 -0
- sgmse/backbones/ncsnpp_utils/op/upfirdn2d.py +203 -0
- sgmse/backbones/ncsnpp_utils/op/upfirdn2d_kernel.cu +369 -0
- sgmse/backbones/ncsnpp_utils/up_or_down_sampling.py +257 -0
- sgmse/backbones/ncsnpp_utils/utils.py +189 -0
- sgmse/backbones/shared.py +123 -0
- sgmse/data_module.py +236 -0
- sgmse/model.py +253 -0
- sgmse/sampling/__init__.py +143 -0
- sgmse/sampling/correctors.py +96 -0
- sgmse/sampling/predictors.py +76 -0
- sgmse/sdes.py +310 -0
- sgmse/util/inference.py +64 -0
- sgmse/util/other.py +141 -0
- sgmse/util/registry.py +34 -0
- sgmse/util/tensors.py +16 -0
    	
        sgmse/backbones/__init__.py
    ADDED
    
    | @@ -0,0 +1,6 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from .shared import BackboneRegistry
         | 
| 2 | 
            +
            from .ncsnpp import NCSNpp
         | 
| 3 | 
            +
            from .ncsnpp_48k import NCSNpp_48k
         | 
| 4 | 
            +
            from .dcunet import DCUNet
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            __all__ = ['BackboneRegistry', 'NCSNpp', 'NCSNpp_48k', 'DCUNet']
         | 
    	
        sgmse/backbones/dcunet.py
    ADDED
    
    | @@ -0,0 +1,627 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from functools import partial
         | 
| 2 | 
            +
            import numpy as np
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
            from torch import nn, Tensor
         | 
| 6 | 
            +
            from torch.nn.modules.batchnorm import _BatchNorm
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            from .shared import BackboneRegistry, ComplexConv2d, ComplexConvTranspose2d, ComplexLinear, \
         | 
| 9 | 
            +
                DiffusionStepEmbedding, GaussianFourierProjection, FeatureMapDense, torch_complex_from_reim
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
def get_activation(name):
    """Look up a ``torch.nn`` activation class by its short name.

    The class itself is returned (not an instance), so the caller decides
    when and how to instantiate it.

    Args:
        name: One of "silu", "relu" or "leaky_relu".

    Returns:
        ``nn.SiLU``, ``nn.ReLU`` or ``nn.LeakyReLU`` respectively.

    Raises:
        NotImplementedError: If ``name`` is not one of the supported names.
    """
    known_activations = (
        ("silu", nn.SiLU),
        ("relu", nn.ReLU),
        ("leaky_relu", nn.LeakyReLU),
    )
    for key, activation_cls in known_activations:
        if name == key:
            return activation_cls
    raise NotImplementedError(f"Unknown activation: {name}")
| 21 | 
            +
             | 
| 22 | 
            +
             | 
class BatchNorm(_BatchNorm):
    """Batch normalization that accepts 2D, 3D or 4D inputs.

    Only the input-rank check of ``_BatchNorm`` is overridden here;
    everything else is inherited unchanged.
    """

    def _check_input_dim(self, input):
        """Validate the rank of ``input``.

        Args:
            input: Tensor about to be normalized. The parameter name shadows
                the builtin, but is kept to match the hook being overridden.

        Raises:
            ValueError: If ``input`` has fewer than 2 or more than 4 dimensions.
        """
        if input.dim() < 2 or input.dim() > 4:
            # The message lists every rank the guard above actually accepts.
            raise ValueError("expected 2D, 3D or 4D input (got {}D input)".format(input.dim()))
| 27 | 
            +
             | 
| 28 | 
            +
             | 
class OnReIm(nn.Module):
    """Apply two independent copies of a module to the real and imaginary parts.

    Two separate instances of ``module_cls`` are built from the same
    constructor arguments: one processes ``x.real``, the other ``x.imag``.
    """

    def __init__(self, module_cls, *args, **kwargs):
        """Build the real-part and imaginary-part sub-modules.

        Args:
            module_cls: Callable returning a module; invoked twice.
            *args: Positional arguments forwarded to ``module_cls``.
            **kwargs: Keyword arguments forwarded to ``module_cls``.
        """
        super().__init__()
        real_branch = module_cls(*args, **kwargs)
        imag_branch = module_cls(*args, **kwargs)
        self.re_module = real_branch
        self.im_module = imag_branch

    def forward(self, x):
        """Process ``x.real`` and ``x.imag`` separately and recombine the results.

        The two outputs are merged via ``torch_complex_from_reim``.
        """
        real_out = self.re_module(x.real)
        imag_out = self.im_module(x.imag)
        return torch_complex_from_reim(real_out, imag_out)
| 37 | 
            +
             | 
| 38 | 
            +
             | 
| 39 | 
            +
            # Code for DCUNet largely copied from Danilo's `informedenh` repo, cheers!
         | 
| 40 | 
            +
             | 
def unet_decoder_args(encoders, *, skip_connections):
    """Derive decoder arguments for the upsampling side of a symmetric U-Net.

    The encoder list is walked in reverse. Each encoder is mirrored into a
    decoder whose input/output channel counts are swapped and whose kernel
    size, stride, padding and dilation are copied unchanged.

    Args:
        encoders (sequence of `N` tuples of
            (in_chan, out_chan, kernel_size, stride, padding, dilation)):
            Arguments used to construct the encoders.
        skip_connections (bool): Whether decoder input channels should account
            for a concatenated skip connection. The first decoder produced
            (mirror of the last encoder) never gets extra skip channels.

    Return:
        tuple of length `N` of tuples of
        (in_chan, out_chan, kernel_size, stride, padding, dilation):
            Arguments to be used to construct the decoders.
    """
    mirrored = []
    for position, (in_ch, out_ch, kernel, stride, padding, dilation) in enumerate(reversed(encoders)):
        # Only decoders after the first one receive a concatenated skip input.
        extra_chan = out_ch if (skip_connections and position > 0) else 0
        mirrored.append((out_ch + extra_chan, in_ch, kernel, stride, padding, dilation))
    return tuple(mirrored)
| 63 | 
            +
             | 
| 64 | 
            +
             | 
def make_unet_encoder_decoder_args(encoder_args, decoder_args):
    """Normalize encoder/decoder specifications into tuples.

    Args:
        encoder_args: Iterable of
            (in_chan, out_chan, kernel_size, stride, padding, dilation).
            ``padding`` may be the string "auto", in which case it is set to
            ``kernel_size // 2`` per dimension; otherwise it is converted to
            a tuple.
        decoder_args: Either the string "auto", in which case the decoders are
            derived from the normalized encoders via ``unet_decoder_args``
            with skip connections enabled, or an iterable of
            (in_chan, out_chan, kernel_size, stride, padding, dilation,
            output_padding). For explicit decoders, a non-"auto" ``padding``
            and the ``output_padding`` are passed through as given.

    Returns:
        A pair ``(encoder_args, decoder_args)`` of tuples of tuples.
    """
    normalized_encoders = []
    for in_chan, out_chan, kernel_size, stride, padding, dilation in encoder_args:
        kernel = tuple(kernel_size)
        strides = tuple(stride)
        if padding == "auto":
            pad = tuple([n // 2 for n in kernel_size])
        else:
            pad = tuple(padding)
        normalized_encoders.append((in_chan, out_chan, kernel, strides, pad, tuple(dilation)))
    encoder_args = tuple(normalized_encoders)

    if decoder_args == "auto":
        return encoder_args, unet_decoder_args(encoder_args, skip_connections=True)

    normalized_decoders = []
    for in_chan, out_chan, kernel_size, stride, padding, dilation, output_padding in decoder_args:
        kernel = tuple(kernel_size)
        strides = tuple(stride)
        # Unlike the encoder branch, explicit decoder padding is kept as given.
        pad = tuple([n // 2 for n in kernel_size]) if padding == "auto" else padding
        normalized_decoders.append(
            (in_chan, out_chan, kernel, strides, pad, tuple(dilation), output_padding)
        )
    return encoder_args, tuple(normalized_decoders)
| 98 | 
            +
             | 
| 99 | 
            +
             | 
# Named DCUNet configurations. Each entry maps an architecture name to the
# (encoder_args, decoder_args) pair returned by `make_unet_encoder_decoder_args`.
# Only the encoders are spelled out; every decoder stack is derived
# automatically ("auto") as the inverse of its encoder stack.
# Encoder rows: (in_chan, out_chan, kernel_size, stride, padding, dilation)
DCUNET_ARCHITECTURES = {
    arch_name: make_unet_encoder_decoder_args(encoder_specs, "auto")
    for arch_name, encoder_specs in (
        (
            "DCUNet-10",
            (
                (1, 32, (7, 5), (2, 2), "auto", (1, 1)),
                (32, 64, (7, 5), (2, 2), "auto", (1, 1)),
                (64, 64, (5, 3), (2, 2), "auto", (1, 1)),
                (64, 64, (5, 3), (2, 2), "auto", (1, 1)),
                (64, 64, (5, 3), (2, 1), "auto", (1, 1)),
            ),
        ),
        (
            "DCUNet-16",
            (
                (1, 32, (7, 5), (2, 2), "auto", (1, 1)),
                (32, 32, (7, 5), (2, 1), "auto", (1, 1)),
                (32, 64, (7, 5), (2, 2), "auto", (1, 1)),
                (64, 64, (5, 3), (2, 1), "auto", (1, 1)),
                (64, 64, (5, 3), (2, 2), "auto", (1, 1)),
                (64, 64, (5, 3), (2, 1), "auto", (1, 1)),
                (64, 64, (5, 3), (2, 2), "auto", (1, 1)),
                (64, 64, (5, 3), (2, 1), "auto", (1, 1)),
            ),
        ),
        (
            "DCUNet-20",
            (
                (1, 32, (7, 1), (1, 1), "auto", (1, 1)),
                (32, 32, (1, 7), (1, 1), "auto", (1, 1)),
                (32, 64, (7, 5), (2, 2), "auto", (1, 1)),
                (64, 64, (7, 5), (2, 1), "auto", (1, 1)),
                (64, 64, (5, 3), (2, 2), "auto", (1, 1)),
                (64, 64, (5, 3), (2, 1), "auto", (1, 1)),
                (64, 64, (5, 3), (2, 2), "auto", (1, 1)),
                (64, 64, (5, 3), (2, 1), "auto", (1, 1)),
                (64, 64, (5, 3), (2, 2), "auto", (1, 1)),
                (64, 90, (5, 3), (2, 1), "auto", (1, 1)),
            ),
        ),
        (
            # architecture used in SGMSE / Interspeech paper
            "DilDCUNet-v2",
            (
                (1, 32, (4, 4), (1, 1), "auto", (1, 1)),
                (32, 32, (4, 4), (1, 1), "auto", (1, 1)),
                (32, 32, (4, 4), (1, 1), "auto", (1, 1)),
                (32, 64, (4, 4), (2, 1), "auto", (2, 1)),
                (64, 128, (4, 4), (2, 2), "auto", (4, 1)),
                (128, 256, (4, 4), (2, 2), "auto", (8, 1)),
            ),
        ),
    )
}
| 163 | 
            +
             | 
| 164 | 
            +
             | 
| 165 | 
            +
            @BackboneRegistry.register("dcunet")
         | 
| 166 | 
            +
            class DCUNet(nn.Module):
         | 
| 167 | 
            +
                @staticmethod
         | 
| 168 | 
            +
                def add_argparse_args(parser):
         | 
| 169 | 
            +
                    parser.add_argument("--dcunet-architecture", type=str, default="DilDCUNet-v2", choices=DCUNET_ARCHITECTURES.keys(), help="The concrete DCUNet architecture. 'DilDCUNet-v2' by default.")
         | 
| 170 | 
            +
                    parser.add_argument("--dcunet-time-embedding", type=str, choices=("gfp", "ds", "none"), default="gfp", help="Timestep embedding style. 'gfp' (Gaussian Fourier Projections) by default.")
         | 
| 171 | 
            +
                    parser.add_argument("--dcunet-temb-layers-global", type=int, default=1, help="Number of global linear+activation layers for the time embedding. 1 by default.")
         | 
| 172 | 
            +
                    parser.add_argument("--dcunet-temb-layers-local", type=int, default=1, help="Number of local (per-encoder/per-decoder) linear+activation layers for the time embedding. 1 by default.")
         | 
| 173 | 
            +
                    parser.add_argument("--dcunet-temb-activation", type=str, default="silu", help="The (complex) activation to use between all (global&local) time embedding layers.")
         | 
| 174 | 
            +
                    parser.add_argument("--dcunet-time-embedding-complex", action="store_true", help="Use complex-valued timestep embedding. Compatible with 'gfp' and 'ds' embeddings.")
         | 
| 175 | 
            +
                    parser.add_argument("--dcunet-fix-length", type=str, default="pad", choices=("pad", "trim", "none"), help="DCUNet strategy to 'fix' mismatched input timespan. 'pad' by default.")
         | 
| 176 | 
            +
                    parser.add_argument("--dcunet-mask-bound", type=str, choices=("tanh", "sigmoid", "none"), default="none", help="DCUNet output bounding strategy. 'none' by default.")
         | 
| 177 | 
            +
                    parser.add_argument("--dcunet-norm-type", type=str, choices=("bN", "CbN"), default="bN", help="The type of norm to use within each encoder and decoder layer. 'bN' (real/imaginary separate batch norm) by default.")
         | 
| 178 | 
            +
                    parser.add_argument("--dcunet-activation", type=str, choices=("leaky_relu", "relu", "silu"), default="leaky_relu", help="The activation to use within each encoder and decoder layer. 'leaky_relu' by default.")
         | 
| 179 | 
            +
                    return parser
         | 
| 180 | 
            +
             | 
| 181 | 
            +
                def __init__(
         | 
| 182 | 
            +
                    self,
         | 
| 183 | 
            +
                    dcunet_architecture: str = "DilDCUNet-v2",
         | 
| 184 | 
            +
                    dcunet_time_embedding: str = "gfp",
         | 
| 185 | 
            +
                    dcunet_temb_layers_global: int = 2,
         | 
| 186 | 
            +
                    dcunet_temb_layers_local: int = 1,
         | 
| 187 | 
            +
                    dcunet_temb_activation: str = "silu",
         | 
| 188 | 
            +
                    dcunet_time_embedding_complex: bool = False,
         | 
| 189 | 
            +
                    dcunet_fix_length: str = "pad",
         | 
| 190 | 
            +
                    dcunet_mask_bound: str = "none",
         | 
| 191 | 
            +
                    dcunet_norm_type: str = "bN",
         | 
| 192 | 
            +
                    dcunet_activation: str = "relu",
         | 
| 193 | 
            +
                    embed_dim: int = 128,
         | 
| 194 | 
            +
                    **kwargs
         | 
| 195 | 
            +
                ):
         | 
| 196 | 
            +
                    super().__init__()
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                    self.architecture = dcunet_architecture
         | 
| 199 | 
            +
                    self.fix_length_mode = (dcunet_fix_length if dcunet_fix_length != "none" else None)
         | 
| 200 | 
            +
                    self.norm_type = dcunet_norm_type
         | 
| 201 | 
            +
                    self.activation = dcunet_activation
         | 
| 202 | 
            +
                    self.input_channels = 2  # for x_t and y -- note that this is 2 rather than 4, because we directly treat complex channels in this DNN
         | 
| 203 | 
            +
                    self.time_embedding = (dcunet_time_embedding if dcunet_time_embedding != "none" else None)
         | 
| 204 | 
            +
                    self.time_embedding_complex = dcunet_time_embedding_complex
         | 
| 205 | 
            +
                    self.temb_layers_global = dcunet_temb_layers_global
         | 
| 206 | 
            +
                    self.temb_layers_local = dcunet_temb_layers_local
         | 
| 207 | 
            +
                    self.temb_activation = dcunet_temb_activation
         | 
| 208 | 
            +
                    conf_encoders, conf_decoders = DCUNET_ARCHITECTURES[dcunet_architecture]
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                    # Replace `input_channels` in encoders config
         | 
| 211 | 
            +
                    _replaced_input_channels, *rest = conf_encoders[0]
         | 
| 212 | 
            +
                    encoders = ((self.input_channels, *rest), *conf_encoders[1:])
         | 
| 213 | 
            +
                    decoders = conf_decoders
         | 
| 214 | 
            +
                    self.encoders_stride_product = np.prod(
         | 
| 215 | 
            +
                        [enc_stride for _, _, _, enc_stride, _, _ in encoders], axis=0
         | 
| 216 | 
            +
                    )
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                    # Prepare kwargs for encoder and decoder (to potentially be modified before layer instantiation)
         | 
| 219 | 
            +
                    encoder_decoder_kwargs = dict(
         | 
| 220 | 
            +
                        norm_type=self.norm_type, activation=self.activation,
         | 
| 221 | 
            +
                        temb_layers=self.temb_layers_local, temb_activation=self.temb_activation)
         | 
| 222 | 
            +
             | 
| 223 | 
            +
                    # Instantiate (global) time embedding layer
         | 
| 224 | 
            +
                    embed_ops = []
         | 
| 225 | 
            +
                    if self.time_embedding is not None:
         | 
| 226 | 
            +
                        complex_valued = self.time_embedding_complex
         | 
| 227 | 
            +
                        if self.time_embedding == "gfp":
         | 
| 228 | 
            +
                            embed_ops += [GaussianFourierProjection(embed_dim=embed_dim, complex_valued=complex_valued)]
         | 
| 229 | 
            +
                            encoder_decoder_kwargs["embed_dim"] = embed_dim
         | 
| 230 | 
            +
                        elif self.time_embedding == "ds":
         | 
| 231 | 
            +
                            embed_ops += [DiffusionStepEmbedding(embed_dim=embed_dim, complex_valued=complex_valued)]
         | 
| 232 | 
            +
                            encoder_decoder_kwargs["embed_dim"] = embed_dim
         | 
| 233 | 
            +
             | 
| 234 | 
            +
                        if self.time_embedding_complex:
         | 
| 235 | 
            +
                            assert self.time_embedding in ("gfp", "ds"), "Complex timestep embedding only available for gfp and ds"
         | 
| 236 | 
            +
                            encoder_decoder_kwargs["complex_time_embedding"] = True
         | 
| 237 | 
            +
                        for _ in range(self.temb_layers_global):
         | 
| 238 | 
            +
                            embed_ops += [
         | 
| 239 | 
            +
                                ComplexLinear(embed_dim, embed_dim, complex_valued=True),
         | 
| 240 | 
            +
                                OnReIm(get_activation(dcunet_temb_activation))
         | 
| 241 | 
            +
                            ]
         | 
| 242 | 
            +
                    self.embed = nn.Sequential(*embed_ops)
         | 
| 243 | 
            +
             | 
| 244 | 
            +
                    ### Instantiate DCUNet layers ###
         | 
| 245 | 
            +
                    output_layer = ComplexConvTranspose2d(*decoders[-1])
         | 
| 246 | 
            +
                    encoders = [DCUNetComplexEncoderBlock(*args, **encoder_decoder_kwargs) for args in encoders]
         | 
| 247 | 
            +
                    decoders = [DCUNetComplexDecoderBlock(*args, **encoder_decoder_kwargs) for args in decoders[:-1]]
         | 
| 248 | 
            +
             | 
| 249 | 
            +
                    self.mask_bound = (dcunet_mask_bound if dcunet_mask_bound != "none" else None)
         | 
| 250 | 
            +
                    if self.mask_bound is not None:
         | 
| 251 | 
            +
                        raise NotImplementedError("sorry, mask bounding not implemented at the moment")
         | 
| 252 | 
            +
                        # TODO we can't use nn.Sequential since the ComplexConvTranspose2d needs a second `output_size` argument
         | 
| 253 | 
            +
                    #operations = (output_layer, complex_nn.BoundComplexMask(self.mask_bound))
         | 
| 254 | 
            +
                    #output_layer = nn.Sequential(*[x for x in operations if x is not None])
         | 
| 255 | 
            +
             | 
| 256 | 
            +
                    assert len(encoders) == len(decoders) + 1
         | 
| 257 | 
            +
                    self.encoders = nn.ModuleList(encoders)
         | 
| 258 | 
            +
                    self.decoders = nn.ModuleList(decoders)
         | 
| 259 | 
            +
                    self.output_layer = output_layer or nn.Identity()
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                def forward(self, spec, t) -> Tensor:
                    """
                    Run the encoder/decoder stack on `spec`, conditioned on `t`.

                    `spec` is first passed through `fix_input_dims`, then through every
                    encoder in order. Each decoder receives the running activation and is
                    asked to produce the shape of the matching encoder output, which is
                    then concatenated onto it along dim 1 (skip connection). The output of
                    the last encoder is the bottleneck and is not reused as a skip.

                    Args:
                        spec (Tensor): input spectrogram; presumably laid out as
                            (batch, channels, freq, time) -- TODO confirm against callers.
                        t: conditioning value. It is only used when
                            `self.time_embedding` is not None, in which case `t + 0j`
                            is fed to `self.embed`.
                    Returns:
                        Tensor: the result of `self.output_layer`, post-processed by
                        `fix_output_dims` with `spec` as the reference.
                    """
                    net_in = self.fix_input_dims(spec)
                    if self.time_embedding is not None:
                        t_embed = self.embed(t+0j)
                    else:
                        t_embed = None

                    # Encoder pass: remember every activation for the skip connections.
                    skips = []
                    h = net_in
                    for encoder in self.encoders:
                        h = encoder(h, t_embed)
                        skips.append(h)

                    # Decoder pass: walk the skips from deepest to shallowest, leaving
                    # out the bottleneck (the last encoder output).
                    for skip, decoder in zip(reversed(skips[:-1]), self.decoders):
                        h = decoder(h, t_embed, output_size=skip.shape)
                        h = torch.cat([h, skip], dim=1)

                    # The output layer is asked to restore the (length-fixed) input shape.
                    result = self.output_layer(h, output_size=net_in.shape)
                    return self.fix_output_dims(result, spec)
         | 
| 291 | 
            +
             | 
| 292 | 
            +
                def fix_input_dims(self, x):
                    """Make `x` compatible with the encoder strides.

                    Delegates to `_fix_dcu_input_dims` with this model's
                    `fix_length_mode` and `encoders_stride_product` (converted from a
                    NumPy array to a tensor).
                    """
                    stride_products = torch.from_numpy(self.encoders_stride_product)
                    return _fix_dcu_input_dims(self.fix_length_mode, x, stride_products)
         | 
| 296 | 
            +
             | 
| 297 | 
            +
                def fix_output_dims(self, out, x):
                    """Bring `out` back to the last-axis length of the reference `x`.

                    Delegates to `_fix_dcu_output_dims` with this model's
                    `fix_length_mode`.
                    """
                    mode = self.fix_length_mode
                    return _fix_dcu_output_dims(mode, out, x)
         | 
| 299 | 
            +
             | 
| 300 | 
            +
             | 
| 301 | 
            +
            def _fix_dcu_input_dims(fix_length_mode, x, encoders_stride_product):
                """Check the frequency axis of `x` and pad or trim its time axis.

                Args:
                    fix_length_mode: None, "pad" or "trim". Only consulted when the
                        time axis does not already fit.
                    x: tensor indexed as [batch, ch, freq + 1, time + 1].
                    encoders_stride_product: indexable whose items 0 and 1 are the
                        frequency and time stride products.

                Returns:
                    `x` itself when `x.shape[3] - 1` is already a multiple of the time
                    stride product, otherwise a copy padded ("pad") or cropped
                    ("trim") at the end of the last axis so that it is.

                Raises:
                    TypeError: if `x.shape[2] - 1` is not a multiple of the frequency
                        stride product, or if the time axis needs fixing but
                        `fix_length_mode` is None.
                    ValueError: if `fix_length_mode` is not a recognised mode.
                """
                freq_prod, time_prod = (int(encoders_stride_product[k]) for k in (0, 1))

                if (x.shape[2] - 1) % freq_prod != 0:
                    raise TypeError(
                        f"Input shape must be [batch, ch, freq + 1, time + 1] with freq divisible by "
                        f"{freq_prod}, got {x.shape} instead"
                    )

                time_remainder = (x.shape[3] - 1) % time_prod
                if time_remainder == 0:
                    # Already aligned: nothing to do.
                    return x

                if fix_length_mode is None:
                    raise TypeError(
                        f"Input shape must be [batch, ch, freq + 1, time + 1] with time divisible by "
                        f"{time_prod}, got {x.shape} instead. Set the 'fix_length_mode' argument "
                        f"in 'DCUNet' to 'pad' or 'trim' to fix shapes automatically."
                    )

                if fix_length_mode == "pad":
                    # Extend up to the next multiple.
                    right_amount = time_prod - time_remainder
                elif fix_length_mode == "trim":
                    # A negative pad amount crops from the end of the axis.
                    right_amount = -time_remainder
                else:
                    raise ValueError(f"Unknown fix_length mode '{fix_length_mode}'")

                return nn.functional.pad(x, [0, right_amount], mode="constant")
         | 
| 327 | 
            +
             | 
| 328 | 
            +
             | 
| 329 | 
            +
            def _fix_dcu_output_dims(fix_length_mode, out, x):
                """Pad or crop the last axis of `out` to the last-axis length of `x`.

                `fix_length_mode` is accepted but not consulted here. A positive
                length difference zero-pads the end of the axis; a negative one
                crops it.
                """
                length_gap = x.shape[-1] - out.shape[-1]
                return nn.functional.pad(out, [0, length_gap])
         | 
| 334 | 
            +
             | 
| 335 | 
            +
             | 
| 336 | 
            +
            def _get_norm(norm_type):
                """Map a norm identifier to a callable that builds the norm layer.

                "CbN" yields the `ComplexBatchNorm` class, "bN" yields
                `partial(OnReIm, BatchNorm)`. Any other value (including None)
                raises NotImplementedError.
                """
                if norm_type == "CbN":
                    return ComplexBatchNorm
                if norm_type == "bN":
                    return partial(OnReIm, BatchNorm)
                raise NotImplementedError(f"Unknown norm type: {norm_type}")
         | 
| 343 | 
            +
             | 
| 344 | 
            +
             | 
| 345 | 
            +
            class DCUNetComplexEncoderBlock(nn.Module):
                """Encoder stage: complex conv, optional time-embedding term, norm, activation.

                When `embed_dim` is given, a small stack (`temb_layers - 1` pairs of
                `ComplexLinear` + activation, followed by `FeatureMapDense` +
                activation) maps the time embedding to `out_chan` features that are
                added to the convolution output before normalisation.
                """

                def __init__(
                    self,
                    in_chan,
                    out_chan,
                    kernel_size,
                    stride,
                    padding,
                    dilation,
                    norm_type="bN",
                    activation="leaky_relu",
                    embed_dim=None,
                    complex_time_embedding=False,
                    temb_layers=1,
                    temb_activation="silu"
                ):
                    super().__init__()

                    # Plain configuration, kept for introspection.
                    self.in_chan, self.out_chan = in_chan, out_chan
                    self.kernel_size, self.stride = kernel_size, stride
                    self.padding, self.dilation = padding, dilation
                    self.temb_layers = temb_layers
                    self.temb_activation = temb_activation
                    self.complex_time_embedding = complex_time_embedding

                    # The conv only carries a bias when no norm layer is requested.
                    use_bias = norm_type is None
                    self.conv = ComplexConv2d(
                        in_chan, out_chan, kernel_size, stride, padding, bias=use_bias, dilation=dilation
                    )
                    norm_factory = _get_norm(norm_type)
                    self.norm = norm_factory(out_chan)
                    self.activation = OnReIm(get_activation(activation))

                    self.embed_dim = embed_dim
                    if embed_dim is not None:
                        stages = []
                        # Hidden embedding layers keep the embedding width.
                        for _ in range(max(0, temb_layers - 1)):
                            stages.append(ComplexLinear(embed_dim, embed_dim, complex_valued=True))
                            stages.append(OnReIm(get_activation(temb_activation)))
                        # Final projection onto the block's output channels.
                        stages.append(FeatureMapDense(embed_dim, out_chan, complex_valued=True))
                        stages.append(OnReIm(get_activation(temb_activation)))
                        self.embed_layer = nn.Sequential(*stages)

                def forward(self, x, t_embed):
                    """Apply the block to `x`; `t_embed` is used only if `embed_dim` was set."""
                    features = self.conv(x)
                    if self.embed_dim is not None:
                        features = features + self.embed_layer(t_embed)
                    normalized = self.norm(features)
                    return self.activation(normalized)
         | 
| 397 | 
            +
             | 
| 398 | 
            +
             | 
| 399 | 
            +
            class DCUNetComplexDecoderBlock(nn.Module):
                """Decoder stage: complex transposed conv, optional time-embedding term, norm, activation.

                When `embed_dim` is given, a small stack (`temb_layers - 1` pairs of
                `ComplexLinear` + activation, followed by `FeatureMapDense` +
                activation) maps the time embedding to `out_chan` features that are
                added to the transposed-convolution output before normalisation.
                """

                def __init__(
                    self,
                    in_chan,
                    out_chan,
                    kernel_size,
                    stride,
                    padding,
                    dilation,
                    output_padding=(0, 0),
                    norm_type="bN",
                    activation="leaky_relu",
                    embed_dim=None,
                    temb_layers=1,
                    temb_activation='swish',
                    complex_time_embedding=False,
                ):
                    super().__init__()

                    # Plain configuration, kept for introspection.
                    self.in_chan, self.out_chan = in_chan, out_chan
                    self.kernel_size, self.stride = kernel_size, stride
                    self.padding, self.dilation = padding, dilation
                    self.output_padding = output_padding
                    self.complex_time_embedding = complex_time_embedding
                    self.temb_layers = temb_layers
                    self.temb_activation = temb_activation

                    # The deconv only carries a bias when no norm layer is requested.
                    use_bias = norm_type is None
                    self.deconv = ComplexConvTranspose2d(
                        in_chan, out_chan, kernel_size, stride, padding, output_padding, dilation=dilation, bias=use_bias
                    )
                    norm_factory = _get_norm(norm_type)
                    self.norm = norm_factory(out_chan)
                    self.activation = OnReIm(get_activation(activation))

                    self.embed_dim = embed_dim
                    if embed_dim is not None:
                        stages = []
                        # Hidden embedding layers keep the embedding width.
                        for _ in range(max(0, temb_layers - 1)):
                            stages.append(ComplexLinear(embed_dim, embed_dim, complex_valued=True))
                            stages.append(OnReIm(get_activation(temb_activation)))
                        # Final projection onto the block's output channels.
                        stages.append(FeatureMapDense(embed_dim, out_chan, complex_valued=True))
                        stages.append(OnReIm(get_activation(temb_activation)))
                        self.embed_layer = nn.Sequential(*stages)

                def forward(self, x, t_embed, output_size=None):
                    """Apply the block to `x`.

                    `output_size` is forwarded to the transposed convolution;
                    `t_embed` is used only if `embed_dim` was set.
                    """
                    features = self.deconv(x, output_size=output_size)
                    if self.embed_dim is not None:
                        features = features + self.embed_layer(t_embed)
                    normalized = self.norm(features)
                    return self.activation(normalized)
         | 
| 453 | 
            +
             | 
| 454 | 
            +
             | 
| 455 | 
            +
            # From https://github.com/chanil1218/DCUnet.pytorch/blob/2dcdd30804be47a866fde6435cbb7e2f81585213/models/layers/complexnn.py
         | 
| 456 | 
            +
            class ComplexBatchNorm(torch.nn.Module):
                """Batch normalisation for complex tensors using a 2x2 real/imag covariance.

                The real and imaginary parts of each channel (dim 1) are centred and
                then multiplied by the inverse square root of their 2x2 covariance
                matrix, so the two parts are normalised jointly rather than one by
                one. With `affine=True` the result is further multiplied by a
                learnable symmetric 2x2 matrix (Wrr, Wri, Wii) and shifted by a
                learnable bias (Br, Bi).

                Args:
                    num_features: size of dim 1 of the expected input.
                    eps: value added to the diagonal of the covariance matrix.
                    momentum: interpolation factor for the running statistics; None
                        selects a cumulative moving average.
                    affine: whether to learn the 2x2 weight and the bias.
                    track_running_stats: whether to keep running mean/covariance
                        buffers and use them in eval mode. When False, batch
                        statistics are always used.
                """

                def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=False):
                    super(ComplexBatchNorm, self).__init__()
                    self.num_features = num_features
                    self.eps = eps
                    self.momentum = momentum
                    self.affine = affine
                    self.track_running_stats = track_running_stats
                    if self.affine:
                        self.Wrr = torch.nn.Parameter(torch.Tensor(num_features))
                        self.Wri = torch.nn.Parameter(torch.Tensor(num_features))
                        self.Wii = torch.nn.Parameter(torch.Tensor(num_features))
                        self.Br = torch.nn.Parameter(torch.Tensor(num_features))
                        self.Bi = torch.nn.Parameter(torch.Tensor(num_features))
                    else:
                        self.register_parameter('Wrr', None)
                        self.register_parameter('Wri', None)
                        self.register_parameter('Wii', None)
                        self.register_parameter('Br', None)
                        self.register_parameter('Bi', None)
                    if self.track_running_stats:
                        self.register_buffer('RMr', torch.zeros(num_features))
                        self.register_buffer('RMi', torch.zeros(num_features))
                        self.register_buffer('RVrr', torch.ones(num_features))
                        self.register_buffer('RVri', torch.zeros(num_features))
                        self.register_buffer('RVii', torch.ones(num_features))
                        self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
                    else:
                        self.register_parameter('RMr', None)
                        self.register_parameter('RMi', None)
                        self.register_parameter('RVrr', None)
                        self.register_parameter('RVri', None)
                        self.register_parameter('RVii', None)
                        self.register_parameter('num_batches_tracked', None)
                    self.reset_parameters()

                def reset_running_stats(self):
                    """Reset running mean to 0, running covariance to identity, counter to 0."""
                    if self.track_running_stats:
                        self.RMr.zero_()
                        self.RMi.zero_()
                        self.RVrr.fill_(1)
                        self.RVri.zero_()
                        self.RVii.fill_(1)
                        self.num_batches_tracked.zero_()

                def reset_parameters(self):
                    """Reset running statistics and (if affine) the learnable weight and bias."""
                    self.reset_running_stats()
                    if self.affine:
                        self.Br.data.zero_()
                        self.Bi.data.zero_()
                        self.Wrr.data.fill_(1)
                        self.Wri.data.uniform_(-.9, +.9)  # W will be positive-definite
                        self.Wii.data.fill_(1)

                def _check_input_dim(self, xr, xi):
                    """Assert that both parts share a shape and have `num_features` channels."""
                    assert(xr.shape == xi.shape)
                    assert(xr.size(1) == self.num_features)

                def forward(self, x):
                    """Normalise the complex tensor `x` over every dim except dim 1.

                    Returns a complex tensor of the same shape as `x`.
                    """
                    xr, xi = x.real, x.imag
                    self._check_input_dim(xr, xi)

                    exponential_average_factor = 0.0

                    if self.training and self.track_running_stats:
                        self.num_batches_tracked += 1
                        if self.momentum is None:  # use cumulative moving average
                            exponential_average_factor = 1.0 / self.num_batches_tracked.item()
                        else:  # use exponential moving average
                            exponential_average_factor = self.momentum

                    #
                    # NOTE: The precise meaning of the "training flag" is:
                    #       True:  Normalize using batch   statistics, update running statistics
                    #              if they are being collected.
                    #       False: Normalize using running statistics, ignore batch   statistics.
                    #
                    training = self.training or not self.track_running_stats
                    # Every dim except the channel dim, reduced one at a time.
                    redux = [i for i in reversed(range(xr.dim())) if i != 1]
                    # Broadcast shape for per-channel vectors: 1 everywhere except dim 1.
                    vdim = [1] * xr.dim()
                    vdim[1] = xr.size(1)

                    #
                    # Mean M Computation and Centering
                    #
                    # Includes running mean update if training and running.
                    #
                    if training:
                        Mr, Mi = xr, xi
                        for d in redux:
                            Mr = Mr.mean(d, keepdim=True)
                            Mi = Mi.mean(d, keepdim=True)
                        if self.track_running_stats:
                            # Detach so the buffers never pick up autograd history
                            # from the batch statistics.
                            self.RMr.lerp_(Mr.detach().squeeze(), exponential_average_factor)
                            self.RMi.lerp_(Mi.detach().squeeze(), exponential_average_factor)
                    else:
                        Mr = self.RMr.view(vdim)
                        Mi = self.RMi.view(vdim)
                    xr, xi = xr - Mr, xi - Mi

                    #
                    # Variance Matrix V Computation
                    #
                    # Includes epsilon numerical stabilizer/Tikhonov regularizer.
                    # Includes running variance update if training and running.
                    #
                    if training:
                        Vrr = xr * xr
                        Vri = xr * xi
                        Vii = xi * xi
                        for d in redux:
                            Vrr = Vrr.mean(d, keepdim=True)
                            Vri = Vri.mean(d, keepdim=True)
                            Vii = Vii.mean(d, keepdim=True)
                        if self.track_running_stats:
                            # Detached for the same reason as the running means.
                            self.RVrr.lerp_(Vrr.detach().squeeze(), exponential_average_factor)
                            self.RVri.lerp_(Vri.detach().squeeze(), exponential_average_factor)
                            self.RVii.lerp_(Vii.detach().squeeze(), exponential_average_factor)
                    else:
                        Vrr = self.RVrr.view(vdim)
                        Vri = self.RVri.view(vdim)
                        Vii = self.RVii.view(vdim)
                    # eps is added to the diagonal only.
                    Vrr = Vrr + self.eps
                    Vii = Vii + self.eps

                    #
                    # Matrix Inverse Square Root U = V^-0.5
                    #
                    # sqrt of a 2x2 matrix,
                    # - https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
                    tau = Vrr + Vii
                    delta = torch.addcmul(Vrr * Vii, Vri, Vri, value=-1)
                    s = delta.sqrt()
                    t = (tau + 2*s).sqrt()

                    # matrix inverse, http://mathworld.wolfram.com/MatrixInverse.html
                    rst = (s * t).reciprocal()
                    Urr = (s + Vii) * rst
                    Uii = (s + Vrr) * rst
                    Uri = (-Vri) * rst

                    #
                    # Optionally left-multiply U by affine weights W to produce combined
                    # weights Z, left-multiply the inputs by Z, then optionally bias them.
                    #
                    # y = Zx + B
                    # y = WUx + B
                    # y = [Wrr Wri][Urr Uri] [xr] + [Br]
                    #     [Wir Wii][Uir Uii] [xi]   [Bi]
                    #
                    if self.affine:
                        Wrr, Wri, Wii = self.Wrr.view(vdim), self.Wri.view(vdim), self.Wii.view(vdim)
                        Zrr = (Wrr * Urr) + (Wri * Uri)
                        Zri = (Wrr * Uri) + (Wri * Uii)
                        Zir = (Wri * Urr) + (Wii * Uri)
                        Zii = (Wri * Uri) + (Wii * Uii)
                    else:
                        Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii

                    yr = (Zrr * xr) + (Zri * xi)
                    yi = (Zir * xr) + (Zii * xi)

                    if self.affine:
                        yr = yr + self.Br.view(vdim)
                        yi = yi + self.Bi.view(vdim)

                    return torch.view_as_complex(torch.stack([yr, yi], dim=-1))

                def extra_repr(self):
                    """Return the constructor arguments as shown in the module's repr."""
                    return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
                            'track_running_stats={track_running_stats}'.format(**self.__dict__)
         | 
    	
        sgmse/backbones/ncsnpp.py
    ADDED
    
    | @@ -0,0 +1,419 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # coding=utf-8
         | 
| 2 | 
            +
            # Copyright 2020 The Google Research Authors.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 5 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 6 | 
            +
            # You may obtain a copy of the License at
         | 
| 7 | 
            +
            #
         | 
| 8 | 
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 9 | 
            +
            #
         | 
| 10 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 11 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 12 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 13 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 14 | 
            +
            # limitations under the License.
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            # pylint: skip-file
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            from .ncsnpp_utils import layers, layerspp, normalization
         | 
| 19 | 
            +
            import torch.nn as nn
         | 
| 20 | 
            +
            import functools
         | 
| 21 | 
            +
            import torch
         | 
| 22 | 
            +
            import numpy as np
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            from .shared import BackboneRegistry
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            # Short module-level aliases for the building blocks, activation/normalization factories and default weight initializer imported from .ncsnpp_utils.
            ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp
         | 
| 27 | 
            +
            ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp
         | 
| 28 | 
            +
            Combine = layerspp.Combine
         | 
| 29 | 
            +
            conv3x3 = layerspp.conv3x3
         | 
| 30 | 
            +
            conv1x1 = layerspp.conv1x1
         | 
| 31 | 
            +
            get_act = layers.get_act
         | 
| 32 | 
            +
            get_normalization = normalization.get_normalization
         | 
| 33 | 
            +
            default_initializer = layers.default_init
         | 
| 34 | 
            +
             | 
| 35 | 
            +
             | 
| 36 | 
            +
            @BackboneRegistry.register("ncsnpp")
         | 
| 37 | 
            +
            class NCSNpp(nn.Module):
         | 
| 38 | 
            +
                """NCSN++ model, adapted from https://github.com/yang-song/score_sde repository"""
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                @staticmethod
                def add_argparse_args(parser):
                    """Register the NCSN++ backbone's command-line options on ``parser``.

                    Adds ``--ch_mult``, ``--num_res_blocks``, ``--attn_resolutions`` and the
                    paired ``--centered`` / ``--no-centered`` switches (both write to the
                    ``centered`` destination, which defaults to ``True``).

                    Args:
                        parser: An ``argparse`` parser (or any object exposing
                            ``add_argument`` and ``set_defaults``) to extend in place.

                    Returns:
                        The same ``parser`` object that was passed in.
                    """
                    parser.add_argument(
                        "--ch_mult", type=int, nargs='+', default=[1, 1, 2, 2, 2, 2, 2],
                        help="Channel multiplier of the base width for each resolution level "
                             "(default: %(default)s)")
                    parser.add_argument(
                        "--num_res_blocks", type=int, default=2,
                        help="Number of residual blocks per resolution level in the "
                             "downsampling path (default: %(default)s)")
                    parser.add_argument(
                        "--attn_resolutions", type=int, nargs='+', default=[16],
                        help="Feature-map resolutions at which attention blocks are inserted "
                             "(default: %(default)s)")
                    # Two flags share one destination so the setting can be toggled either way.
                    parser.add_argument(
                        "--no-centered", dest="centered", action="store_false",
                        help="The data is not centered [-1, 1]")
                    parser.add_argument(
                        "--centered", dest="centered", action="store_true",
                        help="The data is centered [-1, 1]")
                    parser.set_defaults(centered=True)
                    return parser
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                def __init__(self,
                    scale_by_sigma=True,
                    nonlinearity='swish',
                    nf=128,
                    ch_mult=(1, 1, 2, 2, 2, 2, 2),
                    num_res_blocks=2,
                    attn_resolutions=(16,),
                    resamp_with_conv=True,
                    conditional=True,
                    fir=True,
                    fir_kernel=[1, 3, 3, 1],
                    skip_rescale=True,
                    resblock_type='biggan',
                    progressive='output_skip',
                    progressive_input='input_skip',
                    progressive_combine='sum',
                    init_scale=0.,
                    fourier_scale=16,
                    image_size=256,
                    embedding_type='fourier',
                    dropout=.0,
                    centered=True,
                    **unused_kwargs
                ):
                    """Build the flat list of sub-modules that ``forward`` walks in order.

                    All layers are collected into ``self.all_modules`` (an ``nn.ModuleList``);
                    ``forward`` consumes them positionally, so the construction order below
                    must stay in sync with the order in which ``forward`` indexes them.

                    Args:
                        scale_by_sigma: Stored on the instance; not otherwise used here.
                        nonlinearity: Activation name passed to ``get_act``.
                        nf: Base number of feature channels.
                        ch_mult: Per-resolution-level multipliers applied to ``nf``; its
                            length defines the number of resolution levels.
                        num_res_blocks: Residual blocks per level on the down path (the up
                            path builds ``num_res_blocks + 1`` per level).
                        attn_resolutions: Resolutions (taken from ``image_size // 2**i``) at
                            which attention blocks are added.
                        resamp_with_conv: Passed as ``with_conv`` to the Upsample/Downsample
                            layers used when ``resblock_type`` is ``'ddpm'``.
                        conditional: If true, build the two ``nn.Linear`` layers that
                            process the time/noise embedding.
                        fir: Passed through to the resampling and BigGAN residual layers.
                        fir_kernel: Passed through to the resampling and BigGAN residual
                            layers. NOTE: the list default is kept for interface
                            compatibility; it is not modified in this method.
                        skip_rescale: Passed to the residual and attention blocks and stored
                            on the instance.
                        resblock_type: ``'ddpm'`` or ``'biggan'`` (case-insensitive).
                        progressive: ``'none'``, ``'output_skip'`` or ``'residual'``
                            (case-insensitive).
                        progressive_input: ``'none'``, ``'input_skip'`` or ``'residual'``
                            (case-insensitive).
                        progressive_combine: Combination method name (lower-cased) passed to
                            ``Combine``.
                        init_scale: Passed to the attention/residual blocks and to the
                            output ``conv3x3`` layers.
                        fourier_scale: ``scale`` of the ``GaussianFourierProjection``.
                        image_size: Resolution of level 0, used to derive
                            ``self.all_resolutions``.
                        embedding_type: ``'fourier'`` or ``'positional'`` (case-insensitive).
                        dropout: Dropout value passed to the residual blocks.
                        centered: Stored on the instance; read by ``forward``.
                        **unused_kwargs: Accepted and ignored.

                    Raises:
                        ValueError: If ``progressive``, ``progressive_input``,
                            ``embedding_type`` or ``resblock_type`` is not one of the
                            supported names.

                    NOTE(review): ``forward`` reads ``self.sigmas`` in the ``'positional'``
                    embedding path, but this constructor never defines it -- confirm whether
                    it is set elsewhere before relying on that mode.
                    """
                    super().__init__()
                    self.act = act = get_act(nonlinearity)

                    self.nf = nf
                    self.num_res_blocks = num_res_blocks
                    self.attn_resolutions = attn_resolutions
                    self.num_resolutions = num_resolutions = len(ch_mult)
                    self.all_resolutions = all_resolutions = [image_size // (2 ** i) for i in range(num_resolutions)]

                    self.conditional = conditional  # noise-conditional
                    self.centered = centered
                    self.scale_by_sigma = scale_by_sigma

                    self.skip_rescale = skip_rescale
                    self.resblock_type = resblock_type = resblock_type.lower()
                    self.progressive = progressive = progressive.lower()
                    self.progressive_input = progressive_input = progressive_input.lower()
                    self.embedding_type = embedding_type = embedding_type.lower()

                    # Validate with explicit exceptions: ``assert`` is stripped under ``python -O``,
                    # which would let an unsupported mode through silently.
                    if progressive not in ('none', 'output_skip', 'residual'):
                        raise ValueError(f'{progressive} is not a valid name.')
                    if progressive_input not in ('none', 'input_skip', 'residual'):
                        raise ValueError(f'progressive_input {progressive_input} is not a valid name.')
                    if embedding_type not in ('fourier', 'positional'):
                        raise ValueError(f'embedding type {embedding_type} unknown.')

                    combine_method = progressive_combine.lower()
                    combiner = functools.partial(Combine, method=combine_method)

                    num_channels = 4  # x.real, x.imag, y.real, y.imag
                    self.output_layer = nn.Conv2d(num_channels, 2, 1)

                    modules = []
                    # timestep/noise_level embedding
                    if embedding_type == 'fourier':
                        # Gaussian Fourier features embeddings.
                        modules.append(layerspp.GaussianFourierProjection(
                            embedding_size=nf, scale=fourier_scale
                        ))
                        embed_dim = 2 * nf
                    else:  # 'positional' (validated above)
                        embed_dim = nf

                    if conditional:
                        # Two linear layers: embed_dim -> 4*nf -> 4*nf, each with the default
                        # weight initializer and a zero bias.
                        for in_features in (embed_dim, nf * 4):
                            linear = nn.Linear(in_features, nf * 4)
                            linear.weight.data = default_initializer()(linear.weight.shape)
                            nn.init.zeros_(linear.bias)
                            modules.append(linear)

                    AttnBlock = functools.partial(layerspp.AttnBlockpp,
                        init_scale=init_scale, skip_rescale=skip_rescale)

                    Upsample = functools.partial(layerspp.Upsample,
                        with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)

                    if progressive == 'output_skip':
                        self.pyramid_upsample = layerspp.Upsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
                    elif progressive == 'residual':
                        pyramid_upsample = functools.partial(layerspp.Upsample, fir=fir,
                            fir_kernel=fir_kernel, with_conv=True)

                    Downsample = functools.partial(layerspp.Downsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)

                    if progressive_input == 'input_skip':
                        self.pyramid_downsample = layerspp.Downsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
                    elif progressive_input == 'residual':
                        pyramid_downsample = functools.partial(layerspp.Downsample,
                            fir=fir, fir_kernel=fir_kernel, with_conv=True)

                    if resblock_type == 'ddpm':
                        ResnetBlock = functools.partial(ResnetBlockDDPM, act=act,
                            dropout=dropout, init_scale=init_scale,
                            skip_rescale=skip_rescale, temb_dim=nf * 4)

                    elif resblock_type == 'biggan':
                        ResnetBlock = functools.partial(ResnetBlockBigGAN, act=act,
                            dropout=dropout, fir=fir, fir_kernel=fir_kernel,
                            init_scale=init_scale, skip_rescale=skip_rescale, temb_dim=nf * 4)

                    else:
                        raise ValueError(f'resblock type {resblock_type} unrecognized.')

                    # Downsampling block

                    channels = num_channels
                    if progressive_input != 'none':
                        input_pyramid_ch = channels

                    modules.append(conv3x3(channels, nf))
                    # hs_c tracks the channel count of every skip tensor pushed on the down
                    # path so the up path can size its concatenated inputs.
                    hs_c = [nf]

                    in_ch = nf
                    for i_level in range(num_resolutions):
                        # Residual blocks for this resolution
                        for _ in range(num_res_blocks):
                            out_ch = nf * ch_mult[i_level]
                            modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
                            in_ch = out_ch

                            if all_resolutions[i_level] in attn_resolutions:
                                modules.append(AttnBlock(channels=in_ch))
                            hs_c.append(in_ch)

                        if i_level != num_resolutions - 1:
                            if resblock_type == 'ddpm':
                                modules.append(Downsample(in_ch=in_ch))
                            else:
                                modules.append(ResnetBlock(down=True, in_ch=in_ch))

                            if progressive_input == 'input_skip':
                                modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
                                if combine_method == 'cat':
                                    in_ch *= 2

                            elif progressive_input == 'residual':
                                modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
                                input_pyramid_ch = in_ch

                            hs_c.append(in_ch)

                    # Bottleneck: ResNet block -> attention -> ResNet block
                    in_ch = hs_c[-1]
                    modules.append(ResnetBlock(in_ch=in_ch))
                    modules.append(AttnBlock(channels=in_ch))
                    modules.append(ResnetBlock(in_ch=in_ch))

                    pyramid_ch = 0
                    # Upsampling block
                    for i_level in reversed(range(num_resolutions)):
                        # +1 blocks in upsampling because of skip connection from combiner (after downsampling)
                        for _ in range(num_res_blocks + 1):
                            out_ch = nf * ch_mult[i_level]
                            modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
                            in_ch = out_ch

                        if all_resolutions[i_level] in attn_resolutions:
                            modules.append(AttnBlock(channels=in_ch))

                        # ``progressive`` was validated above, so only 'output_skip' and
                        # 'residual' can reach the inner branches.
                        if progressive != 'none':
                            if i_level == num_resolutions - 1:
                                if progressive == 'output_skip':
                                    modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
                                        num_channels=in_ch, eps=1e-6))
                                    modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
                                    pyramid_ch = channels
                                elif progressive == 'residual':
                                    modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
                                    modules.append(conv3x3(in_ch, in_ch, bias=True))
                                    pyramid_ch = in_ch
                            else:
                                if progressive == 'output_skip':
                                    modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
                                        num_channels=in_ch, eps=1e-6))
                                    modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
                                    pyramid_ch = channels
                                elif progressive == 'residual':
                                    modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
                                    pyramid_ch = in_ch

                        if i_level != 0:
                            if resblock_type == 'ddpm':
                                modules.append(Upsample(in_ch=in_ch))
                            else:
                                modules.append(ResnetBlock(in_ch=in_ch, up=True))

                    # Internal invariant: every skip pushed on the down path was consumed.
                    assert not hs_c

                    if progressive != 'output_skip':
                        modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
                            num_channels=in_ch, eps=1e-6))
                        modules.append(conv3x3(in_ch, channels, init_scale=init_scale))

                    self.all_modules = nn.ModuleList(modules)
         | 
| 254 | 
            +
                    
         | 
| 255 | 
            +
             | 
| 256 | 
            +
                def forward(self, x, time_cond):
         | 
| 257 | 
            +
                    # timestep/noise_level embedding; only for continuous training
         | 
| 258 | 
            +
                    modules = self.all_modules
         | 
| 259 | 
            +
                    m_idx = 0
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                    # Convert real and imaginary parts of (x,y) into four channel dimensions
         | 
| 262 | 
            +
                    x = torch.cat((x[:,[0],:,:].real, x[:,[0],:,:].imag,
         | 
| 263 | 
            +
                            x[:,[1],:,:].real, x[:,[1],:,:].imag), dim=1)
         | 
| 264 | 
            +
             | 
| 265 | 
            +
                    if self.embedding_type == 'fourier':
         | 
| 266 | 
            +
                        # Gaussian Fourier features embeddings.
         | 
| 267 | 
            +
                        used_sigmas = time_cond
         | 
| 268 | 
            +
                        temb = modules[m_idx](torch.log(used_sigmas))
         | 
| 269 | 
            +
                        m_idx += 1
         | 
| 270 | 
            +
             | 
| 271 | 
            +
                    elif self.embedding_type == 'positional':
         | 
| 272 | 
            +
                        # Sinusoidal positional embeddings.
         | 
| 273 | 
            +
                        timesteps = time_cond
         | 
| 274 | 
            +
                        used_sigmas = self.sigmas[time_cond.long()]
         | 
| 275 | 
            +
                        temb = layers.get_timestep_embedding(timesteps, self.nf)
         | 
| 276 | 
            +
             | 
| 277 | 
            +
                    else:
         | 
| 278 | 
            +
                        raise ValueError(f'embedding type {self.embedding_type} unknown.')
         | 
| 279 | 
            +
             | 
| 280 | 
            +
                    if self.conditional:
         | 
| 281 | 
            +
                        temb = modules[m_idx](temb)
         | 
| 282 | 
            +
                        m_idx += 1
         | 
| 283 | 
            +
                        temb = modules[m_idx](self.act(temb))
         | 
| 284 | 
            +
                        m_idx += 1
         | 
| 285 | 
            +
                    else:
         | 
| 286 | 
            +
                        temb = None
         | 
| 287 | 
            +
             | 
| 288 | 
            +
                    if not self.centered:
         | 
| 289 | 
            +
                        # If input data is in [0, 1]
         | 
| 290 | 
            +
                        x = 2 * x - 1.
         | 
| 291 | 
            +
             | 
| 292 | 
            +
                    # Downsampling block
         | 
| 293 | 
            +
                    input_pyramid = None
         | 
| 294 | 
            +
                    if self.progressive_input != 'none':
         | 
| 295 | 
            +
                        input_pyramid = x
         | 
| 296 | 
            +
             | 
| 297 | 
            +
                    # Input layer: Conv2d: 4ch -> 128ch
         | 
| 298 | 
            +
                    hs = [modules[m_idx](x)]
         | 
| 299 | 
            +
                    m_idx += 1
         | 
| 300 | 
            +
             | 
| 301 | 
            +
                    # Down path in U-Net
         | 
| 302 | 
            +
                    for i_level in range(self.num_resolutions):
         | 
| 303 | 
            +
                        # Residual blocks for this resolution
         | 
| 304 | 
            +
                        for i_block in range(self.num_res_blocks):
         | 
| 305 | 
            +
                            h = modules[m_idx](hs[-1], temb)
         | 
| 306 | 
            +
                            m_idx += 1
         | 
| 307 | 
            +
                            # Attention layer (optional)
         | 
| 308 | 
            +
                            if h.shape[-2] in self.attn_resolutions: # edit: check H dim (-2) not W dim (-1)
         | 
| 309 | 
            +
                                h = modules[m_idx](h)
         | 
| 310 | 
            +
                                m_idx += 1
         | 
| 311 | 
            +
                            hs.append(h)
         | 
| 312 | 
            +
             | 
| 313 | 
            +
                        # Downsampling
         | 
| 314 | 
            +
                        if i_level != self.num_resolutions - 1:
         | 
| 315 | 
            +
                            if self.resblock_type == 'ddpm':
         | 
| 316 | 
            +
                                h = modules[m_idx](hs[-1])
         | 
| 317 | 
            +
                                m_idx += 1
         | 
| 318 | 
            +
                            else:
         | 
| 319 | 
            +
                                h = modules[m_idx](hs[-1], temb)
         | 
| 320 | 
            +
                                m_idx += 1
         | 
| 321 | 
            +
             | 
| 322 | 
            +
                            if self.progressive_input == 'input_skip':   # Combine h with x
         | 
| 323 | 
            +
                                input_pyramid = self.pyramid_downsample(input_pyramid)
         | 
| 324 | 
            +
                                h = modules[m_idx](input_pyramid, h)
         | 
| 325 | 
            +
                                m_idx += 1
         | 
| 326 | 
            +
             | 
| 327 | 
            +
                            elif self.progressive_input == 'residual':
         | 
| 328 | 
            +
                                input_pyramid = modules[m_idx](input_pyramid)
         | 
| 329 | 
            +
                                m_idx += 1
         | 
| 330 | 
            +
                                if self.skip_rescale:
         | 
| 331 | 
            +
                                    input_pyramid = (input_pyramid + h) / np.sqrt(2.)
         | 
| 332 | 
            +
                                else:
         | 
| 333 | 
            +
                                    input_pyramid = input_pyramid + h
         | 
| 334 | 
            +
                                h = input_pyramid
         | 
| 335 | 
            +
                            hs.append(h)
         | 
| 336 | 
            +
             | 
| 337 | 
            +
                    h = hs[-1] # actualy equal to: h = h
         | 
| 338 | 
            +
                    h = modules[m_idx](h, temb)  # ResNet block
         | 
| 339 | 
            +
                    m_idx += 1
         | 
| 340 | 
            +
                    h = modules[m_idx](h)  # Attention block
         | 
| 341 | 
            +
                    m_idx += 1
         | 
| 342 | 
            +
                    h = modules[m_idx](h, temb)  # ResNet block
         | 
| 343 | 
            +
                    m_idx += 1
         | 
| 344 | 
            +
             | 
| 345 | 
            +
                    pyramid = None
         | 
| 346 | 
            +
             | 
| 347 | 
            +
                    # Upsampling block
         | 
| 348 | 
            +
                    for i_level in reversed(range(self.num_resolutions)):
         | 
| 349 | 
            +
                        for i_block in range(self.num_res_blocks + 1):
         | 
| 350 | 
            +
                            h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
         | 
| 351 | 
            +
                            m_idx += 1
         | 
| 352 | 
            +
             | 
| 353 | 
            +
                        # edit: from -1 to -2
         | 
| 354 | 
            +
                        if h.shape[-2] in self.attn_resolutions:
         | 
| 355 | 
            +
                            h = modules[m_idx](h)
         | 
| 356 | 
            +
                            m_idx += 1
         | 
| 357 | 
            +
             | 
| 358 | 
            +
                        if self.progressive != 'none':
         | 
| 359 | 
            +
                            if i_level == self.num_resolutions - 1:
         | 
| 360 | 
            +
                                if self.progressive == 'output_skip':
         | 
| 361 | 
            +
                                    pyramid = self.act(modules[m_idx](h))  # GroupNorm
         | 
| 362 | 
            +
                                    m_idx += 1
         | 
| 363 | 
            +
                                    pyramid = modules[m_idx](pyramid)  # Conv2D: 256 -> 4
         | 
| 364 | 
            +
                                    m_idx += 1
         | 
| 365 | 
            +
                                elif self.progressive == 'residual':
         | 
| 366 | 
            +
                                    pyramid = self.act(modules[m_idx](h))
         | 
| 367 | 
            +
                                    m_idx += 1
         | 
| 368 | 
            +
                                    pyramid = modules[m_idx](pyramid)
         | 
| 369 | 
            +
                                    m_idx += 1
         | 
| 370 | 
            +
                                else:
         | 
| 371 | 
            +
                                    raise ValueError(f'{self.progressive} is not a valid name.')
         | 
| 372 | 
            +
                            else:
         | 
| 373 | 
            +
                                if self.progressive == 'output_skip':
         | 
| 374 | 
            +
                                    pyramid = self.pyramid_upsample(pyramid)  # Upsample
         | 
| 375 | 
            +
                                    pyramid_h = self.act(modules[m_idx](h))  # GroupNorm
         | 
| 376 | 
            +
                                    m_idx += 1
         | 
| 377 | 
            +
                                    pyramid_h = modules[m_idx](pyramid_h)
         | 
| 378 | 
            +
                                    m_idx += 1
         | 
| 379 | 
            +
                                    pyramid = pyramid + pyramid_h
         | 
| 380 | 
            +
                                elif self.progressive == 'residual':
         | 
| 381 | 
            +
                                    pyramid = modules[m_idx](pyramid)
         | 
| 382 | 
            +
                                    m_idx += 1
         | 
| 383 | 
            +
                                    if self.skip_rescale:
         | 
| 384 | 
            +
                                        pyramid = (pyramid + h) / np.sqrt(2.)
         | 
| 385 | 
            +
                                    else:
         | 
| 386 | 
            +
                                        pyramid = pyramid + h
         | 
| 387 | 
            +
                                    h = pyramid
         | 
| 388 | 
            +
                                else:
         | 
| 389 | 
            +
                                    raise ValueError(f'{self.progressive} is not a valid name')
         | 
| 390 | 
            +
             | 
| 391 | 
            +
                        # Upsampling Layer
         | 
| 392 | 
            +
                        if i_level != 0:
         | 
| 393 | 
            +
                            if self.resblock_type == 'ddpm':
         | 
| 394 | 
            +
                                h = modules[m_idx](h)
         | 
| 395 | 
            +
                                m_idx += 1
         | 
| 396 | 
            +
                            else:
         | 
| 397 | 
            +
                                h = modules[m_idx](h, temb)  # Upspampling
         | 
| 398 | 
            +
                                m_idx += 1
         | 
| 399 | 
            +
             | 
| 400 | 
            +
                    assert not hs
         | 
| 401 | 
            +
             | 
| 402 | 
            +
                    if self.progressive == 'output_skip':
         | 
| 403 | 
            +
                        h = pyramid
         | 
| 404 | 
            +
                    else:
         | 
| 405 | 
            +
                        h = self.act(modules[m_idx](h))
         | 
| 406 | 
            +
                        m_idx += 1
         | 
| 407 | 
            +
                        h = modules[m_idx](h)
         | 
| 408 | 
            +
                        m_idx += 1
         | 
| 409 | 
            +
             | 
| 410 | 
            +
                    assert m_idx == len(modules), "Implementation error"
         | 
| 411 | 
            +
                    if self.scale_by_sigma:
         | 
| 412 | 
            +
                        used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))
         | 
| 413 | 
            +
                        h = h / used_sigmas
         | 
| 414 | 
            +
             | 
| 415 | 
            +
                    # Convert back to complex number
         | 
| 416 | 
            +
                    h = self.output_layer(h)
         | 
| 417 | 
            +
                    h = torch.permute(h, (0, 2, 3, 1)).contiguous()
         | 
| 418 | 
            +
                    h = torch.view_as_complex(h)[:,None, :, :]
         | 
| 419 | 
            +
                    return h
         | 
    	
        sgmse/backbones/ncsnpp_48k.py
    ADDED
    
    | @@ -0,0 +1,424 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # coding=utf-8
         | 
| 2 | 
            +
            # Copyright 2020 The Google Research Authors.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 5 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 6 | 
            +
            # You may obtain a copy of the License at
         | 
| 7 | 
            +
            #
         | 
| 8 | 
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 9 | 
            +
            #
         | 
| 10 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 11 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 12 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 13 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 14 | 
            +
            # limitations under the License.
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            # pylint: skip-file
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            from .ncsnpp_utils import layers, layerspp, normalization
         | 
| 19 | 
            +
            import torch.nn as nn
         | 
| 20 | 
            +
            import functools
         | 
| 21 | 
            +
            import torch
         | 
| 22 | 
            +
            import numpy as np
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            from .shared import BackboneRegistry
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            # Short module-level aliases for the building blocks used by the backbone below.

            # Helpers re-exported from ``layers`` / ``normalization``.
            get_act = layers.get_act
            default_initializer = layers.default_init
            get_normalization = normalization.get_normalization

            # Layer constructors re-exported from ``layerspp``.
            conv1x1 = layerspp.conv1x1
            conv3x3 = layerspp.conv3x3
            Combine = layerspp.Combine
            ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp
            ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp
         | 
| 34 | 
            +
             | 
| 35 | 
            +
             | 
| 36 | 
            +
            @BackboneRegistry.register("ncsnpp_48k")
         | 
| 37 | 
            +
            class NCSNpp_48k(nn.Module):
         | 
| 38 | 
            +
                """NCSN++ model, adapted from https://github.com/yang-song/score_sde repository"""
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                @staticmethod
                def add_argparse_args(parser):
                    """Register the command-line options of this backbone on ``parser``.

                    The option names mirror keyword arguments of ``__init__`` (``ch_mult``,
                    ``num_res_blocks``, ``attn_resolutions``, ``nf``, ``centered``,
                    ``progressive``, ``progressive_input``).

                    Args:
                        parser: An ``argparse.ArgumentParser`` (or argument group) to extend.

                    Returns:
                        The same ``parser`` object, to allow chaining.
                    """
                    parser.add_argument("--ch_mult", type=int, nargs='+', default=[1, 1, 2, 2, 2, 2, 2],
                                        help="Channel multiplier (applied to --nf) for each resolution level")
                    parser.add_argument("--num_res_blocks", type=int, default=2,
                                        help="Number of residual blocks per resolution level")
                    parser.add_argument("--attn_resolutions", type=int, nargs='+', default=[],
                                        help="Resolutions at which attention blocks are inserted")
                    parser.add_argument("--nf", type=int, default=128, help="Number of channels to use in the model")
                    # --centered / --no-centered write to the same destination; the default is set below.
                    parser.add_argument("--no-centered", dest="centered", action="store_false", help="The data is not centered [-1, 1]")
                    parser.add_argument("--centered", dest="centered", action="store_true", help="The data is centered [-1, 1]")
                    # `progressive` configures the output pyramid built in the upsampling path,
                    # `progressive_input` the input pyramid built in the downsampling path
                    # (see __init__); the help texts used to be swapped.
                    parser.add_argument("--progressive", type=str, default='none',
                                        help="Progressive upsampling method (none, output_skip or residual)")
                    parser.add_argument("--progressive_input", type=str, default='none',
                                        help="Progressive downsampling method (none, input_skip or residual)")
                    parser.set_defaults(centered=True)
                    return parser
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                def __init__(
                    self,
                    scale_by_sigma=True,
                    nonlinearity='swish',
                    nf=128,
                    ch_mult=(1, 1, 2, 2, 2, 2, 2),
                    num_res_blocks=2,
                    attn_resolutions=(),
                    resamp_with_conv=True,
                    conditional=True,
                    fir=True,
                    fir_kernel=[1, 3, 3, 1],
                    skip_rescale=True,
                    resblock_type='biggan',
                    progressive='none',
                    progressive_input='none',
                    progressive_combine='sum',
                    init_scale=0.,
                    fourier_scale=16,
                    image_size=256,
                    embedding_type='fourier',
                    dropout=.0,
                    centered=True,
                    **unused_kwargs
                ):
                    """Build the flat list of layers that ``forward`` walks through in order.

                    All layers are collected into ``self.all_modules`` (an ``nn.ModuleList``);
                    ``forward`` consumes them with a running index, so the order in which
                    layers are appended here is part of the contract.

                    Args:
                        scale_by_sigma: Stored on the instance as ``self.scale_by_sigma``.
                        nonlinearity: Name passed to ``get_act`` to obtain the activation.
                        nf: Base channel count; per-level widths are ``nf * ch_mult[level]``.
                        ch_mult: Channel multiplier per resolution level; its length sets
                            the number of resolution levels.
                        num_res_blocks: Residual blocks per level on the down path (the up
                            path uses one more per level).
                        attn_resolutions: Resolutions at which attention blocks are added.
                        resamp_with_conv: Forwarded as ``with_conv`` to the up/down samplers.
                        conditional: If true, two linear layers for the time/noise
                            embedding are added.
                        fir, fir_kernel: Forwarded to the resampling layers and BigGAN
                            residual blocks (this method only passes them through).
                        skip_rescale: Forwarded to the residual and attention blocks.
                        resblock_type: ``'ddpm'`` or ``'biggan'`` (case-insensitive).
                        progressive: ``'none'``, ``'output_skip'`` or ``'residual'``.
                        progressive_input: ``'none'``, ``'input_skip'`` or ``'residual'``.
                        progressive_combine: Combine method used for ``'input_skip'``.
                        init_scale: Forwarded to blocks and output convolutions.
                        fourier_scale: Scale of the Gaussian Fourier projection.
                        image_size: Only used to derive ``self.all_resolutions``.
                        embedding_type: ``'fourier'`` or ``'positional'``.
                        dropout: Forwarded to the residual blocks.
                        centered: Stored as ``self.centered``.
                        **unused_kwargs: Accepted and ignored.

                    Raises:
                        AssertionError: If ``progressive``, ``progressive_input`` or
                            ``embedding_type`` is not one of the accepted names.
                        ValueError: If ``resblock_type`` is not recognized.
                    """
                    super().__init__()
                    act = get_act(nonlinearity)
                    self.act = act

                    # Plain hyper-parameters that forward() reads back from the instance.
                    self.nf = nf
                    self.num_res_blocks = num_res_blocks
                    self.attn_resolutions = attn_resolutions
                    num_resolutions = len(ch_mult)
                    self.num_resolutions = num_resolutions
                    all_resolutions = [image_size // (2 ** level) for level in range(num_resolutions)]
                    self.all_resolutions = all_resolutions

                    self.conditional = conditional  # noise-conditional
                    self.centered = centered
                    self.scale_by_sigma = scale_by_sigma
                    self.skip_rescale = skip_rescale

                    # String options are normalised to lower case before validation.
                    resblock_type = resblock_type.lower()
                    self.resblock_type = resblock_type
                    progressive = progressive.lower()
                    self.progressive = progressive
                    progressive_input = progressive_input.lower()
                    self.progressive_input = progressive_input
                    embedding_type = embedding_type.lower()
                    self.embedding_type = embedding_type

                    assert progressive in ('none', 'output_skip', 'residual')
                    assert progressive_input in ('none', 'input_skip', 'residual')
                    assert embedding_type in ('fourier', 'positional')
                    combine_method = progressive_combine.lower()
                    combiner = functools.partial(Combine, method=combine_method)

                    num_channels = 4  # x.real, x.imag, y.real, y.imag
                    self.output_layer = nn.Conv2d(num_channels, 2, 1)

                    temb_dim = nf * 4

                    def make_group_norm(ch):
                        # Same GroupNorm configuration is used in front of every output conv.
                        return nn.GroupNorm(num_groups=min(ch // 4, 32), num_channels=ch, eps=1e-6)

                    layer_list = []

                    # --- timestep / noise-level embedding ---
                    if embedding_type == 'fourier':
                        # Gaussian Fourier features embeddings.
                        layer_list.append(
                            layerspp.GaussianFourierProjection(embedding_size=nf, scale=fourier_scale)
                        )
                        embed_dim = 2 * nf
                    elif embedding_type == 'positional':
                        embed_dim = nf
                    else:
                        raise ValueError(f'embedding type {embedding_type} unknown.')

                    if conditional:
                        # Two dense layers: embed_dim -> 4*nf -> 4*nf, zero bias.
                        for in_features in (embed_dim, temb_dim):
                            dense = nn.Linear(in_features, temb_dim)
                            dense.weight.data = default_initializer()(dense.weight.shape)
                            nn.init.zeros_(dense.bias)
                            layer_list.append(dense)

                    # --- layer factories with the shared options pre-bound ---
                    AttnBlock = functools.partial(
                        layerspp.AttnBlockpp, init_scale=init_scale, skip_rescale=skip_rescale
                    )
                    Upsample = functools.partial(
                        layerspp.Upsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel
                    )

                    if progressive == 'output_skip':
                        self.pyramid_upsample = layerspp.Upsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
                    elif progressive == 'residual':
                        pyramid_upsample = functools.partial(
                            layerspp.Upsample, fir=fir, fir_kernel=fir_kernel, with_conv=True
                        )

                    Downsample = functools.partial(
                        layerspp.Downsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel
                    )

                    if progressive_input == 'input_skip':
                        self.pyramid_downsample = layerspp.Downsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
                    elif progressive_input == 'residual':
                        pyramid_downsample = functools.partial(
                            layerspp.Downsample, fir=fir, fir_kernel=fir_kernel, with_conv=True
                        )

                    if resblock_type == 'ddpm':
                        ResnetBlock = functools.partial(
                            ResnetBlockDDPM, act=act, dropout=dropout, init_scale=init_scale,
                            skip_rescale=skip_rescale, temb_dim=temb_dim
                        )
                    elif resblock_type == 'biggan':
                        ResnetBlock = functools.partial(
                            ResnetBlockBigGAN, act=act, dropout=dropout, fir=fir, fir_kernel=fir_kernel,
                            init_scale=init_scale, skip_rescale=skip_rescale, temb_dim=temb_dim
                        )
                    else:
                        raise ValueError(f'resblock type {resblock_type} unrecognized.')

                    use_ddpm_resampling = resblock_type == 'ddpm'
                    last_level = num_resolutions - 1

                    # --- downsampling path ---
                    channels = num_channels
                    if progressive_input != 'none':
                        input_pyramid_ch = channels

                    layer_list.append(conv3x3(channels, nf))
                    skip_channels = [nf]  # channel count of every tensor forward() pushes on its skip stack

                    in_ch = nf
                    for i_level in range(num_resolutions):
                        # Residual blocks for this resolution
                        for _ in range(num_res_blocks):
                            out_ch = nf * ch_mult[i_level]
                            layer_list.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
                            in_ch = out_ch

                            if all_resolutions[i_level] in attn_resolutions:
                                layer_list.append(AttnBlock(channels=in_ch))
                            skip_channels.append(in_ch)

                        if i_level == last_level:
                            continue  # no downsampling after the coarsest level

                        if use_ddpm_resampling:
                            layer_list.append(Downsample(in_ch=in_ch))
                        else:
                            layer_list.append(ResnetBlock(down=True, in_ch=in_ch))

                        if progressive_input == 'input_skip':
                            layer_list.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
                            if combine_method == 'cat':
                                in_ch *= 2
                        elif progressive_input == 'residual':
                            layer_list.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
                            input_pyramid_ch = in_ch

                        skip_channels.append(in_ch)

                    # --- bottleneck: ResnetBlock -> AttnBlock -> ResnetBlock ---
                    in_ch = skip_channels[-1]
                    layer_list.append(ResnetBlock(in_ch=in_ch))
                    layer_list.append(AttnBlock(channels=in_ch))
                    layer_list.append(ResnetBlock(in_ch=in_ch))

                    # --- upsampling path ---
                    pyramid_ch = 0
                    for i_level in reversed(range(num_resolutions)):
                        # +1 blocks in upsampling because of skip connection from combiner (after downsampling)
                        for _ in range(num_res_blocks + 1):
                            out_ch = nf * ch_mult[i_level]
                            layer_list.append(ResnetBlock(in_ch=in_ch + skip_channels.pop(), out_ch=out_ch))
                            in_ch = out_ch

                        if all_resolutions[i_level] in attn_resolutions:
                            layer_list.append(AttnBlock(channels=in_ch))

                        at_coarsest = i_level == last_level
                        if progressive == 'output_skip':
                            layer_list.append(make_group_norm(in_ch))
                            if at_coarsest:
                                layer_list.append(conv3x3(in_ch, channels, init_scale=init_scale))
                            else:
                                layer_list.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
                            pyramid_ch = channels
                        elif progressive == 'residual':
                            if at_coarsest:
                                layer_list.append(make_group_norm(in_ch))
                                layer_list.append(conv3x3(in_ch, in_ch, bias=True))
                            else:
                                layer_list.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
                            pyramid_ch = in_ch
                        elif progressive != 'none':
                            if at_coarsest:
                                raise ValueError(f'{progressive} is not a valid name.')
                            raise ValueError(f'{progressive} is not a valid name')

                        if i_level != 0:
                            if use_ddpm_resampling:
                                layer_list.append(Upsample(in_ch=in_ch))
                            else:
                                layer_list.append(ResnetBlock(in_ch=in_ch, up=True))

                    # Every skip connection recorded on the way down must have been consumed.
                    assert not skip_channels

                    if progressive != 'output_skip':
                        # Final GroupNorm + conv back to the input channel count.
                        layer_list.append(make_group_norm(in_ch))
                        layer_list.append(conv3x3(in_ch, channels, init_scale=init_scale))

                    self.all_modules = nn.ModuleList(layer_list)
         | 
| 257 | 
            +
                    
         | 
| 258 | 
            +
             | 
| 259 | 
            +
                def forward(self, x, time_cond):
         | 
| 260 | 
            +
                    # timestep/noise_level embedding; only for continuous training
         | 
| 261 | 
            +
                    modules = self.all_modules
         | 
| 262 | 
            +
                    m_idx = 0
         | 
| 263 | 
            +
             | 
| 264 | 
            +
                    # Convert real and imaginary parts of (x,y) into four channel dimensions
         | 
| 265 | 
            +
                    x = torch.cat((x[:,[0],:,:].real, x[:,[0],:,:].imag,
         | 
| 266 | 
            +
                            x[:,[1],:,:].real, x[:,[1],:,:].imag), dim=1)
         | 
| 267 | 
            +
             | 
| 268 | 
            +
                    if self.embedding_type == 'fourier':
         | 
| 269 | 
            +
                        # Gaussian Fourier features embeddings.
         | 
| 270 | 
            +
                        used_sigmas = time_cond
         | 
| 271 | 
            +
                        temb = modules[m_idx](torch.log(used_sigmas))
         | 
| 272 | 
            +
                        m_idx += 1
         | 
| 273 | 
            +
             | 
| 274 | 
            +
                    elif self.embedding_type == 'positional':
         | 
| 275 | 
            +
                        # Sinusoidal positional embeddings.
         | 
| 276 | 
            +
                        timesteps = time_cond
         | 
| 277 | 
            +
                        used_sigmas = self.sigmas[time_cond.long()]
         | 
| 278 | 
            +
                        temb = layers.get_timestep_embedding(timesteps, self.nf)
         | 
| 279 | 
            +
             | 
| 280 | 
            +
                    else:
         | 
| 281 | 
            +
                        raise ValueError(f'embedding type {self.embedding_type} unknown.')
         | 
| 282 | 
            +
             | 
| 283 | 
            +
                    if self.conditional:
         | 
| 284 | 
            +
                        temb = modules[m_idx](temb)
         | 
| 285 | 
            +
                        m_idx += 1
         | 
| 286 | 
            +
                        temb = modules[m_idx](self.act(temb))
         | 
| 287 | 
            +
                        m_idx += 1
         | 
| 288 | 
            +
                    else:
         | 
| 289 | 
            +
                        temb = None
         | 
| 290 | 
            +
             | 
| 291 | 
            +
                    if not self.centered:
         | 
| 292 | 
            +
                        # If input data is in [0, 1]
         | 
| 293 | 
            +
                        x = 2 * x - 1.
         | 
| 294 | 
            +
             | 
| 295 | 
            +
                    # Downsampling block
         | 
| 296 | 
            +
                    input_pyramid = None
         | 
| 297 | 
            +
                    if self.progressive_input != 'none':
         | 
| 298 | 
            +
                        input_pyramid = x
         | 
| 299 | 
            +
             | 
| 300 | 
            +
                    # Input layer: Conv2d: 4ch -> 128ch
         | 
| 301 | 
            +
                    hs = [modules[m_idx](x)]
         | 
| 302 | 
            +
                    m_idx += 1
         | 
| 303 | 
            +
             | 
| 304 | 
            +
                    # Down path in U-Net
         | 
| 305 | 
            +
                    for i_level in range(self.num_resolutions):
         | 
| 306 | 
            +
                        # Residual blocks for this resolution
         | 
| 307 | 
            +
                        for i_block in range(self.num_res_blocks):
         | 
| 308 | 
            +
                            h = modules[m_idx](hs[-1], temb)
         | 
| 309 | 
            +
                            m_idx += 1
         | 
| 310 | 
            +
                            # Attention layer (optional)
         | 
| 311 | 
            +
                            if h.shape[-2] in self.attn_resolutions: # edit: check H dim (-2) not W dim (-1)
         | 
| 312 | 
            +
                                h = modules[m_idx](h)
         | 
| 313 | 
            +
                                m_idx += 1
         | 
| 314 | 
            +
                            hs.append(h)
         | 
| 315 | 
            +
             | 
| 316 | 
            +
                        # Downsampling
         | 
| 317 | 
            +
                        if i_level != self.num_resolutions - 1:
         | 
| 318 | 
            +
                            if self.resblock_type == 'ddpm':
         | 
| 319 | 
            +
                                h = modules[m_idx](hs[-1])
         | 
| 320 | 
            +
                                m_idx += 1
         | 
| 321 | 
            +
                            else:
         | 
| 322 | 
            +
                                h = modules[m_idx](hs[-1], temb)
         | 
| 323 | 
            +
                                m_idx += 1
         | 
| 324 | 
            +
             | 
| 325 | 
            +
                            if self.progressive_input == 'input_skip':   # Combine h with x
         | 
| 326 | 
            +
                                input_pyramid = self.pyramid_downsample(input_pyramid)
         | 
| 327 | 
            +
                                h = modules[m_idx](input_pyramid, h)
         | 
| 328 | 
            +
                                m_idx += 1
         | 
| 329 | 
            +
             | 
| 330 | 
            +
                            elif self.progressive_input == 'residual':
         | 
| 331 | 
            +
                                input_pyramid = modules[m_idx](input_pyramid)
         | 
| 332 | 
            +
                                m_idx += 1
         | 
| 333 | 
            +
                                if self.skip_rescale:
         | 
| 334 | 
            +
                                    input_pyramid = (input_pyramid + h) / np.sqrt(2.)
         | 
| 335 | 
            +
                                else:
         | 
| 336 | 
            +
                                    input_pyramid = input_pyramid + h
         | 
| 337 | 
            +
                                h = input_pyramid
         | 
| 338 | 
            +
                            hs.append(h)
         | 
| 339 | 
            +
             | 
| 340 | 
            +
                    h = hs[-1] # actualy equal to: h = h
         | 
| 341 | 
            +
                    h = modules[m_idx](h, temb)  # ResNet block
         | 
| 342 | 
            +
                    m_idx += 1
         | 
| 343 | 
            +
                    h = modules[m_idx](h)  # Attention block
         | 
| 344 | 
            +
                    m_idx += 1
         | 
| 345 | 
            +
                    h = modules[m_idx](h, temb)  # ResNet block
         | 
| 346 | 
            +
                    m_idx += 1
         | 
| 347 | 
            +
             | 
| 348 | 
            +
                    pyramid = None
         | 
| 349 | 
            +
             | 
| 350 | 
            +
                    # Upsampling block
         | 
| 351 | 
            +
                    for i_level in reversed(range(self.num_resolutions)):
         | 
| 352 | 
            +
                        for i_block in range(self.num_res_blocks + 1):
         | 
| 353 | 
            +
                            h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
         | 
| 354 | 
            +
                            m_idx += 1
         | 
| 355 | 
            +
             | 
| 356 | 
            +
                        # edit: from -1 to -2
         | 
| 357 | 
            +
                        if h.shape[-2] in self.attn_resolutions:
         | 
| 358 | 
            +
                            h = modules[m_idx](h)
         | 
| 359 | 
            +
                            m_idx += 1
         | 
| 360 | 
            +
             | 
| 361 | 
            +
                        if self.progressive != 'none':
         | 
| 362 | 
            +
                            if i_level == self.num_resolutions - 1:
         | 
| 363 | 
            +
                                if self.progressive == 'output_skip':
         | 
| 364 | 
            +
                                    pyramid = self.act(modules[m_idx](h))  # GroupNorm
         | 
| 365 | 
            +
                                    m_idx += 1
         | 
| 366 | 
            +
                                    pyramid = modules[m_idx](pyramid)  # Conv2D: 256 -> 4
         | 
| 367 | 
            +
                                    m_idx += 1
         | 
| 368 | 
            +
                                elif self.progressive == 'residual':
         | 
| 369 | 
            +
                                    pyramid = self.act(modules[m_idx](h))
         | 
| 370 | 
            +
                                    m_idx += 1
         | 
| 371 | 
            +
                                    pyramid = modules[m_idx](pyramid)
         | 
| 372 | 
            +
                                    m_idx += 1
         | 
| 373 | 
            +
                                else:
         | 
| 374 | 
            +
                                    raise ValueError(f'{self.progressive} is not a valid name.')
         | 
| 375 | 
            +
                            else:
         | 
| 376 | 
            +
                                if self.progressive == 'output_skip':
         | 
| 377 | 
            +
                                    pyramid = self.pyramid_upsample(pyramid)  # Upsample
         | 
| 378 | 
            +
                                    pyramid_h = self.act(modules[m_idx](h))  # GroupNorm
         | 
| 379 | 
            +
                                    m_idx += 1
         | 
| 380 | 
            +
                                    pyramid_h = modules[m_idx](pyramid_h)
         | 
| 381 | 
            +
                                    m_idx += 1
         | 
| 382 | 
            +
                                    pyramid = pyramid + pyramid_h
         | 
| 383 | 
            +
                                elif self.progressive == 'residual':
         | 
| 384 | 
            +
                                    pyramid = modules[m_idx](pyramid)
         | 
| 385 | 
            +
                                    m_idx += 1
         | 
| 386 | 
            +
                                    if self.skip_rescale:
         | 
| 387 | 
            +
                                        pyramid = (pyramid + h) / np.sqrt(2.)
         | 
| 388 | 
            +
                                    else:
         | 
| 389 | 
            +
                                        pyramid = pyramid + h
         | 
| 390 | 
            +
                                    h = pyramid
         | 
| 391 | 
            +
                                else:
         | 
| 392 | 
            +
                                    raise ValueError(f'{self.progressive} is not a valid name')
         | 
| 393 | 
            +
             | 
| 394 | 
            +
                        # Upsampling Layer
         | 
| 395 | 
            +
                        if i_level != 0:
         | 
| 396 | 
            +
                            if self.resblock_type == 'ddpm':
         | 
| 397 | 
            +
                                h = modules[m_idx](h)
         | 
| 398 | 
            +
                                m_idx += 1
         | 
| 399 | 
            +
                            else:
         | 
| 400 | 
            +
                                h = modules[m_idx](h, temb)  # Upspampling
         | 
| 401 | 
            +
                                m_idx += 1
         | 
| 402 | 
            +
             | 
| 403 | 
            +
                    assert not hs
         | 
| 404 | 
            +
             | 
| 405 | 
            +
                    if self.progressive == 'output_skip':
         | 
| 406 | 
            +
                        h = pyramid
         | 
| 407 | 
            +
                    else:
         | 
| 408 | 
            +
                        h = self.act(modules[m_idx](h))
         | 
| 409 | 
            +
                        m_idx += 1
         | 
| 410 | 
            +
                        h = modules[m_idx](h)
         | 
| 411 | 
            +
                        m_idx += 1
         | 
| 412 | 
            +
             | 
| 413 | 
            +
                    assert m_idx == len(modules), "Implementation error"
         | 
| 414 | 
            +
                    
         | 
| 415 | 
            +
                    # Convert back to complex number
         | 
| 416 | 
            +
                    h = self.output_layer(h)
         | 
| 417 | 
            +
             | 
| 418 | 
            +
                    if self.scale_by_sigma:
         | 
| 419 | 
            +
                        used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))
         | 
| 420 | 
            +
                        h = h / used_sigmas
         | 
| 421 | 
            +
             | 
| 422 | 
            +
                    h = torch.permute(h, (0, 2, 3, 1)).contiguous()
         | 
| 423 | 
            +
                    h = torch.view_as_complex(h)[:,None, :, :]
         | 
| 424 | 
            +
                    return h
         | 
    	
        sgmse/backbones/ncsnpp_utils/layers.py
    ADDED
    
    | @@ -0,0 +1,662 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # coding=utf-8
         | 
| 2 | 
            +
            # Copyright 2020 The Google Research Authors.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 5 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 6 | 
            +
            # You may obtain a copy of the License at
         | 
| 7 | 
            +
            #
         | 
| 8 | 
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 9 | 
            +
            #
         | 
| 10 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 11 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 12 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 13 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 14 | 
            +
            # limitations under the License.
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            # pylint: skip-file
         | 
| 17 | 
            +
            """Common layers for defining score networks.
         | 
| 18 | 
            +
            """
         | 
| 19 | 
            +
            import math
         | 
| 20 | 
            +
            import string
         | 
| 21 | 
            +
            from functools import partial
         | 
| 22 | 
            +
            import torch.nn as nn
         | 
| 23 | 
            +
            import torch
         | 
| 24 | 
            +
            import torch.nn.functional as F
         | 
| 25 | 
            +
            import numpy as np
         | 
| 26 | 
            +
            from .normalization import ConditionalInstanceNorm2dPlus
         | 
| 27 | 
            +
             | 
| 28 | 
            +
             | 
| 29 | 
            +
            def get_act(config):
              """Return a newly constructed activation module for the given name.

              Args:
                config: Activation name. Supported values are 'elu', 'relu',
                  'lrelu' (LeakyReLU with negative slope 0.2) and 'swish' (SiLU).

              Returns:
                A fresh `nn.Module` instance implementing the activation.

              Raises:
                NotImplementedError: If `config` matches none of the supported names.
              """
              # Ordered (name, factory) pairs; compared with `==` one after another so
              # the lookup behaves like a chain of equality tests.
              known_activations = (
                ('elu', nn.ELU),
                ('relu', nn.ReLU),
                ('lrelu', lambda: nn.LeakyReLU(negative_slope=0.2)),
                ('swish', nn.SiLU),
              )
              for name, make_activation in known_activations:
                if config == name:
                  return make_activation()
              raise NotImplementedError('activation function does not exist!')
         | 
| 42 | 
            +
             | 
| 43 | 
            +
             | 
| 44 | 
            +
            def ncsn_conv1x1(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=0):
              """1x1 convolution with PyTorch initialization. Same as NCSNv1/v2.

              The layer keeps the default `nn.Conv2d` initialization and then multiplies
              the weight (and the bias, when present) by `init_scale`.

              Args:
                in_planes: Number of input channels.
                out_planes: Number of output channels.
                stride: Convolution stride.
                bias: Whether the convolution has a learnable bias.
                dilation: Convolution dilation.
                init_scale: Factor applied to the initial parameters. A value of 0 is
                  replaced by 1e-10.
                padding: Padding passed to `nn.Conv2d`.

              Returns:
                The configured `nn.Conv2d` module.
              """
              conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias, dilation=dilation,
                               padding=padding)
              init_scale = 1e-10 if init_scale == 0 else init_scale
              conv.weight.data *= init_scale
              # With bias=False, nn.Conv2d stores `bias` as None; there is nothing to
              # rescale and touching `.data` would raise AttributeError.
              if conv.bias is not None:
                conv.bias.data *= init_scale
              return conv
         | 
| 52 | 
            +
             | 
| 53 | 
            +
             | 
| 54 | 
            +
            def variance_scaling(scale, mode, distribution,
                                 in_axis=1, out_axis=0,
                                 dtype=torch.float32,
                                 device='cpu'):
              """Build a variance-scaling initializer (ported from JAX).

              Args:
                scale: Numerator of the target variance.
                mode: 'fan_in', 'fan_out' or 'fan_avg'; selects the denominator of the
                  target variance.
                distribution: 'normal' or 'uniform'.
                in_axis: Index of the input-channel axis in the shape.
                out_axis: Index of the output-channel axis in the shape.
                dtype: Default dtype of the tensors produced by the initializer.
                device: Default device of the tensors produced by the initializer.

              Returns:
                A function `init(shape, dtype=dtype, device=device)` that returns a
                randomly initialized tensor of the given shape. `mode` and
                `distribution` are validated when `init` is called; an unknown value
                raises ValueError at that point.
              """

              def _compute_fans(shape, in_axis=1, out_axis=0):
                # All axes other than the in/out axes make up the receptive field.
                field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
                return shape[in_axis] * field_size, shape[out_axis] * field_size

              def init(shape, dtype=dtype, device=device):
                fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
                candidates = (
                  ("fan_in", fan_in),
                  ("fan_out", fan_out),
                  ("fan_avg", (fan_in + fan_out) / 2),
                )
                for mode_name, fan in candidates:
                  if mode == mode_name:
                    denominator = fan
                    break
                else:
                  raise ValueError(
                    "invalid mode for variance scaling initializer: {}".format(mode))
                variance = scale / denominator
                if distribution == "normal":
                  sample = torch.randn(*shape, dtype=dtype, device=device)
                  return sample * np.sqrt(variance)
                if distribution == "uniform":
                  # Map U[0, 1) to U[-1, 1) and stretch it to the requested variance.
                  sample = torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.
                  return sample * np.sqrt(3 * variance)
                raise ValueError("invalid distribution for variance scaling initializer")

              return init
         | 
| 86 | 
            +
             | 
| 87 | 
            +
             | 
| 88 | 
            +
            def default_init(scale=1.):
              """Return the DDPM-style initializer: uniform variance scaling over fan_avg.

              Args:
                scale: Variance scale. A value of 0 is replaced by 1e-10.

              Returns:
                The initializer function produced by `variance_scaling`.
              """
              if scale == 0:
                scale = 1e-10
              return variance_scaling(scale, 'fan_avg', 'uniform')
         | 
| 92 | 
            +
             | 
| 93 | 
            +
             | 
| 94 | 
            +
            class Dense(nn.Module):
              """Linear layer with `default_init`.

              NOTE(review): as written here the class only runs the `nn.Module` base
              initializer; it declares no parameters and no `forward`.
              """
              def __init__(self):
                super(Dense, self).__init__()
         | 
| 98 | 
            +
             | 
| 99 | 
            +
             | 
| 100 | 
            +
            def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1., padding=0):
              """1x1 convolution with DDPM initialization.

              The weight is replaced by a sample from `default_init(init_scale)` and the
              bias, when present, is set to zero.

              Args:
                in_planes: Number of input channels.
                out_planes: Number of output channels.
                stride: Convolution stride.
                bias: Whether the convolution has a learnable bias.
                init_scale: Scale passed to `default_init`.
                padding: Padding passed to `nn.Conv2d`.

              Returns:
                The configured `nn.Conv2d` module.
              """
              conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
              conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
              # With bias=False, nn.Conv2d stores `bias` as None, which
              # nn.init.zeros_ cannot handle.
              if conv.bias is not None:
                nn.init.zeros_(conv.bias)
              return conv
         | 
| 106 | 
            +
             | 
| 107 | 
            +
             | 
| 108 | 
            +
            def ncsn_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
              """3x3 convolution with PyTorch initialization. Same as NCSNv1/NCSNv2.

              The layer keeps the default `nn.Conv2d` initialization and then multiplies
              the weight (and the bias, when present) by `init_scale`.

              Args:
                in_planes: Number of input channels.
                out_planes: Number of output channels.
                stride: Convolution stride.
                bias: Whether the convolution has a learnable bias.
                dilation: Convolution dilation.
                init_scale: Factor applied to the initial parameters. A value of 0 is
                  replaced by 1e-10.
                padding: Padding passed to `nn.Conv2d`.

              Returns:
                The configured `nn.Conv2d` module.
              """
              init_scale = 1e-10 if init_scale == 0 else init_scale
              conv = nn.Conv2d(in_planes, out_planes, stride=stride, bias=bias,
                               dilation=dilation, padding=padding, kernel_size=3)
              conv.weight.data *= init_scale
              # Callers in this file pass bias=False; nn.Conv2d then stores `bias` as
              # None, so the bias rescaling has to be skipped instead of crashing.
              if conv.bias is not None:
                conv.bias.data *= init_scale
              return conv
         | 
| 116 | 
            +
             | 
| 117 | 
            +
             | 
| 118 | 
            +
            def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
              """3x3 convolution with DDPM initialization.

              The weight is replaced by a sample from `default_init(init_scale)` and the
              bias, when present, is set to zero.

              Args:
                in_planes: Number of input channels.
                out_planes: Number of output channels.
                stride: Convolution stride.
                bias: Whether the convolution has a learnable bias.
                dilation: Convolution dilation.
                init_scale: Scale passed to `default_init`.
                padding: Padding passed to `nn.Conv2d`.

              Returns:
                The configured `nn.Conv2d` module.
              """
              conv = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding,
                               dilation=dilation, bias=bias)
              conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
              # With bias=False, nn.Conv2d stores `bias` as None, which
              # nn.init.zeros_ cannot handle.
              if conv.bias is not None:
                nn.init.zeros_(conv.bias)
              return conv
         | 
| 125 | 
            +
             | 
| 126 | 
            +
              ###########################################################################
         | 
| 127 | 
            +
              # Functions below are ported over from the NCSNv1/NCSNv2 codebase:
         | 
| 128 | 
            +
              # https://github.com/ermongroup/ncsn
         | 
| 129 | 
            +
              # https://github.com/ermongroup/ncsnv2
         | 
| 130 | 
            +
              ###########################################################################
         | 
| 131 | 
            +
             | 
| 132 | 
            +
             | 
| 133 | 
            +
            class CRPBlock(nn.Module):
              """Chain of pooling + 3x3 convolution stages with a running residual sum.

              The activation is applied once to the input. Each stage then applies a
              5x5, stride-1 pooling (max or average) followed by a bias-free 3x3
              convolution to the previous stage's result and adds it to the output.
              """

              def __init__(self, features, n_stages, act=nn.ReLU(), maxpool=True):
                """Create the block.

                Args:
                  features: Channel count used for every convolution (in and out).
                  n_stages: Number of pooling + convolution stages.
                  act: Activation module applied to the input.
                  maxpool: Use max pooling if True, average pooling otherwise.
                """
                super().__init__()
                self.convs = nn.ModuleList(
                  [ncsn_conv3x3(features, features, stride=1, bias=False) for _ in range(n_stages)])
                self.n_stages = n_stages
                pool_cls = nn.MaxPool2d if maxpool else nn.AvgPool2d
                self.pool = pool_cls(kernel_size=5, stride=1, padding=2)

                self.act = act

              def forward(self, x):
                """Return the activated input plus the output of every stage."""
                total = self.act(x)
                branch = total
                for stage in range(self.n_stages):
                  branch = self.convs[stage](self.pool(branch))
                  total = branch + total
                return total
         | 
| 155 | 
            +
             | 
| 156 | 
            +
             | 
| 157 | 
            +
            class CondCRPBlock(nn.Module):
              """Conditional variant of the pooling + convolution chain.

              Each stage applies a conditional normalizer (given `y`), a 5x5 stride-1
              average pooling and a bias-free 3x3 convolution, and adds the result to a
              running output that starts as the activated input.
              """

              def __init__(self, features, n_stages, num_classes, normalizer, act=nn.ReLU()):
                """Create the block.

                Args:
                  features: Channel count used for every norm and convolution.
                  n_stages: Number of norm + pooling + convolution stages.
                  num_classes: Passed to `normalizer` as its second argument.
                  normalizer: Callable building a normalization layer; called as
                    `normalizer(features, num_classes, bias=True)`.
                  act: Activation module applied to the input.
                """
                super().__init__()
                self.convs = nn.ModuleList()
                self.norms = nn.ModuleList()
                self.normalizer = normalizer
                # Norm and conv are created alternately, stage by stage, so the order of
                # parameter initialization stays the same.
                for _ in range(n_stages):
                  stage_norm = normalizer(features, num_classes, bias=True)
                  self.norms.append(stage_norm)
                  stage_conv = ncsn_conv3x3(features, features, stride=1, bias=False)
                  self.convs.append(stage_conv)

                self.n_stages = n_stages
                self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2)
                self.act = act

              def forward(self, x, y):
                """Return the activated input plus the output of every stage.

                Args:
                  x: Input passed through the activation and the stages.
                  y: Conditioning input forwarded to each normalizer.
                """
                total = self.act(x)
                branch = total
                for stage in range(self.n_stages):
                  normed = self.norms[stage](branch, y)
                  branch = self.convs[stage](self.pool(normed))

                  total = branch + total
                return total
         | 
| 181 | 
            +
             | 
| 182 | 
            +
             | 
| 183 | 
            +
            class RCUBlock(nn.Module):
              """Stack of residual units, each made of activation + 3x3 convolution stages.

              Convolutions are registered as attributes named '<block>_<stage>_conv'
              with 1-based indices.
              """

              def __init__(self, features, n_blocks, n_stages, act=nn.ReLU()):
                """Create the block.

                Args:
                  features: Channel count used for every convolution (in and out).
                  n_blocks: Number of residual units.
                  n_stages: Number of activation + convolution stages per unit.
                  act: Activation module applied before every convolution.
                """
                super().__init__()

                for block in range(1, n_blocks + 1):
                  for stage in range(1, n_stages + 1):
                    conv = ncsn_conv3x3(features, features, stride=1, bias=False)
                    setattr(self, '{}_{}_conv'.format(block, stage), conv)

                self.stride = 1
                self.n_blocks = n_blocks
                self.n_stages = n_stages
                self.act = act

              def forward(self, x):
                """Apply every residual unit in order and return the result."""
                for block in range(1, self.n_blocks + 1):
                  skip = x
                  for stage in range(1, self.n_stages + 1):
                    conv = getattr(self, '{}_{}_conv'.format(block, stage))
                    x = conv(self.act(x))

                  # In-place add of the unit's input onto its output.
                  x += skip
                return x
         | 
| 205 | 
            +
             | 
| 206 | 
            +
             | 
| 207 | 
            +
            class CondRCUBlock(nn.Module):
              """Conditional variant of the residual conv unit chain.

              Each stage runs (conditional norm -> activation -> conv); the norm
              layers receive the conditioning input `y` in forward().

              Args:
                features: channel count used for every norm and conv layer.
                n_blocks: number of residual units applied one after another.
                n_stages: number of stages inside each unit.
                num_classes: forwarded as second argument to `normalizer`.
                normalizer: factory called as `normalizer(features, num_classes, bias=True)`.
                act: activation callable applied between norm and conv.
              """

              def __init__(self, features, n_blocks, n_stages, num_classes, normalizer, act=nn.ReLU()):
                super().__init__()

                # Per stage: the norm is registered first, then the conv, under
                # 1-based '<block>_<stage>_norm' / '<block>_<stage>_conv' names.
                for block_no in range(1, n_blocks + 1):
                  for stage_no in range(1, n_stages + 1):
                    norm_layer = normalizer(features, num_classes, bias=True)
                    setattr(self, '{}_{}_norm'.format(block_no, stage_no), norm_layer)
                    conv_layer = ncsn_conv3x3(features, features, stride=1, bias=False)
                    setattr(self, '{}_{}_conv'.format(block_no, stage_no), conv_layer)

                self.stride = 1
                self.n_blocks = n_blocks
                self.n_stages = n_stages
                self.act = act
                self.normalizer = normalizer

              def forward(self, x, y):
                """Apply every residual unit to `x`, conditioning the norms on `y`."""
                for block_no in range(1, self.n_blocks + 1):
                  skip = x
                  for stage_no in range(1, self.n_stages + 1):
                    norm = getattr(self, '{}_{}_norm'.format(block_no, stage_no))
                    x = norm(x, y)
                    x = self.act(x)
                    conv = getattr(self, '{}_{}_conv'.format(block_no, stage_no))
                    x = conv(x)

                  # In-place accumulation of the skip connection.
                  x += skip
                return x
         | 
| 232 | 
            +
             | 
| 233 | 
            +
             | 
| 234 | 
            +
            class MSFBlock(nn.Module):
              """Fuse several feature maps into one by conv, resize and sum.

              Args:
                in_planes: list or tuple with the channel count of each input.
                features: channel count of the fused output.
              """

              def __init__(self, in_planes, features):
                super().__init__()
                assert isinstance(in_planes, (list, tuple))
                self.convs = nn.ModuleList()
                self.features = features

                # One projection conv per input, in the order of `in_planes`.
                for planes in in_planes:
                  self.convs.append(ncsn_conv3x3(planes, features, stride=1, bias=True))

              def forward(self, xs, shape):
                """Return the sum of all projected inputs, each resized to `shape`.

                Args:
                  xs: sequence of inputs, indexed in step with `self.convs`.
                  shape: target spatial size handed to `F.interpolate`.
                """
                batch = xs[0].shape[0]
                # Accumulator lives on the device of the first input.
                fused = torch.zeros(batch, self.features, *shape, device=xs[0].device)
                for idx, conv in enumerate(self.convs):
                  projected = conv(xs[idx])
                  fused += F.interpolate(projected, size=shape, mode='bilinear', align_corners=True)
                return fused
         | 
| 251 | 
            +
             | 
| 252 | 
            +
             | 
| 253 | 
            +
            class CondMSFBlock(nn.Module):
              """Conditional multi-input fusion: norm, conv, resize, then sum.

              Args:
                in_planes: list or tuple with the channel count of each input.
                features: channel count of the fused output.
                num_classes: forwarded as second argument to `normalizer`.
                normalizer: factory called as `normalizer(planes, num_classes, bias=True)`.
              """

              def __init__(self, in_planes, features, num_classes, normalizer):
                super().__init__()
                assert isinstance(in_planes, (list, tuple))

                self.convs = nn.ModuleList()
                self.norms = nn.ModuleList()
                self.features = features
                self.normalizer = normalizer

                # For each input the conv is built before its norm layer.
                for planes in in_planes:
                  self.convs.append(ncsn_conv3x3(planes, features, stride=1, bias=True))
                  self.norms.append(normalizer(planes, num_classes, bias=True))

              def forward(self, xs, y, shape):
                """Return the sum of all normalised, projected and resized inputs.

                Args:
                  xs: sequence of inputs, indexed in step with `self.convs`.
                  y: conditioning input handed to every norm layer.
                  shape: target spatial size handed to `F.interpolate`.
                """
                batch = xs[0].shape[0]
                # Accumulator lives on the device of the first input.
                fused = torch.zeros(batch, self.features, *shape, device=xs[0].device)
                for idx, conv in enumerate(self.convs):
                  normed = self.norms[idx](xs[idx], y)
                  projected = conv(normed)
                  fused += F.interpolate(projected, size=shape, mode='bilinear', align_corners=True)
                return fused
         | 
| 275 | 
            +
             | 
| 276 | 
            +
             | 
| 277 | 
            +
            class RefineBlock(nn.Module):
              """Refinement stage: adapt each input, fuse, pool-refine, then output convs.

              Args:
                in_planes: list or tuple with the channel count of each input.
                features: channel count after fusion.
                act: activation handed to the RCU and CRP sub-blocks.
                start: when true, no `MSFBlock` is created.
                end: when true the output `RCUBlock` uses 3 units instead of 1.
                maxpool: forwarded to `CRPBlock`.
              """

              def __init__(self, in_planes, features, act=nn.ReLU(), start=False, end=False, maxpool=True):
                super().__init__()

                assert isinstance(in_planes, (tuple, list))
                n_blocks = len(in_planes)
                self.n_blocks = n_blocks

                # One adaptation RCU (2 units x 2 stages) per input.
                self.adapt_convs = nn.ModuleList()
                for planes in in_planes:
                  self.adapt_convs.append(RCUBlock(planes, 2, 2, act))

                n_output_units = 3 if end else 1
                self.output_convs = RCUBlock(features, n_output_units, 2, act)

                # Note: creation depends on `start`, while forward() decides on
                # `n_blocks > 1` whether the fusion block is actually used.
                if not start:
                  self.msf = MSFBlock(in_planes, features)

                self.crp = CRPBlock(features, 2, act, maxpool=maxpool)

              def forward(self, xs, output_shape):
                """Refine the inputs `xs` and return a single feature map.

                Args:
                  xs: tuple or list of inputs, one per adaptation block.
                  output_shape: spatial size handed to the fusion block.
                """
                assert isinstance(xs, (tuple, list))
                hs = [self.adapt_convs[idx](x) for idx, x in enumerate(xs)]

                # A single input bypasses the multi-input fusion.
                h = self.msf(hs, output_shape) if self.n_blocks > 1 else hs[0]

                h = self.crp(h)
                return self.output_convs(h)
         | 
| 311 | 
            +
             | 
| 312 | 
            +
             | 
| 313 | 
            +
            class CondRefineBlock(nn.Module):
              """Conditional refinement stage built from the Cond* sub-blocks.

              Args:
                in_planes: list or tuple with the channel count of each input.
                features: channel count after fusion.
                num_classes: forwarded to every conditional sub-block.
                normalizer: normalisation factory forwarded to every sub-block.
                act: activation handed to the RCU and CRP sub-blocks.
                start: when true, no `CondMSFBlock` is created.
                end: when true the output `CondRCUBlock` uses 3 units instead of 1.
              """

              def __init__(self, in_planes, features, num_classes, normalizer, act=nn.ReLU(), start=False, end=False):
                super().__init__()

                assert isinstance(in_planes, (tuple, list))
                n_blocks = len(in_planes)
                self.n_blocks = n_blocks

                # One adaptation RCU (2 units x 2 stages) per input.
                self.adapt_convs = nn.ModuleList()
                for planes in in_planes:
                  self.adapt_convs.append(CondRCUBlock(planes, 2, 2, num_classes, normalizer, act))

                n_output_units = 3 if end else 1
                self.output_convs = CondRCUBlock(features, n_output_units, 2, num_classes, normalizer, act)

                # Note: creation depends on `start`, while forward() decides on
                # `n_blocks > 1` whether the fusion block is actually used.
                if not start:
                  self.msf = CondMSFBlock(in_planes, features, num_classes, normalizer)

                self.crp = CondCRPBlock(features, 2, num_classes, normalizer, act)

              def forward(self, xs, y, output_shape):
                """Refine the inputs `xs`, conditioned on `y`, into one feature map.

                Args:
                  xs: tuple or list of inputs, one per adaptation block.
                  y: conditioning input handed to every sub-block.
                  output_shape: spatial size handed to the fusion block.
                """
                assert isinstance(xs, (tuple, list))
                hs = [self.adapt_convs[idx](x, y) for idx, x in enumerate(xs)]

                # A single input bypasses the multi-input fusion.
                h = self.msf(hs, y, output_shape) if self.n_blocks > 1 else hs[0]

                h = self.crp(h, y)
                return self.output_convs(h, y)
         | 
| 349 | 
            +
             | 
| 350 | 
            +
             | 
| 351 | 
            +
            class ConvMeanPool(nn.Module):
              """Convolution followed by averaging of 2x2 strided slices.

              Args:
                input_dim: input channel count of the convolution.
                output_dim: output channel count of the convolution.
                kernel_size: convolution kernel size; padding is `kernel_size // 2`.
                biases: whether the convolution has a bias term.
                adjust_padding: when true, a `ZeroPad2d((1, 0, 1, 0))` is placed in
                  front of the convolution.
              """

              def __init__(self, input_dim, output_dim, kernel_size=3, biases=True, adjust_padding=False):
                super().__init__()
                # The convolution itself is the same in both configurations.
                conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
                if adjust_padding:
                  self.conv = nn.Sequential(
                    nn.ZeroPad2d((1, 0, 1, 0)),
                    conv
                  )
                else:
                  self.conv = conv

              def forward(self, inputs):
                """Convolve `inputs`, then average the four 2x2 phase slices."""
                feat = self.conv(inputs)
                # The four slices must have equal shapes, which presumably means an
                # even spatial size after the convolution -- TODO confirm.
                phases = [feat[:, :, ::2, ::2], feat[:, :, 1::2, ::2],
                          feat[:, :, ::2, 1::2], feat[:, :, 1::2, 1::2]]
                return sum(phases) / 4.
         | 
| 370 | 
            +
             | 
| 371 | 
            +
             | 
| 372 | 
            +
            class MeanPoolConv(nn.Module):
              """Averaging of 2x2 strided slices followed by a convolution.

              Args:
                input_dim: input channel count of the convolution.
                output_dim: output channel count of the convolution.
                kernel_size: convolution kernel size; padding is `kernel_size // 2`.
                biases: whether the convolution has a bias term.
              """

              def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):
                super().__init__()
                self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)

              def forward(self, inputs):
                """Average the four 2x2 phase slices of `inputs`, then convolve."""
                phases = [inputs[:, :, ::2, ::2], inputs[:, :, 1::2, ::2],
                          inputs[:, :, ::2, 1::2], inputs[:, :, 1::2, 1::2]]
                pooled = sum(phases) / 4.
                return self.conv(pooled)
         | 
| 382 | 
            +
             | 
| 383 | 
            +
             | 
| 384 | 
            +
            class UpsampleConv(nn.Module):
              """2x spatial upsampling via channel tiling + PixelShuffle, then a conv.

              Args:
                input_dim: input channel count of the convolution.
                output_dim: output channel count of the convolution.
                kernel_size: convolution kernel size; padding is `kernel_size // 2`.
                biases: whether the convolution has a bias term.
              """

              def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):
                super().__init__()
                self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
                self.pixelshuffle = nn.PixelShuffle(upscale_factor=2)

              def forward(self, inputs):
                """Upsample `inputs` by a factor of two and convolve the result."""
                # Four copies along the channel axis give PixelShuffle the 4x channels
                # it needs to produce a map of the original channel count.
                # NOTE(review): with more than one input channel the copies are tiled,
                # not repeated per channel, so PixelShuffle appears to mix neighbouring
                # channels -- confirm this is intended before changing it.
                tiled = torch.cat([inputs] * 4, dim=1)
                upsampled = self.pixelshuffle(tiled)
                return self.conv(upsampled)
         | 
| 395 | 
            +
             | 
| 396 | 
            +
             | 
| 397 | 
            +
            class ConditionalResidualBlock(nn.Module):
              """Pre-activation residual block with conditional normalisation.

              The main path is norm -> act -> conv1 -> norm -> act -> conv2; the
              result is added to a shortcut of the input.

              Args:
                input_dim: channel count of the input.
                output_dim: channel count of the output.
                num_classes: forwarded as second argument to `normalization`.
                resample: 'down' for a downsampling block, None for no resampling.
                  Any other value, including the default 1, raises.
                act: activation callable used on the main path.
                normalization: factory called as `normalization(channels, num_classes)`.
                adjust_padding: forwarded to `ConvMeanPool` in the 'down' mode.
                dilation: dilation of the 3x3 convolutions; None (the default) is
                  treated like 1, i.e. no dilation.

              Raises:
                ValueError: if `resample` is neither 'down' nor None.
              """

              def __init__(self, input_dim, output_dim, num_classes, resample=1, act=nn.ELU(),
                           normalization=ConditionalInstanceNorm2dPlus, adjust_padding=False, dilation=None):
                super().__init__()
                self.non_linearity = act
                self.input_dim = input_dim
                self.output_dim = output_dim
                self.resample = resample
                self.normalization = normalization

                if resample == 'down':
                  # The None check keeps the default `dilation=None` from raising a
                  # TypeError in the `> 1` comparison.
                  if dilation is not None and dilation > 1:
                    self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation)
                    self.normalize2 = normalization(input_dim, num_classes)
                    self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
                    conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
                  else:
                    self.conv1 = ncsn_conv3x3(input_dim, input_dim)
                    self.normalize2 = normalization(input_dim, num_classes)
                    self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding)
                    conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding)

                elif resample is None:
                  if dilation is not None and dilation > 1:
                    conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
                    self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
                    self.normalize2 = normalization(output_dim, num_classes)
                    self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation)
                  else:
                    # A bare nn.Conv2d cannot be called as (input_dim, output_dim)
                    # because kernel_size is required; use the 1x1 helper, exactly
                    # as ResidualBlock does for this case.
                    conv_shortcut = partial(ncsn_conv1x1)
                    self.conv1 = ncsn_conv3x3(input_dim, output_dim)
                    self.normalize2 = normalization(output_dim, num_classes)
                    self.conv2 = ncsn_conv3x3(output_dim, output_dim)
                else:
                  raise ValueError('invalid resample value')

                # A learned shortcut is only needed when the identity cannot be used.
                if output_dim != input_dim or resample is not None:
                  self.shortcut = conv_shortcut(input_dim, output_dim)

                self.normalize1 = normalization(input_dim, num_classes)

              def forward(self, x, y):
                """Return shortcut(x) + main_path(x), with norms conditioned on `y`."""
                output = self.normalize1(x, y)
                output = self.non_linearity(output)
                output = self.conv1(output)
                output = self.normalize2(output, y)
                output = self.non_linearity(output)
                output = self.conv2(output)

                if self.output_dim == self.input_dim and self.resample is None:
                  shortcut = x
                else:
                  shortcut = self.shortcut(x)

                return shortcut + output
         | 
| 451 | 
            +
             | 
| 452 | 
            +
             | 
| 453 | 
            +
            class ResidualBlock(nn.Module):
              """Pre-activation residual block with unconditional normalisation.

              The main path is norm -> act -> conv1 -> norm -> act -> conv2; the
              result is added to a shortcut of the input.

              Args:
                input_dim: channel count of the input.
                output_dim: channel count of the output.
                resample: 'down' for a downsampling block, None for no resampling.
                act: activation callable used on the main path.
                normalization: factory called as `normalization(channels)`.
                adjust_padding: forwarded to `ConvMeanPool` in the 'down' mode.
                dilation: dilation of the 3x3 convolutions; values above 1 select
                  the dilated variants.

              Raises:
                Exception: if `resample` is neither 'down' nor None.
              """

              def __init__(self, input_dim, output_dim, resample=None, act=nn.ELU(),
                           normalization=nn.InstanceNorm2d, adjust_padding=False, dilation=1):
                super().__init__()
                self.non_linearity = act
                self.input_dim = input_dim
                self.output_dim = output_dim
                self.resample = resample
                self.normalization = normalization

                if resample == 'down':
                  if dilation > 1:
                    dilated_conv = partial(ncsn_conv3x3, dilation=dilation)
                    self.conv1 = dilated_conv(input_dim, input_dim)
                    self.normalize2 = normalization(input_dim)
                    self.conv2 = dilated_conv(input_dim, output_dim)
                    conv_shortcut = dilated_conv
                  else:
                    self.conv1 = ncsn_conv3x3(input_dim, input_dim)
                    self.normalize2 = normalization(input_dim)
                    self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding)
                    conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding)

                elif resample is None:
                  if dilation > 1:
                    dilated_conv = partial(ncsn_conv3x3, dilation=dilation)
                    conv_shortcut = dilated_conv
                    self.conv1 = dilated_conv(input_dim, output_dim)
                    self.normalize2 = normalization(output_dim)
                    self.conv2 = dilated_conv(output_dim, output_dim)
                  else:
                    # The shortcut uses the 1x1 helper here rather than a bare nn.Conv2d.
                    conv_shortcut = partial(ncsn_conv1x1)
                    self.conv1 = ncsn_conv3x3(input_dim, output_dim)
                    self.normalize2 = normalization(output_dim)
                    self.conv2 = ncsn_conv3x3(output_dim, output_dim)
                else:
                  raise Exception('invalid resample value')

                # A learned shortcut is only needed when the identity cannot be used.
                needs_projection = output_dim != input_dim or resample is not None
                if needs_projection:
                  self.shortcut = conv_shortcut(input_dim, output_dim)

                self.normalize1 = normalization(input_dim)

              def forward(self, x):
                """Return shortcut(x) + main_path(x)."""
                h = self.non_linearity(self.normalize1(x))
                h = self.conv1(h)
                h = self.non_linearity(self.normalize2(h))
                h = self.conv2(h)

                if self.resample is not None or self.output_dim != self.input_dim:
                  skip = self.shortcut(x)
                else:
                  skip = x

                return skip + h
         | 
| 508 | 
            +
             | 
| 509 | 
            +
             | 
| 510 | 
            +
            ###########################################################################
         | 
| 511 | 
            +
            # Functions below are ported over from the DDPM codebase:
         | 
| 512 | 
            +
            #  https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py
         | 
| 513 | 
            +
            ###########################################################################
         | 
| 514 | 
            +
             | 
| 515 | 
            +
            def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
              """Build sinusoidal embeddings for a 1-D tensor of timesteps.

              The first `embedding_dim // 2` columns hold sines and the next
              `embedding_dim // 2` columns hold cosines of `timesteps * freq`, with
              frequencies spaced geometrically from 1 down to `1 / max_positions`.
              An odd `embedding_dim` gets one trailing zero column.

              Args:
                timesteps: 1-D tensor of timestep values.
                embedding_dim: width of the returned embedding.
                max_positions: sets the lowest frequency (10000 as in transformers).

              Returns:
                A float tensor of shape `(len(timesteps), embedding_dim)` on the
                device of `timesteps`.
              """
              assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32
              half_dim = embedding_dim // 2
              # With a single frequency (embedding_dim of 2 or 3) `half_dim - 1` is zero;
              # clamp the divisor so the lone frequency is exp(0) = 1 instead of
              # raising ZeroDivisionError.
              log_step = math.log(max_positions) / max(half_dim - 1, 1)
              freqs = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -log_step)
              # Outer product: one row per timestep, one column per frequency.
              emb = timesteps.float()[:, None] * freqs[None, :]
              emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
              if embedding_dim % 2 == 1:  # zero pad
                emb = F.pad(emb, (0, 1), mode='constant')
              assert emb.shape == (timesteps.shape[0], embedding_dim)
              return emb
         | 
| 530 | 
            +
             | 
| 531 | 
            +
             | 
| 532 | 
            +
            def _einsum(a, b, c, x, y):
         | 
| 533 | 
            +
              einsum_str = '{},{}->{}'.format(''.join(a), ''.join(b), ''.join(c))
         | 
| 534 | 
            +
              return torch.einsum(einsum_str, x, y)
         | 
| 535 | 
            +
             | 
| 536 | 
            +
             | 
| 537 | 
            +
            def contract_inner(x, y):
              """tensordot(x, y, 1): sum over the last axis of `x` and the first of `y`."""
              letters = string.ascii_lowercase
              x_rank, y_rank = len(x.shape), len(y.shape)
              x_axes = list(letters[:x_rank])
              y_axes = list(letters[x_rank:x_rank + y_rank])
              # Sharing one label between the two operands makes einsum contract it.
              y_axes[0] = x_axes[-1]
              result_axes = x_axes[:-1] + y_axes[1:]
              return _einsum(x_axes, y_axes, result_axes, x, y)
         | 
| 544 | 
            +
             | 
| 545 | 
            +
             | 
| 546 | 
            +
            class NIN(nn.Module):
              """Learned linear map over the channel axis of a 4-D input.

              Args:
                in_dim: size of the input channel axis.
                num_units: size of the output channel axis.
                init_scale: scale handed to `default_init` for the weight matrix.
              """

              def __init__(self, in_dim, num_units, init_scale=0.1):
                super().__init__()
                weight_init = default_init(scale=init_scale)
                self.W = nn.Parameter(weight_init((in_dim, num_units)), requires_grad=True)
                self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)

              def forward(self, x):
                """Apply `W` and `b` along axis 1 of `x` and return the same layout."""
                # Move axis 1 last so it is the axis contracted with W.
                channels_last = x.permute(0, 2, 3, 1)
                projected = contract_inner(channels_last, self.W) + self.b
                # Restore the original axis order.
                return projected.permute(0, 3, 1, 2)
         | 
| 556 | 
            +
             | 
| 557 | 
            +
             | 
| 558 | 
            +
            class AttnBlock(nn.Module):
              """Channel-wise self-attention block.

              Queries, keys and values are NIN projections of the group-normalised
              input; attention weights are a softmax over all spatial positions and
              the attended result is added back to the input.
              """

              def __init__(self, channels):
                super().__init__()
                self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)
                self.NIN_0 = NIN(channels, channels)
                self.NIN_1 = NIN(channels, channels)
                self.NIN_2 = NIN(channels, channels)
                # Output projection is created with a zero init scale.
                self.NIN_3 = NIN(channels, channels, init_scale=0.)

              def forward(self, x):
                """Return `x` plus the attention output computed from `x`."""
                B, C, H, W = x.shape
                normed = self.GroupNorm_0(x)
                query = self.NIN_0(normed)
                key = self.NIN_1(normed)
                value = self.NIN_2(normed)

                # Similarity of every position (h, w) with every position (i, j),
                # scaled by 1 / sqrt(C).
                attn = torch.einsum('bchw,bcij->bhwij', query, key) * (int(C) ** (-0.5))
                # Flatten (i, j) so the softmax normalises over all positions at once.
                attn = torch.reshape(attn, (B, H, W, H * W))
                attn = F.softmax(attn, dim=-1)
                attn = torch.reshape(attn, (B, H, W, H, W))
                attended = torch.einsum('bhwij,bcij->bchw', attn, value)
                attended = self.NIN_3(attended)
                return x + attended
         | 
| 582 | 
            +
             | 
| 583 | 
            +
             | 
| 584 | 
            +
            class Upsample(nn.Module):
              """Nearest-neighbour 2x upsampling with an optional convolution.

              Args:
                channels: channel count used for the optional convolution.
                with_conv: when true, a `ddpm_conv3x3` layer follows the upsampling.
              """

              def __init__(self, channels, with_conv=False):
                super().__init__()
                self.with_conv = with_conv
                if with_conv:
                  self.Conv_0 = ddpm_conv3x3(channels, channels)

              def forward(self, x):
                """Double both spatial sizes of the 4-D input `x`."""
                _, _, height, width = x.shape
                upsampled = F.interpolate(x, (height * 2, width * 2), mode='nearest')
                if not self.with_conv:
                  return upsampled
                return self.Conv_0(upsampled)
         | 
| 597 | 
            +
             | 
| 598 | 
            +
             | 
| 599 | 
            +
            class Downsample(nn.Module):
              """2x spatial downsampling by strided convolution or average pooling.

              Args:
                channels: channel count used for the optional convolution.
                with_conv: when true a stride-2 `ddpm_conv3x3` is used, otherwise a
                  2x2 average pool.
              """

              def __init__(self, channels, with_conv=False):
                super().__init__()
                self.with_conv = with_conv
                if with_conv:
                  self.Conv_0 = ddpm_conv3x3(channels, channels, stride=2, padding=0)

              def forward(self, x):
                """Halve both spatial sizes of the 4-D input `x`."""
                batch, chans, height, width = x.shape
                if not self.with_conv:
                  x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=0)
                else:
                  # Emulate 'SAME' padding: one extra column on the right and one
                  # extra row at the bottom before the unpadded strided conv.
                  x = F.pad(x, (0, 1, 0, 1))
                  x = self.Conv_0(x)

                assert x.shape == (batch, chans, height // 2, width // 2)
                return x
         | 
| 617 | 
            +
             | 
| 618 | 
            +
             | 
| 619 | 
            +
            class ResnetBlockDDPM(nn.Module):
              """Residual block in the DDPM style.

              Main path: GroupNorm_0 -> act -> Conv_0 -> (+ Dense_0 of the time
              embedding) -> GroupNorm_1 -> act -> Dropout_0 -> Conv_1. The result is
              added to the input, which is first mapped by `Conv_2` or `NIN_0` when the
              input and output channel counts differ.

              Args:
                act: activation callable applied to feature maps and to `temb`.
                in_ch: number of input channels.
                out_ch: number of output channels; defaults to `in_ch`.
                temb_dim: size of the time embedding; `Dense_0` exists only if given.
                conv_shortcut: use `ddpm_conv3x3` rather than `NIN` on the skip path.
                dropout: dropout probability.
              """

              def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, dropout=0.1):
                super().__init__()
                out_ch = in_ch if out_ch is None else out_ch
                self.in_ch = in_ch
                self.out_ch = out_ch
                self.conv_shortcut = conv_shortcut

                # Sub-modules are created in a fixed order so that parameter
                # initialisation and module registration stay reproducible.
                self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=in_ch, eps=1e-6)
                self.act = act
                self.Conv_0 = ddpm_conv3x3(in_ch, out_ch)
                if temb_dim is not None:
                  time_proj = nn.Linear(temb_dim, out_ch)
                  time_proj.weight.data = default_init()(time_proj.weight.data.shape)
                  nn.init.zeros_(time_proj.bias)
                  self.Dense_0 = time_proj

                self.GroupNorm_1 = nn.GroupNorm(num_groups=32, num_channels=out_ch, eps=1e-6)
                self.Dropout_0 = nn.Dropout(dropout)
                self.Conv_1 = ddpm_conv3x3(out_ch, out_ch, init_scale=0.)
                if in_ch != out_ch:
                  if conv_shortcut:
                    self.Conv_2 = ddpm_conv3x3(in_ch, out_ch)
                  else:
                    self.NIN_0 = NIN(in_ch, out_ch)

              def forward(self, x, temb=None):
                """Apply the block to `x` (unpacked as B, C, H, W), optionally conditioned on `temb`."""
                _, chans, _, _ = x.shape
                assert chans == self.in_ch
                target_ch = self.out_ch or self.in_ch

                h = self.Conv_0(self.act(self.GroupNorm_0(x)))
                # Add bias to each feature map conditioned on the time embedding
                if temb is not None:
                  time_bias = self.Dense_0(self.act(temb))
                  h += time_bias[:, :, None, None]
                h = self.Conv_1(self.Dropout_0(self.act(self.GroupNorm_1(h))))

                if chans != target_ch:
                  shortcut = self.Conv_2 if self.conv_shortcut else self.NIN_0
                  x = shortcut(x)
                return x + h
         | 
    	
        sgmse/backbones/ncsnpp_utils/layerspp.py
    ADDED
    
    | @@ -0,0 +1,274 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # coding=utf-8
         | 
| 2 | 
            +
            # Copyright 2020 The Google Research Authors.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 5 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 6 | 
            +
            # You may obtain a copy of the License at
         | 
| 7 | 
            +
            #
         | 
| 8 | 
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 9 | 
            +
            #
         | 
| 10 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 11 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 12 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 13 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 14 | 
            +
            # limitations under the License.
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            # pylint: skip-file
         | 
| 17 | 
            +
            """Layers for defining NCSN++.
         | 
| 18 | 
            +
            """
         | 
| 19 | 
            +
            from . import layers
         | 
| 20 | 
            +
            from . import up_or_down_sampling
         | 
| 21 | 
            +
            import torch.nn as nn
         | 
| 22 | 
            +
            import torch
         | 
| 23 | 
            +
            import torch.nn.functional as F
         | 
| 24 | 
            +
            import numpy as np
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            conv1x1 = layers.ddpm_conv1x1
         | 
| 27 | 
            +
            conv3x3 = layers.ddpm_conv3x3
         | 
| 28 | 
            +
            NIN = layers.NIN
         | 
| 29 | 
            +
            default_init = layers.default_init
         | 
| 30 | 
            +
             | 
| 31 | 
            +
             | 
| 32 | 
            +
            class GaussianFourierProjection(nn.Module):
              """Fourier feature embedding with fixed, randomly drawn frequencies.

              Args:
                embedding_size: number of frequencies; the output has twice as many
                  features (sine and cosine halves).
                scale: multiplier applied to the standard-normal frequency draws.
              """

              def __init__(self, embedding_size=256, scale=1.0):
                super().__init__()
                frequencies = torch.randn(embedding_size) * scale
                # Stored as a parameter but excluded from gradient updates.
                self.W = nn.Parameter(frequencies, requires_grad=False)

              def forward(self, x):
                """Return `[sin(2*pi*x*W), cos(2*pi*x*W)]` concatenated on the last axis."""
                angles = x[:, None] * self.W[None, :] * 2 * np.pi
                features = (torch.sin(angles), torch.cos(angles))
                return torch.cat(features, dim=-1)
         | 
| 42 | 
            +
             | 
| 43 | 
            +
             | 
| 44 | 
            +
            class Combine(nn.Module):
              """Merge a projected tensor with a second tensor (skip connection).

              Args:
                dim1: input channels of the `conv1x1` projection applied to `x`.
                dim2: output channels of that projection.
                method: 'cat' to concatenate along dim 1, 'sum' to add element-wise.
              """

              def __init__(self, dim1, dim2, method='cat'):
                super().__init__()
                self.Conv_0 = conv1x1(dim1, dim2)
                self.method = method

              def forward(self, x, y):
                """Project `x` with `Conv_0`, then combine it with `y` per `self.method`.

                Raises:
                  ValueError: if `self.method` is neither 'cat' nor 'sum'.
                """
                projected = self.Conv_0(x)
                if self.method == 'cat':
                  return torch.cat([projected, y], dim=1)
                if self.method == 'sum':
                  return projected + y
                raise ValueError(f'Method {self.method} not recognized.')
         | 
| 60 | 
            +
             | 
| 61 | 
            +
             | 
| 62 | 
            +
            class AttnBlockpp(nn.Module):
              """Self-attention block with a residual connection. Modified from DDPM.

              Query/key/value are produced by `NIN` layers from the group-normalised
              input; attention weights are a softmax over all (H * W) key positions
              for each query position.

              Args:
                channels: number of input (and output) channels.
                skip_rescale: when true, the residual sum is divided by sqrt(2).
                init_scale: `init_scale` forwarded to the output `NIN` layer.
              """

              def __init__(self, channels, skip_rescale=False, init_scale=0.):
                super().__init__()
                group_count = min(channels // 4, 32)
                self.GroupNorm_0 = nn.GroupNorm(num_groups=group_count, num_channels=channels,
                                                eps=1e-6)
                self.NIN_0 = NIN(channels, channels)
                self.NIN_1 = NIN(channels, channels)
                self.NIN_2 = NIN(channels, channels)
                self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
                self.skip_rescale = skip_rescale

              def forward(self, x):
                """Return `x` plus the attention output (optionally rescaled)."""
                B, C, H, W = x.shape
                normed = self.GroupNorm_0(x)
                query = self.NIN_0(normed)
                key = self.NIN_1(normed)
                value = self.NIN_2(normed)

                # Dot products over channels between every query and key position.
                scores = torch.einsum('bchw,bcij->bhwij', query, key) * (int(C) ** (-0.5))
                # Flatten the key positions so the softmax normalises over all of them.
                weights = F.softmax(torch.reshape(scores, (B, H, W, H * W)), dim=-1)
                weights = torch.reshape(weights, (B, H, W, H, W))
                attended = torch.einsum('bhwij,bcij->bchw', weights, value)
                attended = self.NIN_3(attended)

                if self.skip_rescale:
                  return (x + attended) / np.sqrt(2.)
                return x + attended
         | 
| 92 | 
            +
             | 
| 93 | 
            +
             | 
| 94 | 
            +
            class Upsample(nn.Module):
              """2x spatial upsampling, with or without FIR filtering and convolution.

              Four modes, selected by `fir` and `with_conv`:
                * fir=False, with_conv=False: nearest-neighbour interpolation.
                * fir=False, with_conv=True: nearest-neighbour, then `Conv_0` (conv3x3).
                * fir=True, with_conv=False: `up_or_down_sampling.upsample_2d`.
                * fir=True, with_conv=True: `up_or_down_sampling.Conv2d` with `up=True`.

              Args:
                in_ch: input channels (needed only when `with_conv` is true).
                out_ch: output channels; falls back to `in_ch` when falsy.
                with_conv: whether a learned convolution is part of the upsampling.
                fir: whether to use the FIR-based resampling path.
                fir_kernel: kernel passed to the FIR resampling routines.
              """

              def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False,
                           fir_kernel=(1, 3, 3, 1)):
                super().__init__()
                out_ch = out_ch if out_ch else in_ch
                if not fir:
                  if with_conv:
                    self.Conv_0 = conv3x3(in_ch, out_ch)
                else:
                  if with_conv:
                    self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch,
                                                             kernel=3, up=True,
                                                             resample_kernel=fir_kernel,
                                                             use_bias=True,
                                                             kernel_init=default_init())
                self.fir = fir
                self.with_conv = with_conv
                self.fir_kernel = fir_kernel
                self.out_ch = out_ch

              def forward(self, x):
                """Upsample `x` (4-D, unpacked as B, C, H, W) by a factor of two."""
                B, C, H, W = x.shape
                if not self.fir:
                  # `size` and `mode` must be passed by keyword: the third positional
                  # parameter of F.interpolate is `scale_factor`, so a positional
                  # 'nearest' collides with `size` and raises ValueError.
                  h = F.interpolate(x, size=(H * 2, W * 2), mode='nearest')
                  if self.with_conv:
                    h = self.Conv_0(h)
                else:
                  if not self.with_conv:
                    h = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2)
                  else:
                    h = self.Conv2d_0(x)

                return h
         | 
| 127 | 
            +
             | 
| 128 | 
            +
             | 
| 129 | 
            +
            class Downsample(nn.Module):
              """2x spatial downsampling, with or without FIR filtering and convolution.

              Four modes, selected by `fir` and `with_conv`:
                * fir=False, with_conv=False: 2x2 average pooling with stride 2.
                * fir=False, with_conv=True: right/bottom pad by one, then a stride-2
                  `Conv_0` (conv3x3 with padding=0).
                * fir=True, with_conv=False: `up_or_down_sampling.downsample_2d`.
                * fir=True, with_conv=True: `up_or_down_sampling.Conv2d` with `down=True`.

              Args:
                in_ch: input channels (needed only when `with_conv` is true).
                out_ch: output channels; falls back to `in_ch` when falsy.
                with_conv: whether a learned convolution is part of the downsampling.
                fir: whether to use the FIR-based resampling path.
                fir_kernel: kernel passed to the FIR resampling routines.
              """

              def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False,
                           fir_kernel=(1, 3, 3, 1)):
                super().__init__()
                out_ch = out_ch or in_ch
                self.fir = fir
                self.fir_kernel = fir_kernel
                self.with_conv = with_conv
                self.out_ch = out_ch

                # Only the convolutional variants own a learnable layer.
                if with_conv:
                  if fir:
                    self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch,
                                                             kernel=3, down=True,
                                                             resample_kernel=fir_kernel,
                                                             use_bias=True,
                                                             kernel_init=default_init())
                  else:
                    self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0)

              def forward(self, x):
                """Downsample `x` (4-D, unpacked as B, C, H, W) by a factor of two."""
                B, C, H, W = x.shape
                if self.fir:
                  if self.with_conv:
                    return self.Conv2d_0(x)
                  return up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2)

                if self.with_conv:
                  padded = F.pad(x, (0, 1, 0, 1))
                  return self.Conv_0(padded)
                return F.avg_pool2d(x, 2, stride=2)
         | 
| 164 | 
            +
             | 
| 165 | 
            +
             | 
| 166 | 
            +
            class ResnetBlockDDPMpp(nn.Module):
              """Residual block adapted from DDPM, with optional skip rescaling.

              Main path: GroupNorm_0 -> act -> Conv_0 -> (+ Dense_0 of the time
              embedding) -> GroupNorm_1 -> act -> Dropout_0 -> Conv_1. The skip path
              is mapped by `Conv_2` or `NIN_0` when the input channel count differs
              from `out_ch`.

              Args:
                act: activation callable applied to feature maps and to `temb`.
                in_ch: number of input channels.
                out_ch: number of output channels; falls back to `in_ch` when falsy.
                temb_dim: size of the time embedding; `Dense_0` exists only if given.
                conv_shortcut: use `conv3x3` rather than `NIN` on the skip path.
                dropout: dropout probability.
                skip_rescale: when true, the residual sum is divided by sqrt(2).
                init_scale: `init_scale` forwarded to the final convolution.
              """

              def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False,
                           dropout=0.1, skip_rescale=False, init_scale=0.):
                super().__init__()
                out_ch = out_ch or in_ch
                self.out_ch = out_ch
                self.conv_shortcut = conv_shortcut
                self.skip_rescale = skip_rescale

                # Sub-modules are created in a fixed order so that parameter
                # initialisation and module registration stay reproducible.
                self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
                self.Conv_0 = conv3x3(in_ch, out_ch)
                if temb_dim is not None:
                  time_proj = nn.Linear(temb_dim, out_ch)
                  time_proj.weight.data = default_init()(time_proj.weight.data.shape)
                  nn.init.zeros_(time_proj.bias)
                  self.Dense_0 = time_proj
                self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
                self.Dropout_0 = nn.Dropout(dropout)
                self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
                if in_ch != out_ch:
                  if conv_shortcut:
                    self.Conv_2 = conv3x3(in_ch, out_ch)
                  else:
                    self.NIN_0 = NIN(in_ch, out_ch)

                self.act = act

              def forward(self, x, temb=None):
                """Apply the block to `x`, optionally conditioned on `temb`."""
                h = self.Conv_0(self.act(self.GroupNorm_0(x)))
                if temb is not None:
                  # Per-channel bias derived from the time embedding.
                  time_bias = self.Dense_0(self.act(temb))
                  h += time_bias[:, :, None, None]
                h = self.Conv_1(self.Dropout_0(self.act(self.GroupNorm_1(h))))

                if x.shape[1] != self.out_ch:
                  shortcut = self.Conv_2 if self.conv_shortcut else self.NIN_0
                  x = shortcut(x)

                if self.skip_rescale:
                  return (x + h) / np.sqrt(2.)
                return x + h
         | 
| 210 | 
            +
             | 
| 211 | 
            +
             | 
| 212 | 
            +
            class ResnetBlockBigGANpp(nn.Module):
              """Residual block with optional 2x up- or down-sampling inside the block.

              Main path: GroupNorm_0 -> act -> (resample) -> Conv_0 -> (+ Dense_0 of the
              time embedding) -> GroupNorm_1 -> act -> Dropout_0 -> Conv_1. The skip
              path receives the same resampling and is mapped by a `conv1x1` (`Conv_2`)
              whenever channels change or resampling is enabled.

              Args:
                act: activation callable applied to feature maps and to `temb`.
                in_ch: number of input channels.
                out_ch: number of output channels; falls back to `in_ch` when falsy.
                temb_dim: size of the time embedding; `Dense_0` exists only if given.
                up: upsample by 2 (takes precedence over `down`).
                down: downsample by 2.
                dropout: dropout probability.
                fir: use the FIR resampling routines instead of the naive ones.
                fir_kernel: kernel passed to the FIR resampling routines.
                skip_rescale: when true, the residual sum is divided by sqrt(2).
                init_scale: `init_scale` forwarded to the final convolution.
              """

              def __init__(self, act, in_ch, out_ch=None, temb_dim=None, up=False, down=False,
                           dropout=0.1, fir=False, fir_kernel=(1, 3, 3, 1),
                           skip_rescale=True, init_scale=0.):
                super().__init__()

                out_ch = out_ch or in_ch
                self.in_ch = in_ch
                self.out_ch = out_ch
                self.up = up
                self.down = down
                self.fir = fir
                self.fir_kernel = fir_kernel
                self.skip_rescale = skip_rescale

                # Sub-modules are created in a fixed order so that parameter
                # initialisation and module registration stay reproducible.
                self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
                self.Conv_0 = conv3x3(in_ch, out_ch)
                if temb_dim is not None:
                  time_proj = nn.Linear(temb_dim, out_ch)
                  time_proj.weight.data = default_init()(time_proj.weight.shape)
                  nn.init.zeros_(time_proj.bias)
                  self.Dense_0 = time_proj

                self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
                self.Dropout_0 = nn.Dropout(dropout)
                self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
                if in_ch != out_ch or up or down:
                  self.Conv_2 = conv1x1(in_ch, out_ch)

                self.act = act

              def _resample(self, t):
                """Apply the configured 2x resampling to `t`; return `t` unchanged if none."""
                if self.up:
                  if self.fir:
                    return up_or_down_sampling.upsample_2d(t, self.fir_kernel, factor=2)
                  return up_or_down_sampling.naive_upsample_2d(t, factor=2)
                if self.down:
                  if self.fir:
                    return up_or_down_sampling.downsample_2d(t, self.fir_kernel, factor=2)
                  return up_or_down_sampling.naive_downsample_2d(t, factor=2)
                return t

              def forward(self, x, temb=None):
                """Apply the block to `x`, optionally conditioned on `temb`."""
                h = self.act(self.GroupNorm_0(x))

                # Both the main path and the skip path are resampled identically.
                h = self._resample(h)
                x = self._resample(x)

                h = self.Conv_0(h)
                # Add bias to each feature map conditioned on the time embedding
                if temb is not None:
                  time_bias = self.Dense_0(self.act(temb))
                  h += time_bias[:, :, None, None]
                h = self.Conv_1(self.Dropout_0(self.act(self.GroupNorm_1(h))))

                needs_projection = self.in_ch != self.out_ch or self.up or self.down
                if needs_projection:
                  x = self.Conv_2(x)

                if self.skip_rescale:
                  return (x + h) / np.sqrt(2.)
                return x + h
         | 
    	
        sgmse/backbones/ncsnpp_utils/normalization.py
    ADDED
    
    | @@ -0,0 +1,215 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # coding=utf-8
         | 
| 2 | 
            +
            # Copyright 2020 The Google Research Authors.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 5 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 6 | 
            +
            # You may obtain a copy of the License at
         | 
| 7 | 
            +
            #
         | 
| 8 | 
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 9 | 
            +
            #
         | 
| 10 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 11 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 12 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 13 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 14 | 
            +
            # limitations under the License.
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            """Normalization layers."""
         | 
| 17 | 
            +
            import torch.nn as nn
         | 
| 18 | 
            +
            import torch
         | 
| 19 | 
            +
            import functools
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            def get_normalization(config, conditional=False):
              """Map `config.model.normalization` to a normalization layer constructor.

              Args:
                config: object exposing `config.model.normalization` (a name string) and,
                  for the conditional case, `config.model.num_classes`.
                conditional: select the class-conditional variant of the layer.

              Returns:
                A class, or a `functools.partial` around one, that builds the layer.

              Raises:
                NotImplementedError: conditional variant requested for an unsupported name.
                ValueError: unknown name in the unconditional case.
              """
              norm = config.model.normalization

              if conditional:
                # Only one conditional normalization is available.
                if norm != 'InstanceNorm++':
                  raise NotImplementedError(f'{norm} not implemented yet.')
                return functools.partial(ConditionalInstanceNorm2dPlus, num_classes=config.model.num_classes)

              if norm == 'InstanceNorm':
                return nn.InstanceNorm2d
              if norm == 'InstanceNorm++':
                return InstanceNorm2dPlus
              if norm == 'VarianceNorm':
                return VarianceNorm2d
              if norm == 'GroupNorm':
                return nn.GroupNorm
              raise ValueError('Unknown normalization: %s' % norm)
         | 
| 41 | 
            +
             | 
| 42 | 
            +
             | 
| 43 | 
            +
            class ConditionalBatchNorm2d(nn.Module):
              """Non-affine batch norm followed by a class-indexed scale (and optional shift).

              The per-class scale/shift values live in an `nn.Embedding` looked up with
              the label tensor `y` passed to `forward`.
              """

              def __init__(self, num_features, num_classes, bias=True):
                super().__init__()
                self.num_features = num_features
                self.bias = bias
                self.bn = nn.BatchNorm2d(num_features, affine=False)
                if self.bias:
                  # Each embedding row is laid out as [scale | shift].
                  self.embed = nn.Embedding(num_classes, 2 * num_features)
                  weight = self.embed.weight.data
                  weight[:, :num_features].uniform_()  # scale drawn with uniform_() defaults
                  weight[:, num_features:].zero_()  # shift starts at 0
                else:
                  # Scale only.
                  self.embed = nn.Embedding(num_classes, num_features)
                  self.embed.weight.data.uniform_()

              def forward(self, x, y):
                """Normalize `x` and modulate it with the embedding rows selected by `y`."""
                normed = self.bn(x)
                shape = (-1, self.num_features, 1, 1)
                if not self.bias:
                  return self.embed(y).view(shape) * normed
                scale, shift = self.embed(y).chunk(2, dim=1)
                return scale.view(shape) * normed + shift.view(shape)
         | 
| 66 | 
            +
             | 
| 67 | 
            +
             | 
| 68 | 
            +
            class ConditionalInstanceNorm2d(nn.Module):
              """Non-affine instance norm followed by a class-indexed scale (and optional shift)."""

              def __init__(self, num_features, num_classes, bias=True):
                super().__init__()
                self.num_features = num_features
                self.bias = bias
                self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
                if bias:
                  # Each embedding row is laid out as [scale | shift].
                  self.embed = nn.Embedding(num_classes, 2 * num_features)
                  weight = self.embed.weight.data
                  weight[:, :num_features].uniform_()  # scale drawn with uniform_() defaults
                  weight[:, num_features:].zero_()  # shift starts at 0
                else:
                  # Scale only.
                  self.embed = nn.Embedding(num_classes, num_features)
                  self.embed.weight.data.uniform_()

              def forward(self, x, y):
                """Instance-normalize `x` and modulate it with the embedding rows selected by `y`."""
                normed = self.instance_norm(x)
                shape = (-1, self.num_features, 1, 1)
                if not self.bias:
                  return self.embed(y).view(shape) * normed
                scale, shift = self.embed(y).chunk(2, dim=-1)
                return scale.view(shape) * normed + shift.view(shape)
         | 
| 91 | 
            +
             | 
| 92 | 
            +
             | 
| 93 | 
            +
            class ConditionalVarianceNorm2d(nn.Module):
              """Divide by the per-channel spatial standard deviation, then apply a class-indexed scale.

              The `bias` argument is stored but not otherwise used by this class.
              """

              def __init__(self, num_features, num_classes, bias=False):
                super().__init__()
                self.num_features = num_features
                self.bias = bias
                self.embed = nn.Embedding(num_classes, num_features)
                self.embed.weight.data.normal_(1, 0.02)  # scale ~ N(1, 0.02)

              def forward(self, x, y):
                """Scale `x` by 1/sqrt(var + 1e-5) over dims (2, 3) and by the embedding row of `y`."""
                variance = torch.var(x, dim=(2, 3), keepdim=True)
                rescaled = x / torch.sqrt(variance + 1e-5)
                scale = self.embed(y).view(-1, self.num_features, 1, 1)
                return scale * rescaled
         | 
| 108 | 
            +
             | 
| 109 | 
            +
             | 
| 110 | 
            +
            class VarianceNorm2d(nn.Module):
              """Divide by the per-channel spatial standard deviation, then apply a learned scale.

              The `bias` argument is stored but not otherwise used by this class.
              """

              def __init__(self, num_features, bias=False):
                super().__init__()
                self.num_features = num_features
                self.bias = bias
                self.alpha = nn.Parameter(torch.zeros(num_features))
                self.alpha.data.normal_(1, 0.02)  # scale ~ N(1, 0.02)

              def forward(self, x):
                """Scale `x` by 1/sqrt(var + 1e-5) over dims (2, 3) and by `alpha` per channel."""
                variance = torch.var(x, dim=(2, 3), keepdim=True)
                rescaled = x / torch.sqrt(variance + 1e-5)
                return self.alpha.view(-1, self.num_features, 1, 1) * rescaled
         | 
| 124 | 
            +
             | 
| 125 | 
            +
             | 
| 126 | 
            +
            class ConditionalNoneNorm2d(nn.Module):
              """Class-indexed scale (and optional shift) applied to the input without normalizing it."""

              def __init__(self, num_features, num_classes, bias=True):
                super().__init__()
                self.num_features = num_features
                self.bias = bias
                if bias:
                  # Each embedding row is laid out as [scale | shift].
                  self.embed = nn.Embedding(num_classes, 2 * num_features)
                  weight = self.embed.weight.data
                  weight[:, :num_features].uniform_()  # scale drawn with uniform_() defaults
                  weight[:, num_features:].zero_()  # shift starts at 0
                else:
                  # Scale only.
                  self.embed = nn.Embedding(num_classes, num_features)
                  self.embed.weight.data.uniform_()

              def forward(self, x, y):
                """Modulate `x` with the embedding rows selected by `y`."""
                shape = (-1, self.num_features, 1, 1)
                if not self.bias:
                  return self.embed(y).view(shape) * x
                scale, shift = self.embed(y).chunk(2, dim=-1)
                return scale.view(shape) * x + shift.view(shape)
         | 
| 147 | 
            +
             | 
| 148 | 
            +
             | 
| 149 | 
            +
            class NoneNorm2d(nn.Module):
              """Identity layer with the same constructor shape as the other norm layers."""

              def __init__(self, num_features, bias=True):
                # Both arguments are accepted and ignored.
                super().__init__()

              def forward(self, x):
                """Return `x` unchanged."""
                return x
         | 
| 155 | 
            +
             | 
| 156 | 
            +
             | 
| 157 | 
            +
            class InstanceNorm2dPlus(nn.Module):
              """Instance norm that adds back a learned multiple of the standardized channel means.

              Parameters `alpha` and `gamma` are drawn from N(1, 0.02); `beta` (created
              only when `bias` is true) starts at zero.
              """

              def __init__(self, num_features, bias=True):
                super().__init__()
                self.num_features = num_features
                self.bias = bias
                self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
                self.alpha = nn.Parameter(torch.zeros(num_features))
                self.gamma = nn.Parameter(torch.zeros(num_features))
                # Initialise in the same order (alpha, then gamma) so seeded runs match.
                for param in (self.alpha, self.gamma):
                  param.data.normal_(1, 0.02)
                if bias:
                  self.beta = nn.Parameter(torch.zeros(num_features))

              def forward(self, x):
                """Apply the normalization to `x`."""
                # Mean over dims (2, 3), then standardize those means along the last dim.
                channel_means = torch.mean(x, dim=(2, 3))
                centre = torch.mean(channel_means, dim=-1, keepdim=True)
                spread = torch.var(channel_means, dim=-1, keepdim=True)
                channel_means = (channel_means - centre) / (torch.sqrt(spread + 1e-5))

                h = self.instance_norm(x)
                h = h + channel_means[..., None, None] * self.alpha[..., None, None]

                out = self.gamma.view(-1, self.num_features, 1, 1) * h
                if self.bias:
                  out = out + self.beta.view(-1, self.num_features, 1, 1)
                return out
         | 
| 184 | 
            +
             | 
| 185 | 
            +
             | 
| 186 | 
            +
            class ConditionalInstanceNorm2dPlus(nn.Module):
              """Class-conditional variant of the mean-restoring instance norm.

              Each embedding row holds [gamma | alpha | beta] when `bias` is true and
              [gamma | alpha] otherwise. gamma and alpha are drawn from N(1, 0.02);
              beta starts at zero.
              """

              def __init__(self, num_features, num_classes, bias=True):
                super().__init__()
                self.num_features = num_features
                self.bias = bias
                self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
                if bias:
                  self.embed = nn.Embedding(num_classes, 3 * num_features)
                  weight = self.embed.weight.data
                  weight[:, :2 * num_features].normal_(1, 0.02)  # gamma and alpha
                  weight[:, 2 * num_features:].zero_()  # beta
                else:
                  self.embed = nn.Embedding(num_classes, 2 * num_features)
                  self.embed.weight.data.normal_(1, 0.02)

              def forward(self, x, y):
                """Normalize `x` using the embedding rows selected by `y`."""
                # Mean over dims (2, 3), then standardize those means along the last dim.
                channel_means = torch.mean(x, dim=(2, 3))
                centre = torch.mean(channel_means, dim=-1, keepdim=True)
                spread = torch.var(channel_means, dim=-1, keepdim=True)
                channel_means = (channel_means - centre) / (torch.sqrt(spread + 1e-5))
                h = self.instance_norm(x)

                params = self.embed(y)
                if self.bias:
                  gamma, alpha, beta = params.chunk(3, dim=-1)
                else:
                  gamma, alpha = params.chunk(2, dim=-1)
                  beta = None

                h = h + channel_means[..., None, None] * alpha[..., None, None]
                out = gamma.view(-1, self.num_features, 1, 1) * h
                if beta is not None:
                  out = out + beta.view(-1, self.num_features, 1, 1)
                return out
         | 
    	
        sgmse/backbones/ncsnpp_utils/op/__init__.py
    ADDED
    
    | @@ -0,0 +1 @@ | |
|  | 
|  | |
| 1 | 
            +
            from .upfirdn2d import upfirdn2d
         | 
    	
        sgmse/backbones/ncsnpp_utils/op/fused_act.py
    ADDED
    
    | @@ -0,0 +1,97 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            from torch import nn
         | 
| 5 | 
            +
            from torch.nn import functional as F
         | 
| 6 | 
            +
            from torch.autograd import Function
         | 
| 7 | 
            +
            from torch.utils.cpp_extension import load
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            # Directory holding this module; the extension sources sit next to it.
            module_path = os.path.dirname(__file__)

            # Build and load the "fused" extension from its C++/CUDA sources at import time.
            fused = load(
                "fused",
                sources=[
                    os.path.join(module_path, source_name)
                    for source_name in ("fused_bias_act.cpp", "fused_bias_act_kernel.cu")
                ],
            )
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
            class FusedLeakyReLUFunctionBackward(Function):
                """Autograd function computing the input and bias gradients of the fused op.

                Written as its own `Function` so that it has a `backward` of its own.
                """

                @staticmethod
                def forward(ctx, grad_output, out, negative_slope, scale):
                    """Return `(grad_input, grad_bias)` for the given upstream gradient."""
                    ctx.save_for_backward(out)
                    ctx.negative_slope = negative_slope
                    ctx.scale = scale

                    # A zero-element tensor stands in for "no bias" in the extension call.
                    no_bias = grad_output.new_empty(0)
                    grad_input = fused.fused_bias_act(
                        grad_output, no_bias, out, 3, 1, negative_slope, scale
                    )

                    # Sum over every dimension except dim 1 to get the bias gradient.
                    reduce_dims = [0] + list(range(2, grad_input.ndim))
                    grad_bias = grad_input.sum(reduce_dims).detach()

                    return grad_input, grad_bias

                @staticmethod
                def backward(ctx, gradgrad_input, gradgrad_bias):
                    """Propagate second-order gradients back to `grad_output` only."""
                    (out,) = ctx.saved_tensors
                    gradgrad_out = fused.fused_bias_act(
                        gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
                    )

                    # One slot per forward() argument; only grad_output receives a gradient.
                    return gradgrad_out, None, None, None
         | 
| 50 | 
            +
             | 
| 51 | 
            +
             | 
| 52 | 
            +
            class FusedLeakyReLUFunction(Function):
                """Autograd function wrapping the extension's fused bias-add + leaky ReLU + scale."""

                @staticmethod
                def forward(ctx, input, bias, negative_slope, scale):
                    """Run the fused op and stash what `backward` needs."""
                    ctx.negative_slope = negative_slope
                    ctx.scale = scale

                    # A zero-element tensor stands in for "no reference" in the extension call.
                    no_ref = input.new_empty(0)
                    out = fused.fused_bias_act(input, bias, no_ref, 3, 0, negative_slope, scale)
                    ctx.save_for_backward(out)

                    return out

                @staticmethod
                def backward(ctx, grad_output):
                    """Return gradients for `input` and `bias`; the two scalars get `None`."""
                    (out,) = ctx.saved_tensors
                    grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
                        grad_output, out, ctx.negative_slope, ctx.scale
                    )

                    return grad_input, grad_bias, None, None
         | 
| 72 | 
            +
             | 
| 73 | 
            +
             | 
| 74 | 
            +
            class FusedLeakyReLU(nn.Module):
                """Module form of `fused_leaky_relu` with a learnable per-channel bias.

                Args:
                    channel: Length of the bias vector (initialised to zeros).
                    negative_slope: Slope passed through to `fused_leaky_relu`.
                    scale: Output multiplier passed through to `fused_leaky_relu`.
                """

                def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
                    super().__init__()

                    self.negative_slope = negative_slope
                    self.scale = scale
                    self.bias = nn.Parameter(torch.zeros(channel))

                def forward(self, input):
                    """Apply bias, leaky ReLU and scaling to `input`."""
                    return fused_leaky_relu(
                        input, self.bias, self.negative_slope, self.scale
                    )
         | 
| 84 | 
            +
             | 
| 85 | 
            +
             | 
| 86 | 
            +
            def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
                """Add a per-channel bias, apply leaky ReLU, then multiply by `scale`.

                Tensors on the CPU take a pure-PyTorch path; every other device type is
                dispatched to `FusedLeakyReLUFunction`.

                Args:
                    input: Tensor whose dim 1 matches `bias.shape[0]`.
                    bias: Bias tensor, reshaped to broadcast along dim 1 of `input`.
                    negative_slope: Slope applied to negative values.
                    scale: Multiplier applied to the activated result.

                Returns:
                    Tensor with the same shape as `input`.
                """
                if input.device.type != "cpu":
                    return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)

                # Pad the bias shape with trailing singleton dims so it broadcasts over input.
                rest_dim = [1] * (input.ndim - bias.ndim - 1)
                biased = input + bias.view(1, bias.shape[0], *rest_dim)
                # Honour the caller's slope; it used to be hard-coded to 0.2 on this path.
                return F.leaky_relu(biased, negative_slope=negative_slope) * scale
         | 
    	
        sgmse/backbones/ncsnpp_utils/op/fused_bias_act.cpp
    ADDED
    
    | @@ -0,0 +1,21 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #include <torch/extension.h>
         | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
         | 
| 5 | 
            +
                int act, int grad, float alpha, float scale);
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            // Argument-validation helpers. Each raises through TORCH_CHECK with the
            // stringified argument name in the message.
            // Tensor::is_cuda() is used directly; going through the deprecated
            // Tensor::type() accessor is unnecessary.
            #define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
            #define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
            // Wrapped in do/while(0) so the two checks behave as one statement, e.g.
            // under an unbraced `if`.
            #define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            // Python-facing entry point: validates that `input` and `bias` are CUDA
            // tensors, then forwards every argument unchanged to fused_bias_act_op.
            torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
                int act, int grad, float alpha, float scale) {
                CHECK_CUDA(input);
                CHECK_CUDA(bias);

                // Make the device that holds `input` current for the duration of the call,
                // so the op runs there rather than on whatever device happens to be current.
                at::DeviceGuard guard(input.device());

                return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
            }
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            // Expose fused_bias_act to Python under the extension's module name.
            PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
                mod.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
            }
         | 
    	
        sgmse/backbones/ncsnpp_utils/op/fused_bias_act_kernel.cu
    ADDED
    
    | @@ -0,0 +1,99 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
         | 
| 2 | 
            +
            //
         | 
| 3 | 
            +
            // This work is made available under the Nvidia Source Code License-NC.
         | 
| 4 | 
            +
            // To view a copy of this license, visit
         | 
| 5 | 
            +
            // https://nvlabs.github.io/stylegan2/license.html
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            #include <torch/types.h>
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            #include <ATen/ATen.h>
         | 
| 10 | 
            +
            #include <ATen/AccumulateType.h>
         | 
| 11 | 
            +
            #include <ATen/cuda/CUDAContext.h>
         | 
| 12 | 
            +
            #include <ATen/cuda/CUDAApplyUtils.cuh>
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            #include <cuda.h>
         | 
| 15 | 
            +
            #include <cuda_runtime.h>
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
| 18 | 
            +
            template <typename scalar_t>
         | 
| 19 | 
            +
            static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
         | 
| 20 | 
            +
                int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
         | 
| 21 | 
            +
                int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                scalar_t zero = 0.0;
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
         | 
| 26 | 
            +
                    scalar_t x = p_x[xi];
         | 
| 27 | 
            +
             | 
| 28 | 
            +
                    if (use_bias) {
         | 
| 29 | 
            +
                        x += p_b[(xi / step_b) % size_b];
         | 
| 30 | 
            +
                    }
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                    scalar_t ref = use_ref ? p_ref[xi] : zero;
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                    scalar_t y;
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                    switch (act * 10 + grad) {
         | 
| 37 | 
            +
                        default:
         | 
| 38 | 
            +
                        case 10: y = x; break;
         | 
| 39 | 
            +
                        case 11: y = x; break;
         | 
| 40 | 
            +
                        case 12: y = 0.0; break;
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                        case 30: y = (x > 0.0) ? x : x * alpha; break;
         | 
| 43 | 
            +
                        case 31: y = (ref > 0.0) ? x : x * alpha; break;
         | 
| 44 | 
            +
                        case 32: y = 0.0; break;
         | 
| 45 | 
            +
                    }
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                    out[xi] = y * scale;
         | 
| 48 | 
            +
                }
         | 
| 49 | 
            +
            }
         | 
| 50 | 
            +
             | 
| 51 | 
            +
             | 
| 52 | 
            +
            // Host-side launcher for fused_bias_act_kernel.
            //
            // Makes contiguous copies/views of the three tensors, derives the launch
            // geometry from the element count of `input`, and returns a new tensor shaped
            // like `input`. A zero-element `bias` or `refer` disables the bias add or the
            // reference read respectively.
            torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
                int act, int grad, float alpha, float scale) {
                // Launch on the current stream of the currently selected device.
                int device_index = -1;
                cudaGetDevice(&device_index);
                cudaStream_t stream = at::cuda::getCurrentCUDAStream(device_index);

                auto x = input.contiguous();
                auto b = bias.contiguous();
                auto ref = refer.contiguous();

                int use_bias = b.numel() ? 1 : 0;
                int use_ref = ref.numel() ? 1 : 0;

                int size_x = x.numel();
                int size_b = b.numel();

                // Product of the sizes of dims 2.. of x; the kernel divides the flat index
                // by this to find the bias entry.
                int step_b = 1;
                for (int d = 2; d < x.dim(); d++) {
                    step_b *= x.size(d);
                }

                // Each block of 128 threads covers loop_x * 128 elements.
                int loop_x = 4;
                int block_size = 4 * 32;
                int grid_size = (size_x - 1) / (loop_x * block_size) + 1;

                auto y = torch::empty_like(x);

                AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
                    fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
                        y.data_ptr<scalar_t>(),
                        x.data_ptr<scalar_t>(),
                        b.data_ptr<scalar_t>(),
                        ref.data_ptr<scalar_t>(),
                        act,
                        grad,
                        alpha,
                        scale,
                        loop_x,
                        size_x,
                        step_b,
                        size_b,
                        use_bias,
                        use_ref
                    );
                });

                return y;
            }
         | 
    	
        sgmse/backbones/ncsnpp_utils/op/upfirdn2d.cpp
    ADDED
    
    | @@ -0,0 +1,23 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #include <torch/extension.h>
         | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
         | 
| 5 | 
            +
                                        int up_x, int up_y, int down_x, int down_y,
         | 
| 6 | 
            +
                                        int pad_x0, int pad_x1, int pad_y0, int pad_y1);
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
         | 
| 9 | 
            +
            #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
         | 
| 10 | 
            +
            #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
         | 
| 13 | 
            +
                                    int up_x, int up_y, int down_x, int down_y,
         | 
| 14 | 
            +
                                    int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
         | 
| 15 | 
            +
                CHECK_CUDA(input);
         | 
| 16 | 
            +
                CHECK_CUDA(kernel);
         | 
| 17 | 
            +
             | 
| 18 | 
            +
                return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
         | 
| 19 | 
            +
            }
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         | 
| 22 | 
            +
                m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
         | 
| 23 | 
            +
            }
         | 
    	
        sgmse/backbones/ncsnpp_utils/op/upfirdn2d.py
    ADDED
    
    | @@ -0,0 +1,203 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            from torch.nn import functional as F
         | 
| 5 | 
            +
            from torch.autograd import Function
         | 
| 6 | 
            +
            from torch.utils.cpp_extension import load
         | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
| 9 | 
            +
            # Directory containing this module; the extension sources are looked up here.
            module_path = os.path.dirname(__file__)

            # The extension is JIT-built only when CUDA is usable; otherwise the name is
            # bound to None so that the module can still be imported.
            upfirdn2d_op = (
                load(
                    "upfirdn2d",
                    sources=[
                        os.path.join(module_path, source_name)
                        for source_name in ("upfirdn2d.cpp", "upfirdn2d_kernel.cu")
                    ],
                )
                if torch.cuda.is_available()
                else None
            )
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            class UpFirDn2dBackward(Function):
                """Input gradient of ``UpFirDn2d``, expressed as its own autograd ``Function``.

                ``forward`` produces the gradient w.r.t. the input by running the compiled
                op on the incoming gradient; ``backward`` differentiates that result by
                running the op once more with the original (forward) parameters.
                """

                @staticmethod
                def forward(
                    ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
                ):
                    """Compute the gradient with respect to the input of ``UpFirDn2d``.

                    The op is invoked with the up and down factors swapped, with
                    ``grad_kernel`` as the filter and ``g_pad`` as the padding. The result
                    is viewed with the first four entries of ``in_size``.
                    """
                    up_x, up_y = up
                    down_x, down_y = down
                    g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
                    pad_x0, pad_x1, pad_y0, pad_y1 = pad

                    # Collapse all leading axes and append a trailing axis of size 1
                    # before handing the gradient to the op.
                    flat_grad = grad_output.reshape(-1, out_size[0], out_size[1], 1)

                    # Note the argument order: down factors go in the "up" slots and
                    # up factors in the "down" slots.
                    grad_input = upfirdn2d_op.upfirdn2d(
                        flat_grad,
                        grad_kernel,
                        down_x,
                        down_y,
                        up_x,
                        up_y,
                        g_pad_x0,
                        g_pad_x1,
                        g_pad_y0,
                        g_pad_y1,
                    )

                    # Everything backward() needs to re-run the op with forward settings.
                    ctx.save_for_backward(kernel)
                    ctx.up_x, ctx.up_y = up_x, up_y
                    ctx.down_x, ctx.down_y = down_x, down_y
                    ctx.pad_x0, ctx.pad_x1 = pad_x0, pad_x1
                    ctx.pad_y0, ctx.pad_y1 = pad_y0, pad_y1
                    ctx.in_size = in_size
                    ctx.out_size = out_size

                    return grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])

                @staticmethod
                def backward(ctx, gradgrad_input):
                    """Differentiate ``forward`` with respect to ``grad_output`` only.

                    Returns a 9-tuple matching the arguments of ``forward``; every entry
                    except the first is ``None``.
                    """
                    (kernel,) = ctx.saved_tensors
                    size_0, size_1, in_h, in_w = (
                        ctx.in_size[0],
                        ctx.in_size[1],
                        ctx.in_size[2],
                        ctx.in_size[3],
                    )
                    out_h, out_w = ctx.out_size[0], ctx.out_size[1]

                    flat_gradgrad = gradgrad_input.reshape(-1, in_h, in_w, 1)

                    gradgrad_out = upfirdn2d_op.upfirdn2d(
                        flat_gradgrad,
                        kernel,
                        ctx.up_x,
                        ctx.up_y,
                        ctx.down_x,
                        ctx.down_y,
                        ctx.pad_x0,
                        ctx.pad_x1,
                        ctx.pad_y0,
                        ctx.pad_y1,
                    )
                    gradgrad_out = gradgrad_out.view(size_0, size_1, out_h, out_w)

                    # One gradient slot per forward() argument after ctx.
                    return (gradgrad_out,) + (None,) * 8
         | 
| 89 | 
            +
             | 
| 90 | 
            +
             | 
| 91 | 
            +
            class UpFirDn2d(Function):
                """Autograd wrapper around the compiled ``upfirdn2d`` op."""

                @staticmethod
                def forward(ctx, input, kernel, up, down, pad):
                    """Run the compiled op on a 4-D ``input`` with a 2-D ``kernel``.

                    ``up`` and ``down`` are ``(x, y)`` pairs and ``pad`` is
                    ``(pad_x0, pad_x1, pad_y0, pad_y1)``. The result is viewed as
                    ``(-1, channel, out_h, out_w)``.
                    """
                    up_x, up_y = up
                    down_x, down_y = down
                    pad_x0, pad_x1, pad_y0, pad_y1 = pad

                    kernel_h, kernel_w = kernel.shape
                    _, channel, in_h, in_w = input.shape

                    # Output extent after upsampling, padding, filtering and downsampling.
                    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
                    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

                    # Padding handed to UpFirDn2dBackward for the gradient pass.
                    g_pad_x0 = kernel_w - pad_x0 - 1
                    g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
                    g_pad_y0 = kernel_h - pad_y0 - 1
                    g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1

                    # State consumed by backward().
                    ctx.in_size = input.shape
                    ctx.out_size = (out_h, out_w)
                    ctx.up = (up_x, up_y)
                    ctx.down = (down_x, down_y)
                    ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
                    ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
                    # The flipped kernel is the one used for the gradient computation.
                    ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))

                    # The op receives the input with leading axes collapsed and a
                    # trailing axis of size 1.
                    flat_input = input.reshape(-1, in_h, in_w, 1)
                    result = upfirdn2d_op.upfirdn2d(
                        flat_input,
                        kernel,
                        up_x,
                        up_y,
                        down_x,
                        down_y,
                        pad_x0,
                        pad_x1,
                        pad_y0,
                        pad_y1,
                    )

                    return result.view(-1, channel, out_h, out_w)

                @staticmethod
                def backward(ctx, grad_output):
                    """Delegate to ``UpFirDn2dBackward``; only ``input`` receives a gradient."""
                    kernel, grad_kernel = ctx.saved_tensors

                    grad_input = UpFirDn2dBackward.apply(
                        grad_output,
                        kernel,
                        grad_kernel,
                        ctx.up,
                        ctx.down,
                        ctx.pad,
                        ctx.g_pad,
                        ctx.in_size,
                        ctx.out_size,
                    )

                    # One slot per forward() argument: input, kernel, up, down, pad.
                    return (grad_input,) + (None,) * 4
         | 
| 146 | 
            +
             | 
| 147 | 
            +
             | 
| 148 | 
            +
            def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
                """Upsample, FIR-filter and downsample ``input`` using identical x/y settings.

                Args:
                    input: 4-D tensor (both code paths unpack ``input.shape`` into four
                        values).
                    kernel: 2-D filter tensor.
                    up: upsampling factor applied to both axes.
                    down: downsampling factor applied to both axes.
                    pad: sequence whose first two entries ``(pad0, pad1)`` are used as the
                        leading/trailing padding on both axes.

                Returns:
                    The filtered tensor produced by the compiled CUDA op or by the
                    pure-PyTorch implementation.

                The compiled op is used only for CUDA tensors and only when the extension
                was actually loaded. Every other case (CPU, other device types such as
                MPS, or a missing extension) uses ``upfirdn2d_native``, instead of failing
                on ``upfirdn2d_op`` being ``None`` or on the op's CUDA-tensor check.
                """
                pad_0, pad_1 = pad[0], pad[1]

                use_compiled_op = upfirdn2d_op is not None and input.device.type == "cuda"
                if use_compiled_op:
                    return UpFirDn2d.apply(
                        input, kernel, (up, up), (down, down), (pad_0, pad_1, pad_0, pad_1)
                    )

                return upfirdn2d_native(
                    input, kernel, up, up, down, down, pad_0, pad_1, pad_0, pad_1
                )
         | 
| 160 | 
            +
             | 
| 161 | 
            +
             | 
| 162 | 
            +
            def upfirdn2d_native(
                input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
            ):
                """Pure-PyTorch upsample / FIR filter / downsample of a 4-D tensor.

                Steps: insert zeros between samples (``up_x``/``up_y``), pad (positive
                values) or crop (negative values) each border, convolve every plane with
                the flipped 2-D ``kernel``, then keep every ``down_x``/``down_y``-th sample.

                Returns a tensor viewed as ``(-1, channel, out_h, out_w)``.
                """
                _, channel, in_h, in_w = input.shape

                # Treat every (batch, channel) plane independently; the trailing axis of
                # size 1 is the "minor" axis.
                planes = input.reshape(-1, in_h, in_w, 1)
                minor = planes.shape[3]
                kernel_h, kernel_w = kernel.shape

                # Zero insertion: give each sample its own (up_y, up_x) cell whose
                # remaining entries are zero, then merge the cells into the spatial axes.
                work = planes.view(-1, in_h, 1, in_w, 1, minor)
                work = F.pad(work, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
                work = work.view(-1, in_h * up_y, in_w * up_x, minor)

                # Positive padding adds zeros ...
                work = F.pad(
                    work,
                    [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)],
                )
                # ... and negative padding removes border samples.
                crop_top, crop_bottom = max(-pad_y0, 0), max(-pad_y1, 0)
                crop_left, crop_right = max(-pad_x0, 0), max(-pad_x1, 0)
                work = work[
                    :,
                    crop_top : work.shape[1] - crop_bottom,
                    crop_left : work.shape[2] - crop_right,
                    :,
                ]

                padded_h = in_h * up_y + pad_y0 + pad_y1
                padded_w = in_w * up_x + pad_x0 + pad_x1

                # conv2d computes a cross-correlation, so the kernel is flipped first.
                work = work.permute(0, 3, 1, 2).reshape([-1, 1, padded_h, padded_w])
                weight = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
                work = F.conv2d(work, weight)
                work = work.reshape(
                    -1, minor, padded_h - kernel_h + 1, padded_w - kernel_w + 1
                )

                # Back to (plane, y, x, minor), then decimate.
                work = work.permute(0, 2, 3, 1)[:, ::down_y, ::down_x, :]

                out_h = (padded_h - kernel_h) // down_y + 1
                out_w = (padded_w - kernel_w) // down_x + 1

                return work.view(-1, channel, out_h, out_w)
         | 
    	
        sgmse/backbones/ncsnpp_utils/op/upfirdn2d_kernel.cu
    ADDED
    
    | @@ -0,0 +1,369 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
         | 
| 2 | 
            +
            //
         | 
| 3 | 
            +
            // This work is made available under the Nvidia Source Code License-NC.
         | 
| 4 | 
            +
            // To view a copy of this license, visit
         | 
| 5 | 
            +
            // https://nvlabs.github.io/stylegan2/license.html
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            #include <torch/types.h>
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            #include <ATen/ATen.h>
         | 
| 10 | 
            +
            #include <ATen/AccumulateType.h>
         | 
| 11 | 
            +
            #include <ATen/cuda/CUDAApplyUtils.cuh>
         | 
| 12 | 
            +
            #include <ATen/cuda/CUDAContext.h>
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            #include <cuda.h>
         | 
| 15 | 
            +
            #include <cuda_runtime.h>
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            // Integer division of a by b, rounded down instead of toward zero.
            //
            // C++ `/` truncates toward zero; whenever the truncated quotient multiplied
            // back by b overshoots a, the quotient is lowered by one.
            // NOTE(review): the correction is only sufficient for a positive divisor;
            // the callers visible in this file pass the up-sampling factor -- confirm
            // that it is always > 0.
            static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
              const int quotient = a / b;
              return (quotient * b > a) ? quotient - 1 : quotient;
            }
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            // Parameter bundle passed by value to the upfirdn2d kernels.
            // Member order is kept exactly as declared originally.
            struct UpFirDn2DKernelParams {
              // Up- and down-sampling factors per axis.
              int up_x, up_y;
              int down_x, down_y;

              // Padding before (0) and after (1) along each axis.
              int pad_x0, pad_x1;
              int pad_y0, pad_y1;

              // Extents used by the kernels to index the input as
              // [major][in_h][in_w][minor].
              int major_dim, in_h, in_w, minor_dim;

              // Filter extents; the filter is indexed as [kernel_h][kernel_w].
              int kernel_h, kernel_w;

              // Output extents; the output is indexed as [major][out_h][out_w][minor].
              int out_h, out_w;

              // Upper bounds on how many major slices / output columns one thread
              // iterates over inside the kernels.
              int loop_major, loop_x;
            };
         | 
| 48 | 
            +
             | 
| 49 | 
            +
            // Generic upfirdn2d kernel: each thread computes output samples one at a
            // time by accumulating input * filter products read straight from the
            // `input` and `kernel` pointers.
            //
            // Thread mapping:
            //   x index -> (out_y, minor) pair, minor varying fastest
            //   y index -> first output column; up to p.loop_x columns are visited,
            //              spaced blockDim.y apart
            //   z block -> first major slice; up to p.loop_major slices are visited
            template <typename scalar_t>
            __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
                                                   const scalar_t *kernel,
                                                   const UpFirDn2DKernelParams p) {
              const int flat_idx = blockIdx.x * blockDim.x + threadIdx.x;
              const int out_y = flat_idx / p.minor_dim;
              const int minor_idx = flat_idx - out_y * p.minor_dim;

              const int first_out_x = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
              const int first_major = blockIdx.z * p.loop_major;

              // Threads that start outside the output have nothing to do.
              if (first_out_x >= p.out_w || out_y >= p.out_h ||
                  first_major >= p.major_dim) {
                return;
              }

              // Vertical footprint of this output row: first contributing input row
              // (clamped to [0, in_h]), number of contributing rows, and the filter
              // row paired with the first input row.
              const int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
              const int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
              const int rows =
                  min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
              const int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;

              // Element steps: the input is walked forwards while the filter is
              // walked backwards (negative steps).
              const int x_col_step = p.minor_dim;
              const int x_row_step = p.in_w * p.minor_dim;
              const int k_col_step = -p.up_x;
              const int k_row_step = -p.up_y * p.kernel_w;

              for (int iter_major = 0, major_idx = first_major;
                   iter_major < p.loop_major && major_idx < p.major_dim;
                   ++iter_major, ++major_idx) {
                for (int iter_x = 0, out_x = first_out_x;
                     iter_x < p.loop_x && out_x < p.out_w;
                     ++iter_x, out_x += blockDim.y) {
                  // Horizontal footprint, computed the same way as the vertical one.
                  const int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
                  const int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
                  const int cols =
                      min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
                  const int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;

                  const scalar_t *x_ptr =
                      &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
                             minor_idx];
                  const scalar_t *k_ptr = &kernel[kernel_y * p.kernel_w + kernel_x];

                  scalar_t acc = 0.0f;

                  for (int row = 0; row < rows; ++row) {
                    for (int col = 0; col < cols; ++col) {
                      acc += static_cast<scalar_t>(*x_ptr) * static_cast<scalar_t>(*k_ptr);
                      x_ptr += x_col_step;
                      k_ptr += k_col_step;
                    }

                    // Undo the column walk and move to the next row of each operand.
                    x_ptr += x_row_step - cols * x_col_step;
                    k_ptr += k_row_step - cols * k_col_step;
                  }

                  out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
                      minor_idx] = acc;
                }
              }
            }
         | 
| 106 | 
            +
             | 
| 107 | 
            +
            template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
         | 
| 108 | 
            +
                      int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
         | 
| 109 | 
            +
            __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
         | 
| 110 | 
            +
                                             const scalar_t *kernel,
         | 
| 111 | 
            +
                                             const UpFirDn2DKernelParams p) {
         | 
| 112 | 
            +
              const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
         | 
| 113 | 
            +
              const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
         | 
| 114 | 
            +
             | 
| 115 | 
            +
              __shared__ volatile float sk[kernel_h][kernel_w];
         | 
| 116 | 
            +
              __shared__ volatile float sx[tile_in_h][tile_in_w];
         | 
| 117 | 
            +
             | 
| 118 | 
            +
              int minor_idx = blockIdx.x;
         | 
| 119 | 
            +
              int tile_out_y = minor_idx / p.minor_dim;
         | 
| 120 | 
            +
              minor_idx -= tile_out_y * p.minor_dim;
         | 
| 121 | 
            +
              tile_out_y *= tile_out_h;
         | 
| 122 | 
            +
              int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
         | 
| 123 | 
            +
              int major_idx_base = blockIdx.z * p.loop_major;
         | 
| 124 | 
            +
             | 
| 125 | 
            +
              if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
         | 
| 126 | 
            +
                  major_idx_base >= p.major_dim) {
         | 
| 127 | 
            +
                return;
         | 
| 128 | 
            +
              }
         | 
| 129 | 
            +
             | 
| 130 | 
            +
              for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
         | 
| 131 | 
            +
                   tap_idx += blockDim.x) {
         | 
| 132 | 
            +
                int ky = tap_idx / kernel_w;
         | 
| 133 | 
            +
                int kx = tap_idx - ky * kernel_w;
         | 
| 134 | 
            +
                scalar_t v = 0.0;
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                if (kx < p.kernel_w & ky < p.kernel_h) {
         | 
| 137 | 
            +
                  v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
         | 
| 138 | 
            +
                }
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                sk[ky][kx] = v;
         | 
| 141 | 
            +
              }
         | 
| 142 | 
            +
             | 
| 143 | 
            +
              for (int loop_major = 0, major_idx = major_idx_base;
         | 
| 144 | 
            +
                   loop_major < p.loop_major & major_idx < p.major_dim;
         | 
| 145 | 
            +
                   loop_major++, major_idx++) {
         | 
| 146 | 
            +
                for (int loop_x = 0, tile_out_x = tile_out_x_base;
         | 
| 147 | 
            +
                     loop_x < p.loop_x & tile_out_x < p.out_w;
         | 
| 148 | 
            +
                     loop_x++, tile_out_x += tile_out_w) {
         | 
| 149 | 
            +
                  int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
         | 
| 150 | 
            +
                  int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
         | 
| 151 | 
            +
                  int tile_in_x = floor_div(tile_mid_x, up_x);
         | 
| 152 | 
            +
                  int tile_in_y = floor_div(tile_mid_y, up_y);
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                  __syncthreads();
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                  for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
         | 
| 157 | 
            +
                       in_idx += blockDim.x) {
         | 
| 158 | 
            +
                    int rel_in_y = in_idx / tile_in_w;
         | 
| 159 | 
            +
                    int rel_in_x = in_idx - rel_in_y * tile_in_w;
         | 
| 160 | 
            +
                    int in_x = rel_in_x + tile_in_x;
         | 
| 161 | 
            +
                    int in_y = rel_in_y + tile_in_y;
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                    scalar_t v = 0.0;
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                    if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
         | 
| 166 | 
            +
                      v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
         | 
| 167 | 
            +
                                    p.minor_dim +
         | 
| 168 | 
            +
                                minor_idx];
         | 
| 169 | 
            +
                    }
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                    sx[rel_in_y][rel_in_x] = v;
         | 
| 172 | 
            +
                  }
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                  __syncthreads();
         | 
| 175 | 
            +
                  for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
         | 
| 176 | 
            +
                       out_idx += blockDim.x) {
         | 
| 177 | 
            +
                    int rel_out_y = out_idx / tile_out_w;
         | 
| 178 | 
            +
                    int rel_out_x = out_idx - rel_out_y * tile_out_w;
         | 
| 179 | 
            +
                    int out_x = rel_out_x + tile_out_x;
         | 
| 180 | 
            +
                    int out_y = rel_out_y + tile_out_y;
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                    int mid_x = tile_mid_x + rel_out_x * down_x;
         | 
| 183 | 
            +
                    int mid_y = tile_mid_y + rel_out_y * down_y;
         | 
| 184 | 
            +
                    int in_x = floor_div(mid_x, up_x);
         | 
| 185 | 
            +
                    int in_y = floor_div(mid_y, up_y);
         | 
| 186 | 
            +
                    int rel_in_x = in_x - tile_in_x;
         | 
| 187 | 
            +
                    int rel_in_y = in_y - tile_in_y;
         | 
| 188 | 
            +
                    int kernel_x = (in_x + 1) * up_x - mid_x - 1;
         | 
| 189 | 
            +
                    int kernel_y = (in_y + 1) * up_y - mid_y - 1;
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                    scalar_t v = 0.0;
         | 
| 192 | 
            +
             | 
| 193 | 
            +
            #pragma unroll
         | 
| 194 | 
            +
                    for (int y = 0; y < kernel_h / up_y; y++)
         | 
| 195 | 
            +
            #pragma unroll
         | 
| 196 | 
            +
                      for (int x = 0; x < kernel_w / up_x; x++)
         | 
| 197 | 
            +
                        v += sx[rel_in_y + y][rel_in_x + x] *
         | 
| 198 | 
            +
                             sk[kernel_y + y * up_y][kernel_x + x * up_x];
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                    if (out_x < p.out_w & out_y < p.out_h) {
         | 
| 201 | 
            +
                      out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
         | 
| 202 | 
            +
                          minor_idx] = v;
         | 
| 203 | 
            +
                    }
         | 
| 204 | 
            +
                  }
         | 
| 205 | 
            +
                }
         | 
| 206 | 
            +
              }
         | 
| 207 | 
            +
            }
         | 
| 208 | 
            +
             | 
| 209 | 
            +
            // Host entry point for the upfirdn2d CUDA op: upsample, FIR-filter and
            // downsample (with padding) a 4-D tensor, returning a newly allocated result.
            //
            //   input           4-D tensor indexed as [major, in_h, in_w, minor].
            //   kernel          2-D FIR filter indexed as [kernel_h, kernel_w]. It is read
            //                   through data_ptr<scalar_t>() using the scalar type of
            //                   `input`.
            //   up_x, up_y      upsampling factors (must be positive).
            //   down_x, down_y  downsampling factors (must be positive).
            //   pad_x0, pad_x1  padding before / after the width axis.
            //   pad_y0, pad_y1  padding before / after the height axis.
            //
            // The result has shape [major, out_h, out_w, minor] where, per axis,
            //   out = (in * up + pad0 + pad1 - kernel + down) / down.
            // The kernels are enqueued on the current CUDA stream of the current device.
            // Throws c10::Error on malformed arguments or when the launch fails.
            torch::Tensor upfirdn2d_op(const torch::Tensor &input,
                                       const torch::Tensor &kernel, int up_x, int up_y,
                                       int down_x, int down_y, int pad_x0, int pad_x1,
                                       int pad_y0, int pad_y1) {
              TORCH_CHECK(input.dim() == 4, "upfirdn2d_op: input must be 4-D, got ",
                          input.dim(), "-D");
              TORCH_CHECK(kernel.dim() == 2, "upfirdn2d_op: kernel must be 2-D, got ",
                          kernel.dim(), "-D");
              // The output-size formula divides by the down factors and the kernels
              // divide by the up factors, so all of them have to be strictly positive.
              TORCH_CHECK(up_x > 0 && up_y > 0 && down_x > 0 && down_y > 0,
                          "upfirdn2d_op: up/down factors must be positive");

              // Passing no device index selects the stream of the current device.
              cudaStream_t stream = at::cuda::getCurrentCUDAStream();

              UpFirDn2DKernelParams p;

              auto x = input.contiguous();
              auto k = kernel.contiguous();

              p.major_dim = x.size(0);
              p.in_h = x.size(1);
              p.in_w = x.size(2);
              p.minor_dim = x.size(3);
              p.kernel_h = k.size(0);
              p.kernel_w = k.size(1);
              p.up_x = up_x;
              p.up_y = up_y;
              p.down_x = down_x;
              p.down_y = down_y;
              p.pad_x0 = pad_x0;
              p.pad_x1 = pad_x1;
              p.pad_y0 = pad_y0;
              p.pad_y1 = pad_y1;

              p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
                        p.down_y;
              p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
                        p.down_x;

              auto out =
                  at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());

              // Nothing to compute; also avoids launching with a zero-sized grid
              // dimension, which is an invalid launch configuration.
              if (out.numel() == 0) {
                return out;
              }

              // Select a specialised tiled kernel when the configuration matches one of
              // the compiled template instantiations. The checks run from the widest
              // kernel bound to the tightest, so a later match deliberately overrides an
              // earlier one. `mode` stays -1 when only the generic kernel applies.
              int mode = -1;
              int tile_out_h = -1;
              int tile_out_w = -1;

              const bool up1 = p.up_x == 1 && p.up_y == 1;
              const bool up2 = p.up_x == 2 && p.up_y == 2;
              const bool down1 = p.down_x == 1 && p.down_y == 1;
              const bool down2 = p.down_x == 2 && p.down_y == 2;
              const auto kernel_fits = [&p](int n) {
                return p.kernel_h <= n && p.kernel_w <= n;
              };

              if (up1 && down1 && kernel_fits(4)) {
                mode = 1;
                tile_out_h = 16;
                tile_out_w = 64;
              }

              if (up1 && down1 && kernel_fits(3)) {
                mode = 2;
                tile_out_h = 16;
                tile_out_w = 64;
              }

              if (up2 && down1 && kernel_fits(4)) {
                mode = 3;
                tile_out_h = 16;
                tile_out_w = 64;
              }

              if (up2 && down1 && kernel_fits(2)) {
                mode = 4;
                tile_out_h = 16;
                tile_out_w = 64;
              }

              if (up1 && down2 && kernel_fits(4)) {
                mode = 5;
                tile_out_h = 8;
                tile_out_w = 32;
              }

              if (up1 && down2 && kernel_fits(2)) {
                mode = 6;
                tile_out_h = 8;
                tile_out_w = 32;
              }

              dim3 block_size;
              dim3 grid_size;

              // Each block iterates over `loop_major` consecutive major indices so that
              // grid_size.z stays small for very large major dimensions.
              p.loop_major = (p.major_dim - 1) / 16384 + 1;

              if (tile_out_h > 0 && tile_out_w > 0) {
                // Tiled kernels: one block per (output row tile, minor index) pair.
                p.loop_x = 1;
                block_size = dim3(32 * 8, 1, 1);
                grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
                                 (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
                                 (p.major_dim - 1) / p.loop_major + 1);
              } else {
                // Generic kernel.
                p.loop_x = 4;
                block_size = dim3(4, 32, 1);
                grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
                                 (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
                                 (p.major_dim - 1) / p.loop_major + 1);
              }

              // Template arguments: <scalar_t, up_x, up_y, down_x, down_y, kernel_h,
              // kernel_w, tile_out_h, tile_out_w>. The tile sizes must agree with the
              // tile_out_h / tile_out_w values used for the grid computation above.
              AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
                switch (mode) {
                case 1:
                  upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
                      <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                             x.data_ptr<scalar_t>(),
                                                             k.data_ptr<scalar_t>(), p);
                  break;

                case 2:
                  upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
                      <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                             x.data_ptr<scalar_t>(),
                                                             k.data_ptr<scalar_t>(), p);
                  break;

                case 3:
                  upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
                      <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                             x.data_ptr<scalar_t>(),
                                                             k.data_ptr<scalar_t>(), p);
                  break;

                case 4:
                  upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
                      <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                             x.data_ptr<scalar_t>(),
                                                             k.data_ptr<scalar_t>(), p);
                  break;

                case 5:
                  upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
                      <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                             x.data_ptr<scalar_t>(),
                                                             k.data_ptr<scalar_t>(), p);
                  break;

                case 6:
                  // Mode 6 is only chosen for kernels no larger than 2x2, so use the
                  // 2x2 instantiation. The 4x4 one yields the same values (unused taps
                  // are zero-filled) but needs a larger shared-memory tile and more
                  // multiply-adds per output.
                  upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32>
                      <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                             x.data_ptr<scalar_t>(),
                                                             k.data_ptr<scalar_t>(), p);
                  break;

                default:
                  upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
                      out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
                      k.data_ptr<scalar_t>(), p);
                }
              });

              // Kernel launches do not report errors by themselves; surface a bad launch
              // here instead of returning an uninitialised output tensor.
              const cudaError_t launch_err = cudaGetLastError();
              TORCH_CHECK(launch_err == cudaSuccess,
                          "upfirdn2d_op: CUDA kernel launch failed: ",
                          cudaGetErrorString(launch_err));

              return out;
            }
         | 
    	
        sgmse/backbones/ncsnpp_utils/up_or_down_sampling.py
    ADDED
    
    | @@ -0,0 +1,257 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            """Layers used for up-sampling or down-sampling images.
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            Many functions are ported from https://github.com/NVlabs/stylegan2.
         | 
| 4 | 
            +
            """
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import torch.nn as nn
         | 
| 7 | 
            +
            import torch
         | 
| 8 | 
            +
            import torch.nn.functional as F
         | 
| 9 | 
            +
            import numpy as np
         | 
| 10 | 
            +
            from .op import upfirdn2d
         | 
| 11 | 
            +
             | 
| 12 | 
            +
             | 
| 13 | 
            +
            # Function ported from StyleGAN2
         | 
| 14 | 
            +
            def get_weight(module, shape, weight_var='weight', kernel_init=None):
              """Fetch (or create) a layer weight by delegating to `module.param`.

              Args:
                module:      Object exposing a `param(name, init, shape)` method.
                shape:       Shape forwarded to `module.param`.
                weight_var:  Name under which the weight is requested.
                kernel_init: Initializer forwarded to `module.param`.

              Returns:
                Whatever `module.param(weight_var, kernel_init, shape)` returns.
              """
              weight = module.param(weight_var, kernel_init, shape)
              return weight
         | 
| 21 | 
            +
             | 
| 22 | 
            +
             | 
| 23 | 
            +
            class Conv2d(nn.Module):
              """Square-kernel 2-D convolution, optionally fused with FIR resampling.

              Depending on the flags, `forward` runs a fused upsample + convolution
              (`up=True`), a fused convolution + downsample (`down=True`), or a plain
              stride-1 convolution with "same" padding. A per-channel bias is added
              afterwards when `use_bias` is set.

              Args:
                in_ch:           Number of input channels.
                out_ch:          Number of output channels.
                kernel:          Odd, positive size of the square convolution kernel.
                up:              Use the fused upsampling path (exclusive with `down`).
                down:            Use the fused downsampling path (exclusive with `up`).
                resample_kernel: FIR filter handed to the resampling helpers as `k`.
                use_bias:        Whether to create and add a learnable bias.
                kernel_init:     Optional callable mapping a shape to the initial weight
                                 tensor; the weight stays all-zero when it is omitted.
              """

              def __init__(self, in_ch, out_ch, kernel, up=False, down=False,
                           resample_kernel=(1, 3, 3, 1),
                           use_bias=True,
                           kernel_init=None):
                super().__init__()
                # Resampling goes in one direction only, and the "same" padding used in
                # `forward` relies on an odd kernel size.
                assert not (up and down)
                assert kernel >= 1 and kernel % 2 == 1

                weight_shape = (out_ch, in_ch, kernel, kernel)
                self.weight = nn.Parameter(torch.zeros(*weight_shape))
                if kernel_init is not None:
                  self.weight.data = kernel_init(self.weight.data.shape)
                if use_bias:
                  self.bias = nn.Parameter(torch.zeros(out_ch))

                self.up, self.down = up, down
                self.resample_kernel = resample_kernel
                self.kernel = kernel
                self.use_bias = use_bias

              def forward(self, x):
                """Apply the (optionally resampling) convolution and the bias to `x`."""
                if self.up:
                  out = upsample_conv_2d(x, self.weight, k=self.resample_kernel)
                elif self.down:
                  out = conv_downsample_2d(x, self.weight, k=self.resample_kernel)
                else:
                  same_pad = self.kernel // 2
                  out = F.conv2d(x, self.weight, stride=1, padding=same_pad)

                if not self.use_bias:
                  return out
                # Broadcast the bias over the batch and both spatial axes.
                return out + self.bias.reshape(1, -1, 1, 1)
         | 
| 57 | 
            +
             | 
| 58 | 
            +
             | 
| 59 | 
            +
            def naive_upsample_2d(x, factor=2):
              """Upsample a `[N, C, H, W]` tensor by repeating every pixel.

              Each input value is replicated into a `factor x factor` block, giving a
              tensor of shape `[N, C, H * factor, W * factor]`.
              """
              _, channels, height, width = x.shape
              # Add singleton axes after H and W, then tile along exactly those axes.
              expanded = x.reshape(-1, channels, height, 1, width, 1)
              tiled = expanded.repeat(1, 1, 1, factor, 1, factor)
              return tiled.reshape(-1, channels, height * factor, width * factor)
         | 
| 64 | 
            +
             | 
| 65 | 
            +
             | 
| 66 | 
            +
            def naive_downsample_2d(x, factor=2):
              """Downsample a `[N, C, H, W]` tensor by averaging `factor x factor` blocks.

              The result has shape `[N, C, H // factor, W // factor]`.
              """
              _, channels, height, width = x.shape
              # Split H and W into (block index, offset inside the block) pairs, then
              # average over the two offset axes.
              blocks = x.reshape(-1, channels, height // factor, factor,
                                 width // factor, factor)
              return blocks.mean(dim=(3, 5))
         | 
| 70 | 
            +
             | 
| 71 | 
            +
             | 
| 72 | 
            +
            def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
              """Fused `upsample_2d()` followed by a 2-D convolution.

              The upsampling and the convolution are carried out by one strided
              transposed convolution with the spatially flipped weights; the result is
              then filtered with the FIR kernel `k` through `upfirdn2d`. Padding is
              performed only once, not between the operations.

              Args:
                x:       Input tensor of the shape `[N, C, H, W]`.
                w:       Weight tensor of the shape `[outChannels, inChannels, filterH,
                         filterW]` with `filterH == filterW`. Grouped convolution can be
                         performed by `inChannels = x.shape[1] // numGroups`.
                k:       FIR filter handed to `_setup_kernel`. The default is
                         `[1] * factor`.
                factor:  Integer upsampling factor (default: 2).
                gain:    Scaling factor for signal magnitude (default: 1.0).

              Returns:
                The output of `upfirdn2d`, expected to be of the shape
                `[N, outChannels, H * factor, W * factor]`.
              """

              assert isinstance(factor, int) and factor >= 1

              # Check weight shape.
              assert len(w.shape) == 4
              _outC, inC, convH, convW = w.shape
              assert convW == convH

              # Setup filter kernel.
              if k is None:
                k = [1] * factor
              k = _setup_kernel(k) * (gain * (factor ** 2))
              p = (k.shape[0] - factor) - (convW - 1)

              # `F.conv_transpose2d` takes one stride per spatial axis (the ported
              # TensorFlow code used a 4-element NCHW stride list here, which PyTorch
              # rejects).
              stride = (factor, factor)

              # Determine data dimensions.
              in_h, in_w = _shape(x, 2), _shape(x, 3)
              output_shape = ((in_h - 1) * factor + convH, (in_w - 1) * factor + convW)
              output_padding = (output_shape[0] - (in_h - 1) * stride[0] - convH,
                                output_shape[1] - (in_w - 1) * stride[1] - convW)
              assert output_padding[0] >= 0 and output_padding[1] >= 0
              num_groups = _shape(x, 1) // inC

              # Rearrange the weights into the `[inChannels, outChannels // groups, kH,
              # kW]` layout of a transposed convolution. Tensors do not support negative
              # slice steps, so the spatial flip is done with `torch.flip`.
              w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
              w = torch.flip(w, dims=(3, 4)).permute(0, 2, 1, 3, 4)
              w = torch.reshape(w, (num_groups * inC, -1, convH, convW))

              # `groups` has to match the grouping baked into the weight layout above;
              # for `num_groups == 1` this is the ordinary transposed convolution.
              x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding,
                                     padding=0, groups=num_groups)

              return upfirdn2d(x, torch.tensor(k, device=x.device),
                               pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
         | 
| 142 | 
            +
             | 
| 143 | 
            +
             | 
| 144 | 
            +
            def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
              """Apply an FIR filter to `x`, then a strided convolution with `w`.

                The input is first passed through `upfirdn2d` with the normalized FIR
                kernel and a padding pair sized from both the FIR kernel and the
                convolution kernel. The filtered result is then convolved with `w` using
                stride `factor` in both spatial directions and no further padding.

                Args:
                    x:            Input tensor handed to `upfirdn2d` and then `F.conv2d`
                      (presumably `[N, C, H, W]` -- confirm against callers).
                    w:            4-D convolution weight, unpacked here as
                      `(outChannels, inChannels, convH, convW)`. The kernel must be
                      square (`convH == convW`).
                    k:            FIR filter, either 1-D (expanded to a separable 2-D
                      kernel) or square 2-D. The default is `[1] * factor`.
                    factor:       Integer downsampling factor, >= 1 (default: 2).
                    gain:         Multiplier applied to the normalized FIR kernel
                      (default: 1).

                Returns:
                    The tensor produced by `F.conv2d` on the filtered input.
              """
              assert isinstance(factor, int) and factor >= 1
              _, _, kernel_h, kernel_w = w.shape
              assert kernel_w == kernel_h

              taps = [1] * factor if k is None else k
              fir = _setup_kernel(taps) * gain

              # Total padding accounts for the FIR support beyond `factor` plus the
              # (convW - 1) samples the unpadded convolution consumes.
              total_pad = (fir.shape[0] - factor) + (kernel_w - 1)
              pad_pair = ((total_pad + 1) // 2, total_pad // 2)

              filtered = upfirdn2d(x, torch.tensor(fir, device=x.device), pad=pad_pair)
              return F.conv2d(filtered, w, stride=[factor, factor], padding=0)
         | 
| 179 | 
            +
             | 
| 180 | 
            +
             | 
| 181 | 
            +
            def _setup_kernel(k):
              """Build a normalized, square 2-D FIR kernel from `k`.

              Args:
                k: Array-like filter taps. A 1-D input is expanded to a separable 2-D
                  kernel via its outer product with itself; a 2-D input is used as is.

              Returns:
                A new float32 `np.ndarray` of shape `[n, n]`, scaled so that its entries
                sum to 1. The caller's `k` is never modified.

              Raises:
                AssertionError: If the kernel is not 2-D and square after expansion.
              """
              kernel = np.asarray(k, dtype=np.float32)
              if kernel.ndim == 1:
                kernel = np.outer(kernel, kernel)
              assert kernel.ndim == 2
              assert kernel.shape[0] == kernel.shape[1]
              # Divide out of place: `np.asarray` hands back the caller's own array when
              # it is already float32, so an in-place `/=` would silently rescale the
              # caller's data on every call.
              return kernel / np.sum(kernel)
         | 
| 189 | 
            +
             | 
| 190 | 
            +
             | 
| 191 | 
            +
            def _shape(x, dim):
              """Return the size of `x` along dimension `dim` (i.e. `x.shape[dim]`)."""
              dims = x.shape
              return dims[dim]
         | 
| 193 | 
            +
             | 
| 194 | 
            +
             | 
| 195 | 
            +
            def upsample_2d(x, k=None, factor=2, gain=1):
              r"""Upsample `x` by an integer factor using an FIR filter.

                The filter is normalized by `_setup_kernel` and then scaled by
                `gain * factor ** 2` before being handed to `upfirdn2d` together with
                `up=factor` and a padding pair derived from the filter size.

                Args:
                    x:            Input tensor passed to `upfirdn2d` (presumably
                      `[N, C, H, W]` -- confirm against callers).
                    k:            FIR filter, either 1-D (expanded to a separable 2-D
                      kernel) or square 2-D. The default is `[1] * factor`.
                    factor:       Integer upsampling factor, >= 1 (default: 2).
                    gain:         Scaling factor for signal magnitude (default: 1).

                Returns:
                    The tensor returned by `upfirdn2d`.
              """
              assert isinstance(factor, int) and factor >= 1
              taps = [1] * factor if k is None else k
              fir = _setup_kernel(taps) * (gain * (factor ** 2))

              # Filter support in excess of the upsampling factor determines padding.
              excess = fir.shape[0] - factor
              pad_lo = (excess + 1) // 2 + factor - 1
              pad_hi = excess // 2

              fir_tensor = torch.tensor(fir, device=x.device)
              return upfirdn2d(x, fir_tensor, up=factor, pad=(pad_lo, pad_hi))
         | 
| 225 | 
            +
             | 
| 226 | 
            +
             | 
| 227 | 
            +
            def downsample_2d(x, k=None, factor=2, gain=1):
              r"""Downsample `x` by an integer factor using an FIR filter.

                The filter is normalized by `_setup_kernel`, scaled by `gain`, and handed
                to `upfirdn2d` together with `down=factor` and a padding pair derived from
                the filter size.

                Args:
                    x:            Input tensor passed to `upfirdn2d` (presumably
                      `[N, C, H, W]` -- confirm against callers).
                    k:            FIR filter, either 1-D (expanded to a separable 2-D
                      kernel) or square 2-D. The default is `[1] * factor`.
                    factor:       Integer downsampling factor, >= 1 (default: 2).
                    gain:         Scaling factor for signal magnitude (default: 1).

                Returns:
                    The tensor returned by `upfirdn2d`.
              """
              assert isinstance(factor, int) and factor >= 1
              taps = [1] * factor if k is None else k
              fir = _setup_kernel(taps) * gain

              # Filter support in excess of the downsampling factor determines padding.
              excess = fir.shape[0] - factor
              pad_pair = ((excess + 1) // 2, excess // 2)

              fir_tensor = torch.tensor(fir, device=x.device)
              return upfirdn2d(x, fir_tensor, down=factor, pad=pad_pair)
         | 
    	
        sgmse/backbones/ncsnpp_utils/utils.py
    ADDED
    
    | @@ -0,0 +1,189 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # coding=utf-8
         | 
| 2 | 
            +
            # Copyright 2020 The Google Research Authors.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 5 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 6 | 
            +
            # You may obtain a copy of the License at
         | 
| 7 | 
            +
            #
         | 
| 8 | 
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 9 | 
            +
            #
         | 
| 10 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 11 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 12 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 13 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 14 | 
            +
            # limitations under the License.
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            """All functions and modules related to model definition.
         | 
| 17 | 
            +
            """
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            import torch
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            import numpy as np
         | 
| 22 | 
            +
            from ...sdes import OUVESDE, OUVPSDE
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            # Module-level registry mapping a model name to its class; filled by
            # `register_model` and read by `get_model`.
            _MODELS = {}
         | 
| 26 | 
            +
             | 
| 27 | 
            +
             | 
| 28 | 
            +
            def register_model(cls=None, *, name=None):
              """Register a model class in the module-level `_MODELS` table.

              Usable either bare (`@register_model`) or with a keyword
              (`@register_model(name='...')`). Without `name`, the class is stored under
              its own `__name__`.

              Args:
                cls: The class being decorated, or `None` when the decorator is called
                  with keyword arguments only.
                name: Optional registry key overriding `cls.__name__`.

              Returns:
                The class itself when `cls` is given, otherwise a decorator that
                registers and returns the class it receives.

              Raises:
                ValueError: If the chosen key is already present in `_MODELS`.
              """

              def _register(model_cls):
                local_name = model_cls.__name__ if name is None else name
                if local_name in _MODELS:
                  raise ValueError(f'Already registered model with name: {local_name}')
                _MODELS[local_name] = model_cls
                return model_cls

              return _register if cls is None else _register(cls)
         | 
| 45 | 
            +
             | 
| 46 | 
            +
             | 
| 47 | 
            +
            def get_model(name):
              """Look up the class stored under `name` in `_MODELS`.

              Raises:
                KeyError: If nothing is registered under `name`.
              """
              model_cls = _MODELS[name]
              return model_cls
         | 
| 49 | 
            +
             | 
| 50 | 
            +
             | 
| 51 | 
            +
            def get_sigmas(sigma_min, sigma_max, num_scales):
              """Return `num_scales` noise levels spaced evenly in log-space.

              Args:
                sigma_min: Smallest noise level (last element of the result).
                sigma_max: Largest noise level (first element of the result).
                num_scales: Number of noise levels to generate.

              Returns:
                sigmas: a numpy array running from `sigma_max` down to `sigma_min`.
              """
              log_high = np.log(sigma_max)
              log_low = np.log(sigma_min)
              log_levels = np.linspace(log_high, log_low, num_scales)
              return np.exp(log_levels)
         | 
| 62 | 
            +
             | 
| 63 | 
            +
             | 
| 64 | 
            +
            def get_ddpm_params(config):
              """Compute the beta/alpha schedule quantities used by DDPM.

              Args:
                config: Object exposing `config.model.beta_min`, `config.model.beta_max`
                  and `config.model.num_scales`.

              Returns:
                A dict with the float64 arrays `betas`, `alphas`, `alphas_cumprod`,
                `sqrt_alphas_cumprod`, `sqrt_1m_alphas_cumprod`, plus the scalars
                `beta_min`, `beta_max` and `num_diffusion_timesteps`.
              """
              # The schedule length is fixed; the start/end values below would need
              # adapting if a different number of time steps were used.
              num_diffusion_timesteps = 1000
              model_cfg = config.model
              beta_start = model_cfg.beta_min / model_cfg.num_scales
              beta_end = model_cfg.beta_max / model_cfg.num_scales

              betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
              alphas = 1. - betas
              alphas_cumprod = np.cumprod(alphas, axis=0)

              params = {}
              params['betas'] = betas
              params['alphas'] = alphas
              params['alphas_cumprod'] = alphas_cumprod
              params['sqrt_alphas_cumprod'] = np.sqrt(alphas_cumprod)
              params['sqrt_1m_alphas_cumprod'] = np.sqrt(1. - alphas_cumprod)
              params['beta_min'] = beta_start * (num_diffusion_timesteps - 1)
              params['beta_max'] = beta_end * (num_diffusion_timesteps - 1)
              params['num_diffusion_timesteps'] = num_diffusion_timesteps
              return params
         | 
| 87 | 
            +
             | 
| 88 | 
            +
             | 
| 89 | 
            +
            def create_model(config):
              """Instantiate the registered score model named by `config.model.name`.

              The class is looked up with `get_model`, constructed with `config`, moved
              to `config.device`, and wrapped in `torch.nn.DataParallel`.

              Args:
                config: Object exposing `config.model.name` and `config.device`; it is
                  also passed unchanged to the model constructor.

              Returns:
                The `torch.nn.DataParallel`-wrapped model.
              """
              model_cls = get_model(config.model.name)
              network = model_cls(config).to(config.device)
              return torch.nn.DataParallel(network)
         | 
| 96 | 
            +
             | 
| 97 | 
            +
             | 
| 98 | 
            +
            def get_model_fn(model, train=False):
              """Wrap `model` in a function that sets its train/eval mode on every call.

              Args:
                model: The score model; must provide `train()`, `eval()` and be callable
                  as `model(x, labels)`.
                train: `True` for training and `False` for evaluation.

              Returns:
                A function `model_fn(x, labels)`.
              """

              def model_fn(x, labels):
                """Switch `model` to the requested mode and evaluate it.

                Args:
                  x: A mini-batch of input data.
                  labels: A mini-batch of conditioning variables for time steps; their
                    interpretation depends on the model.

                Returns:
                  Whatever `model(x, labels)` returns.
                """
                if train:
                  model.train()
                else:
                  model.eval()
                return model(x, labels)

              return model_fn
         | 
| 128 | 
            +
             | 
| 129 | 
            +
             | 
| 130 | 
            +
            def get_score_fn(sde, model, train=False, continuous=False):
              """Build a time-dependent score function `score_fn(x, t)` for `sde`.

              Args:
                sde: The forward SDE; must be an `OUVPSDE` or `OUVESDE` instance.
                model: A score model.
                train: `True` for training and `False` for evaluation.
                continuous: If `True`, the score-based model is expected to directly take
                  continuous time steps.

              Returns:
                A function `score_fn(x, t)`.

              Raises:
                NotImplementedError: If `sde` is neither an `OUVPSDE` nor an `OUVESDE`.
              """
              model_fn = get_model_fn(model, train=train)

              if isinstance(sde, OUVPSDE):
                def score_fn(x, t):
                  # The network output is negated and divided by the standard deviation.
                  # For VP-trained models, t=0 corresponds to the lowest noise level.
                  if continuous:
                    # The maximum value of the time embedding is assumed to be 999 for
                    # continuously-trained models.
                    labels = t * 999
                    raw = model_fn(x, labels)
                    std = sde.marginal_prob(torch.zeros_like(x), t)[1]
                  else:
                    labels = t * (sde.N - 1)
                    raw = model_fn(x, labels)
                    std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()]
                  # `std` is broadcast over three trailing dimensions of the output.
                  return -raw / std[:, None, None, None]

                return score_fn

              if isinstance(sde, OUVESDE):
                def score_fn(x, t):
                  if continuous:
                    labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
                  else:
                    # For VE-trained models, t=0 corresponds to the highest noise level
                    labels = sde.T - t
                    labels *= sde.N - 1
                    labels = torch.round(labels).long()
                  return model_fn(x, labels)

                return score_fn

              raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")
         | 
| 180 | 
            +
             | 
| 181 | 
            +
             | 
| 182 | 
            +
            def to_flattened_numpy(x):
              """Detach tensor `x`, move it to the CPU and return it as a 1-D numpy array."""
              host_array = x.detach().cpu().numpy()
              return host_array.reshape((-1,))
         | 
| 185 | 
            +
             | 
| 186 | 
            +
             | 
| 187 | 
            +
            def from_flattened_numpy(x, shape):
              """Reshape numpy array `x` to `shape` and wrap it as a torch tensor."""
              reshaped = x.reshape(shape)
              return torch.from_numpy(reshaped)
         | 
    	
        sgmse/backbones/shared.py
    ADDED
    
    | @@ -0,0 +1,123 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import functools
         | 
| 2 | 
            +
            import numpy as np
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
            import torch.nn as nn
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            from sgmse.util.registry import Registry
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            # Shared `Registry` instance labelled "Backbone"; re-exported from the
            # `sgmse.backbones` package `__init__`.
            BackboneRegistry = Registry("Backbone")
         | 
| 11 | 
            +
             | 
| 12 | 
            +
             | 
| 13 | 
            +
            class GaussianFourierProjection(nn.Module):
                """Encode time steps with fixed, randomly drawn Fourier features.

                Args:
                    embed_dim: Requested embedding size. In the real-valued case only
                        `embed_dim // 2` frequencies are drawn, because sin and cos are
                        concatenated in the output.
                    scale: Multiplier applied to the standard-normal frequency samples.
                    complex_valued: If `True`, `forward` returns `exp(1j * phase)`
                        instead of concatenated sin/cos features.
                """

                def __init__(self, embed_dim, scale=16, complex_valued=False):
                    super().__init__()
                    self.complex_valued = complex_valued
                    # A real-valued embedding stores sin and cos side by side to avoid
                    # ambiguities, so it needs only half as many frequencies. A complex
                    # number already carries both, so no halving is needed there.
                    num_freqs = embed_dim if complex_valued else embed_dim // 2
                    # Sampled once at construction; frozen via requires_grad=False so the
                    # weights are not updated during optimization.
                    frequencies = torch.randn(num_freqs) * scale
                    self.W = nn.Parameter(frequencies, requires_grad=False)

                def forward(self, t):
                    """Project `t` (indexed as `t[:, None]`) onto the stored frequencies.

                    Returns:
                        `exp(1j * phase)` when `complex_valued`, otherwise sin and cos of
                        the phase concatenated along the last dimension.
                    """
                    phase = t[:, None] * self.W[None, :] * 2 * np.pi
                    if not self.complex_valued:
                        return torch.cat([torch.sin(phase), torch.cos(phase)], dim=-1)
                    return torch.exp(1j * phase)
         | 
| 35 | 
            +
             | 
| 36 | 
            +
             | 
| 37 | 
            +
            class DiffusionStepEmbedding(nn.Module):
                """Diffusion-Step embedding as in DiffWave / Vaswani et al. 2017.

                Args:
                    embed_dim: Requested embedding size. In the real-valued case the
                        stored `self.embed_dim` is `embed_dim // 2`, because sin and cos
                        are concatenated in the output.
                    complex_valued: If `True`, `forward` returns `exp(1j * inner)`
                        instead of concatenated sin/cos features.
                """

                def __init__(self, embed_dim, complex_valued=False):
                    super().__init__()
                    self.complex_valued = complex_valued
                    if not complex_valued:
                        # If the output is real-valued, we concatenate sin+cos of the features to avoid ambiguities.
                        # Therefore, in this case the effective embed_dim is cut in half. For the complex-valued case,
                        # we use complex numbers which each represent sin+cos directly, so the ambiguity is avoided directly,
                        # and this halving is not necessary.
                        embed_dim = embed_dim // 2
                    self.embed_dim = embed_dim

                def forward(self, t):
                    """Embed the time steps `t` (indexed as `t[:, None]`).

                    Frequencies are `10 ** (4 * i / (embed_dim - 1))` for
                    `i = 0 .. embed_dim - 1`.

                    Returns:
                        `exp(1j * inner)` when `complex_valued`, otherwise sin and cos of
                        `inner` concatenated along the last dimension.
                    """
                    # With a single frequency the exponent would be 0 / 0 = NaN and the
                    # whole embedding would be NaN; clamp the denominator so that the
                    # lone frequency is 10 ** 0 = 1.
                    denom = max(self.embed_dim - 1, 1)
                    fac = 10**(4*torch.arange(self.embed_dim, device=t.device) / denom)
                    inner = t[:, None] * fac[None, :]
                    if self.complex_valued:
                        return torch.exp(1j * inner)
                    else:
                        return torch.cat([torch.sin(inner), torch.cos(inner)], dim=-1)
         | 
| 58 | 
            +
             | 
| 59 | 
            +
             | 
| 60 | 
            +
            class ComplexLinear(nn.Module):
                """Linear layer that can optionally operate on complex-valued inputs.

                With `complex_valued=False` this is a single `nn.Linear` stored as `lin`.
                With `complex_valued=True` two independent `nn.Linear` layers (`re`, `im`)
                are combined using the complex multiplication rule.

                Args:
                    input_dim: Number of input features.
                    output_dim: Number of output features.
                    complex_valued: Whether the layer handles complex inputs.
                """
                def __init__(self, input_dim, output_dim, complex_valued):
                    super().__init__()
                    self.complex_valued = complex_valued
                    if not self.complex_valued:
                        self.lin = nn.Linear(input_dim, output_dim)
                        return
                    self.re = nn.Linear(input_dim, output_dim)
                    self.im = nn.Linear(input_dim, output_dim)

                def forward(self, x):
                    """Apply the layer to `x` (read via `x.real` / `x.imag` when complex)."""
                    if not self.complex_valued:
                        return self.lin(x)
                    # (re + i*im)(a + i*b) = (re(a) - im(b)) + i*(re(b) + im(a))
                    real_part = self.re(x.real) - self.im(x.imag)
                    imag_part = self.re(x.imag) + self.im(x.real)
                    return real_part + 1j*imag_part
         | 
| 76 | 
            +
             | 
| 77 | 
            +
             | 
| 78 | 
            +
            class FeatureMapDense(nn.Module):
                """Dense layer whose output gains two trailing singleton dimensions.

                Args:
                    input_dim: Number of input features.
                    output_dim: Number of output features.
                    complex_valued: Forwarded to the underlying `ComplexLinear`.
                """

                def __init__(self, input_dim, output_dim, complex_valued=False):
                    super().__init__()
                    self.complex_valued = complex_valued
                    self.dense = ComplexLinear(input_dim, output_dim, complex_valued=complex_valued)

                def forward(self, x):
                    """Apply the dense layer and append two size-1 dimensions."""
                    projected = self.dense(x)
                    return projected.unsqueeze(-1).unsqueeze(-1)
         | 
| 88 | 
            +
             | 
| 89 | 
            +
             | 
| 90 | 
            +
            def torch_complex_from_reim(re, im):
                """Build a complex tensor from separate real (`re`) and imaginary (`im`) tensors."""
                # Stack both parts on a new trailing axis of size 2, then reinterpret
                # that axis as the (real, imag) pair of a complex number.
                stacked = torch.stack((re, im), dim=-1)
                return torch.view_as_complex(stacked)
         | 
| 92 | 
            +
             | 
| 93 | 
            +
             | 
| 94 | 
            +
            class ArgsComplexMultiplicationWrapper(nn.Module):
                """Adapted from `asteroid`'s `complex_nn.py`, allowing args/kwargs to be passed through forward().

                Builds a complex-valued module `F` from two instances `f1` (`re_module`) and
                `f2` (`im_module`) of a real-valued module `f`, combined by the complex
                multiplication rule:

                    F(a + i b) = f1(a) - f2(b) + i (f1(b) + f2(a))

                `f1` and `f2` are constructed separately and do *not* share weights.

                Args:
                    module_cls (callable): A class or function that returns a Torch module/functional.
                        Constructor of `f` in the formula above.  Called 2x with `*args`, `**kwargs`,
                        to construct the real and imaginary component modules.
                """

                def __init__(self, module_cls, *args, **kwargs):
                    super().__init__()
                    # Real-part module first, imaginary-part module second (two independent instances).
                    self.re_module = module_cls(*args, **kwargs)
                    self.im_module = module_cls(*args, **kwargs)

                def forward(self, x, *args, **kwargs):
                    """Apply the wrapped modules to complex `x`; extra args/kwargs go to every sub-module call."""
                    x_re, x_im = x.real, x.imag
                    out_re = self.re_module(x_re, *args, **kwargs) - self.im_module(x_im, *args, **kwargs)
                    out_im = self.re_module(x_im, *args, **kwargs) + self.im_module(x_re, *args, **kwargs)
                    return torch_complex_from_reim(out_re, out_im)
         | 
| 120 | 
            +
             | 
| 121 | 
            +
             | 
| 122 | 
            +
            # Complex-valued 2-D convolution factory: ArgsComplexMultiplicationWrapper with `module_cls`
            # pre-bound to nn.Conv2d, so the remaining constructor arguments are those of nn.Conv2d.
            ComplexConv2d = functools.partial(ArgsComplexMultiplicationWrapper, nn.Conv2d)
            # Same construction with `module_cls` pre-bound to nn.ConvTranspose2d.
            ComplexConvTranspose2d = functools.partial(ArgsComplexMultiplicationWrapper, nn.ConvTranspose2d)
         | 
    	
        sgmse/data_module.py
    ADDED
    
    | @@ -0,0 +1,236 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
             | 
| 2 | 
            +
            from os.path import join
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import pytorch_lightning as pl
         | 
| 5 | 
            +
            from torch.utils.data import Dataset
         | 
| 6 | 
            +
            from torch.utils.data import DataLoader
         | 
| 7 | 
            +
            from glob import glob
         | 
| 8 | 
            +
            from torchaudio import load
         | 
| 9 | 
            +
            import numpy as np
         | 
| 10 | 
            +
            import torch.nn.functional as F
         | 
| 11 | 
            +
             | 
| 12 | 
            +
             | 
| 13 | 
            +
            def get_window(window_type, window_length):
                """Return a periodic Hann window ('hann') or its element-wise square root ('sqrthann').

                Raises:
                    NotImplementedError: if `window_type` is neither 'sqrthann' nor 'hann'.
                """
                if window_type != 'sqrthann' and window_type != 'hann':
                    raise NotImplementedError(f"Window type {window_type} not implemented!")
                hann = torch.hann_window(window_length, periodic=True)
                return torch.sqrt(hann) if window_type == 'sqrthann' else hann
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            class Specs(Dataset):
                """Dataset yielding (clean, noisy) complex spectrogram pairs computed from paired .wav files.

                Args:
                    data_dir: Root directory of the dataset.
                    subset: Sub-directory of `data_dir` to read from.
                    dummy: If truthy, `__len__` reports only 1/200 of the files (debugging aid).
                    shuffle_spec: If truthy, a random excerpt is cropped from each file; otherwise
                        the centered excerpt is used.
                    num_frames: Number of frames per item; the excerpt length in samples is
                        `(num_frames - 1) * hop_length`.
                    format: Directory naming scheme: 'default' (`clean`/`noisy` sub-directories) or
                        'reverb' (`anechoic`/`reverb` sub-directories).
                    normalize: 'noisy', 'clean' or 'not': both signals are divided by the peak
                        amplitude of the noisy signal, of the clean signal, or by 1.0.
                    spec_transform: Callable applied to each spectrogram. `None` returns the raw STFTs.
                    stft_kwargs: Keyword arguments passed to `torch.stft`. Must contain 'n_fft',
                        'hop_length', 'center' and 'window', with `center=True`.
                    **ignored_kwargs: Accepted and ignored.

                Raises:
                    NotImplementedError: if `format` is unknown.
                    ValueError: if `normalize` is unknown or `stft_kwargs` is not given.
                    AssertionError: if `stft_kwargs` lacks a required key or `center` is not True.
                """

                _NORMALIZE_MODES = ("noisy", "clean", "not")

                def __init__(self, data_dir, subset, dummy, shuffle_spec, num_frames,
                        format='default', normalize="noisy", spec_transform=None,
                        stft_kwargs=None, **ignored_kwargs):

                    # Read file paths according to file naming format.
                    if format == "default":
                        clean_dir, noisy_dir = "clean", "noisy"
                    elif format == "reverb":
                        clean_dir, noisy_dir = "anechoic", "reverb"
                    else:
                        # Feel free to add your own directory format
                        raise NotImplementedError(f"Directory format {format} unknown!")
                    self.clean_files = self._list_wavs(data_dir, subset, clean_dir)
                    self.noisy_files = self._list_wavs(data_dir, subset, noisy_dir)

                    # Fail at construction time rather than with an UnboundLocalError inside __getitem__.
                    if normalize not in self._NORMALIZE_MODES:
                        raise ValueError(
                            f"Unknown normalize mode {normalize!r}, expected one of {self._NORMALIZE_MODES}"
                        )

                    self.dummy = dummy
                    self.num_frames = num_frames
                    self.shuffle_spec = shuffle_spec
                    self.normalize = normalize
                    self.spec_transform = spec_transform

                    if stft_kwargs is None:
                        raise ValueError(
                            "stft_kwargs is required and must contain 'n_fft', 'hop_length', 'center' and 'window'"
                        )
                    # The assert-based checks are kept so callers still see AssertionError as before.
                    assert all(k in stft_kwargs for k in ["n_fft", "hop_length", "center", "window"]), "misconfigured STFT kwargs"
                    self.stft_kwargs = stft_kwargs
                    self.hop_length = self.stft_kwargs["hop_length"]
                    assert self.stft_kwargs.get("center", None) == True, "'center' must be True for current implementation"

                @staticmethod
                def _list_wavs(data_dir, subset, kind):
                    """Return sorted .wav paths directly in `data_dir/subset/kind`, then those one level below it."""
                    root = join(data_dir, subset, kind)
                    # glob is called without recursive=True, so "**" matches a single directory level only.
                    return sorted(glob(join(root, "*.wav"))) + sorted(glob(join(root, "**", "*.wav")))

                def __getitem__(self, i):
                    """Load file pair `i` and return the (clean, noisy) spectrograms `(X, Y)`."""
                    x, _ = load(self.clean_files[i])
                    y, _ = load(self.noisy_files[i])

                    # formula applies for center=True
                    target_len = (self.num_frames - 1) * self.hop_length
                    # Crop/pad amounts are derived from the clean signal and applied to both signals.
                    current_len = x.size(-1)
                    pad = max(target_len - current_len, 0)
                    if pad == 0:
                        # extract random part of the audio file
                        if self.shuffle_spec:
                            start = int(np.random.uniform(0, current_len-target_len))
                        else:
                            start = int((current_len-target_len)/2)
                        x = x[..., start:start+target_len]
                        y = y[..., start:start+target_len]
                    else:
                        # pad audio if the length T is smaller than num_frames
                        x = F.pad(x, (pad//2, pad//2+(pad%2)), mode='constant')
                        y = F.pad(y, (pad//2, pad//2+(pad%2)), mode='constant')

                    # normalize w.r.t. the noisy or the clean signal or not at all
                    # to ensure same clean signal power in x and y.
                    if self.normalize == "noisy":
                        normfac = y.abs().max()
                    elif self.normalize == "clean":
                        normfac = x.abs().max()
                    elif self.normalize == "not":
                        normfac = 1.0
                    else:
                        # Only reachable if the attribute was changed after construction.
                        raise ValueError(f"Unknown normalize mode {self.normalize!r}")
                    # NOTE(review): an all-zero reference signal gives normfac == 0 and a division
                    # by zero here -- confirm whether the data can contain silent files.
                    x = x / normfac
                    y = y / normfac

                    X = torch.stft(x, **self.stft_kwargs)
                    Y = torch.stft(y, **self.stft_kwargs)

                    # Without a transform the raw STFTs are returned.
                    if self.spec_transform is not None:
                        X, Y = self.spec_transform(X), self.spec_transform(Y)
                    return X, Y

                def __len__(self):
                    """Return the number of clean files, divided by 200 (rounded down) in dummy mode."""
                    if self.dummy:
                        # for debugging shrink the data set size
                        return int(len(self.clean_files)/200)
                    return len(self.clean_files)
         | 
| 101 | 
            +
             | 
| 102 | 
            +
             | 
| 103 | 
            +
            class SpecsDataModule(pl.LightningDataModule):
                """Lightning data module that builds `Specs` datasets and bundles the STFT configuration.

                Besides the train/valid/test dataloaders it offers the forward/backward spectrogram
                transforms (`spec_fwd`, `spec_back`) and `stft`/`istft` helpers that move the analysis
                window to the device of their input.
                """

                @staticmethod
                def add_argparse_args(parser):
                    """Register the data module's command-line options on `parser` and return it."""
                    add = parser.add_argument
                    add("--base_dir", type=str, required=True, help="The base directory of the dataset. Should contain `train`, `valid` and `test` subdirectories, each of which contain `clean` and `noisy` subdirectories.")
                    add("--format", type=str, choices=("default", "reverb"), default="default", help="Read file paths according to file naming format.")
                    add("--batch_size", type=int, default=8, help="The batch size. 8 by default.")
                    add("--n_fft", type=int, default=510, help="Number of FFT bins. 510 by default.")   # to assure 256 freq bins
                    add("--hop_length", type=int, default=128, help="Window hop length. 128 by default.")
                    add("--num_frames", type=int, default=256, help="Number of frames for the dataset. 256 by default.")
                    add("--window", type=str, choices=("sqrthann", "hann"), default="hann", help="The window function to use for the STFT. 'hann' by default.")
                    add("--num_workers", type=int, default=4, help="Number of workers to use for DataLoaders. 4 by default.")
                    add("--dummy", action="store_true", help="Use reduced dummy dataset for prototyping.")
                    add("--spec_factor", type=float, default=0.15, help="Factor to multiply complex STFT coefficients by. 0.15 by default.")
                    add("--spec_abs_exponent", type=float, default=0.5, help="Exponent e for the transformation abs(z)**e * exp(1j*angle(z)). 0.5 by default.")
                    add("--normalize", type=str, choices=("clean", "noisy", "not"), default="noisy", help="Normalize the input waveforms by the clean signal, the noisy signal, or not at all.")
                    add("--transform_type", type=str, choices=("exponent", "log", "none"), default="exponent", help="Spectogram transformation for input representation.")
                    return parser

                def __init__(
                    self, base_dir, format='default', batch_size=8,
                    n_fft=510, hop_length=128, num_frames=256, window='hann',
                    num_workers=4, dummy=False, spec_factor=0.15, spec_abs_exponent=0.5,
                    gpu=True, normalize='noisy', transform_type="exponent", **kwargs
                ):
                    super().__init__()
                    # Dataset location and loading options
                    self.base_dir = base_dir
                    self.format = format
                    self.batch_size = batch_size
                    # STFT configuration
                    self.n_fft = n_fft
                    self.hop_length = hop_length
                    self.num_frames = num_frames
                    self.window = get_window(window, self.n_fft)
                    self.windows = {}   # per-device cache filled by _get_window
                    self.num_workers = num_workers
                    self.dummy = dummy
                    # Spectrogram transform parameters
                    self.spec_factor = spec_factor
                    self.spec_abs_exponent = spec_abs_exponent
                    self.gpu = gpu
                    self.normalize = normalize
                    self.transform_type = transform_type
                    # Remaining keyword arguments are forwarded to every Specs dataset in setup().
                    self.kwargs = kwargs

                def setup(self, stage=None):
                    """Create the datasets needed for `stage` ('fit', 'test', or both when `None`)."""
                    specs_kwargs = dict(
                        stft_kwargs=self.stft_kwargs, num_frames=self.num_frames,
                        spec_transform=self.spec_fwd, **self.kwargs
                    )

                    def build(subset, shuffle_spec):
                        return Specs(data_dir=self.base_dir, subset=subset,
                            dummy=self.dummy, shuffle_spec=shuffle_spec, format=self.format,
                            normalize=self.normalize, **specs_kwargs)

                    if stage == 'fit' or stage is None:
                        # Only the training set crops random excerpts.
                        self.train_set = build('train', True)
                        self.valid_set = build('valid', False)
                    if stage == 'test' or stage is None:
                        self.test_set = build('test', False)

                def spec_fwd(self, spec):
                    """Apply the configured forward transform to a complex spectrogram."""
                    if self.transform_type == "exponent":
                        exponent = self.spec_abs_exponent
                        if exponent != 1:
                            # skipped for exponent == 1: it would only waste computation
                            # and introduce numerical error
                            spec = spec.abs()**exponent * torch.exp(1j * spec.angle())
                        return spec * self.spec_factor
                    if self.transform_type == "log":
                        compressed = torch.log(1 + spec.abs()) * torch.exp(1j * spec.angle())
                        return compressed * self.spec_factor
                    # "none" (and any other value) leaves the spectrogram unchanged
                    return spec

                def spec_back(self, spec):
                    """Apply the inverse of `spec_fwd` to a complex spectrogram."""
                    if self.transform_type == "exponent":
                        spec = spec / self.spec_factor
                        exponent = self.spec_abs_exponent
                        if exponent != 1:
                            spec = spec.abs()**(1/exponent) * torch.exp(1j * spec.angle())
                        return spec
                    if self.transform_type == "log":
                        spec = spec / self.spec_factor
                        return (torch.exp(spec.abs()) - 1) * torch.exp(1j * spec.angle())
                    # "none" (and any other value) leaves the spectrogram unchanged
                    return spec

                @property
                def stft_kwargs(self):
                    """Keyword arguments for `torch.stft`: the iSTFT arguments plus `return_complex=True`."""
                    kwargs = self.istft_kwargs   # a fresh dict on every access
                    kwargs["return_complex"] = True
                    return kwargs

                @property
                def istft_kwargs(self):
                    """Keyword arguments for `torch.istft` (n_fft, hop_length, window, center=True)."""
                    return {
                        "n_fft": self.n_fft, "hop_length": self.hop_length,
                        "window": self.window, "center": True,
                    }

                def _get_window(self, x):
                    """
                    Retrieve an appropriate window for the given tensor x, matching the device.
                    Caches the retrieved windows so that only one window tensor will be allocated per device.
                    """
                    device = x.device
                    cached = self.windows.get(device, None)
                    if cached is not None:
                        return cached
                    moved = self.window.to(device)
                    self.windows[device] = moved
                    return moved

                def stft(self, sig):
                    """Compute the STFT of `sig` with the window placed on `sig`'s device."""
                    kwargs = self.stft_kwargs
                    kwargs["window"] = self._get_window(sig)
                    return torch.stft(sig, **kwargs)

                def istft(self, spec, length=None):
                    """Compute the inverse STFT of `spec`, optionally fixing the output `length`."""
                    kwargs = self.istft_kwargs
                    kwargs["window"] = self._get_window(spec)
                    kwargs["length"] = length
                    return torch.istft(spec, **kwargs)

                def train_dataloader(self):
                    """Return the shuffled DataLoader over the training set."""
                    return DataLoader(
                        self.train_set, batch_size=self.batch_size,
                        num_workers=self.num_workers, pin_memory=self.gpu, shuffle=True
                    )

                def val_dataloader(self):
                    """Return the unshuffled DataLoader over the validation set."""
                    return DataLoader(
                        self.valid_set, batch_size=self.batch_size,
                        num_workers=self.num_workers, pin_memory=self.gpu, shuffle=False
                    )

                def test_dataloader(self):
                    """Return the unshuffled DataLoader over the test set."""
                    return DataLoader(
                        self.test_set, batch_size=self.batch_size,
                        num_workers=self.num_workers, pin_memory=self.gpu, shuffle=False
                    )
         | 
    	
        sgmse/model.py
    ADDED
    
    | @@ -0,0 +1,253 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import time
         | 
| 2 | 
            +
            from math import ceil
         | 
| 3 | 
            +
            import warnings
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import torch
         | 
| 6 | 
            +
            import pytorch_lightning as pl
         | 
| 7 | 
            +
            from torch_ema import ExponentialMovingAverage
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from sgmse import sampling
         | 
| 10 | 
            +
            from sgmse.sdes import SDERegistry
         | 
| 11 | 
            +
            from sgmse.backbones import BackboneRegistry
         | 
| 12 | 
            +
            from sgmse.util.inference import evaluate_model
         | 
| 13 | 
            +
            from sgmse.util.other import pad_spec
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            class ScoreModel(pl.LightningModule):
         | 
| 17 | 
            +
                @staticmethod
         | 
| 18 | 
            +
                def add_argparse_args(parser):
         | 
| 19 | 
            +
                    parser.add_argument("--lr", type=float, default=1e-4, help="The learning rate (1e-4 by default)")
         | 
| 20 | 
            +
                    parser.add_argument("--ema_decay", type=float, default=0.999, help="The parameter EMA decay constant (0.999 by default)")
         | 
| 21 | 
            +
                    parser.add_argument("--t_eps", type=float, default=0.03, help="The minimum process time (0.03 by default)")
         | 
| 22 | 
            +
                    parser.add_argument("--num_eval_files", type=int, default=20, help="Number of files for speech enhancement performance evaluation during training. Pass 0 to turn off (no checkpoints based on evaluation metrics will be generated).")
         | 
| 23 | 
            +
                    parser.add_argument("--loss_type", type=str, default="mse", choices=("mse", "mae"), help="The type of loss function to use.")
         | 
| 24 | 
            +
                    return parser
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                def __init__(
         | 
| 27 | 
            +
                    self, backbone, sde, lr=1e-4, ema_decay=0.999, t_eps=0.03,
         | 
| 28 | 
            +
                    num_eval_files=20, loss_type='mse', data_module_cls=None, **kwargs
         | 
| 29 | 
            +
                ):
         | 
| 30 | 
            +
                    """
         | 
| 31 | 
            +
                    Create a new ScoreModel.
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                    Args:
         | 
| 34 | 
            +
                        backbone: Backbone DNN that serves as a score-based model.
         | 
| 35 | 
            +
                        sde: The SDE that defines the diffusion process.
         | 
| 36 | 
            +
                        lr: The learning rate of the optimizer. (1e-4 by default).
         | 
| 37 | 
            +
                        ema_decay: The decay constant of the parameter EMA (0.999 by default).
         | 
| 38 | 
            +
                        t_eps: The minimum time to practically run for to avoid issues very close to zero (1e-5 by default).
         | 
| 39 | 
            +
                        loss_type: The type of loss to use (wrt. noise z/std). Options are 'mse' (default), 'mae'
         | 
| 40 | 
            +
                    """
         | 
| 41 | 
            +
                    super().__init__()
         | 
| 42 | 
            +
                    # Initialize Backbone DNN
         | 
| 43 | 
            +
                    self.backbone = backbone
         | 
| 44 | 
            +
                    dnn_cls = BackboneRegistry.get_by_name(backbone)
         | 
| 45 | 
            +
                    self.dnn = dnn_cls(**kwargs)
         | 
| 46 | 
            +
                    # Initialize SDE
         | 
| 47 | 
            +
                    sde_cls = SDERegistry.get_by_name(sde)
         | 
| 48 | 
            +
                    self.sde = sde_cls(**kwargs)
         | 
| 49 | 
            +
                    # Store hyperparams and save them
         | 
| 50 | 
            +
                    self.lr = lr
         | 
| 51 | 
            +
                    self.ema_decay = ema_decay
         | 
| 52 | 
            +
                    self.ema = ExponentialMovingAverage(self.parameters(), decay=self.ema_decay)
         | 
| 53 | 
            +
                    self._error_loading_ema = False
         | 
| 54 | 
            +
                    self.t_eps = t_eps
         | 
| 55 | 
            +
                    self.loss_type = loss_type
         | 
| 56 | 
            +
                    self.num_eval_files = num_eval_files
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                    self.save_hyperparameters(ignore=['no_wandb'])
         | 
| 59 | 
            +
                    self.data_module = data_module_cls(**kwargs, gpu=kwargs.get('gpus', 0) > 0)
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                def configure_optimizers(self):
         | 
| 62 | 
            +
                    optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
         | 
| 63 | 
            +
                    return optimizer
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                def optimizer_step(self, *args, **kwargs):
         | 
| 66 | 
            +
                    # Method overridden so that the EMA params are updated after each optimizer step
         | 
| 67 | 
            +
                    super().optimizer_step(*args, **kwargs)
         | 
| 68 | 
            +
                    self.ema.update(self.parameters())
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                # on_load_checkpoint / on_save_checkpoint needed for EMA storing/loading
         | 
| 71 | 
            +
                def on_load_checkpoint(self, checkpoint):
         | 
| 72 | 
            +
                    ema = checkpoint.get('ema', None)
         | 
| 73 | 
            +
                    if ema is not None:
         | 
| 74 | 
            +
                        self.ema.load_state_dict(checkpoint['ema'])
         | 
| 75 | 
            +
                    else:
         | 
| 76 | 
            +
                        self._error_loading_ema = True
         | 
| 77 | 
            +
                        warnings.warn("EMA state_dict not found in checkpoint!")
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                def on_save_checkpoint(self, checkpoint):
         | 
| 80 | 
            +
                    checkpoint['ema'] = self.ema.state_dict()
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                def train(self, mode=True, no_ema=False):
                    """
                    Set training/evaluation mode and swap the EMA weights in or out accordingly.

                    Args:
                        mode: True for training mode, False for evaluation mode. Defaults to True so
                            that a plain `.train()` call works, matching `torch.nn.Module.train`.
                        no_ema: If True, do not copy the EMA weights into the model when switching to
                            evaluation mode.

                    Returns:
                        Whatever the parent class' `train` returns.
                    """
                    res = super().train(mode)  # call the standard `train` method with the given mode
                    # If the EMA state could not be loaded from the checkpoint, never touch the weights.
                    if not self._error_loading_ema:
                        if mode == False and not no_ema:
                            # eval: back up the current weights, then evaluate with the EMA weights
                            self.ema.store(self.parameters())
                            self.ema.copy_to(self.parameters())
                        else:
                            # train (or eval with no_ema): put back the backed-up weights, if any
                            if self.ema.collected_params is not None:
                                self.ema.restore(self.parameters())
                    return res
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                def eval(self, no_ema=False):
         | 
| 96 | 
            +
                    return self.train(False, no_ema=no_ema)
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                def _loss(self, err):
                    """
                    Reduce the element-wise error `err` to a scalar loss.

                    Args:
                        err: Error tensor whose first dimension is the batch dimension (may be complex;
                            only its magnitude is used).

                    Returns:
                        A scalar tensor: 0.5 * sum over all non-batch elements, averaged over the batch.

                    Raises:
                        ValueError: If `self.loss_type` is neither 'mse' nor 'mae'.
                    """
                    if self.loss_type == 'mse':
                        losses = torch.square(err.abs())
                    elif self.loss_type == 'mae':
                        losses = err.abs()
                    else:
                        # Previously this fell through and failed later with an UnboundLocalError.
                        raise ValueError(
                            "Unknown loss_type {!r}; expected 'mse' or 'mae'.".format(self.loss_type)
                        )
                    # taken from reduce_op function: sum over channels and position and mean over batch dim
                    # presumably only important for absolute loss number, not for gradients
                    loss = torch.mean(0.5*torch.sum(losses.reshape(losses.shape[0], -1), dim=-1))
                    return loss
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                def _step(self, batch, batch_idx):
         | 
| 109 | 
            +
                    x, y = batch
         | 
| 110 | 
            +
                    t = torch.rand(x.shape[0], device=x.device) * (self.sde.T - self.t_eps) + self.t_eps
         | 
| 111 | 
            +
                    mean, std = self.sde.marginal_prob(x, t, y)
         | 
| 112 | 
            +
                    z = torch.randn_like(x)  # i.i.d. normal distributed with var=0.5
         | 
| 113 | 
            +
                    sigmas = std[:, None, None, None]
         | 
| 114 | 
            +
                    perturbed_data = mean + sigmas * z
         | 
| 115 | 
            +
                    score = self(perturbed_data, t, y)
         | 
| 116 | 
            +
                    err = score * sigmas + z
         | 
| 117 | 
            +
                    loss = self._loss(err)
         | 
| 118 | 
            +
                    return loss
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                def training_step(self, batch, batch_idx):
         | 
| 121 | 
            +
                    loss = self._step(batch, batch_idx)
         | 
| 122 | 
            +
                    self.log('train_loss', loss, on_step=True, on_epoch=True)
         | 
| 123 | 
            +
                    return loss
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                def validation_step(self, batch, batch_idx):
         | 
| 126 | 
            +
                    loss = self._step(batch, batch_idx)
         | 
| 127 | 
            +
                    self.log('valid_loss', loss, on_step=False, on_epoch=True)
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                    # Evaluate speech enhancement performance
         | 
| 130 | 
            +
                    if batch_idx == 0 and self.num_eval_files != 0:
         | 
| 131 | 
            +
                        pesq, si_sdr, estoi = evaluate_model(self, self.num_eval_files)
         | 
| 132 | 
            +
                        self.log('pesq', pesq, on_step=False, on_epoch=True)
         | 
| 133 | 
            +
                        self.log('si_sdr', si_sdr, on_step=False, on_epoch=True)
         | 
| 134 | 
            +
                        self.log('estoi', estoi, on_step=False, on_epoch=True)
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                    return loss
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                def forward(self, x, t, y):
         | 
| 139 | 
            +
                    # Concatenate y as an extra channel
         | 
| 140 | 
            +
                    dnn_input = torch.cat([x, y], dim=1)
         | 
| 141 | 
            +
                    
         | 
| 142 | 
            +
                    # the minus is most likely unimportant here - taken from Song's repo
         | 
| 143 | 
            +
                    score = -self.dnn(dnn_input, t)
         | 
| 144 | 
            +
                    return score
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                def to(self, *args, **kwargs):
         | 
| 147 | 
            +
                    """Override PyTorch .to() to also transfer the EMA of the model weights"""
         | 
| 148 | 
            +
                    self.ema.to(*args, **kwargs)
         | 
| 149 | 
            +
                    return super().to(*args, **kwargs)
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                def get_pc_sampler(self, predictor_name, corrector_name, y, N=None, minibatch=None, **kwargs):
         | 
| 152 | 
            +
                    N = self.sde.N if N is None else N
         | 
| 153 | 
            +
                    sde = self.sde.copy()
         | 
| 154 | 
            +
                    sde.N = N
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                    kwargs = {"eps": self.t_eps, **kwargs}
         | 
| 157 | 
            +
                    if minibatch is None:
         | 
| 158 | 
            +
                        return sampling.get_pc_sampler(predictor_name, corrector_name, sde=sde, score_fn=self, y=y, **kwargs)
         | 
| 159 | 
            +
                    else:
         | 
| 160 | 
            +
                        M = y.shape[0]
         | 
| 161 | 
            +
                        def batched_sampling_fn():
         | 
| 162 | 
            +
                            samples, ns = [], []
         | 
| 163 | 
            +
                            for i in range(int(ceil(M / minibatch))):
         | 
| 164 | 
            +
                                y_mini = y[i*minibatch:(i+1)*minibatch]
         | 
| 165 | 
            +
                                sampler = sampling.get_pc_sampler(predictor_name, corrector_name, sde=sde, score_fn=self, y=y_mini, **kwargs)
         | 
| 166 | 
            +
                                sample, n = sampler()
         | 
| 167 | 
            +
                                samples.append(sample)
         | 
| 168 | 
            +
                                ns.append(n)
         | 
| 169 | 
            +
                            samples = torch.cat(samples, dim=0)
         | 
| 170 | 
            +
                            return samples, ns
         | 
| 171 | 
            +
                        return batched_sampling_fn
         | 
| 172 | 
            +
             | 
| 173 | 
            +
                def get_ode_sampler(self, y, N=None, minibatch=None, **kwargs):
                    """
                    Build an ODE sampler function conditioned on `y`.

                    Args:
                        y: Conditioning tensor; its first dimension is treated as the batch dimension.
                        N: Number of steps set on the copied SDE. Defaults to `self.sde.N`.
                        minibatch: If given, `y` is processed in chunks of this many items and the
                            results are concatenated along the batch dimension.
                        **kwargs: Forwarded to `sampling.get_ode_sampler`; `eps` defaults to `self.t_eps`.

                    Returns:
                        A zero-argument callable. Without `minibatch` it is the sampler returned by
                        `sampling.get_ode_sampler`. With `minibatch` it returns `(samples, ns)`, where
                        `samples` is the concatenation over all minibatches and `ns` is the list of
                        per-minibatch function-evaluation counts.
                    """
                    N = self.sde.N if N is None else N
                    # Work on a copy so that changing N does not affect the model's own SDE.
                    sde = self.sde.copy()
                    sde.N = N

                    kwargs = {"eps": self.t_eps, **kwargs}
                    if minibatch is None:
                        return sampling.get_ode_sampler(sde, self, y=y, **kwargs)
                    else:
                        M = y.shape[0]
                        def batched_sampling_fn():
                            samples, ns = [], []
                            for i in range(int(ceil(M / minibatch))):
                                y_mini = y[i*minibatch:(i+1)*minibatch]
                                sampler = sampling.get_ode_sampler(sde, self, y=y_mini, **kwargs)
                                sample, n = sampler()
                                samples.append(sample)
                                ns.append(n)
                            samples = torch.cat(samples, dim=0)
                            # Return the concatenated result, not just the last minibatch.
                            return samples, ns
                        return batched_sampling_fn
         | 
| 194 | 
            +
             | 
| 195 | 
            +
                def train_dataloader(self):
         | 
| 196 | 
            +
                    return self.data_module.train_dataloader()
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                def val_dataloader(self):
         | 
| 199 | 
            +
                    return self.data_module.val_dataloader()
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                def test_dataloader(self):
         | 
| 202 | 
            +
                    return self.data_module.test_dataloader()
         | 
| 203 | 
            +
             | 
| 204 | 
            +
                def setup(self, stage=None):
         | 
| 205 | 
            +
                    return self.data_module.setup(stage=stage)
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                def to_audio(self, spec, length=None):
         | 
| 208 | 
            +
                    return self._istft(self._backward_transform(spec), length)
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                def _forward_transform(self, spec):
         | 
| 211 | 
            +
                    return self.data_module.spec_fwd(spec)
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                def _backward_transform(self, spec):
         | 
| 214 | 
            +
                    return self.data_module.spec_back(spec)
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                def _stft(self, sig):
         | 
| 217 | 
            +
                    return self.data_module.stft(sig)
         | 
| 218 | 
            +
             | 
| 219 | 
            +
                def _istft(self, spec, length=None):
         | 
| 220 | 
            +
                    return self.data_module.istft(spec, length)
         | 
| 221 | 
            +
             | 
| 222 | 
            +
                def enhance(self, y, sampler_type="pc", predictor="reverse_diffusion",
                    corrector="ald", N=30, corrector_steps=1, snr=0.5, timeit=False,
                    **kwargs
                ):
                    """
                    One-call speech enhancement of noisy speech `y`, for convenience.

                    Args:
                        y: Noisy input signal tensor; `y.size(1)` is used as the output length
                            (presumably shaped (channels, samples) -- confirm against callers).
                        sampler_type: Either "pc" (predictor-corrector) or "ode".
                        predictor: Predictor name, used by the "pc" sampler.
                        corrector: Corrector name, used by the "pc" sampler.
                        N: Number of sampling steps.
                        corrector_steps: Number of corrector steps, used by the "pc" sampler.
                        snr: SNR parameter of the corrector, used by the "pc" sampler.
                        timeit: If True, also return the number of function evaluations and the
                            real-time factor.
                        **kwargs: Forwarded to the sampler factory.

                    Returns:
                        The enhanced signal as a numpy array, or `(x_hat, nfe, rtf)` if `timeit` is True.

                    Raises:
                        ValueError: If `sampler_type` is neither "pc" nor "ode".

                    Note:
                        The input is moved to the GPU with `.cuda()`, so a CUDA device is required.
                    """
                    # Fail fast, before any GPU work; previously an invalid type was only printed
                    # and the call then crashed on the undefined `sampler` variable.
                    if sampler_type not in ("pc", "ode"):
                        raise ValueError("{} is not a valid sampler type!".format(sampler_type))
                    sr=16000  # sampling rate, only used to compute the real-time factor
                    start = time.time()
                    T_orig = y.size(1)
                    # Peak-normalize the input; the scale is restored on the output below.
                    norm_factor = y.abs().max().item()
                    y = y / norm_factor
                    Y = torch.unsqueeze(self._forward_transform(self._stft(y.cuda())), 0)
                    Y = pad_spec(Y)
                    if sampler_type == "pc":
                        sampler = self.get_pc_sampler(predictor, corrector, Y.cuda(), N=N,
                            corrector_steps=corrector_steps, snr=snr, intermediate=False,
                            **kwargs)
                    else:
                        sampler = self.get_ode_sampler(Y.cuda(), N=N, **kwargs)
                    sample, nfe = sampler()
                    x_hat = self.to_audio(sample.squeeze(), T_orig)
                    x_hat = x_hat * norm_factor
                    x_hat = x_hat.squeeze().cpu().numpy()
                    end = time.time()
                    if timeit:
                        # real-time factor: processing time divided by audio duration
                        rtf = (end-start)/(len(x_hat)/sr)
                        return x_hat, nfe, rtf
                    else:
                        return x_hat
         | 
    	
        sgmse/sampling/__init__.py
    ADDED
    
    | @@ -0,0 +1,143 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Adapted from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/sampling.py
         | 
| 2 | 
            +
            """Various sampling methods."""
         | 
| 3 | 
            +
            from scipy import integrate
         | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            from .predictors import Predictor, PredictorRegistry, ReverseDiffusionPredictor
         | 
| 7 | 
            +
            from .correctors import Corrector, CorrectorRegistry
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            # Public API of the sampling package. Every name listed here must exist in this module,
            # otherwise `from sgmse.sampling import *` raises AttributeError.
            __all__ = [
                'PredictorRegistry', 'CorrectorRegistry', 'Predictor', 'Corrector',
                'get_pc_sampler', 'get_ode_sampler'
            ]
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            def to_flattened_numpy(x):
         | 
| 17 | 
            +
                """Flatten a torch tensor `x` and convert it to numpy."""
         | 
| 18 | 
            +
                return x.detach().cpu().numpy().reshape((-1,))
         | 
| 19 | 
            +
             | 
| 20 | 
            +
             | 
| 21 | 
            +
            def from_flattened_numpy(x, shape):
         | 
| 22 | 
            +
                """Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
         | 
| 23 | 
            +
                return torch.from_numpy(x.reshape(shape))
         | 
| 24 | 
            +
             | 
| 25 | 
            +
             | 
| 26 | 
            +
            def get_pc_sampler(
         | 
| 27 | 
            +
                predictor_name, corrector_name, sde, score_fn, y,
         | 
| 28 | 
            +
                denoise=True, eps=3e-2, snr=0.1, corrector_steps=1, probability_flow: bool = False,
         | 
| 29 | 
            +
                intermediate=False, **kwargs
         | 
| 30 | 
            +
            ):
         | 
| 31 | 
            +
                """Create a Predictor-Corrector (PC) sampler.
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                Args:
         | 
| 34 | 
            +
                    predictor_name: The name of a registered `sampling.Predictor`.
         | 
| 35 | 
            +
                    corrector_name: The name of a registered `sampling.Corrector`.
         | 
| 36 | 
            +
                    sde: An `sdes.SDE` object representing the forward SDE.
         | 
| 37 | 
            +
                    score_fn: A function (typically learned model) that predicts the score.
         | 
| 38 | 
            +
                    y: A `torch.Tensor`, representing the (non-white-)noisy starting point(s) to condition the prior on.
         | 
| 39 | 
            +
                    denoise: If `True`, add one-step denoising to the final samples.
         | 
| 40 | 
            +
                    eps: A `float` number. The reverse-time SDE and ODE are integrated to `epsilon` to avoid numerical issues.
         | 
| 41 | 
            +
                    snr: The SNR to use for the corrector. 0.1 by default, and ignored for `NoneCorrector`.
         | 
| 42 | 
            +
                    N: The number of reverse sampling steps. If `None`, uses the SDE's `N` property by default.
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                Returns:
         | 
| 45 | 
            +
                    A sampling function that returns samples and the number of function evaluations during sampling.
         | 
| 46 | 
            +
                """
         | 
| 47 | 
            +
                predictor_cls = PredictorRegistry.get_by_name(predictor_name)
         | 
| 48 | 
            +
                corrector_cls = CorrectorRegistry.get_by_name(corrector_name)
         | 
| 49 | 
            +
                predictor = predictor_cls(sde, score_fn, probability_flow=probability_flow)
         | 
| 50 | 
            +
                corrector = corrector_cls(sde, score_fn, snr=snr, n_steps=corrector_steps)
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                def pc_sampler():
         | 
| 53 | 
            +
                    """The PC sampler function."""
         | 
| 54 | 
            +
                    with torch.no_grad():
         | 
| 55 | 
            +
                        xt = sde.prior_sampling(y.shape, y).to(y.device)
         | 
| 56 | 
            +
                        timesteps = torch.linspace(sde.T, eps, sde.N, device=y.device)
         | 
| 57 | 
            +
                        for i in range(sde.N):
         | 
| 58 | 
            +
                            t = timesteps[i]
         | 
| 59 | 
            +
                            if i != len(timesteps) - 1:
         | 
| 60 | 
            +
                                stepsize = t - timesteps[i+1]
         | 
| 61 | 
            +
                            else:
         | 
| 62 | 
            +
                                stepsize = timesteps[-1] # from eps to 0
         | 
| 63 | 
            +
                            vec_t = torch.ones(y.shape[0], device=y.device) * t
         | 
| 64 | 
            +
                            xt, xt_mean = corrector.update_fn(xt, vec_t, y)
         | 
| 65 | 
            +
                            xt, xt_mean = predictor.update_fn(xt, vec_t, y, stepsize)
         | 
| 66 | 
            +
                        x_result = xt_mean if denoise else xt
         | 
| 67 | 
            +
                        ns = sde.N * (corrector.n_steps + 1)
         | 
| 68 | 
            +
                        return x_result, ns
         | 
| 69 | 
            +
                
         | 
| 70 | 
            +
                return pc_sampler
         | 
| 71 | 
            +
             | 
| 72 | 
            +
             | 
| 73 | 
            +
            def get_ode_sampler(
         | 
| 74 | 
            +
                sde, score_fn, y, inverse_scaler=None,
         | 
| 75 | 
            +
                denoise=True, rtol=1e-5, atol=1e-5,
         | 
| 76 | 
            +
                method='RK45', eps=3e-2, device='cuda', **kwargs
         | 
| 77 | 
            +
            ):
         | 
| 78 | 
            +
                """Probability flow ODE sampler with the black-box ODE solver.
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                Args:
         | 
| 81 | 
            +
                    sde: An `sdes.SDE` object representing the forward SDE.
         | 
| 82 | 
            +
                    score_fn: A function (typically learned model) that predicts the score.
         | 
| 83 | 
            +
                    y: A `torch.Tensor`, representing the (non-white-)noisy starting point(s) to condition the prior on.
         | 
| 84 | 
            +
                    inverse_scaler: The inverse data normalizer.
         | 
| 85 | 
            +
                    denoise: If `True`, add one-step denoising to final samples.
         | 
| 86 | 
            +
                    rtol: A `float` number. The relative tolerance level of the ODE solver.
         | 
| 87 | 
            +
                    atol: A `float` number. The absolute tolerance level of the ODE solver.
         | 
| 88 | 
            +
                    method: A `str`. The algorithm used for the black-box ODE solver.
         | 
| 89 | 
            +
                        See the documentation of `scipy.integrate.solve_ivp`.
         | 
| 90 | 
            +
                    eps: A `float` number. The reverse-time SDE/ODE will be integrated to `eps` for numerical stability.
         | 
| 91 | 
            +
                    device: PyTorch device.
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                Returns:
         | 
| 94 | 
            +
                    A sampling function that returns samples and the number of function evaluations during sampling.
         | 
| 95 | 
            +
                """
         | 
| 96 | 
            +
                predictor = ReverseDiffusionPredictor(sde, score_fn, probability_flow=False)
         | 
| 97 | 
            +
                rsde = sde.reverse(score_fn, probability_flow=True)
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                def denoise_update_fn(x):
         | 
| 100 | 
            +
                    vec_eps = torch.ones(x.shape[0], device=x.device) * eps
         | 
| 101 | 
            +
                    _, x = predictor.update_fn(x, vec_eps, y)
         | 
| 102 | 
            +
                    return x
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                def drift_fn(x, t, y):
         | 
| 105 | 
            +
                    """Get the drift function of the reverse-time SDE."""
         | 
| 106 | 
            +
                    return rsde.sde(x, t, y)[0]
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                def ode_sampler(z=None, **kwargs):
         | 
| 109 | 
            +
                    """The probability flow ODE sampler with black-box ODE solver.
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                    Args:
         | 
| 112 | 
            +
                        model: A score model.
         | 
| 113 | 
            +
                        z: If present, generate samples from latent code `z`.
         | 
| 114 | 
            +
                    Returns:
         | 
| 115 | 
            +
                        samples, number of function evaluations.
         | 
| 116 | 
            +
                    """
         | 
| 117 | 
            +
                    with torch.no_grad():
         | 
| 118 | 
            +
                        # Sample the initial state from the prior distribution of the SDE (the `z` argument is currently unused).
         | 
| 119 | 
            +
                        x = sde.prior_sampling(y.shape, y).to(device)
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                        def ode_func(t, x):
         | 
| 122 | 
            +
                            x = from_flattened_numpy(x, y.shape).to(device).type(torch.complex64)
         | 
| 123 | 
            +
                            vec_t = torch.ones(y.shape[0], device=x.device) * t
         | 
| 124 | 
            +
                            drift = drift_fn(x, vec_t, y)
         | 
| 125 | 
            +
                            return to_flattened_numpy(drift)
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                        # Black-box ODE solver for the probability flow ODE
         | 
| 128 | 
            +
                        solution = integrate.solve_ivp(
         | 
| 129 | 
            +
                            ode_func, (sde.T, eps), to_flattened_numpy(x),
         | 
| 130 | 
            +
                            rtol=rtol, atol=atol, method=method, **kwargs
         | 
| 131 | 
            +
                        )
         | 
| 132 | 
            +
                        nfe = solution.nfev
         | 
| 133 | 
            +
                        x = torch.tensor(solution.y[:, -1]).reshape(y.shape).to(device).type(torch.complex64)
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                        # Denoising is equivalent to running one predictor step without adding noise
         | 
| 136 | 
            +
                        if denoise:
         | 
| 137 | 
            +
                            x = denoise_update_fn(x)
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                        if inverse_scaler is not None:
         | 
| 140 | 
            +
                            x = inverse_scaler(x)
         | 
| 141 | 
            +
                        return x, nfe
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                return ode_sampler
         | 
    	
        sgmse/sampling/correctors.py
    ADDED
    
    | @@ -0,0 +1,96 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import abc
         | 
| 2 | 
            +
            import torch
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            from sgmse import sdes
         | 
| 5 | 
            +
            from sgmse.util.registry import Registry
         | 
| 6 | 
            +
             | 
| 7 | 
            +
             | 
| 8 | 
            +
            CorrectorRegistry = Registry("Corrector")
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            class Corrector(abc.ABC):
                """Abstract base class shared by all corrector algorithms."""

                def __init__(self, sde, score_fn, snr, n_steps):
                    super().__init__()
                    # Keep the sampling hyper-parameters and the score function around for subclasses.
                    self.score_fn, self.snr, self.n_steps = score_fn, snr, n_steps
                    # Reverse-time process built from the forward SDE and the score function.
                    self.rsde = sde.reverse(score_fn)

                @abc.abstractmethod
                def update_fn(self, x, t, *args):
                    """Perform one corrector update.

                    Args:
                        x: A PyTorch tensor holding the current state.
                        t: A PyTorch tensor holding the current time step.
                        *args: Optional extra arguments, in particular `y` for OU processes.

                    Returns:
                        x: A PyTorch tensor holding the next state.
                        x_mean: A PyTorch tensor holding the next state without random noise
                            (useful for denoising).
                    """
         | 
| 35 | 
            +
             | 
| 36 | 
            +
             | 
| 37 | 
            +
            @CorrectorRegistry.register(name='langevin')
            class LangevinCorrector(Corrector):
                """Langevin-dynamics corrector with a step size derived from a target SNR.

                The constructor arguments (`sde`, `score_fn`, `snr`, `n_steps`) are stored by
                the `Corrector` base class; nothing extra is kept here.
                """

                def __init__(self, sde, score_fn, snr, n_steps):
                    super().__init__(sde, score_fn, snr, n_steps)

                def update_fn(self, x, t, *args):
                    """Run `self.n_steps` Langevin updates starting from `x`.

                    In every step the score is evaluated at the current state, fresh Gaussian
                    noise is drawn, and the step size is set to
                    `2 * (snr * mean_noise_norm / mean_score_norm) ** 2`, where both norms are
                    per-sample L2 norms averaged over the batch.

                    Args:
                        x: A PyTorch tensor holding the current state (first axis is the batch).
                        t: A PyTorch tensor holding the current time step.
                        *args: Extra arguments forwarded to the score function (e.g. `y`).

                    Returns:
                        x: The state after the last Langevin step (noise included).
                        x_mean: The state after the last step without its noise term. Both
                            equal the input `x` when `n_steps` is 0.
                    """
                    target_snr = self.snr
                    # With zero steps there is nothing to correct: report the input unchanged
                    # instead of failing on an undefined `x_mean`.
                    x_mean = x
                    for _ in range(self.n_steps):
                        grad = self.score_fn(x, t, *args)
                        noise = torch.randn_like(x)
                        grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
                        noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
                        # The step size is a single scalar for the whole batch, so ordinary
                        # scalar broadcasting works for inputs of any rank.
                        step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2
                        x_mean = x + step_size * grad
                        x = x_mean + noise * torch.sqrt(step_size * 2)

                    return x, x_mean
         | 
| 57 | 
            +
             | 
| 58 | 
            +
             | 
| 59 | 
            +
            @CorrectorRegistry.register(name='ald')
            class AnnealedLangevinDynamics(Corrector):
                """The original annealed Langevin dynamics corrector in NCSN/NCSNv2."""

                def __init__(self, sde, score_fn, snr, n_steps):
                    super().__init__(sde, score_fn, snr, n_steps)
                    if not isinstance(sde, (sdes.OUVESDE,)):
                        raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")
                    # The base class only keeps the reverse SDE; the forward SDE is needed
                    # here for `marginal_prob`.
                    self.sde = sde

                def update_fn(self, x, t, *args):
                    """Run `self.n_steps` annealed Langevin updates starting from `x`.

                    The step size is `2 * (snr * std) ** 2`, where `std` is the second value
                    returned by `sde.marginal_prob(x, t, *args)`. It is evaluated once from the
                    incoming state and reused for every step.

                    Args:
                        x: A PyTorch tensor holding the current state; the step size is
                            broadcast over three trailing axes, so a 4-D tensor is expected.
                        t: A PyTorch tensor holding the current time step.
                        *args: Extra arguments forwarded to the SDE and score function (e.g. `y`).

                    Returns:
                        x: The state after the last step (noise included).
                        x_mean: The state after the last step without its noise term. Both
                            equal the input `x` when `n_steps` is 0.
                    """
                    n_steps = self.n_steps
                    target_snr = self.snr
                    std = self.sde.marginal_prob(x, t, *args)[1]

                    # Loop-invariant: neither `std` nor the target SNR changes between steps.
                    step_size = (target_snr * std) ** 2 * 2
                    drift_scale = step_size[:, None, None, None]
                    noise_scale = torch.sqrt(step_size * 2)[:, None, None, None]

                    # With zero steps the input is returned unchanged instead of failing on an
                    # undefined `x_mean`.
                    x_mean = x
                    for _ in range(n_steps):
                        grad = self.score_fn(x, t, *args)
                        noise = torch.randn_like(x)
                        x_mean = x + drift_scale * grad
                        x = x_mean + noise * noise_scale

                    return x, x_mean
         | 
| 84 | 
            +
             | 
| 85 | 
            +
             | 
| 86 | 
            +
            @CorrectorRegistry.register(name='none')
            class NoneCorrector(Corrector):
                """A no-op corrector: every update hands the state back untouched."""

                def __init__(self, *args, **kwargs):
                    # Any constructor arguments are accepted and ignored; the base-class
                    # initialiser is intentionally not invoked.
                    self.snr = 0
                    self.n_steps = 0

                def update_fn(self, x, t, *args):
                    """Return the input state as both the noisy and the noise-free result."""
                    unchanged = x
                    return unchanged, unchanged
         | 
    	
        sgmse/sampling/predictors.py
    ADDED
    
    | @@ -0,0 +1,76 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import abc
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import numpy as np
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            from sgmse.util.registry import Registry
         | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
| 9 | 
            +
            PredictorRegistry = Registry("Predictor")
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            class Predictor(abc.ABC):
                """The abstract class for a predictor algorithm."""

                def __init__(self, sde, score_fn, probability_flow=False):
                    """Store the forward SDE and build its reverse-time counterpart.

                    Args:
                        sde: The forward SDE; must provide `reverse(score_fn, probability_flow=...)`.
                        score_fn: A function that takes x, t and extra arguments and returns the score.
                        probability_flow: If `True`, the reverse process is built as the
                            probability flow ODE instead of the reverse-time SDE.
                    """
                    super().__init__()
                    self.sde = sde
                    # Forward the flag so the reverse process actually honours it; without
                    # this the argument was stored but had no effect on `rsde`.
                    self.rsde = sde.reverse(score_fn, probability_flow=probability_flow)
                    self.score_fn = score_fn
                    self.probability_flow = probability_flow

                @abc.abstractmethod
                def update_fn(self, x, t, *args):
                    """One update of the predictor.

                    Args:
                        x: A PyTorch tensor representing the current state
                        t: A PyTorch tensor representing the current time step.
                        *args: Possibly additional arguments, in particular `y` for OU processes

                    Returns:
                        x: A PyTorch tensor of the next state.
                        x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
                    """
                    pass

                def debug_update_fn(self, x, t, *args):
                    """Optional debugging variant of `update_fn`; not implemented by default.

                    Raises:
                        NotImplementedError: Always, unless a subclass overrides this method.
                    """
                    raise NotImplementedError(f"Debug update function not implemented for predictor {self}.")
         | 
| 39 | 
            +
             | 
| 40 | 
            +
             | 
| 41 | 
            +
            @PredictorRegistry.register('euler_maruyama')
            class EulerMaruyamaPredictor(Predictor):
                """Predictor that advances the reverse process by one Euler-Maruyama step."""

                def __init__(self, sde, score_fn, probability_flow=False):
                    super().__init__(sde, score_fn, probability_flow=probability_flow)

                def update_fn(self, x, t, *args):
                    """Take one step of size `-1 / N` along the reverse process.

                    Args:
                        x: A PyTorch tensor holding the current state.
                        t: A PyTorch tensor holding the current time step.
                        *args: Extra arguments forwarded to the reverse SDE (e.g. `y`).

                    Returns:
                        x: The next state, including the random diffusion term.
                        x_mean: The next state without the random term.
                    """
                    # Negative step: the reverse process runs backwards in time.
                    step = -1. / self.rsde.N
                    # The noise is drawn before the reverse SDE is evaluated.
                    gaussian = torch.randn_like(x)
                    drift, diffusion = self.rsde.sde(x, t, *args)
                    mean_next = x + drift * step
                    noisy_next = mean_next + diffusion[:, None, None, None] * np.sqrt(-step) * gaussian
                    return noisy_next, mean_next
         | 
| 53 | 
            +
             | 
| 54 | 
            +
             | 
| 55 | 
            +
            @PredictorRegistry.register('reverse_diffusion')
            class ReverseDiffusionPredictor(Predictor):
                """Predictor that applies the discretized reverse-diffusion update rule."""

                def __init__(self, sde, score_fn, probability_flow=False):
                    super().__init__(sde, score_fn, probability_flow=probability_flow)

                def update_fn(self, x, t, y, stepsize):
                    """Apply one reverse-diffusion step.

                    Args:
                        x: A PyTorch tensor holding the current state.
                        t: A PyTorch tensor holding the current time step.
                        y: Extra conditioning argument forwarded to the reverse SDE.
                        stepsize: Step size forwarded to `rsde.discretize`.

                    Returns:
                        x: The next state, including the random diffusion term.
                        x_mean: The next state without the random term.
                    """
                    # The discretized drift/diffusion are evaluated first, then the noise is drawn.
                    drift_step, diffusion_step = self.rsde.discretize(x, t, y, stepsize)
                    gaussian = torch.randn_like(x)
                    mean_next = x - drift_step
                    noisy_next = mean_next + diffusion_step[:, None, None, None] * gaussian
                    return noisy_next, mean_next
         | 
| 66 | 
            +
             | 
| 67 | 
            +
             | 
| 68 | 
            +
            @PredictorRegistry.register('none')
            class NonePredictor(Predictor):
                """A no-op predictor: every update hands the state back untouched."""

                def __init__(self, *args, **kwargs):
                    # Any constructor arguments are accepted and ignored; the base-class
                    # initialiser is intentionally not invoked.
                    return None

                def update_fn(self, x, t, *args):
                    """Return the input state as both the noisy and the noise-free result."""
                    unchanged = x
                    return unchanged, unchanged
         | 
    	
        sgmse/sdes.py
    ADDED
    
    | @@ -0,0 +1,310 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            """
         | 
| 2 | 
            +
            Abstract SDE classes, Reverse SDE, and VE/VP SDEs.
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            Taken and adapted from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/sde_lib.py
         | 
| 5 | 
            +
            """
         | 
| 6 | 
            +
            import abc
         | 
| 7 | 
            +
            import warnings
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            import numpy as np
         | 
| 10 | 
            +
            from sgmse.util.tensors import batch_broadcast
         | 
| 11 | 
            +
            import torch
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            from sgmse.util.registry import Registry
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            SDERegistry = Registry("SDE")
         | 
| 17 | 
            +
             | 
| 18 | 
            +
             | 
| 19 | 
            +
            class SDE(abc.ABC):
                """SDE abstract class. Functions are designed for a mini-batch of inputs."""

                def __init__(self, N):
                    """Construct an SDE.

                    Args:
                        N: number of discretization time steps.
                    """
                    super().__init__()
                    self.N = N

                @property
                @abc.abstractmethod
                def T(self):
                    """End time of the SDE."""
                    pass

                @abc.abstractmethod
                def sde(self, x, t, *args):
                    """Return the `(drift, diffusion)` pair of the SDE at state `x` and time `t`."""
                    pass

                @abc.abstractmethod
                def marginal_prob(self, x, t, *args):
                    """Parameters to determine the marginal distribution of the SDE, $p_t(x|args)$."""
                    pass

                @abc.abstractmethod
                def prior_sampling(self, shape, *args):
                    """Generate one sample from the prior distribution, $p_T(x|args)$ with shape `shape`."""
                    pass

                @abc.abstractmethod
                def prior_logp(self, z):
                    """Compute log-density of the prior distribution.

                    Useful for computing the log-likelihood via probability flow ODE.

                    Args:
                        z: latent code
                    Returns:
                        log probability density
                    """
                    pass

                @staticmethod
                @abc.abstractmethod
                def add_argparse_args(parent_parser):
                    """
                    Add the necessary arguments for instantiation of this SDE class to an argparse ArgumentParser.
                    """
                    pass

                def discretize(self, x, t, y, stepsize):
                    """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.

                    Useful for reverse diffusion sampling and probability flow sampling.
                    Defaults to Euler-Maruyama discretization.

                    Args:
                        x: a torch tensor
                        t: a torch float representing the time step (from 0 to `self.T`)
                        y: extra argument passed through to `self.sde`
                        stepsize: the step size `dt`; a torch tensor or a plain Python/NumPy scalar

                    Returns:
                        f, G
                    """
                    # `torch.sqrt` only accepts tensors, so promote scalar step sizes first.
                    # Tensors pass through `as_tensor` unchanged.
                    dt = torch.as_tensor(stepsize)
                    drift, diffusion = self.sde(x, t, y)
                    f = drift * dt
                    G = diffusion * torch.sqrt(dt)
                    return f, G

                def reverse(oself, score_model, probability_flow=False):
                    """Create the reverse-time SDE/ODE.

                    Args:
                        score_model: A function that takes x, t and y and returns the score.
                        probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.

                    Returns:
                        An instance of a dynamically created subclass of this SDE's class whose
                        `sde` and `discretize` methods describe the reverse-time process.
                    """
                    # Capture everything the reverse process needs from the forward SDE. The
                    # bound methods keep referring to the original instance (`oself`).
                    N = oself.N
                    T = oself.T
                    sde_fn = oself.sde
                    discretize_fn = oself.discretize

                    # Build the class for reverse-time SDE.
                    class RSDE(oself.__class__):
                        def __init__(self):
                            # The parent constructor is deliberately not called; only the
                            # values needed by the methods below are set.
                            self.N = N
                            self.probability_flow = probability_flow

                        @property
                        def T(self):
                            return T

                        def sde(self, x, t, *args):
                            """Create the drift and diffusion functions for the reverse SDE/ODE."""
                            rsde_parts = self.rsde_parts(x, t, *args)
                            total_drift, diffusion = rsde_parts["total_drift"], rsde_parts["diffusion"]
                            return total_drift, diffusion

                        def rsde_parts(self, x, t, *args):
                            """Return all intermediate terms of the reverse SDE/ODE as a dict.

                            Keys: 'total_drift', 'diffusion', 'sde_drift', 'sde_diffusion',
                            'score_drift' and 'score'.
                            """
                            sde_drift, sde_diffusion = sde_fn(x, t, *args)
                            score = score_model(x, t, *args)
                            # The score term is halved and the diffusion zeroed for the probability flow ODE.
                            score_drift = -sde_diffusion[:, None, None, None]**2 * score * (0.5 if self.probability_flow else 1.)
                            diffusion = torch.zeros_like(sde_diffusion) if self.probability_flow else sde_diffusion
                            total_drift = sde_drift + score_drift
                            return {
                                'total_drift': total_drift, 'diffusion': diffusion, 'sde_drift': sde_drift,
                                'sde_diffusion': sde_diffusion, 'score_drift': score_drift, 'score': score,
                            }

                        def discretize(self, x, t, y, stepsize):
                            """Create discretized iteration rules for the reverse diffusion sampler."""
                            f, G = discretize_fn(x, t, y, stepsize)
                            rev_f = f - G[:, None, None, None] ** 2 * score_model(x, t, y) * (0.5 if self.probability_flow else 1.)
                            rev_G = torch.zeros_like(G) if self.probability_flow else G
                            return rev_f, rev_G

                    return RSDE()

                @abc.abstractmethod
                def copy(self):
                    """Return a copy of this SDE."""
                    pass
         | 
| 142 | 
            +
             | 
| 143 | 
            +
             | 
| 144 | 
            +
            @SDERegistry.register("ouve")
         | 
| 145 | 
            +
            class OUVESDE(SDE):
         | 
| 146 | 
            +
                @staticmethod
         | 
| 147 | 
            +
                def add_argparse_args(parser):
         | 
| 148 | 
            +
                    parser.add_argument("--sde-n", type=int, default=1000, help="The number of timesteps in the SDE discretization. 30 by default")
         | 
| 149 | 
            +
                    parser.add_argument("--theta", type=float, default=1.5, help="The constant stiffness of the Ornstein-Uhlenbeck process. 1.5 by default.")
         | 
| 150 | 
            +
                    parser.add_argument("--sigma-min", type=float, default=0.05, help="The minimum sigma to use. 0.05 by default.")
         | 
| 151 | 
            +
                    parser.add_argument("--sigma-max", type=float, default=0.5, help="The maximum sigma to use. 0.5 by default.")
         | 
| 152 | 
            +
                    return parser
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                def __init__(self, theta, sigma_min, sigma_max, N=1000, **ignored_kwargs):
         | 
| 155 | 
            +
                    """Construct an Ornstein-Uhlenbeck Variance Exploding SDE.
         | 
| 156 | 
            +
             | 
| 157 | 
            +
                    Note that the "steady-state mean" `y` is not provided at construction, but must rather be given as an argument
         | 
| 158 | 
            +
                    to the methods which require it (e.g., `sde` or `marginal_prob`).
         | 
| 159 | 
            +
             | 
| 160 | 
            +
                    dx = -theta (y-x) dt + sigma(t) dw
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                    with
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                    sigma(t) = sigma_min (sigma_max/sigma_min)^t * sqrt(2 log(sigma_max/sigma_min))
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                    Args:
         | 
| 167 | 
            +
                        theta: stiffness parameter.
         | 
| 168 | 
            +
                        sigma_min: smallest sigma.
         | 
| 169 | 
            +
                        sigma_max: largest sigma.
         | 
| 170 | 
            +
                        N: number of discretization steps
         | 
| 171 | 
            +
                    """
         | 
| 172 | 
            +
                    super().__init__(N)
         | 
| 173 | 
            +
                    self.theta = theta
         | 
| 174 | 
            +
                    self.sigma_min = sigma_min
         | 
| 175 | 
            +
                    self.sigma_max = sigma_max
         | 
| 176 | 
            +
                    self.logsig = np.log(self.sigma_max / self.sigma_min)
         | 
| 177 | 
            +
                    self.N = N
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                def copy(self):
         | 
| 180 | 
            +
                    return OUVESDE(self.theta, self.sigma_min, self.sigma_max, N=self.N)
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                @property
         | 
| 183 | 
            +
                def T(self):
         | 
| 184 | 
            +
                    return 1
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                def sde(self, x, t, y):
         | 
| 187 | 
            +
                    drift = self.theta * (y - x)
         | 
| 188 | 
            +
                    # the sqrt(2*logsig) factor is required here so that logsig does not in the end affect the perturbation kernel
         | 
| 189 | 
            +
                    # standard deviation. this can be understood from solving the integral of [exp(2s) * g(s)^2] from s=0 to t
         | 
| 190 | 
            +
                    # with g(t) = sigma(t) as defined here, and seeing that `logsig` remains in the integral solution
         | 
| 191 | 
            +
                    # unless this sqrt(2*logsig) factor is included.
         | 
| 192 | 
            +
                    sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
         | 
| 193 | 
            +
                    diffusion = sigma * np.sqrt(2 * self.logsig)
         | 
| 194 | 
            +
                    return drift, diffusion
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                def _mean(self, x0, t, y):
         | 
| 197 | 
            +
                    theta = self.theta
         | 
| 198 | 
            +
                    exp_interp = torch.exp(-theta * t)[:, None, None, None]
         | 
| 199 | 
            +
                    return exp_interp * x0 + (1 - exp_interp) * y
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                def alpha(self, t):
         | 
| 202 | 
            +
                    return torch.exp(-self.theta * t)
         | 
| 203 | 
            +
             | 
| 204 | 
            +
                def _std(self, t):
         | 
| 205 | 
            +
                    # This is a full solution to the ODE for P(t) in our derivations, after choosing g(s) as in self.sde()
         | 
| 206 | 
            +
                    sigma_min, theta, logsig = self.sigma_min, self.theta, self.logsig
         | 
| 207 | 
            +
                    # could maybe replace the two torch.exp(... * t) terms here by cached values **t
         | 
| 208 | 
            +
                    return torch.sqrt(
         | 
| 209 | 
            +
                        (
         | 
| 210 | 
            +
                            sigma_min**2
         | 
| 211 | 
            +
                            * torch.exp(-2 * theta * t)
         | 
| 212 | 
            +
                            * (torch.exp(2 * (theta + logsig) * t) - 1)
         | 
| 213 | 
            +
                            * logsig
         | 
| 214 | 
            +
                        )
         | 
| 215 | 
            +
                        /
         | 
| 216 | 
            +
                        (theta + logsig)
         | 
| 217 | 
            +
                    )
         | 
| 218 | 
            +
             | 
| 219 | 
            +
                def marginal_prob(self, x0, t, y):
         | 
| 220 | 
            +
                    return self._mean(x0, t, y), self._std(t)
         | 
| 221 | 
            +
             | 
| 222 | 
            +
                def prior_sampling(self, shape, y):
         | 
| 223 | 
            +
                    if shape != y.shape:
         | 
| 224 | 
            +
                        warnings.warn(f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape.")
         | 
| 225 | 
            +
                    std = self._std(torch.ones((y.shape[0],), device=y.device))
         | 
| 226 | 
            +
                    x_T = y + torch.randn_like(y) * std[:, None, None, None]
         | 
| 227 | 
            +
                    return x_T
         | 
| 228 | 
            +
             | 
| 229 | 
            +
                def prior_logp(self, z):
         | 
| 230 | 
            +
                    raise NotImplementedError("prior_logp for OU SDE not yet implemented!")
         | 
| 231 | 
            +
             | 
| 232 | 
            +
             | 
| 233 | 
            +
            @SDERegistry.register("ouvp")
            class OUVPSDE(SDE):
                # !!! We do not utilize this SDE in our works due to observed instabilities around t=0.2. !!!
                @staticmethod
                def add_argparse_args(parser):
                    """Add this SDE's command-line options to `parser` and return the parser."""
                    parser.add_argument("--sde-n", type=int, default=1000,
                        help="The number of timesteps in the SDE discretization. 1000 by default")
                    parser.add_argument("--beta-min", type=float, required=True,
                        help="The minimum beta to use.")
                    parser.add_argument("--beta-max", type=float, required=True,
                        help="The maximum beta to use.")
                    parser.add_argument("--stiffness", type=float, default=1,
                        help="The stiffness factor for the drift, to be multiplied by 0.5*beta(t). 1 by default.")
                    return parser

                def __init__(self, beta_min, beta_max, stiffness=1, N=1000, **ignored_kwargs):
                    """
                    !!! We do not utilize this SDE in our works due to observed instabilities around t=0.2. !!!

                    Construct an Ornstein-Uhlenbeck Variance Preserving SDE:

                    dx = 1/2 * beta(t) * stiffness * (y-x) dt + sqrt(beta(t)) * dw

                    with

                    beta(t) = beta_min + t(beta_max - beta_min)

                    Note that the "steady-state mean" `y` is not provided at construction, but must rather be given as an argument
                    to the methods which require it (e.g., `sde` or `marginal_prob`).

                    Args:
                        beta_min: value of beta(t) at t=0.
                        beta_max: value of beta(t) at t=1.
                        stiffness: stiffness factor of the drift. 1 by default.
                        N: number of discretization steps
                        **ignored_kwargs: accepted and discarded.
                    """
                    super().__init__(N)
                    self.beta_min = beta_min
                    self.beta_max = beta_max
                    self.stiffness = stiffness
                    self.N = N

                def copy(self):
                    """Return a new OUVPSDE with the same beta_min, beta_max, stiffness and N."""
                    return OUVPSDE(self.beta_min, self.beta_max, self.stiffness, N=self.N)

                @property
                def T(self):
                    """End time of the process; always 1."""
                    return 1

                def _beta(self, t):
                    """Linear schedule beta(t) = beta_min + t * (beta_max - beta_min)."""
                    return self.beta_min + t * (self.beta_max - self.beta_min)

                def sde(self, x, t, y):
                    """Return (drift, diffusion) at state x, time t and steady-state mean y.

                    The drift is 0.5 * stiffness * beta(t) * (y - x), with beta(t) passed through
                    batch_broadcast against y; the diffusion is sqrt(beta(t)), returned without
                    broadcasting.
                    """
                    beta_t = self._beta(t)
                    drift = 0.5 * self.stiffness * batch_broadcast(beta_t, y) * (y - x)
                    diffusion = torch.sqrt(beta_t)
                    return drift, diffusion

                def _mean(self, x0, t, y):
                    """Mean of the perturbation kernel: y + exp(-0.5 * stiffness * B(t)) * (x0 - y).

                    B(t) = 0.5 * t * (t * (beta_max - beta_min) + 2 * beta_min) is the integral of
                    beta over [0, t].
                    """
                    b0, b1, s = self.beta_min, self.beta_max, self.stiffness
                    x0y_fac = torch.exp(-0.25 * s * t * (t * (b1-b0) + 2 * b0))[:, None, None, None]
                    return y + x0y_fac * (x0 - y)

                def _std(self, t):
                    """Standard deviation of the perturbation kernel at time t.

                    For the SDE implemented in `sde`, the variance P(t) satisfies
                    dP/dt = -stiffness * beta(t) * P + beta(t) with P(0) = 0, whose solution is
                    P(t) = (1 - exp(-stiffness * B(t))) / stiffness, B(t) being the integral of beta
                    over [0, t]. The standard deviation is the square root of that quantity.

                    NOTE(review): earlier revisions returned P(t) itself (no square root) from this
                    method, although every caller treats the result as a standard deviation.
                    """
                    b0, b1, s = self.beta_min, self.beta_max, self.stiffness
                    variance = (1 - torch.exp(-0.5 * s * t * (t * (b1-b0) + 2 * b0))) / s
                    return torch.sqrt(variance)

                def marginal_prob(self, x0, t, y):
                    """Return the pair (self._mean(x0, t, y), self._std(t))."""
                    return self._mean(x0, t, y), self._std(t)

                def prior_sampling(self, shape, y):
                    """Draw a sample around y with the standard deviation self._std gives at t = 1.

                    `shape` is only compared against y.shape; on mismatch a warning is issued and
                    the shape of y is used regardless.
                    """
                    if shape != y.shape:
                        warnings.warn(f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape.")
                    std = self._std(torch.ones((y.shape[0],), device=y.device))
                    x_T = y + torch.randn_like(y) * std[:, None, None, None]
                    return x_T

                def prior_logp(self, z):
                    """Not available for this SDE; always raises NotImplementedError."""
                    raise NotImplementedError("prior_logp for OU SDE not yet implemented!")
         | 
    	
        sgmse/util/inference.py
    ADDED
    
    | @@ -0,0 +1,64 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            from torchaudio import load
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            from pesq import pesq
         | 
| 5 | 
            +
            from pystoi import stoi
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            from .other import si_sdr, pad_spec
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            # Settings
            sr = 16000  # sample rate handed to pesq() and stoi() in evaluate_model
            snr = 0.5  # forwarded as `snr=` to model.get_pc_sampler
            N = 30  # forwarded as `N=` to model.get_pc_sampler
            corrector_steps = 1  # forwarded as `corrector_steps=` to model.get_pc_sampler
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            def evaluate_model(model, num_eval_files):
                """Enhance a subset of the validation files and return averaged PESQ, SI-SDR and ESTOI.

                Args:
                    model: object providing `data_module.valid_set.clean_files` / `.noisy_files`,
                        `_stft`, `_forward_transform`, `get_pc_sampler` and `to_audio`.
                    num_eval_files: number of evenly spaced validation files to evaluate; also the
                        divisor used for the averages.

                Returns:
                    Tuple (pesq, si_sdr, estoi), each the sum over the evaluated files divided by
                    `num_eval_files`.
                """
                valid_set = model.data_module.valid_set
                clean_files = valid_set.clean_files
                noisy_files = valid_set.noisy_files

                # Select test files uniformly across validation files
                total_num_files = len(clean_files)
                indices = torch.linspace(0, total_num_files-1, num_eval_files, dtype=torch.int)
                clean_files = [clean_files[i] for i in indices]
                noisy_files = [noisy_files[i] for i in indices]

                _pesq = 0
                _si_sdr = 0
                _estoi = 0
                # iterate over files
                for (clean_file, noisy_file) in zip(clean_files, noisy_files):
                    # Load wavs
                    x, _ = load(clean_file)
                    y, _ = load(noisy_file)
                    T_orig = x.size(1)

                    # Normalize per utterance
                    norm_factor = y.abs().max()
                    y = y / norm_factor

                    # Prepare DNN input
                    Y = torch.unsqueeze(model._forward_transform(model._stft(y.cuda())), 0)
                    Y = pad_spec(Y)

                    # Reverse sampling
                    sampler = model.get_pc_sampler(
                        'reverse_diffusion', 'ald', Y.cuda(), N=N,
                        corrector_steps=corrector_steps, snr=snr)
                    sample, _ = sampler()

                    # Back to the time domain, undoing the per-utterance normalization
                    x_hat = model.to_audio(sample.squeeze(), T_orig)
                    x_hat = x_hat * norm_factor

                    x_hat = x_hat.squeeze().cpu().numpy()
                    x = x.squeeze().cpu().numpy()
                    # The noisy signal is not needed past this point, so it is neither rescaled
                    # nor converted to numpy.

                    _si_sdr += si_sdr(x, x_hat)
                    _pesq += pesq(sr, x, x_hat, 'wb')
                    _estoi += stoi(x, x_hat, sr, extended=True)

                return _pesq/num_eval_files, _si_sdr/num_eval_files, _estoi/num_eval_files
         | 
| 64 | 
            +
             | 
    	
        sgmse/util/other.py
    ADDED
    
    | @@ -0,0 +1,141 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import torch
         | 
| 3 | 
            +
            import numpy as np
         | 
| 4 | 
            +
            import scipy.stats
         | 
| 5 | 
            +
            from scipy.signal import butter, sosfilt
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            from pesq import pesq
         | 
| 8 | 
            +
            from pystoi import stoi
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            def si_sdr_components(s_hat, s, n):
                """Decompose the estimate s_hat into target, noise and artifact parts.

                Returns:
                    (s_target, e_noise, e_art): the projection of s_hat onto s, the projection of
                    s_hat onto n, and what remains of s_hat after subtracting both.
                """
                # scaling of the reference that best explains the estimate
                target_gain = np.dot(s_hat, s) / np.linalg.norm(s)**2
                # scaling of the noise that best explains the estimate
                noise_gain = np.dot(s_hat, n) / np.linalg.norm(n)**2

                target_part = target_gain * s
                noise_part = noise_gain * n
                artifact_part = s_hat - target_part - noise_part

                return target_part, noise_part, artifact_part
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            def energy_ratios(s_hat, s, n):
                """Return (SI-SDR, SI-SIR, SI-SAR) in dB for estimate s_hat, reference s and noise n.

                Each value is 10*log10 of the target energy over the energy of, respectively,
                noise plus artifacts, noise only, and artifacts only, using the parts returned by
                si_sdr_components.
                """
                s_target, e_noise, e_art = si_sdr_components(s_hat, s, n)
                target_energy = np.linalg.norm(s_target)**2

                def _ratio_db(residual):
                    # energy of the target relative to the given residual, in dB
                    return 10*np.log10(target_energy / np.linalg.norm(residual)**2)

                return _ratio_db(e_noise + e_art), _ratio_db(e_noise), _ratio_db(e_art)
         | 
| 33 | 
            +
             | 
| 34 | 
            +
            def mean_conf_int(data, confidence=0.95):
                """Return (mean, half_width) of a Student-t confidence interval for the mean of data.

                The half-width is the standard error of the mean times the t quantile at
                (1 + confidence) / 2 with len(data) - 1 degrees of freedom.
                """
                samples = 1.0 * np.array(data)
                degrees_of_freedom = len(samples) - 1
                center = np.mean(samples)
                standard_error = scipy.stats.sem(samples)
                t_quantile = scipy.stats.t.ppf((1 + confidence) / 2., degrees_of_freedom)
                return center, standard_error * t_quantile
         | 
| 40 | 
            +
             | 
| 41 | 
            +
            class Method():
                """Named container that accumulates per-metric lists of values."""

                def __init__(self, name, base_dir, metrics):
                    """Store name and base_dir and start one empty value list per metric.

                    Args:
                        name: stored as `self.name`.
                        base_dir: stored as `self.base_dir`.
                        metrics: iterable of hashable metric keys; each gets its own empty list
                            in `self.metrics`.
                    """
                    self.name = name
                    self.base_dir = base_dir
                    self.metrics = {metric: [] for metric in metrics}

                def append(self, matric, value):
                    """Append value to the list stored under the metric key `matric`.

                    The parameter keeps its historical spelling so keyword callers keep working.
                    Raises KeyError if the key was not passed to the constructor.
                    """
                    self.metrics[matric].append(value)

                def get_mean_ci(self, metric):
                    """Return mean_conf_int applied to the values collected for `metric`."""
                    return mean_conf_int(np.array(self.metrics[metric]))
         | 
| 57 | 
            +
             | 
| 58 | 
            +
            def hp_filter(signal, cut_off=80, order=10, sr=16000):
                """High-pass filter `signal` with a Butterworth filter in second-order sections.

                The critical frequency handed to scipy is cut_off / sr * 2.
                """
                normalized_cutoff = cut_off /sr * 2
                sections = butter(order, normalized_cutoff, btype='hp', output='sos')
                return sosfilt(sections, signal)
         | 
| 63 | 
            +
             | 
| 64 | 
            +
            def si_sdr(s, s_hat):
                """Return 10*log10 of the energy of the scaled reference over the residual energy.

                The reference s is scaled by dot(s_hat, s) / ||s||^2 before being compared
                with the estimate s_hat.
                """
                scale = np.dot(s_hat, s)/np.linalg.norm(s)**2
                scaled_reference = scale*s
                target_energy = np.linalg.norm(scaled_reference)**2
                residual_energy = np.linalg.norm(scaled_reference - s_hat)**2
                return 10*np.log10(target_energy/residual_energy)
         | 
| 69 | 
            +
             | 
| 70 | 
            +
            def snr_dB(s,n):
                """Return 10*log10 of the mean-square of s over the mean-square of n."""
                signal_power = 1/len(s)*np.sum(s**2)
                noise_power = 1/len(n)*np.sum(n**2)
                power_ratio = signal_power/noise_power
                return 10*np.log10(power_ratio)
         | 
| 75 | 
            +
             | 
| 76 | 
            +
            def pad_spec(Y, mode="zero_pad"):
                """Pad the end of dimension 3 of Y so its length becomes a multiple of 64.

                Args:
                    Y: tensor whose size along dimension 3 is read.
                    mode: "zero_pad", "reflection" or "replication".

                Raises:
                    NotImplementedError: for any other mode.
                """
                length = Y.size(3)
                remainder = length%64
                num_pad = 64-remainder if remainder != 0 else 0

                pad_layers = (
                    ("zero_pad", torch.nn.ZeroPad2d),
                    ("reflection", torch.nn.ReflectionPad2d),
                    ("replication", torch.nn.ReplicationPad2d),
                )
                layer_cls = next((cls for key, cls in pad_layers if mode == key), None)
                if layer_cls is None:
                    raise NotImplementedError("This function hasn't been implemented yet.")
                # padding order is (left, right, top, bottom): only the right side is extended
                return layer_cls((0, num_pad, 0,0))(Y)
         | 
| 91 | 
            +
             | 
| 92 | 
            +
            def ensure_dir(file_path):
                """Create the directory `file_path` (including parents) if nothing exists at that path.

                Despite the parameter name, the argument itself is used as the directory path.
                An already existing path, directory or not, is left untouched.
                """
                directory = file_path
                if not os.path.exists(directory):
                    # exist_ok=True: another process may create the directory between the check
                    # above and this call; that must not raise FileExistsError.
                    os.makedirs(directory, exist_ok=True)
         | 
| 96 | 
            +
             | 
| 97 | 
            +
             | 
| 98 | 
            +
            def print_metrics(x, y, x_hat_list, labels, sr=16000):
                """Print PESQ, ESTOI and SI-SDR for the mixture and for every estimate.

                Args:
                    x: reference signal all metrics are computed against.
                    y: mixture, reported on the first line.
                    x_hat_list: estimates, reported one per line.
                    labels: labels[i] prefixes the line of x_hat_list[i].
                    sr: sample rate passed to pesq and stoi.
                """
                _si_sdr_mix = si_sdr(x, y)
                _pesq_mix = pesq(sr, x, y, 'wb')
                _estoi_mix = stoi(x, y, sr, extended=True)
                print(f'Mixture:  PESQ: {_pesq_mix:.2f}, ESTOI: {_estoi_mix:.2f}, SI-SDR: {_si_sdr_mix:.2f}')
                for i, x_hat in enumerate(x_hat_list):
                    _si_sdr = si_sdr(x, x_hat)
                    _pesq = pesq(sr, x, x_hat, 'wb')
                    _estoi = stoi(x, x_hat, sr, extended=True)
                    # Label the first number as PESQ, matching the mixture line above.
                    print(f'{labels[i]}: PESQ: {_pesq:.2f}, ESTOI: {_estoi:.2f}, SI-SDR: {_si_sdr:.2f}')
         | 
| 108 | 
            +
             | 
| 109 | 
            +
            def mean_std(data):
                """Return (mean, std) of the array `data` after dropping NaN entries."""
                is_valid = np.logical_not(np.isnan(data))
                valid = data[is_valid]
                return np.mean(valid), np.std(valid)
         | 
| 114 | 
            +
             | 
| 115 | 
            +
            def print_mean_std(data, decimal=2):
                """Return 'mean ± std' of the non-NaN entries of data as a formatted string.

                Args:
                    data: array-like of numbers; NaN entries are ignored.
                    decimal: number of digits after the decimal point for both values. Any
                        non-negative integer value is accepted (earlier versions only handled
                        1 and 2 and failed with UnboundLocalError otherwise).
                """
                data = np.array(data)
                data = data[~np.isnan(data)]
                mean = np.mean(data)
                std = np.std(data)
                digits = int(decimal)
                return f'{mean:.{digits}f} ± {std:.{digits}f}'
         | 
| 125 | 
            +
             | 
| 126 | 
            +
            def set_torch_cuda_arch_list():
                """Set TORCH_CUDA_ARCH_LIST to the compute capabilities of all visible CUDA devices.

                Capabilities are written as "major.minor", one entry per device, joined with ';'.
                When CUDA is unavailable a message is printed and the environment is not touched.
                """
                if not torch.cuda.is_available():
                    print("CUDA is not available. No GPUs found.")
                    return

                device_count = torch.cuda.device_count()
                compute_capabilities = [
                    "{}.{}".format(*torch.cuda.get_device_capability(device_index))
                    for device_index in range(device_count)
                ]

                arch_list = ";".join(compute_capabilities)
                os.environ['TORCH_CUDA_ARCH_LIST'] = arch_list
                print(f"Set TORCH_CUDA_ARCH_LIST to: {arch_list}")
         | 
    	
        sgmse/util/registry.py
    ADDED
    
    | @@ -0,0 +1,34 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import warnings
         | 
| 2 | 
            +
            from typing import Callable
         | 
| 3 | 
            +
             | 
| 4 | 
            +
             | 
| 5 | 
            +
            class Registry:
                """Name-to-class lookup table that is filled through a decorator."""

                def __init__(self, managed_thing: str):
                    """
                    Set up an empty registry.

                    Args:
                        managed_thing: Human-readable description of the kind of thing stored here. It appears in
                            warnings and errors, so it should be globally unique and easy to understand.
                    """
                    self._registry = {}
                    self.managed_thing = managed_thing

                def register(self, name: str) -> Callable:
                    """Return a class decorator that stores the decorated class under `name`.

                    Registering a name a second time emits a warning and replaces the earlier class.
                    """
                    def _decorate(wrapped_class) -> Callable:
                        already_registered = name in self._registry
                        if already_registered:
                            warnings.warn(f"{self.managed_thing} with name '{name}' doubly registered, old class will be replaced.")
                        self._registry[name] = wrapped_class
                        return wrapped_class
                    return _decorate

                def get_by_name(self, name: str):
                    """Look up a registered thing by name; raises ValueError for unknown names."""
                    if name not in self._registry:
                        raise ValueError(f"{self.managed_thing} with name '{name}' unknown.")
                    return self._registry[name]

                def get_all_names(self):
                    """Return the registered names as a list, in registration order."""
                    return [*self._registry]
         | 
    	
        sgmse/util/tensors.py
    ADDED
    
    | @@ -0,0 +1,16 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            def batch_broadcast(a, x):
                """Broadcasts a over all dimensions of x, except the batch dimension, which must match.

                Args:
                    a: tensor with exactly one non-singleton dimension (or a single element).
                    x: tensor whose number of dimensions and batch size (dimension 0) are used.

                Returns:
                    A view of `a` with shape (a_batch, 1, ..., 1) and as many dimensions as x,
                    where a_batch is x.shape[0] or 1. A batch of 1 is left for regular
                    broadcasting to expand.

                Raises:
                    ValueError: if `a` has more than one non-singleton dimension, or its length
                        is neither x.shape[0] nor 1.
                """
                if len(a.shape) != 1:
                    a = a.squeeze()
                    if len(a.shape) == 0:
                        # a single element squeezes down to 0-dim; treat it as a batch of one
                        a = a.reshape(1)
                    elif len(a.shape) != 1:
                        raise ValueError(
                            f"Don't know how to batch-broadcast tensor `a` with more than one effective dimension (shape {a.shape})"
                        )

                if a.shape[0] != x.shape[0] and a.shape[0] != 1:
                    raise ValueError(
                        f"Don't know how to batch-broadcast shape {a.shape} over {x.shape} as the batch dimension is not matching")

                # Use a's own length for the batch axis: viewing a length-1 tensor as
                # (x.shape[0], 1, ...) is impossible when x.shape[0] > 1.
                out = a.view((a.shape[0], *(1 for _ in range(len(x.shape)-1))))
                return out
         | 
