"""Audio I/O helpers for reading, resampling, and (de)serializing waveforms."""

import io
import typing as tp
from typing import List

import numpy as np
import torch
import torchaudio
from fairseq2.data import FileMapper, read_sequence
from fairseq2.data.audio import AudioDecoder
from numpy.typing import NDArray

from stopes.modules.speech import utils as speech_utils


def wav_to_bytes(
    wav: torch.Tensor | NDArray, sample_rate: int = 16_000, format: str = "ogg"
) -> NDArray[np.int8]:
    """Encode a waveform into an in-memory audio container and return its raw bytes."""
    if isinstance(wav, np.ndarray):
        wav = torch.from_numpy(wav)
    if wav.dtype != torch.float32:
        wav = wav.float()
    if wav.ndim == 1:
        # torchaudio.save expects a (channels, samples) tensor.
        wav = wav.unsqueeze(0)
    buffer = io.BytesIO()
    torchaudio.save(
        buffer,
        wav,
        sample_rate=sample_rate,
        format=format,
    )
    buffer.seek(0)
    return np.frombuffer(buffer.getvalue(), dtype=np.int8)


def fs2_read_audio(
    file_seqs: List[str], sample_rate: int = 16_000, nb_threads: int = 10
):
    """Decode a batch of audio files with a fairseq2 data pipeline.

    Returns the decoded entries resampled to `sample_rate` and averaged to mono.
    """
    audio_decoder = AudioDecoder(dtype=torch.float32)
    file_mapper = FileMapper(cached_fd_count=200)

    def post_process(data):
        # Resample to the target rate and collapse multi-channel audio to mono.
        sr = data["sample_rate"]
        wav = data["waveform"]
        if sr != sample_rate:
            wav = torchaudio.functional.resample(wav, sr, sample_rate)
        if len(wav.shape) > 1:
            wav = wav.mean(dim=0, keepdim=True)
        data["waveform"] = wav
        data["sample_rate"] = sample_rate
        return data

    builder = read_sequence(file_seqs)
    builder.map(file_mapper, num_parallel_calls=nb_threads)
    builder.map(audio_decoder, selector="data", num_parallel_calls=nb_threads)
    builder.map(post_process, selector="data")
    builder.map(lambda x: x["data"])
    pipe = builder.and_return()
    return list(iter(pipe))


def read_audio(
    filepath: str,
    sample_rate: int | None = 16_000,
    offset: float = 0,
    duration: float = -1,
) -> tp.Tuple[torch.Tensor, int]:
    """Load (a segment of) an audio file, resampled to `sample_rate` and mono."""
    info = torchaudio.info(filepath)
    sr = info.sample_rate
    # Convert offset/duration from seconds to frames at the file's native rate.
    frame_offset = int(offset * sr)
    num_frames = int(duration * sr) if duration > 0 else -1
    wav, sr = torchaudio.load(
        filepath, frame_offset=frame_offset, num_frames=num_frames
    )
    if sample_rate is not None and sample_rate != sr:
        wav = torchaudio.functional.resample(wav, sr, sample_rate)
    if len(wav.shape) > 1:
        wav = wav.mean(dim=0, keepdim=True)
    return (wav, sample_rate or sr)


def audio_bytes_to_numpy(path: str, byte_offset: int, length: int) -> np.ndarray:
    """Read `length` raw bytes starting at `byte_offset` from a file."""
    with open(path, "rb") as f:
        f.seek(byte_offset)
        audio_bytes = f.read(length)
    return np.frombuffer(audio_bytes, dtype=np.int8)


def bytes_to_tensor(
    audio_arr: np.ndarray, target_sample_rate: int = 16_000
) -> np.ndarray:
    """Decode encoded audio bytes (e.g. from `wav_to_bytes`) into a flat mono array."""
    buffer = audio_arr.tobytes()
    wav, sample_rate = torchaudio.load(io.BytesIO(buffer))
    if sample_rate != target_sample_rate:
        wav = torchaudio.functional.resample(wav, sample_rate, target_sample_rate)
    if len(wav.shape) > 1:
        wav = wav.mean(dim=0, keepdim=True)
    return wav.cpu().numpy().flatten()


def replace_prefix(
    old_str: str,
    old_prefix: str = "/checkpoint/seamless/ust",
    new_prefix: str = "/checkpoint/mms/data",
) -> str:
    return old_str.replace(old_prefix, new_prefix)


def load_audio(
    input_line: str,
    sampling_factor: int = 16,
    as_numpy: bool = True,
    read_audio_func: tp.Callable = speech_utils.read_audio,
    collapse_channels: bool = False,
) -> tp.Union[torch.Tensor, np.ndarray]:
    """Load a waveform described by a stopes audio manifest line.

    `sampling_factor` is in kHz, hence the `* 1000` to obtain a rate in Hz.
    """
    audio_meta = speech_utils.parse_audio(input_line, sampling_factor=sampling_factor)
    if isinstance(audio_meta, speech_utils.Audio):
        wav = read_audio_func(audio_meta.path, audio_meta.sampling_factor * 1000)
        # Slice out the requested segment.
        if len(wav.shape) > 1:
            wav = wav[:, audio_meta.start : audio_meta.end]
        else:
            wav = wav[audio_meta.start : audio_meta.end]
    elif isinstance(audio_meta, speech_utils.AudioBytes):
        wav = audio_meta.load()
    elif isinstance(audio_meta, speech_utils.Text):
        wav = read_audio_func(audio_meta.content, sampling_factor * 1000)
    else:
        # Guard against `wav` being unbound for an unrecognized metadata type.
        raise ValueError(f"Unsupported audio metadata type: {type(audio_meta)}")
    if collapse_channels and len(wav.shape) > 1:
        wav = wav.mean(0)
    if as_numpy:
        wav = wav.cpu().numpy().flatten()
    return wav
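

if __name__ == "__main__":
    # Minimal usage sketch. "example.wav" is a placeholder path, not a file
    # shipped with this module, and the ogg round trip assumes a torchaudio
    # backend with Vorbis support (e.g. soundfile) is installed.
    wav, sr = read_audio("example.wav", sample_rate=16_000)
    print(f"loaded waveform of shape {tuple(wav.shape)} at {sr} Hz")

    # Encode to an in-memory ogg container, then decode back to a flat array.
    encoded = wav_to_bytes(wav, sample_rate=sr, format="ogg")
    decoded = bytes_to_tensor(encoded, target_sample_rate=sr)
    print(f"round-tripped {decoded.shape[0]} samples")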