# mms-transcription/server/audio_reading_tools.py
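"""Audio I/O helpers: reading, decoding, resampling, and (re)encoding waveforms."""
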
import io
import typing as tp
import numpy as np
import torch
import torchaudio
from fairseq2.data import FileMapper, read_sequence
from fairseq2.data.audio import AudioDecoder
from numpy.typing import NDArray
from stopes.modules.speech import utils as speech_utils


def wav_to_bytes(
wav: torch.Tensor | NDArray, sample_rate: int = 16_000, format: str = "ogg"
) -> NDArray[np.int8]:
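    """Encode a waveform as compressed audio bytes.

    Accepts a 1-D or 2-D float waveform (numpy or torch), encodes it with
    torchaudio in the given container ``format``, and returns the encoded
    bytes as an int8 numpy array.
    """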
if isinstance(wav, np.ndarray):
wav = torch.from_numpy(wav)
if wav.dtype != torch.float32:
wav = wav.float()
if wav.ndim == 1:
wav = wav.unsqueeze(0)
buffer = io.BytesIO()
torchaudio.save(
buffer,
wav,
sample_rate=sample_rate,
format=format,
)
buffer.seek(0)
return np.frombuffer(buffer.getvalue(), dtype=np.int8)


def fs2_read_audio(
    file_seqs: list[str], sample_rate: int = 16_000, nb_threads: int = 10
) -> list[dict]:
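    """Decode a list of audio files through a fairseq2 data pipeline.

    Each path in ``file_seqs`` is read into memory, decoded, resampled to
    ``sample_rate``, and downmixed to mono. Returns the decoded items
    (dicts carrying ``waveform`` and ``sample_rate``) in input order.
    """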
audio_decoder = AudioDecoder(dtype=torch.float32)
file_mapper = FileMapper(cached_fd_count=200)

    def post_process(data: dict) -> dict:
sr = data["sample_rate"]
wav = data["waveform"]
if sr != sample_rate:
wav = torchaudio.functional.resample(wav, sr, sample_rate)
if len(wav.shape) > 1:
wav = wav.mean(dim=0, keepdim=True)
data["waveform"] = wav
data["sample_rate"] = sample_rate
return data
    # Build the pipeline: read each path into memory, decode the audio,
    # normalize sample rate/channels, and strip the file-mapper envelope.
    pipeline = (
        read_sequence(file_seqs)
        .map(file_mapper, num_parallel_calls=nb_threads)
        .map(audio_decoder, selector="data", num_parallel_calls=nb_threads)
        .map(post_process, selector="data")
        .map(lambda x: x["data"])
        .and_return()
    )
    return list(pipeline)


def read_audio(
filepath: str,
sample_rate: int | None = 16_000,
offset: float = 0,
duration: float = -1,
) -> tp.Tuple[torch.Tensor, int]:
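    """Load (a slice of) an audio file as a mono float tensor.

    ``offset`` and ``duration`` are given in seconds and converted to frames
    using the file's native sample rate; the waveform is then resampled to
    ``sample_rate`` (when given) and downmixed to mono.
    """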
info = torchaudio.info(filepath)
sr = info.sample_rate
# Convert seconds to frames
frame_offset = int(offset * sr)
num_frames = int(duration * sr) if duration > 0 else -1
wav, sr = torchaudio.load(
filepath, frame_offset=frame_offset, num_frames=num_frames
)
if sample_rate is not None and sample_rate != sr:
wav = torchaudio.functional.resample(wav, sr, sample_rate)
if len(wav.shape) > 1:
wav = wav.mean(dim=0, keepdim=True)
return (wav, sample_rate or sr)


def audio_bytes_to_numpy(path: str, byte_offset: int, length: int) -> np.ndarray:
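    """Read ``length`` raw encoded-audio bytes at ``byte_offset`` from ``path``."""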
with open(path, "rb") as f:
f.seek(byte_offset)
audio_bytes = f.read(length)
return np.frombuffer(audio_bytes, dtype=np.int8)


def bytes_to_tensor(
audio_arr: np.ndarray, target_sample_rate: int = 16_000
) -> np.ndarray:
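    """Decode encoded audio bytes into a flat mono float32 numpy array.

    Despite the name, the result is a numpy array, not a torch tensor: the
    bytes are decoded with torchaudio, resampled to ``target_sample_rate``,
    downmixed to mono, and flattened.
    """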
buffer = audio_arr.tobytes()
wav, sample_rate = torchaudio.load(io.BytesIO(buffer))
if sample_rate != target_sample_rate:
wav = torchaudio.functional.resample(wav, sample_rate, target_sample_rate)
if len(wav.shape) > 1:
wav = wav.mean(dim=0, keepdim=True)
return wav.cpu().numpy().flatten()


def replace_prefix(
    old_str: str,
    old_prefix: str = "/checkpoint/seamless/ust",
    new_prefix: str = "/checkpoint/mms/data",
) -> str:
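    """Rewrite the legacy checkpoint prefix in a path.

    Note that ``str.replace`` substitutes every occurrence of ``old_prefix``,
    not only a leading one.
    """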
return old_str.replace(old_prefix, new_prefix)


def load_audio(
input_line: str,
sampling_factor: int = 16,
as_numpy: bool = True,
read_audio_func: tp.Callable = speech_utils.read_audio,
collapse_channels: bool = False,
) -> tp.Union[torch.Tensor, np.ndarray]:
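    """Load the audio referenced by a stopes manifest line.

    ``input_line`` may describe an audio file with optional start/end frames,
    inlined audio bytes, or a plain path; ``sampling_factor`` is the target
    sample rate in kHz (16 -> 16 kHz).
    """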
audio_meta = speech_utils.parse_audio(input_line, sampling_factor=sampling_factor)
if isinstance(audio_meta, speech_utils.Audio):
wav = read_audio_func(audio_meta.path, audio_meta.sampling_factor * 1000)
if len(wav.shape) > 1:
wav = wav[:, audio_meta.start : audio_meta.end]
else:
wav = wav[audio_meta.start : audio_meta.end]
elif isinstance(audio_meta, speech_utils.AudioBytes):
wav = audio_meta.load()
    elif isinstance(audio_meta, speech_utils.Text):
        wav = read_audio_func(audio_meta.content, sampling_factor * 1000)
    else:
        raise ValueError(f"Unsupported audio metadata type: {type(audio_meta)!r}")
if collapse_channels and len(wav.shape) > 1:
wav = wav.mean(0)
if as_numpy:
wav = wav.cpu().numpy().flatten()
return wav
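

if __name__ == "__main__":
    # Minimal self-check sketch: round-trip a synthetic 440 Hz tone through
    # wav_to_bytes and bytes_to_tensor. WAV is used so no lossy codec is
    # involved and the sample count should survive the trip unchanged.
    sr = 16_000
    t = torch.arange(sr, dtype=torch.float32) / sr
    tone = 0.1 * torch.sin(2 * torch.pi * 440.0 * t)
    encoded = wav_to_bytes(tone, sample_rate=sr, format="wav")
    decoded = bytes_to_tensor(encoded, target_sample_rate=sr)
    print(f"encoded {encoded.size} bytes -> {decoded.shape[0]} samples")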