from typing import List, Optional, Tuple, TYPE_CHECKING

import numpy as np
import librosa
import torch
import torch.nn.functional as F

from s3tokenizer.utils import padding
from s3tokenizer.model_v2 import (
    S3TokenizerV2,
    ModelConfig,
)

if TYPE_CHECKING:
    from accelerate import Accelerator


# Sampling rate of the inputs to S3TokenizerV2
S3_SR = 16_000
S3_HOP = 160  # 100 frames/sec
S3_TOKEN_HOP = 640  # 25 tokens/sec
S3_TOKEN_RATE = 25
SPEECH_VOCAB_SIZE = 6561


class S3Tokenizer(S3TokenizerV2):
    """
    s3tokenizer.S3TokenizerV2 with the following changes:
    - a more integrated `forward`
    - compute `log_mel_spectrogram` using the `_mel_filters` and `window`
      buffers registered via `register_buffer`
    """

    ignore_state_dict_missing = ("_mel_filters", "window")

    def __init__(
        self,
        name: str = "speech_tokenizer_v2_25hz",
        config: ModelConfig = ModelConfig(),
    ):
        super().__init__(name)

        self.n_fft = 400
        _mel_filters = librosa.filters.mel(
            sr=S3_SR,
            n_fft=self.n_fft,
            n_mels=config.n_mels,
        )
        self.register_buffer(
            "_mel_filters",
            torch.FloatTensor(_mel_filters),
        )
        self.register_buffer(
            "window",
            torch.hann_window(self.n_fft),
        )

    def pad(self, wavs, sr) -> List[torch.Tensor]:
        """
        Given a list of wavs with the same sample rate, pad each one so that
        its length is a multiple of 40 ms (S3 runs at 25 tokens/sec).
        """
        processed_wavs = []
        for wav in wavs:
            if isinstance(wav, np.ndarray):
                wav = torch.from_numpy(wav)
            if wav.dim() == 1:
                wav = wav.unsqueeze(0)

            n_tokens = (wav.shape[1] / sr) * S3_TOKEN_RATE
            n_tokens = np.ceil(n_tokens)
            intended_wav_len = n_tokens * (sr / S3_TOKEN_RATE)
            intended_wav_len = int(intended_wav_len)
            wav = torch.nn.functional.pad(
                wav,
                (0, intended_wav_len - wav.shape[-1]),
                mode="constant",
                value=0,
            )
            processed_wavs.append(wav)
        return processed_wavs

    def _prepare_audio(self, wavs):
        """Prepare a list of audios for s3tokenizer processing."""
        processed_wavs = []
        for wav in wavs:
            if isinstance(wav, np.ndarray):
                wav = torch.from_numpy(wav)
            if wav.dim() == 1:
                wav = wav.unsqueeze(0)
            processed_wavs.append(wav)
        return processed_wavs

    @torch.no_grad()
    def forward(
        self,
        wavs: torch.Tensor,
        accelerator: Optional["Accelerator"] = None,
        max_len: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """
        NOTE: mel-spec has a hop size of 160 points (100 frames/sec).
        FIXME: this class inherits `nn.Module` but doesn't accept `torch.Tensor`
        and handles a list of wavs one by one, which is unexpected.

        Args
        ----
        - `wavs`: 16 kHz speech audio
        - `max_len`: max length to truncate the output token sequence to
          (25 tokens/sec). NOTE: please pad the waveform if a longer sequence
          is needed.
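
        Returns
        -------
        - `speech_tokens`: batch of speech token ids, as a `(B, T)` long tensor
        - `speech_token_lens`: number of valid tokens per utterance, `(B,)`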
""" processed_wavs = self._prepare_audio(wavs) mels, mel_lens = [], [] for wav in processed_wavs: wav = wav.to(self.device) mel = self.log_mel_spectrogram(wav) # [B=1, F, T] if max_len is not None: mel = mel[..., :max_len * 4] # num_mel_frames = 4 * num_tokens mels.append(mel.squeeze(0)) mels, mel_lens = padding(mels) if accelerator is None: tokenizer = self else: tokenizer = accelerator.unwrap_model(self) speech_tokens, speech_token_lens = tokenizer.quantize(mels, mel_lens.to(self.device)) return ( speech_tokens.long().detach(), speech_token_lens.long().detach(), ) def log_mel_spectrogram( self, audio: torch.Tensor, padding: int = 0, ): """ Compute the log-Mel spectrogram of Parameters ---------- audio: torch.Tensor, shape = (*) The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz padding: int Number of zero samples to pad to the right Returns ------- torch.Tensor, shape = (128, n_frames) A Tensor that contains the Mel spectrogram """ if not torch.is_tensor(audio): audio = torch.from_numpy(audio) audio = audio.to(self.device) if padding > 0: audio = F.pad(audio, (0, padding)) stft = torch.stft( audio, self.n_fft, S3_HOP, window=self.window.to(self.device), return_complex=True ) magnitudes = stft[..., :-1].abs()**2 mel_spec = self._mel_filters.to(self.device) @ magnitudes log_spec = torch.clamp(mel_spec, min=1e-10).log10() log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) log_spec = (log_spec + 4.0) / 4.0 return log_spec