from typing import List, Tuple

import numpy as np
import librosa
import torch
import torch.nn.functional as F

from s3tokenizer.utils import padding
from s3tokenizer.model_v2 import (
    S3TokenizerV2,
    ModelConfig,
)


# Sampling rate of the inputs to S3TokenizerV2
S3_SR = 16_000
S3_HOP = 160  # 100 mel frames/sec (160 samples per frame at 16 kHz)
S3_TOKEN_HOP = 640  # 25 speech tokens/sec (640 samples per token at 16 kHz)
S3_TOKEN_RATE = 25
SPEECH_VOCAB_SIZE = 6561
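
# Relation between the hops: S3_TOKEN_HOP / S3_HOP = 640 / 160 = 4 mel frames
# per speech token; `forward` below relies on this when it truncates mels to
# `max_len * 4` frames.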


class S3Tokenizer(S3TokenizerV2):
    """
    s3tokenizer.S3TokenizerV2 with the following changes:
    - a more integrated `forward`
    - computes `log_mel_spectrogram` using `_mel_filters` and `window` held in
      registered buffers
    """

    ignore_state_dict_missing = ("_mel_filters", "window")

    def __init__(
        self,
        name: str = "speech_tokenizer_v2_25hz",
        config: ModelConfig = ModelConfig()
    ):
        super().__init__(name)

        self.n_fft = 400  # 25 ms analysis window at 16 kHz (hop is S3_HOP = 10 ms)
        _mel_filters = librosa.filters.mel(
            sr=S3_SR,
            n_fft=self.n_fft,
            n_mels=config.n_mels
        )
        self.register_buffer(
            "_mel_filters",
            torch.FloatTensor(_mel_filters),
        )

        self.register_buffer(
            "window",
            torch.hann_window(self.n_fft),
        )
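
    # NOTE: keeping `_mel_filters` and `window` as buffers means they follow the
    # module across `.to(device)` and appear in `state_dict()`; presumably
    # `ignore_state_dict_missing` (above) exists so that checkpoints saved
    # without these buffers still load cleanly.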

    def pad(self, wavs, sr) -> List[torch.Tensor]:
        """
        Given a list of wavs with the same `sample_rate`, pad them so that the
        length is a multiple of 40 ms (S3 runs at 25 tokens/sec).
        """
        processed_wavs = []
        for wav in wavs:
            if isinstance(wav, np.ndarray):
                wav = torch.from_numpy(wav)
            if wav.dim() == 1:
                wav = wav.unsqueeze(0)

            # Round the duration up to a whole number of tokens, then pad the
            # waveform with zeros up to that many 40 ms hops.
            n_tokens = (wav.shape[1] / sr) * S3_TOKEN_RATE
            n_tokens = np.ceil(n_tokens)
            intended_wav_len = n_tokens * (sr / S3_TOKEN_RATE)
            intended_wav_len = int(intended_wav_len)
            wav = torch.nn.functional.pad(
                wav,
                (0, intended_wav_len - wav.shape[-1]),
                mode="constant",
                value=0
            )
            processed_wavs.append(wav)
        return processed_wavs
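
    # Worked example for `pad` (illustrative numbers): a 1.01 s wav at 16 kHz is
    # 16160 samples -> 16160 / 16000 * 25 = 25.25 tokens -> ceil to 26 tokens
    # -> padded to 26 * 640 = 16640 samples; a 1.00 s wav (16000 samples) is
    # already exactly 25 tokens and is left untouched.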

    def _prepare_audio(self, wavs):
        """Prepare a list of audios for s3tokenizer processing."""
        processed_wavs = []
        for wav in wavs:
            if isinstance(wav, np.ndarray):
                wav = torch.from_numpy(wav)
            if wav.dim() == 1:
                wav = wav.unsqueeze(0)

            processed_wavs.append(wav)
        return processed_wavs

    def forward(
        self,
        wavs: List[torch.Tensor],
        accelerator: 'Accelerator' = None,
        max_len: int = None,
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """
        NOTE: mel-spec has a hop size of 160 samples (100 frames/sec).
        FIXME: this class inherits `nn.Module` but takes a list of wavs rather
        than a batched `torch.Tensor`, and processes them one by one, which is
        unexpected.

        Args
        ----
        - `wavs`: a list of 16 kHz speech waveforms
        - `max_len`: max length to truncate the output token sequence to
          (at 25 tokens/sec).
          NOTE: please pad the waveform if a longer sequence is needed.
        """
        processed_wavs = self._prepare_audio(wavs)
        mels, mel_lens = [], []
        for wav in processed_wavs:
            wav = wav.to(self.device)
            mel = self.log_mel_spectrogram(wav)  # [B=1, F, T]
            if max_len is not None:
                mel = mel[..., :max_len * 4]  # num_mel_frames = 4 * num_tokens
            mels.append(mel.squeeze(0))

        mels, mel_lens = padding(mels)
        if accelerator is None:
            tokenizer = self
        else:
            # Unwrap DDP/compiled wrappers so `quantize` can be called directly.
            tokenizer = accelerator.unwrap_model(self)

        speech_tokens, speech_token_lens = tokenizer.quantize(mels, mel_lens.to(self.device))
        return (
            speech_tokens.long().detach(),
            speech_token_lens.long().detach(),
        )

    def log_mel_spectrogram(
        self,
        audio: torch.Tensor,
        padding: int = 0,
    ):
        """
        Compute the log-Mel spectrogram of the given audio.

        Parameters
        ----------
        audio: torch.Tensor, shape = (*)
            A NumPy array or Tensor containing the audio waveform at 16 kHz

        padding: int
            Number of zero samples to pad to the right

        Returns
        -------
        torch.Tensor, shape = (128, n_frames)
            A Tensor that contains the log-Mel spectrogram
        """
        if not torch.is_tensor(audio):
            audio = torch.from_numpy(audio)

        audio = audio.to(self.device)
        if padding > 0:
            audio = F.pad(audio, (0, padding))
        stft = torch.stft(
            audio, self.n_fft, S3_HOP,
            window=self.window.to(self.device),
            return_complex=True
        )
        magnitudes = stft[..., :-1].abs()**2  # drop the final STFT frame (Whisper-style framing)

        mel_spec = self._mel_filters.to(self.device) @ magnitudes

        # Whisper-style dynamic-range compression
        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        return log_spec
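

# ---------------------------------------------------------------------------
# Usage sketch. This block is illustrative only: "example.wav" is a made-up
# path, and a freshly constructed S3Tokenizer has untrained weights unless a
# checkpoint is loaded separately (e.g. via `load_state_dict`), so the token
# values below are not meaningful without one.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tokenizer = S3Tokenizer().eval()

    # Load any waveform and resample it to the 16 kHz rate S3 expects.
    wav, _ = librosa.load("example.wav", sr=S3_SR)

    # Pad to a whole number of 40 ms tokens, then tokenize.
    wavs = tokenizer.pad([wav], sr=S3_SR)
    with torch.no_grad():
        tokens, token_lens = tokenizer(wavs)  # [B, T_tok], [B]
    print(tokens.shape, token_lens)  # T_tok is about 25 tokens per second of audio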