# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
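"""Forced alignment of token sequences to audio.

`AudioAlignment` maps a pyarrow table with an audio-bytes column and a token
column through torchaudio's CTC forced aligner, appending per-segment start
times, durations, and trimmed audio bytes as new list columns.
"""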
import gc
import io
from dataclasses import dataclass
from typing import Dict, List

import pyarrow as pa
import torch
import torchaudio
import torchaudio.functional as audio_F
from stopes.modules.partitioned_data_mapper import BatchMapper

from align_utils import (
    get_spans,
    load_model_dict,
    merge_repeats,
    time_to_frame,
)
from audio_reading_tools import wav_to_bytes
@dataclass
class AlignmentStruct:
    """Column names for the alignment inputs and outputs."""

    segment_tokens: str
    audio: str
    segment_audio_bytes: str = "segment_audio_bytes"
    segment_duration: str = "segment_duration"
    segment_start_sec: str = "segment_start_sec"
@dataclass
class AudioAlignmentConfig:
    alignment_column: AlignmentStruct
    model_path_name: str = ""
    emission_interval: int = 30  # seconds of audio per emission window
    sample_rate: int = 16000
    audio_format: str = "flac"
    use_star: bool = False
    device: str = "cuda"
class AudioAlignment(BatchMapper):
    # Frame-to-time bookkeeping is done in milliseconds (scale=1000).
    scale: int = 1000

    def __init__(self, config: AudioAlignmentConfig):
        super().__init__(config)
        # FIXME: pass model name correctly
        self.model, self.dictionary = load_model_dict()
        self.device = torch.device(config.device)
        self.model.to(self.device)
        if self.config.use_star:
            self.dictionary["<star>"] = len(self.dictionary)
        self.blank = self.dictionary["<blank>"]
        self.inverse_dictionary = {v: k for k, v in self.dictionary.items()}
        self._alignment_column = self.config.alignment_column
    def generate_emissions(self, waveform: torch.Tensor):
        """Run the model over the waveform in fixed-size windows and return
        frame-level log-probabilities plus the stride in milliseconds."""
        reading_sr = self.config.sample_rate
        emission_interval = self.config.emission_interval
        total_duration = waveform.size(1) / reading_sr
        emissions_arr = []
        i = 0
        while i < total_duration:
            segment_start_time, segment_end_time = (i, i + emission_interval)
            # Pad each window with 10% context on both sides so the model sees
            # surrounding audio, then crop the emissions back to the window.
            context = emission_interval * 0.1
            input_start_time = max(segment_start_time - context, 0)
            input_end_time = min(segment_end_time + context, total_duration)
            waveform_split = waveform[
                :,
                int(reading_sr * input_start_time) : int(reading_sr * input_end_time),
            ]
            model_outs, _ = self.model(waveform_split)
            emissions_ = model_outs[0]
            emission_start_frame = time_to_frame(segment_start_time)
            emission_end_frame = time_to_frame(segment_end_time)
            offset = time_to_frame(input_start_time)
            emissions_ = emissions_[
                emission_start_frame - offset : emission_end_frame - offset, :
            ]
            emissions_arr.append(emissions_)
            i += emission_interval
        emissions = torch.cat(emissions_arr, dim=0).squeeze()
        emissions = torch.log_softmax(emissions, dim=-1)
        # Milliseconds of audio represented by each emission frame.
        stride = float(waveform.size(1) * self.scale / emissions.size(0) / reading_sr)
        return emissions, stride
    def get_one_row_alignments(
        self,
        audio_arr,
        tokens: List[str],
    ):
        """Force-align one row's tokens to its audio and return the aligned
        segments as (start_sec, duration, audio_bytes) dicts."""
        reading_sr = self.config.sample_rate
        buffer = audio_arr.tobytes()
        waveform, audio_sf = torchaudio.load(io.BytesIO(buffer))
        waveform = waveform.to(self.device)
        assert audio_sf == reading_sr
        emissions, stride = self.generate_emissions(waveform)
        waveform = waveform.cpu()
        if self.config.use_star:
            # Append a <star> emission column and prepend the wildcard token
            # so audio outside the transcript can be absorbed by it.
            T, _ = emissions.size()
            emissions = torch.cat(
                [emissions, torch.zeros(T, 1, device=self.device)], dim=1
            )
            tokens = ["<star>"] + tokens
        token_indices = [
            self.dictionary[c]
            for c in " ".join(tokens).split(" ")
            if c in self.dictionary
        ]
        targets = torch.tensor(token_indices, dtype=torch.int32, device=self.device)
        input_lengths = torch.tensor(emissions.shape[0]).unsqueeze(-1)
        target_lengths = torch.tensor(targets.shape[0]).unsqueeze(-1)
        path, _ = audio_F.forced_align(
            emissions.unsqueeze(0),
            targets.unsqueeze(0),
            input_lengths,
            target_lengths,
            blank=self.blank,
        )
        path = path.squeeze().to("cpu").tolist()
        segments = merge_repeats(path, self.inverse_dictionary)
        spans = get_spans(tokens, segments)
        audio_segments = []
        for span in spans:
            seg_start_idx, seg_end_idx = span[0].start, span[-1].end
            # stride is milliseconds per frame; divide by scale for seconds.
            segment_start_sec = seg_start_idx * stride / self.scale
            segment_end_sec = seg_end_idx * stride / self.scale
            start_frame = int(segment_start_sec * reading_sr)
            end_frame = int(segment_end_sec * reading_sr)
            trimmed_waveform = waveform[:, start_frame:end_frame]
            audio_segments.append(
                {
                    self._alignment_column.segment_start_sec: segment_start_sec,
                    self._alignment_column.segment_duration: segment_end_sec
                    - segment_start_sec,
                    self._alignment_column.segment_audio_bytes: wav_to_bytes(
                        trimmed_waveform, reading_sr, self.config.audio_format
                    ),
                }
            )
        return audio_segments
    def get_alignments(self, table: pa.Table) -> Dict[str, pa.Array | pa.ChunkedArray]:
        """Align every row of the table and build the new output columns."""
        results = []
        for dd in (
            table.select(
                [self._alignment_column.audio, self._alignment_column.segment_tokens]
            )
            .to_pandas()
            .to_dict(orient="records")
        ):
            struct = self.get_one_row_alignments(
                dd[self._alignment_column.audio],
                dd[self._alignment_column.segment_tokens],
            )
            results.append(struct)
        batch = {}
        segment_audio_bytes = self._alignment_column.segment_audio_bytes
        batch[segment_audio_bytes] = pa.array(
            [[seg[segment_audio_bytes] for seg in doc] for doc in results],
            type=pa.list_(pa.large_list(pa.int8())),
        )
        segment_duration = self._alignment_column.segment_duration
        batch[segment_duration] = pa.array(
            [[seg[segment_duration] for seg in doc] for doc in results],
            type=pa.list_(pa.float32()),
        )
        segment_start_sec = self._alignment_column.segment_start_sec
        batch[segment_start_sec] = pa.array(
            [[seg[segment_start_sec] for seg in doc] for doc in results],
            type=pa.list_(pa.float32()),
        )
        # Release per-batch tensors before the next table is processed.
        gc.collect()
        torch.cuda.empty_cache()
        return batch
    def __call__(self, table: pa.Table | None) -> pa.Table | None:
        if table is None:
            return table
        batch = self.get_alignments(table)
        for name, col in batch.items():
            table = table.append_column(name, col)  # type: ignore
        return table
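

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original pipeline. Assumptions:
    # - "example.flac" is a hypothetical 16 kHz FLAC file on disk,
    # - the column names "audio" and "segment_tokens" are illustrative,
    # - the token strings below are placeholders; real tokens must match
    #   entries in the dictionary returned by load_model_dict().
    import numpy as np

    config = AudioAlignmentConfig(
        alignment_column=AlignmentStruct(
            segment_tokens="segment_tokens",
            audio="audio",
        ),
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    aligner = AudioAlignment(config)

    # Store the encoded audio bytes as a uint8 array; get_one_row_alignments
    # recovers them via .tobytes() before decoding with torchaudio.
    with open("example.flac", "rb") as f:
        audio_arr = np.frombuffer(f.read(), dtype=np.uint8)
    table = pa.table(
        {
            "audio": [audio_arr],
            "segment_tokens": [["hello", "world"]],
        }
    )
    aligned = aligner(table)
    print(aligned.schema)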