Spaces:
Runtime error
Runtime error
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the CC BY-NC license found in the
# LICENSE file in the root directory of this source tree.
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)
import torch.nn as nn

from repcodec.modules.decoder import Decoder
from repcodec.modules.encoder import Encoder
from repcodec.modules.projector import Projector
from repcodec.modules.quantizer import Quantizer
class RepCodec(nn.Module):
    """Chain an Encoder, Projector, Quantizer and Decoder into one module.

    ``forward`` runs the input through the encoder, projects the result to
    ``code_dim`` channels, quantizes the projection and decodes the
    quantized result.
    """

    def __init__(
        self,
        input_channels=768,
        output_channels=768,
        encode_channels=768,
        decode_channels=768,
        code_dim=768,
        codebook_num=1,
        codebook_size=1024,
        bias=True,
        enc_ratios=(1, 1),
        dec_ratios=(1, 1),
        enc_strides=(1, 1),
        dec_strides=(1, 1),
        enc_kernel_size=3,
        dec_kernel_size=3,
        enc_block_dilations=(1, 1),
        enc_block_kernel_size=3,
        dec_block_dilations=(1, 1),
        dec_block_kernel_size=3
    ):
        """Build the four sub-modules.

        Args:
            input_channels: Passed to the Encoder as ``input_channels``; also
                stored on ``self.input_channels``.
            output_channels: Passed to the Decoder as ``output_channels``.
            encode_channels: Passed to the Encoder as ``encode_channels``.
            decode_channels: Passed to the Decoder as ``decode_channels``.
            code_dim: Shared by the Projector, the Quantizer and the Decoder
                (as each one's ``code_dim``).
            codebook_num: Passed to the Quantizer as ``codebook_num``.
            codebook_size: Passed to the Quantizer as ``codebook_size``.
            bias: Passed to the Encoder and the Decoder as ``bias``. It is
                not applied to the Projector (see note in the body).
            enc_ratios: Passed to the Encoder as ``channel_ratios``.
            dec_ratios: Passed to the Decoder as ``channel_ratios``.
            enc_strides: Passed to the Encoder as ``strides``.
            dec_strides: Passed to the Decoder as ``strides``.
            enc_kernel_size: Passed to the Encoder as ``kernel_size``.
            dec_kernel_size: Passed to the Decoder as ``kernel_size``.
            enc_block_dilations: Passed to the Encoder as ``block_dilations``.
            enc_block_kernel_size: Passed to the Encoder as
                ``unit_kernel_size``.
            dec_block_dilations: Passed to the Decoder as ``block_dilations``.
            dec_block_kernel_size: Passed to the Decoder as
                ``unit_kernel_size``.
        """
        super().__init__()

        self.input_channels = input_channels

        # NOTE: the sub-modules are created in this fixed order (encoder,
        # decoder, projector, quantizer). Reordering would change the
        # module registration order and the order in which any random
        # initialisation is drawn.
        self.encoder = Encoder(
            input_channels=input_channels,
            encode_channels=encode_channels,
            channel_ratios=enc_ratios,
            strides=enc_strides,
            kernel_size=enc_kernel_size,
            bias=bias,
            block_dilations=enc_block_dilations,
            unit_kernel_size=enc_block_kernel_size
        )

        # The decoder is fed the quantizer output in forward(), hence it is
        # configured with ``code_dim`` rather than the encoder width.
        self.decoder = Decoder(
            code_dim=code_dim,
            output_channels=output_channels,
            decode_channels=decode_channels,
            channel_ratios=dec_ratios,
            strides=dec_strides,
            kernel_size=dec_kernel_size,
            bias=bias,
            block_dilations=dec_block_dilations,
            unit_kernel_size=dec_block_kernel_size
        )

        # Maps the encoder's output width (``self.encoder.out_channels``) to
        # ``code_dim``. Kernel size, stride and bias are fixed here; the
        # constructor's ``bias`` argument is deliberately not forwarded.
        self.projector = Projector(
            input_channels=self.encoder.out_channels,
            code_dim=code_dim,
            kernel_size=3,
            stride=1,
            bias=False
        )

        self.quantizer = Quantizer(
            code_dim=code_dim,
            codebook_num=codebook_num,
            codebook_size=codebook_size
        )

    def forward(self, x):
        """Encode, project, quantize and decode ``x``.

        Args:
            x: Input handed directly to the encoder.
                # NOTE(review): presumably a (batch, channels, time) tensor
                # with ``input_channels`` channels -- confirm against Encoder.

        Returns:
            A 5-tuple ``(y, zq, z, vqloss, perplexity)`` where ``z`` is the
            projector output, ``zq``, ``vqloss`` and ``perplexity`` are the
            three values returned by the quantizer for ``z`` (in that
            order), and ``y`` is the decoder output for ``zq``.
        """
        x = self.encoder(x)
        z = self.projector(x)
        zq, vqloss, perplexity = self.quantizer(z)
        y = self.decoder(zq)
        return y, zq, z, vqloss, perplexity