Spaces:
Runtime error
Runtime error
import io | |
import typing as tp | |
from typing import List | |
import numpy as np | |
import torch | |
import torchaudio | |
from fairseq2.data import FileMapper, read_sequence | |
from fairseq2.data.audio import AudioDecoder | |
from numpy.typing import NDArray | |
from stopes.modules.speech import utils as speech_utils | |
def wav_to_bytes(
    wav: torch.Tensor | NDArray, sample_rate: int = 16_000, format: str = "ogg"
) -> NDArray[np.int8]:
    """Encode a waveform with ``torchaudio.save`` and return the encoded bytes.

    Args:
        wav: Waveform as a torch tensor or a NumPy array. NumPy input is
            wrapped with ``torch.from_numpy``, non-float32 data is cast to
            float32, and a 1-D input gets a leading dimension added.
        sample_rate: Sample rate forwarded to ``torchaudio.save``.
        format: Audio format name forwarded to ``torchaudio.save``.

    Returns:
        The encoded file contents viewed as a 1-D ``np.int8`` array.
    """
    tensor = torch.from_numpy(wav) if isinstance(wav, np.ndarray) else wav
    if tensor.dtype != torch.float32:
        tensor = tensor.float()
    if tensor.ndim == 1:
        # presumably the leading axis is the channel axis torchaudio.save
        # expects -- confirm against the torchaudio version in use
        tensor = tensor.unsqueeze(0)

    encoded = io.BytesIO()
    torchaudio.save(encoded, tensor, sample_rate=sample_rate, format=format)
    # getvalue() returns the whole buffer regardless of the stream position.
    return np.frombuffer(encoded.getvalue(), dtype=np.int8)
def fs2_read_audio(
    file_seqs: List[str], sample_rate: int = 16_000, nb_threads: int = 10
):
    """Decode a list of audio files through a fairseq2 data pipeline.

    Every entry of ``file_seqs`` goes through a ``FileMapper`` and an
    ``AudioDecoder`` (float32), is resampled to ``sample_rate`` when its own
    rate differs, and is averaged over dim 0 when it has more than one
    dimension.

    Args:
        file_seqs: Sequence handed to ``read_sequence``.
        sample_rate: Target sample rate written into every returned item.
        nb_threads: ``num_parallel_calls`` for the file-mapping and decoding
            stages.

    Returns:
        A list holding the ``"data"`` entry of each pipeline item, with its
        ``"waveform"`` and ``"sample_rate"`` fields rewritten as described.
    """
    decoder = AudioDecoder(dtype=torch.float32)
    mapper = FileMapper(cached_fd_count=200)

    def _normalize(decoded):
        source_rate = decoded["sample_rate"]
        waveform = decoded["waveform"]
        if source_rate != sample_rate:
            waveform = torchaudio.functional.resample(
                waveform, source_rate, sample_rate
            )
        if waveform.ndim > 1:
            # NOTE(review): this averages over dim 0, i.e. it assumes dim 0 is
            # the channel axis -- confirm against AudioDecoder's output layout.
            waveform = waveform.mean(dim=0, keepdim=True)
        decoded["waveform"] = waveform
        decoded["sample_rate"] = sample_rate
        return decoded

    pipeline = (
        read_sequence(file_seqs)
        .map(mapper, num_parallel_calls=nb_threads)
        .map(decoder, selector="data", num_parallel_calls=nb_threads)
        .map(_normalize, selector="data")
        .map(lambda item: item["data"])
        .and_return()
    )
    return list(iter(pipeline))
def read_audio(
    filepath: str,
    sample_rate: int | None = 16_000,
    offset: float = 0,
    duration: float = -1,
) -> tp.Tuple[torch.Tensor, int]:
    """Load all or part of an audio file, optionally resampling it.

    Args:
        filepath: Path given to ``torchaudio.info`` and ``torchaudio.load``.
        sample_rate: Target sample rate; ``None`` keeps the file's own rate.
        offset: Start position in seconds, converted to frames using the
            file's sample rate.
        duration: Length to read in seconds. Any value that is not ``> 0``
            results in ``num_frames=-1`` being passed to ``torchaudio.load``.

    Returns:
        ``(wav, rate)``. ``wav`` is averaged over dim 0 (kept as a size-1
        dimension) when it has more than one dimension; ``rate`` is
        ``sample_rate`` when that is truthy, otherwise the loaded file's rate.
    """
    native_rate = torchaudio.info(filepath).sample_rate

    # Seconds -> frames, expressed in the file's native rate.
    start_frame = int(offset * native_rate)
    if duration > 0:
        frame_count = int(duration * native_rate)
    else:
        frame_count = -1

    wav, loaded_rate = torchaudio.load(
        filepath, frame_offset=start_frame, num_frames=frame_count
    )

    needs_resample = sample_rate is not None and sample_rate != loaded_rate
    if needs_resample:
        wav = torchaudio.functional.resample(wav, loaded_rate, sample_rate)
    if wav.ndim > 1:
        wav = wav.mean(dim=0, keepdim=True)

    final_rate = sample_rate or loaded_rate
    return (wav, final_rate)
def audio_bytes_to_numpy(path: str, byte_offset: int, length: int) -> np.ndarray:
    """Read a byte range from a file and expose it as an ``np.int8`` array.

    Args:
        path: File to open in binary mode.
        byte_offset: Position to seek to before reading.
        length: Number of bytes requested from ``read``; fewer are returned
            when the file ends first.

    Returns:
        A 1-D ``np.int8`` array built over the bytes that were read.
    """
    with open(path, "rb") as handle:
        handle.seek(byte_offset)
        raw = handle.read(length)
    return np.frombuffer(raw, dtype=np.int8)
def bytes_to_tensor(
    audio_arr: np.ndarray, target_sample_rate: int = 16_000
) -> np.ndarray:
    """Decode an array of encoded audio bytes into a flat NumPy waveform.

    Despite the name, the result is a NumPy array, not a torch tensor.

    Args:
        audio_arr: Array whose raw bytes (``tobytes()``) form an audio file
            that ``torchaudio.load`` can decode.
        target_sample_rate: Rate the waveform is resampled to when the
            decoded rate differs.

    Returns:
        The decoded waveform, averaged over dim 0 when it has more than one
        dimension, moved to CPU and flattened to 1-D.
    """
    stream = io.BytesIO(audio_arr.tobytes())
    decoded, source_rate = torchaudio.load(stream)

    if source_rate != target_sample_rate:
        decoded = torchaudio.functional.resample(
            decoded, source_rate, target_sample_rate
        )
    if decoded.ndim > 1:
        decoded = decoded.mean(dim=0, keepdim=True)

    as_array = decoded.cpu().numpy()
    return as_array.flatten()
def replace_prefix(
    old_str: str,
    old_prefix: str = "/checkpoint/seamless/ust",
    new_prefix: str = "/checkpoint/mms/data",
) -> str:
    """Swap a leading ``old_prefix`` of ``old_str`` for ``new_prefix``.

    Only a match at the very start of the string is rewritten. The earlier
    ``str.replace``-based version also rewrote occurrences in the middle of
    the string, which is not a prefix replacement.

    Args:
        old_str: String to rewrite (typically a path).
        old_prefix: Prefix to look for at the start of ``old_str``.
        new_prefix: Text substituted for ``old_prefix``.

    Returns:
        ``old_str`` with its leading ``old_prefix`` replaced, or ``old_str``
        unchanged when it does not start with ``old_prefix``.
    """
    if not old_str.startswith(old_prefix):
        return old_str
    return new_prefix + old_str[len(old_prefix):]
def load_audio(
    input_line: str,
    sampling_factor: int = 16,
    as_numpy: bool = True,
    read_audio_func: tp.Callable = speech_utils.read_audio,
    collapse_channels: bool = False,
) -> tp.Union[torch.Tensor, np.ndarray]:
    """Load the audio described by one manifest line.

    The line is parsed with ``speech_utils.parse_audio`` and handled by type:

    * ``Audio``: ``read_audio_func(path, sampling_factor * 1000)`` is called
      with the parsed entry's own ``sampling_factor``, and the result is
      sliced to ``[start:end]`` (along dim 1 when it has more than one
      dimension, otherwise along dim 0).
    * ``AudioBytes``: the entry's ``load()`` result is used as is.
    * ``Text``: ``read_audio_func(content, sampling_factor * 1000)`` is
      called with this function's ``sampling_factor``; no slicing is applied.

    Args:
        input_line: Line to parse.
        sampling_factor: Forwarded to ``parse_audio``; multiplied by 1000 to
            obtain the rate passed to ``read_audio_func`` (presumably kHz).
        as_numpy: When true, return ``wav.cpu().numpy().flatten()``.
        read_audio_func: Callable ``(path, rate) -> waveform``. It must
            return a tensor-like object directly; a ``(wav, rate)`` tuple such
            as the one returned by this file's ``read_audio`` will not work.
        collapse_channels: When true and the waveform has more than one
            dimension, average over dim 0.

    Returns:
        The waveform as a torch tensor, or a flat NumPy array when
        ``as_numpy`` is true.

    Raises:
        TypeError: If ``parse_audio`` returns an unsupported type. Previously
            this surfaced later as an ``UnboundLocalError`` on ``wav``.
    """
    audio_meta = speech_utils.parse_audio(input_line, sampling_factor=sampling_factor)

    if isinstance(audio_meta, speech_utils.Audio):
        wav = read_audio_func(audio_meta.path, audio_meta.sampling_factor * 1000)
        if len(wav.shape) > 1:
            wav = wav[:, audio_meta.start : audio_meta.end]
        else:
            wav = wav[audio_meta.start : audio_meta.end]
    elif isinstance(audio_meta, speech_utils.AudioBytes):
        wav = audio_meta.load()
    elif isinstance(audio_meta, speech_utils.Text):
        wav = read_audio_func(audio_meta.content, sampling_factor * 1000)
    else:
        # Fail loudly here instead of hitting an unbound ``wav`` further down.
        raise TypeError(
            f"Unsupported audio entry type {type(audio_meta).__name__!r} "
            f"parsed from line: {input_line!r}"
        )

    if collapse_channels and len(wav.shape) > 1:
        wav = wav.mean(0)
    if as_numpy:
        wav = wav.cpu().numpy().flatten()
    return wav