from typing import List, Optional, Tuple
import numpy as np
import librosa
import torch
import torch.nn.functional as F
from s3tokenizer.utils import padding
from s3tokenizer.model_v2 import (
S3TokenizerV2,
ModelConfig,
)
# Sampling rate of the inputs to S3TokenizerV2
S3_SR = 16_000
S3_HOP = 160 # 100 frames/sec
S3_TOKEN_HOP = 640 # 25 tokens/sec
S3_TOKEN_RATE = 25
SPEECH_VOCAB_SIZE = 6561
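# How the constants relate:
#   S3_SR / S3_HOP = 100 mel frames per second
#   S3_SR / S3_TOKEN_HOP = 25 speech tokens per second
#   S3_TOKEN_HOP / S3_HOP = 4 mel frames per speech token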
class S3Tokenizer(S3TokenizerV2):
"""
s3tokenizer.S3TokenizerV2 with the following changes:
- a more integrated `forward`
    - computes `log_mel_spectrogram` using `_mel_filters` and `window`, which are registered as module buffers
"""
ignore_state_dict_missing = ("_mel_filters", "window")
def __init__(
self,
        name: str = "speech_tokenizer_v2_25hz",
config: ModelConfig = ModelConfig()
):
super().__init__(name)
        self.n_fft = 400  # 25 ms window at 16 kHz
_mel_filters = librosa.filters.mel(
sr=S3_SR,
n_fft=self.n_fft,
n_mels=config.n_mels
)
self.register_buffer(
"_mel_filters",
torch.FloatTensor(_mel_filters),
)
self.register_buffer(
"window",
torch.hann_window(self.n_fft),
)
def pad(self, wavs, sr) -> List[torch.Tensor]:
"""
        Given a list of wavs with the same sample rate `sr`, pad them so that their lengths are a multiple of 40 ms (S3 runs at 25 tokens/sec).
"""
processed_wavs = []
for wav in wavs:
if isinstance(wav, np.ndarray):
wav = torch.from_numpy(wav)
if wav.dim() == 1:
wav = wav.unsqueeze(0)
            # round the token count up, then pad the wav to exactly that many 40 ms frames
            n_tokens = np.ceil((wav.shape[1] / sr) * S3_TOKEN_RATE)
            intended_wav_len = int(n_tokens * (sr / S3_TOKEN_RATE))
            wav = F.pad(
wav,
(0, intended_wav_len - wav.shape[-1]),
mode="constant",
value=0
)
processed_wavs.append(wav)
return processed_wavs
def _prepare_audio(self, wavs):
"""Prepare a list of audios for s3tokenizer processing."""
processed_wavs = []
for wav in wavs:
if isinstance(wav, np.ndarray):
wav = torch.from_numpy(wav)
if wav.dim() == 1:
wav = wav.unsqueeze(0)
processed_wavs.append(wav)
return processed_wavs
@torch.no_grad()
def forward(
self,
        wavs: List[torch.Tensor],
        accelerator: 'Accelerator' = None,
        max_len: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
"""
NOTE: mel-spec has a hop size of 160 points (100 frame/sec).
FIXME: this class inherits `nn.Module` but doesn't accept `torch.Tensor` and handles a list of wavs one by one, which is unexpected.
Args
----
- `wavs`: 16 kHz speech audio
- `max_len` max length to truncate the output sequence to (25 token/sec).
NOTE: please pad the waveform if longer sequence is needed.
"""
processed_wavs = self._prepare_audio(wavs)
        mels = []
for wav in processed_wavs:
wav = wav.to(self.device)
mel = self.log_mel_spectrogram(wav) # [B=1, F, T]
if max_len is not None:
mel = mel[..., :max_len * 4] # num_mel_frames = 4 * num_tokens
mels.append(mel.squeeze(0))
mels, mel_lens = padding(mels)
if accelerator is None:
tokenizer = self
else:
tokenizer = accelerator.unwrap_model(self)
speech_tokens, speech_token_lens = tokenizer.quantize(mels, mel_lens.to(self.device))
return (
speech_tokens.long().detach(),
speech_token_lens.long().detach(),
)
def log_mel_spectrogram(
self,
audio: torch.Tensor,
padding: int = 0,
):
"""
        Compute the log-Mel spectrogram of the given audio.
        Parameters
        ----------
        audio: torch.Tensor, shape = (*)
            A NumPy array or Tensor containing the audio waveform, sampled at 16 kHz
        padding: int
            Number of zero samples to pad on the right
        Returns
        -------
        torch.Tensor, shape = (n_mels, n_frames)
            A Tensor that contains the log-Mel spectrogram
"""
if not torch.is_tensor(audio):
audio = torch.from_numpy(audio)
audio = audio.to(self.device)
if padding > 0:
audio = F.pad(audio, (0, padding))
        stft = torch.stft(  # 25 ms window (400 samples), 10 ms hop (160 samples)
audio, self.n_fft, S3_HOP,
window=self.window.to(self.device),
return_complex=True
)
        magnitudes = stft[..., :-1].abs()**2  # power spectrum; drop the final STFT frame
mel_spec = self._mel_filters.to(self.device) @ magnitudes
        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)  # clamp dynamic range to 80 dB
        log_spec = (log_spec + 4.0) / 4.0  # normalize
return log_spec
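

# Minimal usage sketch. Assumptions: a 16 kHz mono file "sample.wav" exists
# (hypothetical path), and pretrained weights have been loaded separately;
# a freshly constructed model is randomly initialized, so its tokens would
# not be meaningful.
if __name__ == "__main__":
    import torchaudio

    tokenizer = S3Tokenizer("speech_tokenizer_v2_25hz")
    wav, sr = torchaudio.load("sample.wav")  # [1, T] float waveform
    assert sr == S3_SR, "input must be sampled at 16 kHz"
    wavs = tokenizer.pad([wav], sr=sr)
    tokens, token_lens = tokenizer(wavs)
    print(tokens.shape, token_lens)  # roughly 25 tokens per second of audio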