# -*- coding: utf-8 -*-
# Copyright 2024 Yiwei Guo
# Licensed under the Apache 2.0 license.
"""vec2wav2.0 main architectures"""
import torch
from vec2wav2.models.conformer.decoder import Decoder as ConformerDecoder
from vec2wav2.utils import crop_seq
from vec2wav2.models.bigvgan import BigVGAN
from vec2wav2.models.prompt_prenet import ConvPromptPrenet
import logging
class CTXVEC2WAVFrontend(torch.nn.Module):
    """Frontend of vec2wav2: maps VQ-vectors plus a mel prompt to hidden features.

    A prompt prenet projects the mel prompt into the conformer attention
    dimension; two stacked conformer decoders (cross-attending to the prompt)
    then produce the hidden sequence fed to the BigVGAN generator, together
    with an auxiliary mel-spectrogram prediction.
    """

    def __init__(self,
                 prompt_net_type,
                 num_mels,
                 vqvec_channels,
                 prompt_channels,
                 conformer_params):
        super(CTXVEC2WAVFrontend, self).__init__()
        # Attention dimension is referenced repeatedly; hoist it once.
        adim = conformer_params["attention_dim"]
        if prompt_net_type == "ConvPromptPrenet":
            self.prompt_prenet = ConvPromptPrenet(
                embed=prompt_channels,
                conv_layers=[(128, 3, 1, 1), (256, 5, 1, 2), (512, 5, 1, 2), (adim, 3, 1, 1)],
                dropout=0.1,
                skip_connections=True,
                residual_scale=0.25,
                non_affine_group_norm=False,
                conv_bias=True,
                activation=torch.nn.ReLU(),
            )
        elif prompt_net_type == "Conv1d":
            self.prompt_prenet = torch.nn.Conv1d(prompt_channels, adim, kernel_size=5, padding=2)
        else:
            raise NotImplementedError
        # First conformer stack consumes the VQ-vectors through a linear input layer.
        self.encoder1 = ConformerDecoder(vqvec_channels, input_layer='linear', **conformer_params)
        self.hidden_proj = torch.nn.Linear(adim, adim)
        # Second stack refines the projected hidden sequence (no input layer).
        self.encoder2 = ConformerDecoder(0, input_layer=None, **conformer_params)
        self.mel_proj = torch.nn.Linear(adim, num_mels)

    def forward(self, vqvec, prompt, mask=None, prompt_mask=None):
        """Encode VQ-vectors conditioned on an acoustic prompt.

        params:
            vqvec: sequence of VQ-vectors.
            prompt: sequence of mel-spectrogram prompt (acoustic context).
            mask: mask of the vqvec. True or 1 stands for valid values.
            prompt_mask: mask of the prompt.
        returns:
            enc_out: the input to the vec2wav2 Generator (BigVGAN);
            mel: the frontend predicted mel spectrogram (for faster convergence);
            None (placeholder kept for interface compatibility).
        """
        # Prenet works channel-first, so flip to (B, D, T) and back afterwards.
        prompt = self.prompt_prenet(prompt.transpose(1, 2)).transpose(1, 2)
        # Conformer layers expect masks with a broadcastable length axis.
        mask = None if mask is None else mask.unsqueeze(-2)
        prompt_mask = None if prompt_mask is None else prompt_mask.unsqueeze(-2)
        enc_out, _ = self.encoder1(vqvec, mask, prompt, prompt_mask)
        enc_out, _ = self.encoder2(self.hidden_proj(enc_out), mask, prompt, prompt_mask)
        mel = self.mel_proj(enc_out)  # (B, L, num_mels)
        return enc_out, mel, None
class VEC2WAV2Generator(torch.nn.Module):
    """Full vec2wav2 generator: conformer frontend followed by a BigVGAN backend."""

    def __init__(self, frontend: CTXVEC2WAVFrontend, backend: BigVGAN):
        super(VEC2WAV2Generator, self).__init__()
        self.frontend = frontend
        self.backend = backend

    def forward(self, vqvec, prompt, mask=None, prompt_mask=None, crop_len=0, crop_offsets=None):
        """
        :param vqvec: (torch.Tensor) The shape is (B, L, D). Sequence of VQ-vectors.
        :param prompt: (torch.Tensor) The shape is (B, L', 80). Sequence of mel-spectrogram prompt (acoustic context)
        :param mask: (torch.Tensor) The dtype is torch.bool. The shape is (B, L). True or 1 stands for valid values in `vqvec`.
        :param prompt_mask: (torch.Tensor) The dtype is torch.bool. The shape is (B, L'). True or 1 stands for valid values in `prompt`.
        :param crop_len: (int) If positive, crop the hidden sequence to this length (training-time random crop).
        :param crop_offsets: crop start offsets passed through to `crop_seq`.
        :return: frontend predicted mel spectrogram; None placeholder; reconstructed waveform.
        """
        h, mel, _ = self.frontend(vqvec, prompt, mask=mask, prompt_mask=prompt_mask)  # (B, L, adim), (B, L, 80)
        if mask is not None:
            # Zero out padded frames before feeding the backend.
            h = h.masked_fill(~mask.unsqueeze(-1), 0)
        h = h.transpose(1, 2)
        if crop_len > 0:
            h = crop_seq(h, crop_offsets, crop_len)
        # Condition BigVGAN on the time-averaged prompt (mean over valid frames only).
        if prompt_mask is None:
            prompt_avg = prompt.mean(1)
        else:
            masked_prompt = prompt.masked_fill(~prompt_mask.unsqueeze(-1), 0)
            prompt_avg = masked_prompt.sum(1) / prompt_mask.sum(1).unsqueeze(-1)
        wav = self.backend(h, prompt_avg)  # (B, C, T)
        return mel, None, wav

    def inference(self, vqvec, prompt):
        """Unbatched/unmasked synthesis path: encode then vocode in one call."""
        h, mel, _ = self.frontend(vqvec, prompt)
        wav = self.backend(h.transpose(1, 2), prompt.mean(1))
        return mel, None, wav