File size: 4,635 Bytes
05005db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# -*- coding: utf-8 -*-
# Copyright 2024 Yiwei Guo
#   Licensed under the Apache 2.0 license.

"""vec2wav2.0 main architectures"""

import torch
from vec2wav2.models.conformer.decoder import Decoder as ConformerDecoder
from vec2wav2.utils import crop_seq
from vec2wav2.models.bigvgan import BigVGAN
from vec2wav2.models.prompt_prenet import ConvPromptPrenet
import logging


class CTXVEC2WAVFrontend(torch.nn.Module):
    """vec2wav 2.0 frontend: VQ-vector sequence + prompt -> hidden features and a mel prediction.

    The prompt is first passed through a prompt prenet, then fed to two
    conformer stacks (``encoder1`` and ``encoder2``) alongside the main
    sequence. ``encoder2``'s output is the input to the vec2wav2 generator
    backend, and ``mel_proj`` predicts a mel spectrogram from it.
    """

    def __init__(self,
                 prompt_net_type,
                 num_mels,
                 vqvec_channels,
                 prompt_channels,
                 conformer_params):
        """
        params:
            prompt_net_type: "ConvPromptPrenet" or "Conv1d"; selects the prompt prenet.
            num_mels: output dimension of the mel prediction head (``mel_proj``).
            vqvec_channels: feature dimension of the input VQ-vectors.
            prompt_channels: feature dimension of the prompt sequence.
            conformer_params: keyword arguments forwarded to both conformer stacks;
                must contain "attention_dim".

        raises:
            NotImplementedError: if ``prompt_net_type`` is not one of the supported names.
        """
        super(CTXVEC2WAVFrontend, self).__init__()

        if prompt_net_type == "ConvPromptPrenet":
            self.prompt_prenet = ConvPromptPrenet(
                embed=prompt_channels,
                conv_layers=[(128, 3, 1, 1), (256, 5, 1, 2), (512, 5, 1, 2), (conformer_params["attention_dim"], 3, 1, 1)],
                dropout=0.1,
                skip_connections=True,
                residual_scale=0.25,
                non_affine_group_norm=False,
                conv_bias=True,
                activation=torch.nn.ReLU()
            )
        elif prompt_net_type == "Conv1d":
            self.prompt_prenet = torch.nn.Conv1d(prompt_channels, conformer_params["attention_dim"], kernel_size=5, padding=2)
        else:
            raise NotImplementedError(
                f"Unsupported prompt_net_type: {prompt_net_type!r} "
                f"(expected 'ConvPromptPrenet' or 'Conv1d')"
            )

        adim = conformer_params["attention_dim"]

        # First stack embeds the raw VQ-vectors through a linear input layer.
        self.encoder1 = ConformerDecoder(vqvec_channels, input_layer='linear', **conformer_params)

        self.hidden_proj = torch.nn.Linear(adim, adim)

        # Second stack has no input layer: it consumes hidden_proj's adim-sized output directly.
        self.encoder2 = ConformerDecoder(0, input_layer=None, **conformer_params)
        self.mel_proj = torch.nn.Linear(adim, num_mels)

    def forward(self, vqvec, prompt, mask=None, prompt_mask=None):
        """
        params:
            vqvec: sequence of VQ-vectors, shape [B, T, D] (time-major).
            prompt: mel-spectrogram prompt (acoustic context), shape [B, T', D'] (time-major).
            mask: mask of the vqvec, shape [B, T]. True or 1 stands for valid values.
            prompt_mask: mask of the prompt, shape [B, T'].

        returns:
            enc_out: the input to the vec2wav2 Generator (BigVGAN), shape [B, T, attention_dim];
            mel: the frontend predicted mel spectrogram (for faster convergence), shape [B, T, num_mels];
            None: placeholder kept for interface compatibility.
        """
        # The prenet convolves over time, so it needs channels-first [B, D', T']; restore time-major afterwards.
        prompt = self.prompt_prenet(prompt.transpose(1, 2)).transpose(1, 2)

        # Masks are broadcast to [B, 1, T] before being handed to the conformer stacks.
        if mask is not None:
            mask = mask.unsqueeze(-2)
        if prompt_mask is not None:
            prompt_mask = prompt_mask.unsqueeze(-2)
        enc_out, _ = self.encoder1(vqvec, mask, prompt, prompt_mask)

        h = self.hidden_proj(enc_out)

        enc_out, _ = self.encoder2(h, mask, prompt, prompt_mask)
        mel = self.mel_proj(enc_out)  # (B, T, num_mels)

        return enc_out, mel, None


class VEC2WAV2Generator(torch.nn.Module):
    """vec2wav 2.0 generator: a frontend followed by a waveform backend.

    The frontend maps VQ-vectors (conditioned on a mel prompt) to hidden
    features and a predicted mel spectrogram. The backend receives those
    hidden features (channels-first) plus a time-averaged prompt vector.
    """

    def __init__(self, frontend: "CTXVEC2WAVFrontend", backend: "BigVGAN"):
        """
        :param frontend: module returning ``(hidden, mel, _)`` from ``(vqvec, prompt, mask=..., prompt_mask=...)``.
        :param backend: module called as ``backend(h, prompt_avg)`` with ``h`` of shape (B, adim, L).
        """
        super(VEC2WAV2Generator, self).__init__()
        self.frontend = frontend
        self.backend = backend

    def forward(self, vqvec, prompt, mask=None, prompt_mask=None, crop_len=0, crop_offsets=None):
        """

        :param vqvec: (torch.Tensor) The shape is (B, L, D). Sequence of VQ-vectors.

        :param prompt: (torch.Tensor) The shape is (B, L', 80). Sequence of mel-spectrogram prompt (acoustic context)

        :param mask: (torch.Tensor) The shape is (B, L). True or 1 stands for valid values in `vqvec`.
            Boolean or 0/1 integer masks are accepted.

        :param prompt_mask: (torch.Tensor) The shape is (B, L'). True or 1 stands for valid values in `prompt`.
            Boolean or 0/1 integer masks are accepted.

        :param crop_len: if > 0, the hidden sequence is cropped with ``crop_seq(h, crop_offsets, crop_len)``
            before being passed to the backend.

        :param crop_offsets: offsets forwarded to ``crop_seq``; only used when ``crop_len > 0``.

        :return: tuple ``(mel, None, wav)``: the frontend predicted mel spectrogram (B, L, num_mels),
            a ``None`` placeholder, and the backend output (reconstructed waveform).

        """
        h, mel, _ = self.frontend(vqvec, prompt, mask=mask, prompt_mask=prompt_mask)  # (B, L, adim), (B, L, num_mels)
        if mask is not None:
            # Zero padded frames. `.bool()` makes `~` a logical NOT for 0/1 integer masks too.
            h = h.masked_fill(~mask.bool().unsqueeze(-1), 0)
        h = h.transpose(1, 2)  # (B, adim, L): channels-first for the backend
        if crop_len > 0:
            h = crop_seq(h, crop_offsets, crop_len)
        if prompt_mask is not None:
            valid = prompt_mask.bool().unsqueeze(-1)  # (B, L', 1)
            # Clamp so a fully-masked prompt row averages to 0 instead of 0/0 = NaN.
            n_valid = valid.sum(1).clamp(min=1)  # (B, 1)
            prompt_avg = prompt.masked_fill(~valid, 0).sum(1) / n_valid
        else:
            prompt_avg = prompt.mean(1)
        wav = self.backend(h, prompt_avg)  # (B, C, T)
        return mel, None, wav

    def inference(self, vqvec, prompt):
        """Generate without masks or cropping.

        :param vqvec: (torch.Tensor) The shape is (B, L, D). Sequence of VQ-vectors.
        :param prompt: (torch.Tensor) The shape is (B, L', 80). Mel-spectrogram prompt; averaged over all frames.
        :return: tuple ``(mel, None, wav)``, same layout as ``forward``.
        """
        h, mel, _ = self.frontend(vqvec, prompt)
        wav = self.backend(h.transpose(1, 2), prompt.mean(1))

        return mel, None, wav