# Modified from CosyVoice https://github.com/FunAudioLLM/CosyVoice
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import numpy as np
import torch
import torchaudio as ta
from functools import lru_cache
from typing import Optional

from ..s3tokenizer import S3_SR, SPEECH_VOCAB_SIZE, S3Tokenizer
from .const import S3GEN_SR
from .flow import CausalMaskedDiffWithXvec
from .xvector import CAMPPlus
from .utils.mel import mel_spectrogram
from .f0_predictor import ConvRNNF0Predictor
from .hifigan import HiFTGenerator
from .transformer.upsample_encoder import UpsampleConformerEncoder
from .flow_matching import CausalConditionalCFM
from .decoder import ConditionalDecoder
from .configs import CFM_PARAMS


def drop_invalid_tokens(x):
    assert len(x.shape) <= 2 and x.shape[0] == 1, "only batch size of one allowed for now"
    return x[x < SPEECH_VOCAB_SIZE]


# TODO: global resampler cache
def get_resampler(src_sr, dst_sr, device):
    return ta.transforms.Resample(src_sr, dst_sr).to(device)
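

# Sketch of the global resampler cache mentioned in the TODO above (not part of
# the original module): functools.lru_cache (already imported) memoizes one
# Resample transform per (src_sr, dst_sr, device) key, so repeated calls reuse
# the same kernel. The name `get_cached_resampler` is introduced here purely
# for illustration.
@lru_cache(maxsize=None)
def get_cached_resampler(src_sr: int, dst_sr: int, device: str = "cpu"):
    return ta.transforms.Resample(src_sr, dst_sr).to(device)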


class S3Token2Mel(torch.nn.Module):
    """
    CosyVoice2's CFM decoder maps S3 speech tokens to mel-spectrograms.

    TODO: make these modules configurable?
    """
    def __init__(self):
        super().__init__()
        self.tokenizer = S3Tokenizer("speech_tokenizer_v2_25hz")
        self.mel_extractor = mel_spectrogram  # TODO: make it a torch module?
        self.speaker_encoder = CAMPPlus()  # use default args

        encoder = UpsampleConformerEncoder(
            output_size=512,
            attention_heads=8,
            linear_units=2048,
            num_blocks=6,
            dropout_rate=0.1,
            positional_dropout_rate=0.1,
            attention_dropout_rate=0.1,
            normalize_before=True,
            input_layer='linear',
            pos_enc_layer_type='rel_pos_espnet',
            selfattention_layer_type='rel_selfattn',
            input_size=512,
            use_cnn_module=False,
            macaron_style=False,
        )

        estimator = ConditionalDecoder(
            in_channels=320,
            out_channels=80,
            causal=True,
            channels=[256],
            dropout=0.0,
            attention_head_dim=64,
            n_blocks=4,
            num_mid_blocks=12,
            num_heads=8,
            act_fn='gelu',
        )

        cfm_params = CFM_PARAMS
        decoder = CausalConditionalCFM(
            spk_emb_dim=80,
            cfm_params=cfm_params,
            estimator=estimator,
        )

        self.flow = CausalMaskedDiffWithXvec(
            encoder=encoder,
            decoder=decoder
        )

        self.resamplers = {}

    @property
    def device(self):
        # `self.device` is read as an attribute throughout this file, so it must be a property.
        params = self.tokenizer.parameters()
        return next(params).device

    def embed_ref(
        self,
        ref_wav: torch.Tensor,
        ref_sr: int,
        device="auto",
        ref_fade_out=True,
    ):
        device = self.device if device == "auto" else device
        if isinstance(ref_wav, np.ndarray):
            ref_wav = torch.from_numpy(ref_wav).float()

        if ref_wav.device != device:
            ref_wav = ref_wav.to(device)

        if len(ref_wav.shape) == 1:
            ref_wav = ref_wav.unsqueeze(0)  # (B, L)

        if ref_wav.size(1) > 10 * ref_sr:
            print("WARNING: cosydec received ref longer than 10s")

        ref_wav_24 = ref_wav
        if ref_sr != S3GEN_SR:
            ref_wav_24 = get_resampler(ref_sr, S3GEN_SR, device)(ref_wav)

        ref_mels_24 = self.mel_extractor(ref_wav_24).transpose(1, 2).to(device)
        ref_mels_24_len = None

        # Resample to 16kHz
        ref_wav_16 = get_resampler(ref_sr, S3_SR, device)(ref_wav).to(device)

        # Speaker embedding
        ref_x_vector = self.speaker_encoder.inference(ref_wav_16)

        # Tokenize 16khz reference
        ref_speech_tokens, ref_speech_token_lens = self.tokenizer(ref_wav_16)

        # Ensure mel_len == 2 * stoken_len (a mismatch happens when the input
        # is not padded to a multiple of 40ms).
        if ref_mels_24.shape[1] != 2 * ref_speech_tokens.shape[1]:
            logging.warning(
                "Reference mel length is not equal to 2 * reference token length.\n"
            )
            ref_speech_tokens = ref_speech_tokens[:, :ref_mels_24.shape[1] // 2]
            ref_speech_token_lens[0] = ref_speech_tokens.shape[1]

        return dict(
            prompt_token=ref_speech_tokens.to(device),
            prompt_token_len=ref_speech_token_lens,
            prompt_feat=ref_mels_24,
            prompt_feat_len=ref_mels_24_len,
            embedding=ref_x_vector,
        )
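
    # Usage sketch (illustrative, not part of the original file): precompute the
    # reference conditioning once and pass it as `ref_dict` to forward(). The
    # file name "ref.wav" is a placeholder.
    #
    #   model = S3Token2Mel().eval()
    #   ref_wav, ref_sr = ta.load("ref.wav")          # (1, T) waveform, any sample rate
    #   ref_dict = model.embed_ref(ref_wav, ref_sr)   # prompt_token / prompt_feat / embedding / ...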

    def forward(
        self,
        speech_tokens: torch.LongTensor,
        # locally-computed ref embedding (mutex with ref_dict)
        ref_wav: Optional[torch.Tensor],
        ref_sr: Optional[int],
        # pre-computed ref embedding (prod API)
        ref_dict: Optional[dict] = None,
        finalize: bool = False,
    ):
        """
        Generate mel-spectrograms from S3 speech tokens and a reference waveform, from which the speaker timbre is inferred.

        NOTE:
        - The speaker encoder accepts 16 kHz waveform.
        - S3TokenizerV2 accepts 16 kHz waveform.
        - The mel-spectrogram for the reference assumes 24 kHz input signal.
        - This function is designed for batch_size=1 only.

        Args
        ----
        - `speech_tokens`: S3 speech tokens [B=1, T]
        - `ref_wav`: reference waveform (`torch.Tensor` with shape=[B=1, T])
        - `ref_sr`: reference sample rate
        - `finalize`: whether streaming is finished or not. Note that if False, the last 3 tokens will be ignored.
        """
        assert (ref_wav is None) ^ (ref_dict is None), f"Must provide exactly one of ref_wav or ref_dict (got {ref_wav} and {ref_dict})"

        if ref_dict is None:
            ref_dict = self.embed_ref(ref_wav, ref_sr)
        else:
            # type/device casting (all values will be numpy if it's from a prod API call)
            for rk in list(ref_dict):
                if isinstance(ref_dict[rk], np.ndarray):
                    ref_dict[rk] = torch.from_numpy(ref_dict[rk])
                if torch.is_tensor(ref_dict[rk]):
                    ref_dict[rk] = ref_dict[rk].to(self.device)

        if len(speech_tokens.shape) == 1:
            speech_tokens = speech_tokens.unsqueeze(0)

        # assert speech_tokens.shape[0] == 1, "only batch size of one allowed for now"
        speech_token_lens = torch.LongTensor([speech_tokens.size(1)]).to(self.device)

        output_mels, _ = self.flow.inference(
            token=speech_tokens,
            token_len=speech_token_lens,
            finalize=finalize,
            **ref_dict,
        )
        return output_mels
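

# Token-to-mel usage sketch (illustrative, continuing the embed_ref() example
# above): `speech_tokens` is a placeholder for S3 tokens produced upstream
# (e.g. by a language model or by S3Tokenizer).
#
#   mels = model(speech_tokens, ref_wav=None, ref_sr=None,
#                ref_dict=ref_dict, finalize=True)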


class S3Token2Wav(S3Token2Mel):
    """
    The CosyVoice2 decoder is the concatenation of a token-to-mel (CFM) module and a mel-to-waveform (HiFiGAN) module.

    TODO: make these modules configurable?
    """
    def __init__(self):
        super().__init__()

        f0_predictor = ConvRNNF0Predictor()
        self.mel2wav = HiFTGenerator(
            sampling_rate=S3GEN_SR,
            upsample_rates=[8, 5, 3],
            upsample_kernel_sizes=[16, 11, 7],
            source_resblock_kernel_sizes=[7, 7, 11],
            source_resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            f0_predictor=f0_predictor,
        )

        # silence out a few ms and fade audio in to reduce artifacts
        n_trim = S3GEN_SR // 50  # 20ms = half of a frame
        trim_fade = torch.zeros(2 * n_trim)
        trim_fade[n_trim:] = (torch.cos(torch.linspace(torch.pi, 0, n_trim)) + 1) / 2
        self.register_buffer("trim_fade", trim_fade, persistent=False)  # (buffers get automatic device casting)
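        # With S3GEN_SR = 24000, n_trim is 480 samples (20 ms), so trim_fade spans
        # 960 samples: 480 zeros followed by a raised-cosine ramp from 0 up to 1,
        # which mutes and then fades in the very start of the generated waveform.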

    def forward(
        self,
        speech_tokens,
        # locally-computed ref embedding (mutex with ref_dict)
        ref_wav: Optional[torch.Tensor],
        ref_sr: Optional[int],
        # pre-computed ref embedding (prod API)
        ref_dict: Optional[dict] = None,
        finalize: bool = False,
    ):
        output_mels = super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)

        # TODO jrm: ignoring the speed control (mel interpolation) and the HiFTGAN caching mechanisms for now.
        hift_cache_source = torch.zeros(1, 1, 0).to(self.device)

        output_wavs, *_ = self.mel2wav.inference(speech_feat=output_mels, cache_source=hift_cache_source)

        if not self.training:
            # NOTE: ad-hoc method to reduce "spillover" from the reference clip.
            output_wavs[:, :len(self.trim_fade)] *= self.trim_fade

        return output_wavs

    def flow_inference(
        self,
        speech_tokens,
        # locally-computed ref embedding (mutex with ref_dict)
        ref_wav: Optional[torch.Tensor] = None,
        ref_sr: Optional[int] = None,
        # pre-computed ref embedding (prod API)
        ref_dict: Optional[dict] = None,
        finalize: bool = False,
    ):
        return super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)

    def hift_inference(self, speech_feat, cache_source: Optional[torch.Tensor] = None):
        if cache_source is None:
            cache_source = torch.zeros(1, 1, 0).to(self.device)
        return self.mel2wav.inference(speech_feat=speech_feat, cache_source=cache_source)

    def inference(
        self,
        speech_tokens,
        # locally-computed ref embedding (mutex with ref_dict)
        ref_wav: Optional[torch.Tensor] = None,
        ref_sr: Optional[int] = None,
        # pre-computed ref embedding (prod API)
        ref_dict: Optional[dict] = None,
        cache_source: Optional[torch.Tensor] = None,  # NOTE: this arg is for streaming, it can probably be removed here
        finalize: bool = True,
    ):
        output_mels = self.flow_inference(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
        output_wavs, output_sources = self.hift_inference(output_mels, cache_source)

        # NOTE: ad-hoc method to reduce "spillover" from the reference clip.
        output_wavs[:, :len(self.trim_fade)] *= self.trim_fade

        return output_wavs, output_sources
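

# End-to-end usage sketch (illustrative, not part of the original file): turn S3
# speech tokens into a 24 kHz waveform that clones the timbre of a reference
# clip. "ref.wav", "out.wav", and `speech_tokens` are placeholders for
# caller-provided inputs.
#
#   s3gen = S3Token2Wav().eval()
#   ref_wav, ref_sr = ta.load("ref.wav")
#   with torch.inference_mode():
#       wavs, _ = s3gen.inference(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr)
#   ta.save("out.wav", wavs.cpu(), S3GEN_SR)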