# Sulai2005's picture
# Initial commit
# 506a2b4
# Modified from CosyVoice https://github.com/FunAudioLLM/CosyVoice
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
import torch
import torchaudio as ta
from functools import lru_cache
from typing import Optional
from ..s3tokenizer import S3_SR, SPEECH_VOCAB_SIZE, S3Tokenizer
from .const import S3GEN_SR
from .flow import CausalMaskedDiffWithXvec
from .xvector import CAMPPlus
from .utils.mel import mel_spectrogram
from .f0_predictor import ConvRNNF0Predictor
from .hifigan import HiFTGenerator
from .transformer.upsample_encoder import UpsampleConformerEncoder
from .flow_matching import CausalConditionalCFM
from .decoder import ConditionalDecoder
from .configs import CFM_PARAMS
def drop_invalid_tokens(x):
    """Return only the entries of `x` that are smaller than `SPEECH_VOCAB_SIZE`.

    Only a batch of one is accepted: `x` may have at most two dimensions and
    its leading dimension must be of size 1. The selection is done with a
    boolean mask built by comparing `x` against `SPEECH_VOCAB_SIZE`.
    """
    is_single_item_batch = len(x.shape) <= 2 and x.shape[0] == 1
    assert is_single_item_batch, "only batch size of one allowed for now"
    keep_mask = x < SPEECH_VOCAB_SIZE
    return x[keep_mask]
# TODO: global resampler cache
@lru_cache(maxsize=100)
def get_resampler(src_sr, dst_sr, device):
    """Build a `torchaudio` resampler from `src_sr` to `dst_sr`, placed on `device`.

    Calls are memoized on the `(src_sr, dst_sr, device)` arguments, with at
    most 100 entries retained, so repeated requests reuse the same transform.
    """
    resampler = ta.transforms.Resample(src_sr, dst_sr)
    return resampler.to(device)
class S3Token2Mel(torch.nn.Module):
    """
    CosyVoice2's CFM decoder maps S3 speech tokens to mel-spectrograms.

    The module bundles a speech tokenizer, a mel extractor and a speaker
    encoder (all used to embed the reference clip) with the flow model that
    produces the output mels.

    TODO: make these modules configurable?
    """

    def __init__(self):
        super().__init__()
        self.tokenizer = S3Tokenizer("speech_tokenizer_v2_25hz")
        self.mel_extractor = mel_spectrogram  # TODO: make it a torch module?
        self.speaker_encoder = CAMPPlus()  # use default args

        encoder = UpsampleConformerEncoder(
            output_size=512,
            attention_heads=8,
            linear_units=2048,
            num_blocks=6,
            dropout_rate=0.1,
            positional_dropout_rate=0.1,
            attention_dropout_rate=0.1,
            normalize_before=True,
            input_layer='linear',
            pos_enc_layer_type='rel_pos_espnet',
            selfattention_layer_type='rel_selfattn',
            input_size=512,
            use_cnn_module=False,
            macaron_style=False,
        )

        estimator = ConditionalDecoder(
            in_channels=320,
            out_channels=80,
            causal=True,
            channels=[256],
            dropout=0.0,
            attention_head_dim=64,
            n_blocks=4,
            num_mid_blocks=12,
            num_heads=8,
            act_fn='gelu',
        )
        cfm_params = CFM_PARAMS
        decoder = CausalConditionalCFM(
            spk_emb_dim=80,
            cfm_params=cfm_params,
            estimator=estimator,
        )

        self.flow = CausalMaskedDiffWithXvec(
            encoder=encoder,
            decoder=decoder
        )

        # Not read anywhere in this class (resamplers come from `get_resampler`);
        # kept so the instance attribute remains available to external code.
        self.resamplers = {}

    @property
    def device(self):
        """Device of the first parameter of the tokenizer."""
        params = self.tokenizer.parameters()
        return next(params).device

    def embed_ref(
        self,
        ref_wav: torch.Tensor,
        ref_sr: int,
        device="auto",
        ref_fade_out=True,
    ):
        """
        Compute the conditioning dict for a reference waveform.

        Args
        ----
        - `ref_wav`: reference waveform, a `torch.Tensor` or `np.ndarray`; a 1-D input
          gets a leading batch dimension added.
        - `ref_sr`: sample rate of `ref_wav`.
        - `device`: target device, or "auto" to use this module's device.
        - `ref_fade_out`: accepted for interface compatibility; not used by this method.

        Returns
        -------
        A dict with the keys `prompt_token`, `prompt_token_len`, `prompt_feat`,
        `prompt_feat_len` (always `None`) and `embedding`, i.e. the keyword arguments
        that `forward` passes on to `self.flow.inference`.
        """
        device = self.device if device == "auto" else device
        if isinstance(ref_wav, np.ndarray):
            ref_wav = torch.from_numpy(ref_wav).float()

        if ref_wav.device != device:
            ref_wav = ref_wav.to(device)

        if len(ref_wav.shape) == 1:
            ref_wav = ref_wav.unsqueeze(0)  # (B, L)

        if ref_wav.size(1) > 10 * ref_sr:
            # Reported through `logging`, like the length-mismatch warning below.
            logging.warning("cosydec received ref longer than 10s")

        # Mel features are extracted from the reference at S3GEN_SR.
        ref_wav_24 = ref_wav
        if ref_sr != S3GEN_SR:
            ref_wav_24 = get_resampler(ref_sr, S3GEN_SR, device)(ref_wav)

        ref_mels_24 = self.mel_extractor(ref_wav_24).transpose(1, 2).to(device)
        ref_mels_24_len = None

        # Resample to 16kHz
        ref_wav_16 = get_resampler(ref_sr, S3_SR, device)(ref_wav).to(device)

        # Speaker embedding
        ref_x_vector = self.speaker_encoder.inference(ref_wav_16)

        # Tokenize 16khz reference
        ref_speech_tokens, ref_speech_token_lens = self.tokenizer(ref_wav_16)

        # Make sure mel_len = 2 * stoken_len (happens when the input is not padded to multiple of 40ms)
        if ref_mels_24.shape[1] != 2 * ref_speech_tokens.shape[1]:
            logging.warning(
                "Reference mel length is not equal to 2 * reference token length.\n"
            )
            # Truncate the tokens to half the mel length and keep the length tensor in sync.
            ref_speech_tokens = ref_speech_tokens[:, :ref_mels_24.shape[1] // 2]
            ref_speech_token_lens[0] = ref_speech_tokens.shape[1]

        return dict(
            prompt_token=ref_speech_tokens.to(device),
            prompt_token_len=ref_speech_token_lens,
            prompt_feat=ref_mels_24,
            prompt_feat_len=ref_mels_24_len,
            embedding=ref_x_vector,
        )

    def forward(
        self,
        speech_tokens: torch.LongTensor,
        # locally-computed ref embedding (mutex with ref_dict)
        ref_wav: Optional[torch.Tensor] = None,
        ref_sr: Optional[int] = None,
        # pre-computed ref embedding (prod API)
        ref_dict: Optional[dict] = None,
        finalize: bool = False,
    ):
        """
        Generate mel-spectrograms from S3 speech tokens and a reference, which the speaker timbre is inferred from.

        Exactly one of `ref_wav` / `ref_dict` must be given. `ref_wav` and `ref_sr` default to `None`
        so that callers holding a pre-computed `ref_dict` do not have to pass placeholders.

        NOTE:
        - The speaker encoder accepts 16 kHz waveform.
        - S3TokenizerV2 accepts 16 kHz waveform.
        - The mel-spectrogram for the reference assumes 24 kHz input signal.
        - This function is designed for batch_size=1 only.

        Args
        ----
        - `speech_tokens`: S3 speech tokens [B=1, T]
        - `ref_wav`: reference waveform (`torch.Tensor` with shape=[B=1, T])
        - `ref_sr`: reference sample rate
        - `ref_dict`: pre-computed reference embedding, as returned by `embed_ref`.
          NumPy values are converted to tensors and tensor values are moved to this
          module's device; the conversion is written back into the passed dict.
        - `finalize`: whether streaming is finished or not. Note that if False, the last 3 tokens will be ignored.

        Returns
        -------
        The first element of the tuple returned by `self.flow.inference` (the output mels).
        """
        assert (ref_wav is None) ^ (ref_dict is None), f"Must provide exactly one of ref_wav or ref_dict (got {ref_wav} and {ref_dict})"

        if ref_dict is None:
            ref_dict = self.embed_ref(ref_wav, ref_sr)
        else:
            # type/device casting (all values will be numpy if it's from a prod API call)
            for rk in list(ref_dict):
                if isinstance(ref_dict[rk], np.ndarray):
                    ref_dict[rk] = torch.from_numpy(ref_dict[rk])
                if torch.is_tensor(ref_dict[rk]):
                    ref_dict[rk] = ref_dict[rk].to(self.device)

        if len(speech_tokens.shape) == 1:
            speech_tokens = speech_tokens.unsqueeze(0)

        # assert speech_tokens.shape[0] == 1, "only batch size of one allowed for now"
        speech_token_lens = torch.LongTensor([speech_tokens.size(1)]).to(self.device)

        output_mels, _ = self.flow.inference(
            token=speech_tokens,
            token_len=speech_token_lens,
            finalize=finalize,
            **ref_dict,
        )
        return output_mels
class S3Token2Wav(S3Token2Mel):
    """
    The decoder of CosyVoice2 is a concat of token-to-mel (CFM) and a mel-to-waveform (HiFiGAN) modules.

    TODO: make these modules configurable?
    """

    def __init__(self):
        super().__init__()

        f0_predictor = ConvRNNF0Predictor()
        self.mel2wav = HiFTGenerator(
            sampling_rate=S3GEN_SR,
            upsample_rates=[8, 5, 3],
            upsample_kernel_sizes=[16, 11, 7],
            source_resblock_kernel_sizes=[7, 7, 11],
            source_resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            f0_predictor=f0_predictor,
        )

        # silence out a few ms and fade audio in to reduce artifacts
        n_trim = S3GEN_SR // 50  # 20ms = half of a frame
        # First half stays at zero (silence), second half is a raised-cosine ramp from 0 to 1.
        trim_fade = torch.zeros(2 * n_trim)
        trim_fade[n_trim:] = (torch.cos(torch.linspace(torch.pi, 0, n_trim)) + 1) / 2
        self.register_buffer("trim_fade", trim_fade, persistent=False)  # (buffers get automatic device casting)

    def _apply_trim_fade(self, wavs):
        """Multiply the start of `wavs` (indexed as [batch, time]) by `trim_fade`, in place.

        The fade window is clamped to the waveform length so that outputs shorter
        than the window do not fail on a shape mismatch.
        """
        n_fade = min(wavs.size(1), len(self.trim_fade))
        wavs[:, :n_fade] *= self.trim_fade[:n_fade]
        return wavs

    def forward(
        self,
        speech_tokens,
        # locally-computed ref embedding (mutex with ref_dict)
        ref_wav: Optional[torch.Tensor] = None,
        ref_sr: Optional[int] = None,
        # pre-computed ref embedding (prod API)
        ref_dict: Optional[dict] = None,
        finalize: bool = False
    ):
        """
        Generate a waveform from S3 speech tokens and a reference.

        Runs the parent class' token-to-mel `forward`, then `self.mel2wav.inference`
        with an empty cache source. Outside of training mode the start of the output
        is silenced/faded in with `trim_fade`.

        Exactly one of `ref_wav` (with `ref_sr`) or `ref_dict` must be provided; see the
        parent `forward`, which receives these arguments unchanged.
        """
        output_mels = super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)

        # TODO jrm: ignoring the speed control (mel interpolation) and the HiFTGAN caching mechanisms for now.
        hift_cache_source = torch.zeros(1, 1, 0).to(self.device)

        output_wavs, *_ = self.mel2wav.inference(speech_feat=output_mels, cache_source=hift_cache_source)

        if not self.training:
            # NOTE: ad-hoc method to reduce "spillover" from the reference clip.
            self._apply_trim_fade(output_wavs)

        return output_wavs

    @torch.inference_mode()
    def flow_inference(
        self,
        speech_tokens,
        # locally-computed ref embedding (mutex with ref_dict)
        ref_wav: Optional[torch.Tensor] = None,
        ref_sr: Optional[int] = None,
        # pre-computed ref embedding (prod API)
        ref_dict: Optional[dict] = None,
        finalize: bool = False,
    ):
        """Token-to-mel step only: call the parent class' `forward` under inference mode."""
        return super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)

    @torch.inference_mode()
    def hift_inference(self, speech_feat, cache_source: Optional[torch.Tensor] = None):
        """Mel-to-waveform step only: run `self.mel2wav.inference` under inference mode.

        When `cache_source` is `None`, an empty (1, 1, 0) tensor on this module's device is used.
        Returns whatever `self.mel2wav.inference` returns.
        """
        if cache_source is None:
            cache_source = torch.zeros(1, 1, 0).to(self.device)
        return self.mel2wav.inference(speech_feat=speech_feat, cache_source=cache_source)

    @torch.inference_mode()
    def inference(
        self,
        speech_tokens,
        # locally-computed ref embedding (mutex with ref_dict)
        ref_wav: Optional[torch.Tensor] = None,
        ref_sr: Optional[int] = None,
        # pre-computed ref embedding (prod API)
        ref_dict: Optional[dict] = None,
        cache_source: Optional[torch.Tensor] = None,  # NOTE: this arg is for streaming, it can probably be removed here
        finalize: bool = True,
    ):
        """
        Full token-to-waveform inference: `flow_inference` followed by `hift_inference`.

        The start of the waveform is always silenced/faded in with `trim_fade`.
        Returns the pair `(output_wavs, output_sources)` from `hift_inference`.
        """
        output_mels = self.flow_inference(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
        output_wavs, output_sources = self.hift_inference(output_mels, cache_source)

        # NOTE: ad-hoc method to reduce "spillover" from the reference clip.
        self._apply_trim_fade(output_wavs)

        return output_wavs, output_sources