Spaces:
Runtime error
Runtime error
| import logging as logger | |
| from .architecture.DAT import DAT | |
| from .architecture.face.codeformer import CodeFormer | |
| from .architecture.face.gfpganv1_clean_arch import GFPGANv1Clean | |
| from .architecture.face.restoreformer_arch import RestoreFormer | |
| from .architecture.HAT import HAT | |
| from .architecture.LaMa import LaMa | |
| from .architecture.OmniSR.OmniSR import OmniSR | |
| from .architecture.RRDB import RRDBNet as ESRGAN | |
| from .architecture.SCUNet import SCUNet | |
| from .architecture.SPSR import SPSRNet as SPSR | |
| from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2 | |
| from .architecture.SwiftSRGAN import Generator as SwiftSRGAN | |
| from .architecture.Swin2SR import Swin2SR | |
| from .architecture.SwinIR import SwinIR | |
| from .types import PyTorchModel | |
class UnsupportedModel(Exception):
    """Raised by the loader when a state dict cannot be turned into a model."""
def load_state_dict(state_dict) -> PyTorchModel:
    """Instantiate the model architecture that matches ``state_dict``.

    The architecture is detected purely from which keys are present in the
    state dict. The checks run in a fixed order and the first match wins;
    if none of the signature keys match, the state dict is handed to the
    ESRGAN constructor as a last resort.

    Args:
        state_dict: Mapping of parameter names to values. The weights may be
            nested under a ``"params_ema"``, ``"params-ema"`` or ``"params"``
            key, in which case that inner mapping is used instead.

    Returns:
        The model object built by the matching architecture class.

    Raises:
        UnsupportedModel: If no signature matched and the ESRGAN fallback
            could not be constructed from the state dict. The underlying
            error is attached as ``__cause__``.
    """
    logger.debug("Loading state dict into pytorch model arch")

    # Unwrap nested checkpoints. Order matters: "params_ema" is preferred
    # over "params-ema", which is preferred over plain "params".
    for wrapper_key in ("params_ema", "params-ema", "params"):
        if wrapper_key in state_dict.keys():
            state_dict = state_dict[wrapper_key]
            break

    # A key view gives hash-based membership tests instead of scanning a
    # list copy of the keys for every architecture check below.
    state_dict_keys = state_dict.keys()

    # SRVGGNet Real-ESRGAN (v2)
    if "body.0.weight" in state_dict_keys and "body.1.weight" in state_dict_keys:
        model = RealESRGANv2(state_dict)
    # SPSR (ESRGAN with lots of extra layers)
    elif "f_HR_conv1.0.weight" in state_dict_keys:
        model = SPSR(state_dict)
    # Swift-SRGAN
    elif (
        "model" in state_dict_keys
        and "initial.cnn.depthwise.weight" in state_dict["model"].keys()
    ):
        model = SwiftSRGAN(state_dict)
    # SwinIR, Swin2SR, HAT (all share the residual-group norm key)
    elif "layers.0.residual_group.blocks.0.norm1.weight" in state_dict_keys:
        if (
            "layers.0.residual_group.blocks.0.conv_block.cab.0.weight"
            in state_dict_keys
        ):
            model = HAT(state_dict)
        elif "patch_embed.proj.weight" in state_dict_keys:
            model = Swin2SR(state_dict)
        else:
            model = SwinIR(state_dict)
    # GFPGAN
    elif (
        "toRGB.0.weight" in state_dict_keys
        and "stylegan_decoder.style_mlp.1.weight" in state_dict_keys
    ):
        model = GFPGANv1Clean(state_dict)
    # RestoreFormer
    elif (
        "encoder.conv_in.weight" in state_dict_keys
        and "encoder.down.0.block.0.norm1.weight" in state_dict_keys
    ):
        model = RestoreFormer(state_dict)
    # CodeFormer
    elif (
        "encoder.blocks.0.weight" in state_dict_keys
        and "quantize.embedding.weight" in state_dict_keys
    ):
        model = CodeFormer(state_dict)
    # LaMa
    elif (
        "model.model.1.bn_l.running_mean" in state_dict_keys
        or "generator.model.1.bn_l.running_mean" in state_dict_keys
    ):
        model = LaMa(state_dict)
    # Omni-SR
    elif "residual_layer.0.residual_layer.0.layer.0.fn.0.weight" in state_dict_keys:
        model = OmniSR(state_dict)
    # SCUNet
    elif "m_head.0.weight" in state_dict_keys and "m_tail.0.weight" in state_dict_keys:
        model = SCUNet(state_dict)
    # DAT
    elif "layers.0.blocks.2.attn.attn_mask_0" in state_dict_keys:
        model = DAT(state_dict)
    # Regular ESRGAN, "new-arch" ESRGAN, Real-ESRGAN v1
    else:
        try:
            model = ESRGAN(state_dict)
        except Exception as err:
            # Catch Exception rather than everything so KeyboardInterrupt and
            # SystemExit still propagate, and keep the original failure
            # attached for debugging.
            raise UnsupportedModel(
                "Could not load state dict as any supported model architecture"
            ) from err
    return model