# @title Model code (no change needed) | |
"""Model code""" | |
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
from __future__ import annotations | |
import io | |
import logging | |
import zlib | |
from dataclasses import dataclass, field | |
from typing import Dict, final, Final, List, Literal, Tuple | |
import librosa | |
import numpy as np | |
import soundfile as sf | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torchaudio | |
import torchaudio.functional as audio_F | |
import uroman | |
from fairseq2.context import RuntimeContext | |
from fairseq2.data import VocabularyInfo | |
from fairseq2.models.asr import AsrModel, AsrModelOutput | |
from fairseq2.models.llama import LLaMAConfig, LLaMAFactory | |
from fairseq2.models.seq2seq import Seq2SeqBatch | |
from fairseq2.models.wav2vec2 import ( | |
StandardWav2Vec2Masker, | |
Wav2Vec2EncoderConfig, | |
Wav2Vec2EncoderFactory, | |
Wav2Vec2Frontend, | |
Wav2Vec2Masker, | |
) | |
from fairseq2.models.wav2vec2.asr import Wav2Vec2AsrConfig | |
from fairseq2.nn import IncrementalStateBag, Linear, StandardEmbedding | |
from fairseq2.nn.padding import PaddingMask | |
from fairseq2.nn.transformer import TransformerDecoder, TransformerEncoder | |
from torch import Tensor | |
from torch.nn import Dropout | |
class Wav2Vec2LlamaModel(AsrModel): | |
"""Represents a wav2vec 2.0 encoder feeding to a Llama decoder for ASR.""" | |
model_dim: int | |
encoder_frontend: Wav2Vec2Frontend | |
encoder: TransformerEncoder | |
encoder_proj: nn.Module | |
text_frontend: StandardEmbedding | |
llama_decoder: TransformerDecoder | |
final_proj: nn.Module | |
masker: Wav2Vec2Masker | None | |
final_dropout: Dropout | None | |
target_vocab_info: VocabularyInfo | |
def __init__( | |
self, | |
encoder_frontend: Wav2Vec2Frontend, | |
encoder: TransformerEncoder, | |
encoder_proj: nn.Module, | |
text_frontend: StandardEmbedding, | |
llama_decoder: TransformerDecoder, | |
final_proj: nn.Module, | |
target_vocab_info: VocabularyInfo, | |
*, | |
masker: Wav2Vec2Masker | None = None, | |
final_dropout_p: float = 0.0, | |
max_generation_length: int = 8192, | |
encoder_stacking: int = 1, | |
frozen_encoder: bool = False, | |
random_context_length: bool = True, | |
) -> None: | |
""" | |
:param encoder_frontend: | |
The encoder frontend. | |
:param encoder: | |
The encoder (i.e. context network). | |
:param encoder_proj: | |
Normally a linear layer projecting the encoder outputs to the decoder's model dim. | |
:param text_frontend:
The embedding module for text tokens. | |
:param llama_decoder: | |
The decoder-only model. | |
:param final_proj: | |
The last linear layers projecting from the decoder to logits. | |
:param target_vocab_info: | |
The vocabulary information of sequences produced by the model. | |
:param masker: | |
The feature masker. | |
:param final_dropout_p: | |
The dropout probability on context network outputs. | |
:param max_generation_length: | |
The maximum length of generated sequences. | |
:param encoder_stacking: | |
The number of audio embedding frames to stack before each decoder call.
:param frozen_encoder: | |
If ``True``, the encoder is frozen during training. | |
""" | |
super().__init__() | |
self.model_dim = encoder.model_dim | |
self.encoder_frontend = encoder_frontend | |
self.encoder = encoder | |
self.encoder_proj = encoder_proj | |
self.text_frontend = text_frontend | |
self.llama_decoder = llama_decoder | |
self.final_proj = final_proj | |
self.target_vocab_info = target_vocab_info | |
self.max_generation_length = max_generation_length | |
self.encoder_stacking = encoder_stacking | |
self.frozen_encoder = frozen_encoder | |
self.random_context_length = random_context_length | |
self.context_len_rng = np.random.RandomState(42) | |
self.register_module("masker", masker) | |
if final_dropout_p > 0.0: | |
self.final_dropout = Dropout(final_dropout_p) | |
else: | |
self.register_module("final_dropout", None) | |
def forward(self, batch: Seq2SeqBatch) -> Wav2Vec2LlamaOutput: # type: ignore[override] | |
""" | |
:param batch: | |
The batch of sequences to process. | |
""" | |
device = batch.source_seqs.device | |
dtype = batch.source_seqs.dtype | |
batch = self.prepare_batch(batch) | |
inputs = self.create_default_syntax_inference(batch, device) | |
# Embed all modalities | |
embedded = self.embed_inputs(inputs, dtype) | |
# Concat all decoder inputs | |
( | |
decoder_inputs, | |
decoder_inputs_padding_mask, | |
decoder_context_inputs, | |
decoder_context_padding_mask, | |
) = self.concat_inputs(embedded) | |
# Run the decoder | |
dec_out, _ = self.llama_decoder(decoder_inputs, decoder_inputs_padding_mask) | |
logits = self.final_proj(dec_out) | |
assert self.target_vocab_info.pad_idx is not None | |
assert self.target_vocab_info.eos_idx is not None | |
return Wav2Vec2LlamaOutput( | |
logits=logits, | |
logits_padding_mask=decoder_inputs_padding_mask, | |
decoder_context_inputs=decoder_context_inputs, | |
decoder_context_padding_mask=decoder_context_padding_mask, | |
model=self, | |
pad_idx=self.target_vocab_info.pad_idx, | |
eos_idx=self.target_vocab_info.eos_idx, | |
padding_mask=None, | |
) | |
def prepare_batch(self, batch: Seq2SeqBatch) -> Seq2SeqBatch: | |
# Create padding masks if there aren't any | |
if batch.source_padding_mask is None: | |
lengths = torch.full_like( | |
batch.source_seqs[:, 0], | |
fill_value=batch.source_seqs.size(1), | |
dtype=torch.int64, | |
) | |
batch.source_padding_mask = PaddingMask(lengths, int(lengths.max())) | |
if batch.target_padding_mask is None: | |
lengths = torch.full_like( | |
batch.target_seqs[:, 0], | |
fill_value=batch.target_seqs.size(1), | |
dtype=torch.int64, | |
) | |
batch.target_padding_mask = PaddingMask(lengths, int(lengths.max())) | |
# Padding masks for context audio and text | |
if "context_audio" in batch.example: | |
for i in range(len(batch.example["context_audio"])): | |
# For audio | |
seq_lens = batch.example["context_audio"][i]["data"]["waveform"][ | |
"seq_lens" | |
] | |
batch.example["context_audio"][i]["data"]["waveform"][ | |
"padding_mask" | |
] = PaddingMask(seq_lens, int(seq_lens.max())) | |
# For text | |
seq_lens = batch.example["context_text"][i]["seq_lens"] | |
batch.example["context_text"][i]["padding_mask"] = PaddingMask( | |
seq_lens, int(seq_lens.max()) | |
) | |
return batch | |
def create_default_syntax_inference( | |
self, batch: Seq2SeqBatch, device | |
) -> List[Dict[str, object]]: | |
# Create a dict of inputs for the base case. The syntax is:
# target audio <bos> target text <eos> | |
inputs = [ | |
{ | |
"value": { | |
"seqs": batch.source_seqs, | |
"padding_mask": batch.source_padding_mask, | |
}, | |
"type": "audio", | |
"loss": False, | |
}, | |
{ | |
"value": { | |
"seqs": self.create_single_char( | |
batch, self.target_vocab_info.bos_idx, device | |
) | |
}, | |
"type": "text", | |
"loss": False, | |
}, | |
] | |
return inputs | |
def create_single_char(self, batch: Seq2SeqBatch, char: int, device) -> Tensor:
return torch.full_like( | |
batch.target_seqs[:, :1], fill_value=char, device=device # type: ignore | |
) | |
def embed_inputs( | |
self, inputs: List[Dict[str, object]], dtype: torch.dtype
) -> List[Dict[str, object]]: | |
# Embed the different modalities | |
for inp in inputs: | |
if inp["type"] == "audio": | |
inp["value"]["seqs"], inp["value"]["padding_mask"] = self.embed_audio( | |
inp["value"]["seqs"], inp["value"]["padding_mask"] | |
) | |
elif inp["type"] == "text": | |
inp["value"]["seqs"] = self.embed_text(inp["value"]["seqs"], dtype) | |
else: | |
raise ValueError(f"Unknown input type: {inp['type']}") | |
return inputs | |
def embed_audio( | |
self, seqs: Tensor, padding_mask: PaddingMask | |
) -> tuple[Tensor, PaddingMask | None]: | |
# This is somewhat more memory-efficient than setting param.requires_grad to False,
# since the encoder activations will not be saved in the graph either.
with torch.set_grad_enabled(not self.frozen_encoder): | |
# Run the encoder | |
enc_out, enc_padding_mask, _ = self.encoder_frontend.extract_features( | |
seqs, padding_mask | |
) | |
enc_out, enc_padding_mask, _ = self.encoder_frontend.process_features( | |
enc_out, enc_padding_mask, self.masker if self.training else None | |
) | |
enc_out, enc_padding_mask = self.encoder(enc_out, enc_padding_mask) | |
if self.final_dropout is not None: | |
enc_out = self.final_dropout(enc_out) | |
# Stack the encoder outputs | |
if enc_out.size(1) % self.encoder_stacking != 0: | |
n_padding = self.encoder_stacking - ( | |
enc_out.size(1) % self.encoder_stacking | |
) | |
enc_out = F.pad(enc_out, (0, 0, 0, n_padding)) | |
assert enc_out.size(1) % self.encoder_stacking == 0 | |
enc_out = enc_out.view( | |
enc_out.size(0), | |
enc_out.size(1) // self.encoder_stacking, | |
enc_out.size(-1) * self.encoder_stacking, | |
) | |
new_lengths = torch.where( | |
(enc_padding_mask.seq_lens % self.encoder_stacking) == 0, | |
enc_padding_mask.seq_lens // self.encoder_stacking, | |
enc_padding_mask.seq_lens // self.encoder_stacking + 1, | |
) | |
enc_padding_mask = PaddingMask(new_lengths, int(new_lengths.max())) | |
# Project encoder outputs to decoder input dimension | |
enc_out = self.encoder_proj(enc_out) | |
return enc_out, enc_padding_mask | |
def embed_text(self, seqs: Tensor, dtype: torch.dtype) -> Tensor:
return self.text_frontend(seqs).to(dtype) | |
def concat_inputs( | |
self, inputs: List[Dict[str, object]] | |
) -> Tuple[Tensor, PaddingMask, Tensor, PaddingMask]:
t = inputs[0]["value"]["seqs"] | |
device = t.device | |
dtype = t.dtype | |
B = t.size(0) | |
input_dim = t.size(2) | |
ones = torch.ones(dtype=torch.int64, device=device, size=[B]) | |
# Compute total lengths | |
lengths = [ | |
( | |
inp["value"]["padding_mask"].seq_lens | |
if "padding_mask" in inp["value"] | |
else ones | |
) | |
for inp in inputs | |
] | |
total_lengths = sum(lengths) | |
padding_mask = PaddingMask(total_lengths, int(total_lengths.max())) | |
# Init the matrix with zeros | |
decoder_inputs = torch.zeros( | |
[B, int(total_lengths.max()), input_dim], | |
device=device, | |
dtype=dtype, | |
) | |
# Put everything in the right place | |
for b in range(B): | |
b_inputs = [ | |
inp["value"]["seqs"][b : b + 1, : length[b]] | |
for (inp, length) in zip(inputs, lengths) | |
] | |
b_inputs = torch.cat(b_inputs, dim=1) | |
assert b_inputs.size(1) == padding_mask.seq_lens[b] | |
decoder_inputs[b, : b_inputs.size(1)] = b_inputs | |
# Compute total context length (everything that we don't train the loss for) | |
context_lengths = [ | |
( | |
inp["value"]["padding_mask"].seq_lens | |
if "padding_mask" in inp["value"] | |
else ones | |
) | |
for inp in inputs | |
if inp["loss"] == False | |
] | |
total_context_lengths = sum(context_lengths) | |
context_padding_mask = PaddingMask( | |
total_context_lengths, int(total_context_lengths.max()) | |
) | |
decoder_context_inputs = decoder_inputs[:, : total_context_lengths.max()] | |
return ( | |
decoder_inputs, | |
padding_mask, | |
decoder_context_inputs, | |
context_padding_mask, | |
) | |
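# Illustrative sketch (not part of the model): the frame stacking performed in
# `embed_audio` pads the encoder output of shape (B, T, D) so that T is divisible by
# `encoder_stacking` and then views it as (B, T // k, D * k). The sizes below are
# arbitrary example values.
def _demo_frame_stacking(k: int = 4) -> torch.Size:
    enc_out = torch.randn(2, 10, 8)  # (B=2, T=10, D=8)
    n_padding = (-enc_out.size(1)) % k  # frames needed to reach a multiple of k
    enc_out = F.pad(enc_out, (0, 0, 0, n_padding))  # pad along the time dimension
    stacked = enc_out.view(enc_out.size(0), enc_out.size(1) // k, enc_out.size(-1) * k)
    return stacked.shape  # torch.Size([2, 3, 32]) for these example sizes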
@dataclass
class Wav2Vec2LlamaOutput(AsrModelOutput):
logits: Tensor | |
"""The logits for next-step prediction. *Shape:* :math:`(N,S_{out}, V)`, | |
where :math:`N` is the batch size, :math:`S_{out}` is the decoder sequence | |
length, :math:`V` is the size | |
of the vocabulary.""" | |
logits_padding_mask: PaddingMask | |
"""The padding mask for the above tensor. *Shape:* :math:`(N,S_{out})`.""" | |
decoder_context_inputs: Tensor | |
""" | |
Inputs to the llama decoder for everything except the final text. *Shape:* :math:`(N,S_{out},D)`. | |
""" | |
decoder_context_padding_mask: PaddingMask | |
"""The padding mask for the above tensor. *Shape:* :math:`(N,S_{out})`, where | |
:math:`N` is the batch size and :math:`S_{out}` a sequence | |
length.""" | |
model: nn.Module | |
"""A reference to the model.""" | |
pad_idx: int | |
"""The index of the padding symbol in the target vocabulary.""" | |
eos_idx: int | |
"""The index of the end-of-sequence symbol in the target vocabulary.""" | |
def add_eos( | |
self, targets: Tensor, target_padding_mask: PaddingMask | |
) -> tuple[Tensor, PaddingMask]: | |
targets = torch.cat( | |
[ | |
targets, | |
torch.full_like(targets[:, :1], fill_value=self.pad_idx), | |
], | |
dim=-1, | |
) | |
targets[torch.arange(targets.size(0)), target_padding_mask.seq_lens] = ( | |
self.eos_idx | |
) | |
target_padding_mask = PaddingMask( | |
target_padding_mask.seq_lens + 1, | |
int(target_padding_mask.seq_lens.max()) + 1, | |
) | |
return targets, target_padding_mask | |
def remove_context_logits( | |
self, | |
targets: Tensor, | |
target_padding_mask: PaddingMask, | |
) -> Tensor: | |
assert self.decoder_context_padding_mask is not None | |
logits_no_context = torch.zeros_like( | |
self.logits[:, : targets.size(1), :], | |
) | |
for i in range(self.logits.size(0)): | |
context_len_i = self.decoder_context_padding_mask.seq_lens[i] | |
tgt_len_i = target_padding_mask.seq_lens[i] | |
total_len_i = self.logits_padding_mask.seq_lens[i] | |
assert context_len_i + tgt_len_i == total_len_i | |
logits_no_context[i, :tgt_len_i] = self.logits[ | |
i, context_len_i - 1 : context_len_i - 1 + tgt_len_i | |
] | |
return logits_no_context | |
def combine_masks(self, mask1: Tensor, mask2: Tensor) -> Tensor:
combined_mask = torch.zeros_like(mask1) | |
combined_mask[mask1] = mask2 | |
return combined_mask | |
def idx_1d_to_2d(self, idx: Tensor, dim2: int) -> tuple[Tensor, Tensor]:
return idx // dim2, idx % dim2 | |
def compression_ratio(self, text: str) -> float:
text_bytes = text.encode("utf-8") | |
return len(text_bytes) / len(zlib.compress(text_bytes)) | |
def generate_hypotheses( | |
self, pad_idx: int, blank_label: int = 0 | |
) -> tuple[Tensor, PaddingMask | None]: | |
# Some init | |
nbest = 5 | |
length_norm = False | |
B = self.decoder_context_inputs.size(0) | |
device = self.decoder_context_inputs.device | |
dtype = self.decoder_context_inputs.dtype | |
ex_separator = torch.arange(B, device=device).unsqueeze(1) * nbest | |
eos_idx = self.model.target_vocab_info.eos_idx | |
# Prepare a decoder input matrix, prefill with context | |
decoder_inputs = torch.zeros( | |
[ | |
B * nbest, | |
self.model.max_generation_length, | |
self.model.llama_decoder.model_dim, | |
], | |
device=device, | |
dtype=dtype, | |
) | |
decoder_inputs[:, : self.decoder_context_inputs.size(1)] = ( | |
self.decoder_context_inputs.repeat_interleave(nbest, dim=0) | |
) | |
context_lengths = self.decoder_context_padding_mask.seq_lens.repeat_interleave( | |
nbest | |
) | |
# Prepare a token output matrix and a scores matrix | |
out_tokens = torch.full_like( | |
decoder_inputs[:, :, 0], | |
fill_value=pad_idx, | |
dtype=torch.int, | |
) | |
scores = torch.zeros_like(decoder_inputs[:, 0, 0], dtype=torch.float) | |
# Prefill with shortest context, keep state | |
state_bag = IncrementalStateBag(max_num_steps=self.model.max_generation_length) | |
min_context_len = int(context_lengths.min()) | |
_, _ = self.model.llama_decoder( | |
seqs=decoder_inputs[:, :min_context_len], | |
padding_mask=None, | |
state_bag=state_bag, | |
) | |
state_bag.increment_step_nr(min_context_len) | |
# Iterative decoding | |
# For each sample, choose either context, or emitted text embedding | |
# If EOS is emitted, the sample is non-active | |
# Stop when there are no active samples | |
eos_mask = torch.zeros_like(context_lengths, dtype=torch.bool) | |
done = False | |
t = context_lengths.min() - 1 | |
while not done: | |
# Run the decoder on mixed context and emitted text embeddings | |
dec_out, _ = self.model.llama_decoder( | |
seqs=decoder_inputs[:, t : t + 1], | |
padding_mask=None, | |
state_bag=state_bag, | |
) | |
state_bag.increment_step_nr(1) | |
logits = self.model.final_proj(dec_out).squeeze(1) # [B * nbest, V] | |
log_probs = F.log_softmax(logits, dim=-1) | |
# Choose nbest | |
if length_norm: | |
n_tokens = torch.logical_and( | |
out_tokens[:, :t] != pad_idx, out_tokens[:, :t] != eos_idx | |
).sum(dim=1, keepdim=True) | |
candidate_scores = (scores.unsqueeze(1) * n_tokens + log_probs) / ( | |
n_tokens + 1 | |
) | |
else: | |
candidate_scores = scores.unsqueeze(1) + log_probs # [B * nbest, V] | |
candidate_scores[eos_mask] = -torch.inf | |
candidate_scores[eos_mask, eos_idx] = scores[ | |
eos_mask | |
] # Don't change scores for ended hypos | |
top_scores, top_idx = candidate_scores.view(B, -1).topk( | |
k=nbest, dim=-1, sorted=True | |
) | |
top_idx_nbest, top_idx_v = self.idx_1d_to_2d( | |
top_idx, candidate_scores.size(-1) | |
) | |
top_idx_b = (top_idx_nbest + ex_separator).view(-1) # Parent hypos indices | |
# Reorder some tensors based on parent hypos | |
out_tokens = out_tokens[top_idx_b] | |
eos_mask = eos_mask[top_idx_b] | |
state_bag.reorder(top_idx_b) | |
scores = torch.where(eos_mask, scores, top_scores.view(-1)) | |
out_tokens[:, t] = top_idx_v.view(-1) | |
# For hypos that still don't emit tokens, set new tokens to pad_idx, score to 0. | |
no_token_mask = t < context_lengths - 1 | |
out_tokens[no_token_mask, t] = pad_idx | |
scores[no_token_mask] = 0.0 | |
# For hypos that had EOS previously, set new tokens to EOS. Scores don't change. | |
# Set new EOS mask. | |
out_tokens[eos_mask, t] = eos_idx | |
new_tokens = out_tokens[:, t : t + 1] | |
eos_mask = (new_tokens == eos_idx).squeeze(1) | |
# Run new tokens through frontend, set in decoder input | |
new_tokens_embedded = self.model.embed_text(new_tokens, dtype=dtype) | |
decoder_inputs[~no_token_mask, t + 1] = ( | |
new_tokens_embedded[~no_token_mask].to(decoder_inputs.dtype).squeeze(1) | |
) # Don't override audio encoder outputs | |
# Early stopping if the model keeps emitting repeating characters, detected via the
# compression ratio; checked only every 250 steps, and only once a hypothesis has been
# emitting tokens for more than `compression_window` steps past its context.
compression_window = 100 | |
compression_threshold = 4.0 | |
if t % 250 == 0: | |
cpu_tokens = out_tokens[:, t - compression_window : t].cpu().numpy() | |
ratios_floats = [ | |
self.compression_ratio( | |
np.array_str(cpu_tokens[i]).replace("\n", "") | |
) | |
for i in range(B * nbest) | |
] | |
ratios = torch.tensor(ratios_floats, device=device) | |
early_stopping_mask = torch.logical_and( | |
ratios > compression_threshold, | |
t > context_lengths + compression_window, | |
) | |
eos_mask = torch.logical_or(eos_mask, early_stopping_mask) | |
# Decide if we are done | |
done = bool( | |
torch.logical_or( | |
torch.all(eos_mask), | |
t == self.model.max_generation_length - 4, | |
) | |
) | |
t += 1 | |
# Get final tokens, only use top hypo | |
out_tokens = out_tokens[::nbest] | |
valid_tokens_mask = torch.logical_and( | |
torch.logical_and( | |
out_tokens != pad_idx, | |
out_tokens != self.model.target_vocab_info.bos_idx, | |
), | |
out_tokens != eos_idx, | |
) | |
valid_tokens_count = valid_tokens_mask.sum(dim=1) | |
final_tokens = torch.full( | |
[B, int(valid_tokens_count.max())], | |
fill_value=pad_idx, | |
dtype=torch.int64, | |
device=device, | |
) | |
for i in range(B): | |
final_tokens[i, : valid_tokens_count[i]] = out_tokens[i][ | |
valid_tokens_mask[i] | |
] | |
padding_mask = PaddingMask(valid_tokens_count, int(valid_tokens_count.max())) | |
return final_tokens, padding_mask | |
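# Illustrative sketch (not part of the model): `compression_ratio` above is used during
# decoding to detect degenerate, repetitive output. Repetitive text compresses very well
# under zlib and therefore yields a much higher ratio than normal text, which is what the
# `compression_threshold` check in `generate_hypotheses` relies on. The strings below are
# arbitrary examples.
def _demo_compression_ratio() -> tuple:
    def ratio(text: str) -> float:
        data = text.encode("utf-8")
        return len(data) / len(zlib.compress(data))
    normal = ratio("the quick brown fox jumps over the lazy dog")
    repetitive = ratio("ha" * 100)
    return normal, repetitive  # the repetitive string produces the larger ratio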
class Wav2Vec2LlamaFactory: | |
_config: Wav2Vec2LlamaConfig | |
def __init__( | |
self, | |
config: Wav2Vec2LlamaConfig, | |
) -> None: | |
self._config = config | |
def create_encoder(self) -> tuple[Wav2Vec2Frontend, TransformerEncoder]: | |
factory = Wav2Vec2EncoderFactory(self._config.wav2vec_ctc_config.encoder_config) | |
return factory.create_encoder_frontend(), factory.create_encoder() | |
def create_masker(self) -> Wav2Vec2Masker: | |
config = self._config.wav2vec_ctc_config | |
return StandardWav2Vec2Masker( | |
config.mask_codebase, | |
config.encoder_config.model_dim, | |
config.temporal_mask_span_len, | |
config.max_temporal_mask_prob, | |
config.min_num_temporal_mask_spans, | |
config.spatial_mask_span_len, | |
config.max_spatial_mask_prob, | |
config.min_num_spatial_mask_spans, | |
) | |
def create_model(self) -> Wav2Vec2LlamaModel: | |
encoder_frontend, encoder = self.create_encoder() | |
masker = ( | |
self.create_masker() | |
if self._config.wav2vec_ctc_config.use_masking | |
else None | |
) | |
encoder_proj = Linear( | |
self._config.wav2vec_ctc_config.encoder_config.model_dim | |
* self._config.encoder_stacking, | |
self._config.llama_config.model_dim, | |
bias=True, | |
) | |
text_frontend = StandardEmbedding( | |
num_embeddings=self._config.llama_config.vocab_info.size, | |
embedding_dim=self._config.llama_config.model_dim, | |
) | |
llama_decoder = LLaMAFactory(self._config.llama_config).create_decoder() | |
final_proj = Linear( | |
self._config.llama_config.model_dim, | |
self._config.llama_config.vocab_info.size, | |
bias=False, | |
) | |
return Wav2Vec2LlamaModel( | |
encoder_frontend=encoder_frontend, | |
encoder=encoder, | |
encoder_proj=encoder_proj, | |
text_frontend=text_frontend, | |
llama_decoder=llama_decoder, | |
final_proj=final_proj, | |
target_vocab_info=self._config.wav2vec_ctc_config.vocab_info, | |
masker=masker, | |
final_dropout_p=self._config.wav2vec_ctc_config.final_dropout_p, | |
max_generation_length=self._config.llama_config.max_seq_len, | |
encoder_stacking=self._config.encoder_stacking, | |
frozen_encoder=self._config.frozen_encoder, | |
) | |
"""Configs""" | |
from dataclasses import dataclass, field | |
from typing import Final | |
from fairseq2.context import RuntimeContext | |
from fairseq2.data import VocabularyInfo | |
from fairseq2.models.wav2vec2 import Wav2Vec2EncoderConfig | |
WAV2VEC2_ASR_MODEL_FAMILY: Final = "wav2vec2_asr" | |
@dataclass
class Wav2Vec2AsrConfig:
"""Holds the configuration of a wav2vec 2.0 ASR model. | |
The default values correspond to the base 10h architecture as described in | |
:cite:t:`https://doi.org/10.48550/arxiv.2006.11477`. | |
""" | |
encoder_config: Wav2Vec2EncoderConfig = field( | |
default_factory=lambda: Wav2Vec2EncoderConfig( | |
feature_gradient_scale=1.0, | |
dropout_p=0.0, | |
attn_dropout_p=0.0, | |
ffn_inner_dropout_p=0.1, | |
) | |
) | |
"""The configuration of the encoder.""" | |
vocab_info: VocabularyInfo = field( | |
default_factory=lambda: VocabularyInfo( | |
size=32, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1 | |
) | |
) | |
"""The vocabulary information.""" | |
final_dropout_p: float = 0.0 | |
"""The dropout probability on the output of the encoder.""" | |
# Mask | |
mask_codebase: str = "fairseq2" | |
use_masking: bool = True | |
"""If ``True``, masks features as regularization.""" | |
temporal_mask_span_len: int = 10 | |
"""The length of each temporal mask span that is applied over time steps.""" | |
max_temporal_mask_prob: float = 0.69 | |
"""The maximum probability of masking a time step. Note that, due to mask | |
span overlap, the effective probability will be lower.""" | |
min_num_temporal_mask_spans: int = 2 | |
"""The minimum number of temporal masks sampled per sequence.""" | |
spatial_mask_span_len: int = 64 | |
"""The length of each spatial mask span that is applied over features.""" | |
max_spatial_mask_prob: float = 0.55 | |
"""The maximum probability of masking a feature. Note that, due to mask span | |
overlap, the effective probability will be lower.""" | |
min_num_spatial_mask_spans: int = 2 | |
"""The minimum number of spatial masks sampled per sequence.""" | |
def register_wav2vec2_asr_configs(context: RuntimeContext) -> None: | |
registry = context.get_config_registry(Wav2Vec2AsrConfig) | |
wav2vec2_asr_arch = registry.decorator | |
w2v2_encoder_registry = context.get_config_registry(Wav2Vec2EncoderConfig) | |
def base_10h() -> Wav2Vec2AsrConfig: | |
return Wav2Vec2AsrConfig() | |
def base_100h() -> Wav2Vec2AsrConfig: | |
config = base_10h() | |
config.encoder_config.layer_drop_p = 0.1 | |
return config | |
def large_10h() -> Wav2Vec2AsrConfig: | |
config = base_10h() | |
config.encoder_config = w2v2_encoder_registry.get("large") | |
config.encoder_config.feature_gradient_scale = 1.0 | |
config.encoder_config.dropout_p = 0.0 | |
config.encoder_config.attn_dropout_p = 0.0 | |
config.encoder_config.ffn_inner_dropout_p = 0.1 | |
config.encoder_config.layer_drop_p = 0.1 | |
config.max_temporal_mask_prob = 0.80 | |
config.max_spatial_mask_prob = 0.30 | |
return config | |
def large_100h() -> Wav2Vec2AsrConfig: | |
config = large_10h() | |
config.max_temporal_mask_prob = 0.53 | |
config.max_spatial_mask_prob = 0.55 | |
return config | |
def large_lv60k_10h() -> Wav2Vec2AsrConfig: | |
config = base_10h() | |
config.encoder_config = w2v2_encoder_registry.get("large_lv60k") | |
config.encoder_config.feature_gradient_scale = 1.0 | |
config.encoder_config.dropout_p = 0.0 | |
config.encoder_config.attn_dropout_p = 0.0 | |
config.encoder_config.ffn_inner_dropout_p = 0.1 | |
config.encoder_config.layer_drop_p = 0.1 | |
config.max_temporal_mask_prob = 0.80 | |
config.max_spatial_mask_prob = 0.30 | |
return config | |
def large_lv60k_100h() -> Wav2Vec2AsrConfig: | |
config = large_lv60k_10h() | |
config.max_temporal_mask_prob = 0.53 | |
config.max_spatial_mask_prob = 0.55 | |
return config | |
def bib61_300m() -> Wav2Vec2AsrConfig: | |
config = base_10h() | |
config.encoder_config = w2v2_encoder_registry.get("large_lv60k") | |
config.encoder_config.feature_gradient_scale = 1.0 | |
config.encoder_config.dropout_p = 0.0 | |
config.encoder_config.attn_dropout_p = 0.0 | |
config.encoder_config.ffn_inner_dropout_p = 0.1 | |
config.encoder_config.layer_drop_p = 0.1 | |
config.use_masking = False | |
config.max_temporal_mask_prob = 0.0 | |
config.max_spatial_mask_prob = 0.0 | |
config.vocab_info.size = 2475 | |
return config | |
def bib1143_300m() -> Wav2Vec2AsrConfig: | |
config = bib61_300m() | |
config.vocab_info.size = 3335 | |
return config | |
def bib61_1b() -> Wav2Vec2AsrConfig: | |
config = base_10h() | |
config.encoder_config = w2v2_encoder_registry.get("1b") | |
config.encoder_config.feature_gradient_scale = 1.0 | |
config.encoder_config.dropout_p = 0.0 | |
config.encoder_config.attn_dropout_p = 0.0 | |
config.encoder_config.ffn_inner_dropout_p = 0.1 | |
config.encoder_config.layer_drop_p = 0.1 | |
config.use_masking = False | |
config.max_temporal_mask_prob = 0.0 | |
config.max_spatial_mask_prob = 0.0 | |
config.vocab_info.size = 2475 | |
return config | |
def llama_bib61_1b() -> Wav2Vec2AsrConfig: | |
config = base_10h() | |
config.encoder_config = w2v2_encoder_registry.get("1b_llama") | |
config.encoder_config.feature_gradient_scale = 1.0 | |
config.encoder_config.dropout_p = 0.0 | |
config.encoder_config.attn_dropout_p = 0.0 | |
config.encoder_config.ffn_inner_dropout_p = 0.1 | |
config.encoder_config.layer_drop_p = 0.1 | |
config.use_masking = False | |
config.max_temporal_mask_prob = 0.0 | |
config.max_spatial_mask_prob = 0.0 | |
config.vocab_info.size = 2475 | |
return config | |
def bib61_2b() -> Wav2Vec2AsrConfig: | |
config = base_10h() | |
config.encoder_config = w2v2_encoder_registry.get("2b") | |
config.encoder_config.feature_gradient_scale = 1.0 | |
config.encoder_config.dropout_p = 0.0 | |
config.encoder_config.attn_dropout_p = 0.0 | |
config.encoder_config.ffn_inner_dropout_p = 0.1 | |
config.encoder_config.layer_drop_p = 0.1 | |
config.use_masking = False | |
config.max_temporal_mask_prob = 0.0 | |
config.max_spatial_mask_prob = 0.0 | |
config.vocab_info.size = 2475 | |
return config | |
def bib61_3b() -> Wav2Vec2AsrConfig: | |
config = base_10h() | |
config.encoder_config = w2v2_encoder_registry.get("3b") | |
config.encoder_config.feature_gradient_scale = 1.0 | |
config.encoder_config.dropout_p = 0.0 | |
config.encoder_config.attn_dropout_p = 0.0 | |
config.encoder_config.ffn_inner_dropout_p = 0.1 | |
config.encoder_config.layer_drop_p = 0.1 | |
config.use_masking = False | |
config.max_temporal_mask_prob = 0.0 | |
config.max_spatial_mask_prob = 0.0 | |
config.vocab_info.size = 2475 | |
return config | |
def bib61_5b() -> Wav2Vec2AsrConfig: | |
config = base_10h() | |
config.encoder_config = w2v2_encoder_registry.get("5b") | |
config.encoder_config.feature_gradient_scale = 1.0 | |
config.encoder_config.dropout_p = 0.0 | |
config.encoder_config.attn_dropout_p = 0.0 | |
config.encoder_config.ffn_inner_dropout_p = 0.1 | |
config.encoder_config.layer_drop_p = 0.1 | |
config.use_masking = False | |
config.max_temporal_mask_prob = 0.0 | |
config.max_spatial_mask_prob = 0.0 | |
config.vocab_info.size = 2475 | |
return config | |
def bib61_7b() -> Wav2Vec2AsrConfig: | |
config = base_10h() | |
config.encoder_config = w2v2_encoder_registry.get("7b") | |
config.encoder_config.feature_gradient_scale = 1.0 | |
config.encoder_config.dropout_p = 0.0 | |
config.encoder_config.attn_dropout_p = 0.0 | |
config.encoder_config.ffn_inner_dropout_p = 0.1 | |
config.encoder_config.layer_drop_p = 0.1 | |
config.use_masking = False | |
config.max_temporal_mask_prob = 0.0 | |
config.max_spatial_mask_prob = 0.0 | |
config.vocab_info.size = 2475 | |
return config | |
def higher_bib61_3b() -> Wav2Vec2AsrConfig: | |
config = base_10h() | |
config.encoder_config = w2v2_encoder_registry.get("3.25b") | |
config.encoder_config.feature_gradient_scale = 1.0 | |
config.encoder_config.dropout_p = 0.0 | |
config.encoder_config.attn_dropout_p = 0.0 | |
config.encoder_config.ffn_inner_dropout_p = 0.1 | |
config.encoder_config.layer_drop_p = 0.1 | |
config.use_masking = False | |
config.max_temporal_mask_prob = 0.0 | |
config.max_spatial_mask_prob = 0.0 | |
config.vocab_info.size = 2475 | |
return config | |
def front51_5b() -> Wav2Vec2AsrConfig: | |
config = base_10h() | |
config.encoder_config = w2v2_encoder_registry.get("5b") | |
config.encoder_config.feature_gradient_scale = 1.0 | |
config.encoder_config.dropout_p = 0.0 | |
config.encoder_config.attn_dropout_p = 0.0 | |
config.encoder_config.ffn_inner_dropout_p = 0.1 | |
config.encoder_config.layer_drop_p = 0.1 | |
config.use_masking = False | |
config.max_temporal_mask_prob = 0.0 | |
config.max_spatial_mask_prob = 0.0 | |
config.vocab_info.size = 222 | |
return config | |
def front51_7b() -> Wav2Vec2AsrConfig: | |
config = base_10h() | |
config.encoder_config = w2v2_encoder_registry.get("7b") | |
config.encoder_config.feature_gradient_scale = 1.0 | |
config.encoder_config.dropout_p = 0.0 | |
config.encoder_config.attn_dropout_p = 0.0 | |
config.encoder_config.ffn_inner_dropout_p = 0.1 | |
config.encoder_config.layer_drop_p = 0.1 | |
config.use_masking = False | |
config.max_temporal_mask_prob = 0.0 | |
config.max_spatial_mask_prob = 0.0 | |
config.vocab_info.size = 222 | |
return config | |
def bib1143_1b() -> Wav2Vec2AsrConfig: | |
config = base_10h() | |
config.encoder_config = w2v2_encoder_registry.get("1b") | |
config.encoder_config.feature_gradient_scale = 1.0 | |
config.encoder_config.dropout_p = 0.0 | |
config.encoder_config.attn_dropout_p = 0.0 | |
config.encoder_config.ffn_inner_dropout_p = 0.1 | |
config.encoder_config.layer_drop_p = 0.1 | |
config.use_masking = False | |
config.max_temporal_mask_prob = 0.0 | |
config.max_spatial_mask_prob = 0.0 | |
config.vocab_info.size = 3335 | |
return config | |
def bib1143_3b() -> Wav2Vec2AsrConfig: | |
config = base_10h() | |
config.encoder_config = w2v2_encoder_registry.get("3b") | |
config.encoder_config.feature_gradient_scale = 1.0 | |
config.encoder_config.dropout_p = 0.0 | |
config.encoder_config.attn_dropout_p = 0.0 | |
config.encoder_config.ffn_inner_dropout_p = 0.1 | |
config.encoder_config.layer_drop_p = 0.1 | |
config.use_masking = False | |
config.max_temporal_mask_prob = 0.0 | |
config.max_spatial_mask_prob = 0.0 | |
config.vocab_info.size = 3335 | |
return config | |
def bib1143_5b() -> Wav2Vec2AsrConfig: | |
config = base_10h() | |
config.encoder_config = w2v2_encoder_registry.get("5b") | |
config.encoder_config.feature_gradient_scale = 1.0 | |
config.encoder_config.dropout_p = 0.0 | |
config.encoder_config.attn_dropout_p = 0.0 | |
config.encoder_config.ffn_inner_dropout_p = 0.1 | |
config.encoder_config.layer_drop_p = 0.1 | |
config.use_masking = False | |
config.max_temporal_mask_prob = 0.0 | |
config.max_spatial_mask_prob = 0.0 | |
config.vocab_info.size = 3335 # following bibfront1194's vocab size | |
return config | |
def bib1143_7b() -> Wav2Vec2AsrConfig: | |
config = base_10h() | |
config.encoder_config = w2v2_encoder_registry.get("7b") | |
config.encoder_config.feature_gradient_scale = 1.0 | |
config.encoder_config.dropout_p = 0.0 | |
config.encoder_config.attn_dropout_p = 0.0 | |
config.encoder_config.ffn_inner_dropout_p = 0.1 | |
config.encoder_config.layer_drop_p = 0.1 | |
config.use_masking = False | |
config.max_temporal_mask_prob = 0.0 | |
config.max_spatial_mask_prob = 0.0 | |
config.vocab_info.size = 3335 | |
return config | |
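# Hedged sketch (not part of the original registration code): the architecture functions
# above are meant to be registered on the config registry via its decorator, obtained as
# `wav2vec2_asr_arch = registry.decorator` at the top of `register_wav2vec2_asr_configs`,
# so that configs can later be retrieved by name, which is how
# `w2v2_ctc_registry.get("7b_bib1143")` in `load_mms_model` resolves a config. The
# registration name used below is illustrative only, not one of the real architectures.
def _demo_register_and_get(context: RuntimeContext) -> Wav2Vec2AsrConfig:
    registry = context.get_config_registry(Wav2Vec2AsrConfig)
    arch = registry.decorator

    @arch("demo_arch")  # illustrative name
    def demo_arch() -> Wav2Vec2AsrConfig:
        return Wav2Vec2AsrConfig()

    return registry.get("demo_arch")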
from dataclasses import dataclass, field | |
from typing import Final | |
from fairseq2.context import RuntimeContext | |
from fairseq2.nn.transformer import TransformerNormOrder | |
from fairseq2.utils.validation import ValidationError, ValidationResult | |
WAV2VEC2_MODEL_FAMILY: Final = "wav2vec2" | |
@dataclass
class Wav2Vec2Config:
"""Holds the configuration of a wav2vec 2.0 model. | |
The default values correspond to the base architecture as described in | |
:cite:t:`https://doi.org/10.48550/arxiv.2006.11477`. | |
""" | |
encoder_config: Wav2Vec2EncoderConfig = field( | |
default_factory=lambda: Wav2Vec2EncoderConfig() | |
) | |
"""The configuration of the wav2vec 2.0 encoder.""" | |
final_dim: int = 256 | |
"""The dimensionality of the final projection that is applied to context | |
network outputs and quantized targets.""" | |
final_proj_bias: bool = True | |
"""If ``True``, the final projection learns an additive bias.""" | |
quantizer_encoder_grad: bool = True | |
"""If ``True``, gradients are propagated from the quantizer through the convolutional | |
encoder. Otherwise, they are detached and the encoder is only trained with gradients | |
from the transformer. """ | |
# Mask | |
mask_codebase: str = "fairseq2" | |
temporal_mask_span_len: int = 10 | |
"""The length of each temporal mask span that is applied over time steps.""" | |
max_temporal_mask_prob: float = 0.69 | |
"""The maximum probability of masking a time step. Note that, due to mask | |
span overlap, the effective probability will be lower.""" | |
min_num_temporal_mask_spans: int = 2 | |
"""The minimum number of temporal masks sampled per sequence.""" | |
spatial_mask_span_len: int = 10 | |
"""The length of each spatial mask span that is applied over features.""" | |
max_spatial_mask_prob: float = 0.0 | |
"""The maximum probability of masking a feature. Note that, due to mask span | |
overlap, the effective probability will be lower.""" | |
min_num_spatial_mask_spans: int = 2 | |
"""The minimum number of spatial masks sampled per sequence.""" | |
# Quantization | |
quantized_dim: int = 256 | |
"""The output dimensionality of vector quantizer.""" | |
num_codebooks: int = 2 | |
"""The number of codebooks.""" | |
num_codebook_entries: int = 320 | |
"""The number of entries per codebook.""" | |
codebook_sampling_temperature: tuple[float, float, float] = (2.0, 0.5, 0.999995) | |
"""A tuple of start temperature, end temperature, and decay factor for | |
codebook entry sampling.""" | |
# Loss | |
num_distractors: int = 100 | |
"""The number of distractors to use in contrastive prediction.""" | |
logit_temp: float = 0.1 | |
"""The temperature to divide logits by.""" | |
@dataclass
class Wav2Vec2EncoderConfig:
"""Holds the configuration of a wav2vec 2.0 encoder. | |
The default values correspond to the base architecture described in | |
:cite:t:`https://doi.org/10.48550/arxiv.2006.11477`. | |
""" | |
model_dim: int = 768 | |
"""The dimensionality of the model.""" | |
max_seq_len: int = 4096 | |
"""The maximum sequence length after feature extraction.""" | |
# Features | |
feature_dim: int = 512 | |
"""The dimensionality of extracted features.""" | |
use_fbank: bool = False | |
"""If ``True``, uses log-mel filterbanks instead of waveforms as input.""" | |
first_pass_dropout_p: float = 0.0 | |
"""The dropout probability on extracted features before masking and | |
positional encoding.""" | |
layer_norm_features: bool = True | |
"""If ``True``, applies Layer Normalization to extracted features.""" | |
# Waveform Feature Extractor | |
feature_extractor_layer_descs: list[tuple[int, int, int]] = field( | |
default_factory=lambda: [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2 | |
) | |
"""A tuple of output dimension, kernel size, and stride for each feature | |
extraction layer.""" | |
feature_extractor_bias: bool = False | |
"""If ``True``, convolutions in feature extraction layers learn an additive | |
bias.""" | |
feature_extractor_layer_norm_convs: bool = False | |
"""If ``True``, applies Layer Normalization to outputs of convolutions in | |
feature extraction layers.""" | |
feature_gradient_scale: float = 0.1 | |
"""The scale factor for gradients of extracted features. Setting to a value | |
less than 1.0 allows the feature extractor to learn at a lower rate than the | |
rest of the model.""" | |
# Filterbank Feature Extractor | |
num_fbank_channels: int = 0 | |
"""The number of source log-mel filterbank channels.""" | |
fbank_stride: int = 0 | |
sample_fbank_every_k: int = 0 | |
# Position Encoder | |
pos_encoder_type: str = "conv" | |
"""The type of position encoder ('conv', 'relative', 'rotary').""" | |
# Convolutional Position Encoder | |
pos_encoder_depth: int = 1 | |
"""The number of stacked position encoder layers.""" | |
pos_conv_kernel_size: int = 128 | |
"""The total kernel size of 1D convolutions in position encoder layers.""" | |
num_pos_conv_groups: int = 16 | |
"""The number of convolution groups in position encoder layers.""" | |
# Encoder (i.e. Context Network) | |
use_conformer: bool = False | |
"""If ``True``, uses Conformer blocks instead of Transformer encoder layers.""" | |
num_encoder_layers: int = 12 | |
"""The number of encoder layers.""" | |
num_encoder_attn_heads: int = 12 | |
"""The number of attention heads in encoder layers.""" | |
ffn_inner_dim: int = 3072 | |
"""The inner dimensionality of feed-forward networks.""" | |
dropout_p: float = 0.1 | |
"""The dropout probability on outputs of Transformer layers.""" | |
attn_dropout_p: float = 0.1 | |
"""The dropout probability on attention weights.""" | |
ffn_inner_dropout_p: float = 0.0 | |
"""The dropout probability on inner activations of feed-forward networks.""" | |
layer_drop_p: float = 0.05 | |
"""If greater than zero, applies LayerDrop to encoder layers as described in | |
:cite:t:`https://doi.org/10.48550/arxiv.1909.11556`.""" | |
norm_order: TransformerNormOrder = TransformerNormOrder.POST | |
"""The Layer Normalization order.""" | |
depthwise_conv_kernel_size: int = 0 | |
"""The kernel size of depthwise convolutions in Conformer blocks.""" | |
def validate(self) -> None: | |
result = ValidationResult() | |
if self.use_conformer and self.norm_order != TransformerNormOrder.POST: | |
result.add_error( | |
f"`norm_order` must be `POST` when `use_conformer` is `True`, but is `{self.norm_order}` instead." | |
) | |
if result.has_error: | |
raise ValidationError( | |
"The wav2vec 2.0 encoder configuration has one or more validation errors:", result # fmt: skip | |
) | |
def register_wav2vec2_configs(context: RuntimeContext) -> None: | |
arch = context.get_config_registry(Wav2Vec2Config).decorator | |
arch_encoder = context.get_config_registry(Wav2Vec2EncoderConfig).decorator | |
def base() -> Wav2Vec2Config: | |
return Wav2Vec2Config() | |
def base_encoder() -> Wav2Vec2EncoderConfig: | |
return base().encoder_config | |
def large() -> Wav2Vec2Config: | |
config = base() | |
config.encoder_config.model_dim = 1024 | |
config.encoder_config.num_encoder_layers = 24 | |
config.encoder_config.num_encoder_attn_heads = 16 | |
config.encoder_config.ffn_inner_dim = 4096 | |
config.encoder_config.dropout_p = 0.0 | |
config.encoder_config.layer_drop_p = 0.2 | |
config.quantized_dim = 768 | |
config.final_dim = 768 | |
return config | |
def large_encoder() -> Wav2Vec2EncoderConfig: | |
return large().encoder_config | |
def large_lv60k() -> Wav2Vec2Config: | |
config = large() | |
config.encoder_config.layer_norm_features = False | |
config.encoder_config.feature_extractor_bias = True | |
config.encoder_config.feature_extractor_layer_norm_convs = True | |
config.encoder_config.layer_drop_p = 0.0 | |
config.encoder_config.norm_order = TransformerNormOrder.PRE | |
config.codebook_sampling_temperature = (2.0, 0.1, 0.999995) | |
return config | |
def large_lv60k_encoder() -> Wav2Vec2EncoderConfig: | |
return large_lv60k().encoder_config | |
def xlsr_base() -> Wav2Vec2Config: | |
config = large_lv60k() | |
config.encoder_config.attn_dropout_p = 0.0 | |
config.encoder_config.feature_gradient_scale = 1.0 | |
return config | |
def xlsr_base_encoder() -> Wav2Vec2EncoderConfig: | |
return xlsr_base().encoder_config | |
def base_conformer() -> Wav2Vec2Config: | |
config = xlsr_base() | |
config.encoder_config.use_conformer = True | |
config.encoder_config.norm_order = TransformerNormOrder.POST | |
config.encoder_config.depthwise_conv_kernel_size = 31 | |
# pos_encoder_type | |
return config | |
def base_conformer_encoder() -> Wav2Vec2EncoderConfig: | |
return base_conformer().encoder_config | |
def tiny() -> Wav2Vec2Config: | |
config = xlsr_base() | |
config.encoder_config.model_dim = 1280 | |
config.encoder_config.num_encoder_layers = 4 | |
config.encoder_config.ffn_inner_dim = 1280 | |
config.encoder_config.dropout_p = 0.0 | |
config.quantized_dim = 512 | |
config.final_dim = 512 | |
config.encoder_config.first_pass_dropout_p = 0.1 | |
return config | |
def tiny_encoder() -> Wav2Vec2EncoderConfig: | |
return tiny().encoder_config | |
def b1() -> Wav2Vec2Config: | |
config = xlsr_base() | |
config.encoder_config.model_dim = 1280 | |
config.encoder_config.num_encoder_layers = 48 | |
config.encoder_config.ffn_inner_dim = 5120 | |
config.encoder_config.dropout_p = 0.0 | |
config.quantized_dim = 1024 | |
config.final_dim = 1024 | |
config.encoder_config.first_pass_dropout_p = 0.1 | |
return config | |
def b1_encoder() -> Wav2Vec2EncoderConfig: | |
return b1().encoder_config | |
def b2() -> Wav2Vec2Config: | |
config = b1() | |
config.encoder_config.model_dim = 1920 | |
config.encoder_config.ffn_inner_dim = 7680 | |
return config | |
def b2_encoder() -> Wav2Vec2EncoderConfig: | |
return b2().encoder_config | |
def b3() -> Wav2Vec2Config: | |
config = b1() | |
config.encoder_config.num_encoder_layers = 60 | |
config.encoder_config.model_dim = 2048 | |
config.encoder_config.ffn_inner_dim = 8192 | |
return config | |
def b3_encoder() -> Wav2Vec2EncoderConfig: | |
return b3().encoder_config | |
def mel_3b() -> Wav2Vec2Config: | |
config = b3() | |
config.encoder_config.use_fbank = True | |
config.encoder_config.num_fbank_channels = 80 | |
config.encoder_config.fbank_stride = 2 | |
config.encoder_config.sample_fbank_every_k = 1 | |
config.encoder_config.feature_dim = 160 | |
return config | |
def mel_3b_encoder() -> Wav2Vec2EncoderConfig: | |
return mel_3b().encoder_config | |
def higher_3b() -> Wav2Vec2Config: | |
config = b1() | |
config.encoder_config.num_encoder_layers = 64 | |
config.encoder_config.model_dim = 2048 | |
config.encoder_config.ffn_inner_dim = 8192 | |
config.encoder_config.num_encoder_attn_heads = 32 | |
config.quantized_dim = 1280 | |
config.final_dim = 1280 | |
return config | |
def higher_3b_encoder() -> Wav2Vec2EncoderConfig: | |
return higher_3b().encoder_config | |
def b4() -> Wav2Vec2Config: | |
config = b2() | |
config.quantized_dim = 1280 | |
config.final_dim = 1280 | |
config.encoder_config.num_encoder_layers = 64 | |
config.encoder_config.model_dim = 2304 | |
config.encoder_config.ffn_inner_dim = 9216 | |
config.encoder_config.num_encoder_attn_heads = 32 | |
return config | |
def b4_encoder() -> Wav2Vec2EncoderConfig: | |
return b4().encoder_config | |
def llama_1b() -> Wav2Vec2Config: | |
config = xlsr_base() | |
config.encoder_config.model_dim = 2048 | |
config.encoder_config.num_encoder_layers = 16 | |
config.encoder_config.ffn_inner_dim = int(2048 * 4 * 1.5) | |
config.encoder_config.num_encoder_attn_heads = 32 | |
config.encoder_config.dropout_p = 0.0 | |
config.quantized_dim = 1024 | |
config.final_dim = 1024 | |
config.encoder_config.first_pass_dropout_p = 0.1 | |
return config | |
def llama_1b_encoder() -> Wav2Vec2EncoderConfig: | |
return llama_1b().encoder_config | |
def llama_3b() -> Wav2Vec2Config: | |
config = llama_1b() | |
config.encoder_config.model_dim = 2560 | |
config.encoder_config.num_encoder_layers = 32 | |
config.encoder_config.ffn_inner_dim = int(2560 * 4 * 1.0) | |
config.quantized_dim = 2048 | |
config.final_dim = 2048 | |
return config | |
def llama_3b_encoder() -> Wav2Vec2EncoderConfig: | |
return llama_3b().encoder_config | |
def b5() -> Wav2Vec2Config: | |
config = b3() | |
config.encoder_config.num_encoder_layers = 96 | |
config.encoder_config.model_dim = 2048 | |
config.encoder_config.ffn_inner_dim = 8192 | |
config.encoder_config.num_encoder_attn_heads = 16 | |
config.quantized_dim = 1024 | |
config.final_dim = 1024 | |
return config | |
def b5_encoder() -> Wav2Vec2EncoderConfig: | |
return b5().encoder_config | |
def b7() -> Wav2Vec2Config: | |
config = b5() | |
config.encoder_config.num_encoder_layers = 128 | |
config.encoder_config.model_dim = 2048 | |
config.encoder_config.ffn_inner_dim = 8192 | |
config.encoder_config.num_encoder_attn_heads = 16 | |
config.quantized_dim = 1024 | |
config.final_dim = 1024 | |
return config | |
def b7_encoder() -> Wav2Vec2EncoderConfig: | |
return b7().encoder_config | |
# @title Create model and load weights | |
"""Create model and load weights""" | |
from dataclasses import field | |
import torch | |
from fairseq2 import setup_fairseq2 | |
from fairseq2.context import get_runtime_context | |
from fairseq2.data.text.tokenizers.sentencepiece import RawSentencePieceTokenizer | |
@dataclass
class Wav2Vec2LlamaConfig:
    wav2vec_ctc_config: Wav2Vec2AsrConfig = None  # set after construction in load_mms_model
    llama_config: LLaMAConfig = None  # set after construction in load_mms_model
    encoder_stacking: int = 1
    frozen_encoder: bool = False
def load_mms_model(ckpt_path: str, tokenizer_path: str, device=None): | |
""" | |
Load the MMS model and tokenizer from checkpoint files with memory optimization. | |
Args: | |
ckpt_path (str): Path to the model checkpoint file | |
tokenizer_path (str): Path to the tokenizer model file | |
device: Device to load the model on. If None, auto-detects GPU/CPU | |
Returns: | |
tuple: (model, text_decoder, device) where: | |
- model: The loaded and configured MMS model | |
- text_decoder: The tokenizer decoder | |
- device: The device the model is loaded on | |
""" | |
import gc | |
import os | |
import psutil | |
logger = logging.getLogger(__name__) | |
def log_memory_usage(step: str): | |
"""Log current memory usage.""" | |
process = psutil.Process(os.getpid()) | |
memory_info = process.memory_info() | |
virtual_memory = psutil.virtual_memory() | |
logger.info( | |
f"[{step}] Process RSS: {memory_info.rss / (1024**3):.2f} GB, " | |
f"System Available: {virtual_memory.available / (1024**3):.2f} GB" | |
) | |
logger.info(f"Starting MMS model loading process...") | |
logger.info(f"Checkpoint path: {ckpt_path}") | |
logger.info(f"Tokenizer path: {tokenizer_path}") | |
# Check file size | |
if os.path.exists(ckpt_path): | |
ckpt_size_gb = os.path.getsize(ckpt_path) / (1024**3) | |
logger.info(f"Checkpoint file size: {ckpt_size_gb:.2f} GB") | |
log_memory_usage("Initial") | |
# Set device with proper CUDA initialization | |
if device is None: | |
try: | |
# Initialize CUDA context properly | |
logger.info("Checking CUDA availability...") | |
if torch.cuda.is_available(): | |
logger.info( | |
f"CUDA is available. Device count: {torch.cuda.device_count()}" | |
) | |
# Initialize CUDA context | |
torch.cuda.init() | |
# Set device to first available GPU | |
device = torch.device("cuda:0") | |
logger.info(f"CUDA device name: {torch.cuda.get_device_name(0)}") | |
logger.info( | |
f"CUDA device memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB" | |
) | |
else: | |
logger.warning("CUDA is not available, falling back to CPU") | |
device = torch.device("cpu") | |
except Exception as e: | |
logger.warning(f"CUDA initialization failed: {e}, falling back to CPU") | |
device = torch.device("cpu") | |
else: | |
device = torch.device("cpu") # Force CPU for memory efficiency | |
logger.info(f"Using device: {device}") | |
# Load model parameters from checkpoint with memory optimization | |
logger.info("Loading model parameters from checkpoint...") | |
try: | |
# Try memory-mapped loading first (more memory efficient) | |
logger.info("Attempting memory-mapped loading...") | |
model_params = torch.load(ckpt_path, map_location="cpu", mmap=True) | |
logger.info("✓ Model parameters loaded successfully (memory-mapped)") | |
except Exception as e: | |
logger.warning(f"Memory-mapped loading failed: {e}") | |
logger.info("Falling back to regular loading...") | |
# Force garbage collection before loading | |
gc.collect() | |
model_params = torch.load(ckpt_path, map_location="cpu") | |
logger.info("✓ Model parameters loaded successfully (regular)") | |
log_memory_usage("After checkpoint load") | |
# Create context | |
logger.info("Setting up fairseq2 context and registering configs...") | |
setup_fairseq2() | |
context = get_runtime_context() | |
try: | |
register_wav2vec2_configs(context) | |
register_wav2vec2_asr_configs(context) | |
logger.info("✓ Configs registered successfully") | |
except Exception as e: | |
logger.warning(f"Config registration failed (may already be registered): {e}") | |
w2v2_ctc_registry = context.get_config_registry(Wav2Vec2AsrConfig) | |
# Create config | |
logger.info("Creating model configuration...") | |
wav2vec_ctc_config = w2v2_ctc_registry.get("7b_bib1143") | |
logger.info( | |
f"✓ wav2vec config loaded: vocab_size={wav2vec_ctc_config.vocab_info.size}" | |
) | |
llama_config = LLaMAConfig( | |
model_dim=4096, | |
max_seq_len=8192, | |
vocab_info=wav2vec_ctc_config.vocab_info, | |
num_layers=12, | |
num_attn_heads=8, | |
num_key_value_heads=8, | |
ffn_inner_dim=4096, | |
rope_theta=10_000.0, | |
dropout_p=0.1, | |
) | |
logger.info( | |
f"✓ LLaMA config created: model_dim={llama_config.model_dim}, layers={llama_config.num_layers}" | |
) | |
config = Wav2Vec2LlamaConfig() | |
config.wav2vec_ctc_config = wav2vec_ctc_config | |
config.llama_config = llama_config | |
# Instantiate model | |
logger.info("Instantiating model from factory...") | |
factory = Wav2Vec2LlamaFactory(config) | |
model = factory.create_model() | |
logger.info("✓ Model instantiated successfully") | |
# Load state dict from ckpt | |
logger.info("Loading model state dictionary...") | |
model.load_state_dict(model_params["model"]) | |
del model_params | |
logger.info("✓ Model weights loaded successfully") | |
# Move to device and set eval mode | |
logger.info(f"Moving model to device {device} and setting eval mode...") | |
model = model.to(device).eval() | |
logger.info("✓ Model moved to device and set to eval mode") | |
# Create tokenizer | |
logger.info(f"Creating tokenizer from {tokenizer_path}...") | |
tokenizer = RawSentencePieceTokenizer(tokenizer_path) | |
text_decoder_1143 = tokenizer.create_decoder() | |
logger.info("✓ Tokenizer created successfully") | |
logger.info("MMS model loading completed successfully!") | |
return model, text_decoder_1143, device | |
def prepare_audio_batch(wav_path: str, device, max_duration_seconds=2): | |
""" | |
Load a wav file from disk and prepare batch for model inference. | |
Args: | |
wav_path (str): Path to the WAV file | |
device: Device to place the batch on | |
max_duration_seconds (int): Maximum duration to process (for efficiency) | |
Returns: | |
Seq2SeqBatch: Prepared batch for model inference | |
""" | |
logger = logging.getLogger(__name__) | |
logger.info(f"Preparing audio batch from: {wav_path}") | |
logger.info(f"Max duration: {max_duration_seconds}s, target device: {device}") | |
# Load the WAV file, resample the data to 16 kHz | |
logger.info("Loading and resampling audio file...") | |
data, fs = librosa.load(wav_path) | |
logger.info(f"Original sample rate: {fs} Hz, duration: {len(data)/fs:.2f}s") | |
data = librosa.resample(data, orig_sr=fs, target_sr=16000) | |
logger.info("✓ Audio resampled to 16kHz") | |
# If the data is multi-channel, merge all channels | |
if len(data.shape) > 1:
    logger.info("Multi-channel audio detected, merging channels...")
    data = np.mean(data, axis=0)
# Cut to specified duration (for efficiency) | |
if max_duration_seconds > 0: | |
original_length = len(data) | |
data = data[: 16000 * max_duration_seconds] | |
if len(data) < original_length: | |
logger.info( | |
f"Audio truncated from {original_length/16000:.2f}s to {len(data)/16000:.2f}s" | |
) | |
# Convert to tensor and normalize | |
logger.info("Converting to tensor and normalizing...") | |
# Originally data = torch.Tensor(data).to(torch.bfloat16) | |
data = torch.Tensor(data).float() # Use float32 to match model expectations | |
data = F.layer_norm(data, data.shape) | |
# Create batch | |
logger.info("Creating batch for inference...") | |
batch = Seq2SeqBatch( | |
source_seqs=data.unsqueeze(0).to(device), | |
source_padding_mask=None, | |
target_seqs=torch.tensor([1], dtype=torch.long) | |
.unsqueeze(0) | |
.to(device), # Not used for inference | |
target_padding_mask=None, | |
example=[], | |
) | |
logger.info( | |
f"✓ Audio batch prepared successfully, shape: {batch.source_seqs.shape}" | |
) | |
return batch | |
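# Hedged sketch: `prepare_audio_batch` accepts any audio file readable by librosa and
# returns a Seq2SeqBatch whose `source_seqs` is a (1, num_samples) float32 tensor at
# 16 kHz. The example below writes a short synthetic sine tone to a temporary WAV file
# purely to show the resulting shape; the tone and the temporary path are arbitrary.
def _demo_prepare_audio_batch() -> Seq2SeqBatch:
    import tempfile
    sr = 16000
    tone = 0.1 * np.sin(2 * np.pi * 440.0 * np.arange(sr) / sr)  # 1 second of 440 Hz
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        wav_path = f.name
    sf.write(wav_path, tone.astype(np.float32), sr)
    return prepare_audio_batch(wav_path, torch.device("cpu"), max_duration_seconds=1)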
def run_inference(model, batch, text_decoder, config, device): | |
""" | |
Run model inference on a prepared batch. | |
Args: | |
model: The loaded MMS model | |
batch: Prepared audio batch | |
text_decoder: Tokenizer decoder | |
config: Model configuration | |
device: Device for inference | |
Returns: | |
list: Decoded text outputs | |
""" | |
logger = logging.getLogger(__name__) | |
logger.info("Starting model inference...") | |
logger.info(f"Input batch shape: {batch.source_seqs.shape}, device: {device}") | |
with torch.no_grad(): | |
ctx = ( | |
torch.cuda.amp.autocast() | |
if torch.cuda.is_available() | |
else torch.cpu.amp.autocast(dtype=torch.bfloat16) | |
) | |
logger.info( | |
f"Using autocast context: {'CUDA' if torch.cuda.is_available() else 'CPU'}" | |
) | |
with ctx: | |
logger.info("Running forward pass...") | |
output = model(batch) | |
logger.info("✓ Forward pass completed") | |
logger.info("Generating hypotheses...") | |
hyp_seq, hyp_padding_mask = output.generate_hypotheses( | |
pad_idx=config.llama_config.vocab_info.pad_idx | |
) | |
logger.info(f"✓ Generated {len(hyp_seq)} hypotheses") | |
logger.info("Decoding text...") | |
results = [text_decoder(s) for s in hyp_seq] | |
logger.info(f"✓ Inference completed, results: {results}") | |
return results | |
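# Hedged end-to-end sketch tying the helpers above together without going through the
# server singleton used by `transcribe_audio` below. The paths are placeholders; the
# "7b_bib1143" architecture name and the LLaMA hyper-parameters simply mirror the ones
# used in `load_mms_model`.
def _demo_transcription_pipeline(wav_path: str, ckpt_path: str, tokenizer_path: str):
    model, text_decoder, device = load_mms_model(ckpt_path, tokenizer_path)
    context = get_runtime_context()  # fairseq2 was already set up by load_mms_model
    wav2vec_ctc_config = context.get_config_registry(Wav2Vec2AsrConfig).get("7b_bib1143")
    llama_config = LLaMAConfig(
        model_dim=4096,
        max_seq_len=8192,
        vocab_info=wav2vec_ctc_config.vocab_info,
        num_layers=12,
        num_attn_heads=8,
        num_key_value_heads=8,
        ffn_inner_dim=4096,
        rope_theta=10_000.0,
        dropout_p=0.1,
    )
    config = Wav2Vec2LlamaConfig()
    config.wav2vec_ctc_config = wav2vec_ctc_config
    config.llama_config = llama_config
    batch = prepare_audio_batch(wav_path, device, max_duration_seconds=10)
    return run_inference(model, batch, text_decoder, config, device)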
def transcribe_audio( | |
wav_path: str, | |
ckpt_path: str = None, | |
tokenizer_path: str = None, | |
max_duration_seconds=2, | |
): | |
""" | |
Complete pipeline to transcribe audio using MMS model. | |
Uses the singleton model instance from server.py to avoid reloading. | |
Args: | |
wav_path (str): Path to the WAV file | |
ckpt_path (str): Path to the model checkpoint (not used, kept for compatibility) | |
tokenizer_path (str): Path to the tokenizer (not used, kept for compatibility) | |
max_duration_seconds (int): Maximum duration to process | |
Returns: | |
tuple: (transcription_results, audio_data) where: | |
- transcription_results: list of transcribed text | |
- audio_data: processed audio tensor for reuse in alignment | |
""" | |
logger = logging.getLogger(__name__) | |
logger.info("Starting complete audio transcription pipeline...") | |
try: | |
# Get model from singleton (don't reload) | |
logger.info("Getting pre-loaded MMS model from singleton...") | |
from server import get_device, get_model, get_text_decoder | |
model = get_model() | |
text_decoder = get_text_decoder() | |
device = get_device() | |
if model is None or text_decoder is None or device is None: | |
raise RuntimeError("Model not properly loaded in server singleton") | |
logger.info(f"✓ Using pre-loaded model on device: {device}") | |
# Get config (needed for inference) | |
logger.info("Setting up configuration for inference...") | |
setup_fairseq2() | |
context = get_runtime_context() | |
try: | |
register_wav2vec2_configs(context) | |
register_wav2vec2_asr_configs(context) | |
except Exception as e: | |
logger.warning(f"Config registration warning: {e}") | |
w2v2_ctc_registry = context.get_config_registry(Wav2Vec2AsrConfig) | |
wav2vec_ctc_config = w2v2_ctc_registry.get("7b_bib1143") | |
llama_config = LLaMAConfig( | |
model_dim=4096, | |
max_seq_len=8192, | |
vocab_info=wav2vec_ctc_config.vocab_info, | |
num_layers=12, | |
num_attn_heads=8, | |
num_key_value_heads=8, | |
ffn_inner_dim=4096, | |
rope_theta=10_000.0, | |
dropout_p=0.1, | |
) | |
config = Wav2Vec2LlamaConfig() | |
config.wav2vec_ctc_config = wav2vec_ctc_config | |
config.llama_config = llama_config | |
# Prepare batch | |
logger.info("Preparing audio batch...") | |
batch = prepare_audio_batch(wav_path, device, max_duration_seconds) | |
# Extract the processed audio data for return | |
audio_data = batch.source_seqs.squeeze(0) # Remove batch dimension | |
# Run inference | |
logger.info("Running inference...") | |
results = run_inference(model, batch, text_decoder, config, device) | |
logger.info(f"Transcription pipeline completed successfully: {results}") | |
return results, audio_data | |
except Exception as e: | |
logger.error(f"Error in transcription pipeline: {str(e)}", exc_info=True) | |
raise | |
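# Hedged example (placeholder path; the server singleton must already hold the model,
# tokenizer decoder and device):
#
#     texts, audio = transcribe_audio("clip.wav", max_duration_seconds=2)
#     print(texts)        # list with one decoded string per hypothesis
#     print(audio.shape)  # processed waveform, reusable for forced alignment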
def normalize_text_with_uroman(text: str) -> str: | |
""" | |
Normalize text using uroman for better forced alignment. | |
Args: | |
text (str): Input text to normalize | |
Returns: | |
str: Normalized text | |
""" | |
logger = logging.getLogger(__name__) | |
try: | |
# Use uroman to normalize the text | |
uroman_instance = uroman.Uroman() | |
normalized = uroman_instance.romanize_string(text) | |
logger.info(f"Text normalized: '{text}' -> '{normalized}'") | |
return normalized | |
except Exception as e: | |
logger.warning(f"Failed to normalize text with uroman: {e}") | |
# Fallback to basic normalization | |
return text.lower().strip() | |
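# Illustrative only; the exact output depends on the installed uroman version:
#
#     normalize_text_with_uroman("Привет")  # -> something like "Privet"
#
# On failure the function falls back to plain lowercasing and stripping.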
def perform_forced_alignment(
    audio_data: np.ndarray | Tensor,
    transcription_tokens: List[str],
    model,
    device,
    sample_rate: int = 16000,
) -> List[Dict]:
""" | |
Perform forced alignment using the AudioAlignment class from audio_sentence_alignment.py. | |
Uses pre-processed audio data from prepare_audio_batch. | |
Args: | |
        audio_data (np.ndarray | Tensor): Pre-processed audio data from prepare_audio_batch
transcription_tokens (List[str]): List of tokens from transcription | |
model: The loaded MMS model (not used directly, AudioAlignment loads its own) | |
device: Device for computation | |
sample_rate (int): Audio sample rate | |
Returns: | |
List[Dict]: List of segments with timestamps and text | |
""" | |
logger = logging.getLogger(__name__) | |
try: | |
logger.info(f"Starting forced alignment with pre-processed audio data") | |
logger.info(f"Audio shape: {audio_data.shape}, sample_rate: {sample_rate}") | |
logger.info(f"Tokens to align: {transcription_tokens}") | |
from audio_reading_tools import wav_to_bytes | |
# Import AudioAlignment and its config classes | |
from audio_sentence_alignment import ( | |
AlignmentStruct, | |
AudioAlignment, | |
AudioAlignmentConfig, | |
) | |
# Use the pre-processed audio data directly | |
# Convert to the format expected by AudioAlignment.get_one_row_alignments | |
if hasattr(audio_data, "cpu"): | |
# If it's a torch tensor, use it directly | |
audio_tensor = audio_data.float() | |
else: | |
# If it's numpy, convert to tensor | |
audio_tensor = torch.from_numpy(audio_data).float() | |
# Ensure it's 1D (flatten if needed) | |
if len(audio_tensor.shape) > 1: | |
audio_tensor = audio_tensor.flatten() | |
# Convert audio tensor to bytes format expected by AudioAlignment | |
# Use wav_to_bytes to create proper audio bytes | |
audio_arr = wav_to_bytes(audio_tensor, sample_rate=sample_rate, format="wav") | |
logger.info( | |
f"Converted audio to bytes: shape={audio_arr.shape}, dtype={audio_arr.dtype}" | |
) | |
# Preprocess tokens for MMS alignment model using the same approach as TextRomanizer | |
# The MMS alignment model expects romanized tokens in the same format as text_sentences_tokens | |
try: | |
# Join tokens back to text for uroman processing | |
transcription_text = " ".join(transcription_tokens) | |
# Import the required functions from TextRomanizer pipeline | |
from align_utils import get_uroman_tokens | |
from text_normalization import text_normalize | |
# Create uroman instance and process the text the same way as TextRomanizer | |
uroman_instance = uroman.Uroman() | |
# Step 1: Normalize the text first using text_normalize function (same as TextRomanizer) | |
normalized_text = text_normalize(transcription_text.strip(), "en") | |
# Step 2: Get uroman tokens using the same function as TextRomanizer | |
# This creates character-level tokens with spaces between characters | |
uroman_tokens_str = get_uroman_tokens( | |
[normalized_text], uroman_instance, "en" | |
)[0] | |
# Step 3: Split by spaces to get individual character tokens (same as real MMS pipeline) | |
alignment_tokens = uroman_tokens_str.split() | |
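            # Illustrative only (the exact tokenization comes from get_uroman_tokens):
            # the alignment tokens end up as individual romanized characters, e.g.
            # roughly ["h", "e", "l", "l", "o"] for the word "hello".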
logger.info(f"Original tokens: {transcription_tokens}") | |
logger.info(f"Original text: '{transcription_text}'") | |
logger.info(f"Normalized text: '{normalized_text}'") | |
logger.info(f"Uroman tokens string: '{uroman_tokens_str}'") | |
logger.info( | |
f"Alignment tokens (count={len(alignment_tokens)}): {alignment_tokens[:20]}..." | |
) | |
# Additional debugging - check for any unusual characters | |
for i, token in enumerate(alignment_tokens[:10]): # Check first 10 tokens | |
logger.info( | |
f"Token {i}: '{token}' (length={len(token)}, chars={[c for c in token]})" | |
) | |
except Exception as e: | |
logger.warning( | |
f"Failed to preprocess tokens with TextRomanizer approach: {e}" | |
) | |
logger.exception("Full error traceback:") | |
# Fallback: use simple character-level tokenization | |
transcription_text = " ".join(transcription_tokens).lower() | |
# Simple character-level tokenization as fallback | |
alignment_tokens = [] | |
for char in transcription_text: | |
if char == " ": | |
alignment_tokens.append(" ") | |
else: | |
alignment_tokens.append(char) | |
logger.info(f"Using fallback character tokens: {alignment_tokens[:20]}...") | |
logger.info( | |
f"Using {len(alignment_tokens)} alignment tokens for forced alignment" | |
) | |
# Create alignment configuration | |
alignment_struct = AlignmentStruct( | |
segement_tokens="tokens", | |
audio="audio", | |
) | |
config = AudioAlignmentConfig( | |
alignment_column=alignment_struct, | |
sample_rate=sample_rate, | |
device=str(device), | |
use_star=False, # Set to False for standard alignment | |
) | |
# Create AudioAlignment instance | |
logger.info("Creating AudioAlignment instance...") | |
alignment = AudioAlignment(config) | |
# Perform alignment using get_one_row_alignments | |
logger.info("Performing alignment...") | |
logger.info(f"About to call get_one_row_alignments with:") | |
logger.info(f" audio_arr type: {type(audio_arr)}, shape: {audio_arr.shape}") | |
logger.info( | |
f" alignment_tokens type: {type(alignment_tokens)}, length: {len(alignment_tokens)}" | |
) | |
logger.info( | |
f" First 10 tokens: {alignment_tokens[:10] if len(alignment_tokens) >= 10 else alignment_tokens}" | |
) | |
# Check for any problematic characters in tokens | |
for i, token in enumerate(alignment_tokens[:5]): | |
token_chars = [ord(c) for c in str(token)] | |
logger.info(f" Token {i} '{token}' char codes: {token_chars}") | |
# Check if tokens contain any RTL characters that might cause the LTR assertion | |
rtl_chars = [] | |
for i, token in enumerate(alignment_tokens): | |
for char in str(token): | |
# Check for Arabic, Hebrew, and other RTL characters | |
if ( | |
"\u0590" <= char <= "\u08ff" | |
or "\ufb1d" <= char <= "\ufdff" | |
or "\ufe70" <= char <= "\ufeff" | |
): | |
rtl_chars.append((i, token, char, ord(char))) | |
if rtl_chars: | |
logger.warning(f"Found RTL characters in tokens: {rtl_chars[:10]}...") | |
try: | |
audio_segments = alignment.get_one_row_alignments( | |
audio_arr, alignment_tokens | |
) | |
except Exception as alignment_error: | |
logger.error(f"Alignment failed with error: {alignment_error}") | |
logger.error(f"Error type: {type(alignment_error)}") | |
# Try to provide more context about the error | |
if "ltr" in str(alignment_error).lower(): | |
logger.error("LTR assertion error detected. This might be due to:") | |
logger.error("1. RTL characters in the input tokens") | |
logger.error( | |
"2. Incorrect token format - tokens should be individual characters" | |
) | |
logger.error("3. Unicode normalization issues") | |
# Try a simple ASCII-only fallback | |
logger.info("Attempting ASCII-only fallback...") | |
ascii_tokens = [] | |
for token in alignment_tokens: | |
# Keep only ASCII characters | |
ascii_token = "".join(c for c in str(token) if ord(c) < 128) | |
if ascii_token: | |
ascii_tokens.append(ascii_token) | |
logger.info( | |
f"ASCII tokens (count={len(ascii_tokens)}): {ascii_tokens[:20]}..." | |
) | |
try: | |
audio_segments = alignment.get_one_row_alignments( | |
audio_arr, ascii_tokens | |
) | |
alignment_tokens = ascii_tokens # Update for later use | |
logger.info("ASCII fallback successful!") | |
except Exception as ascii_error: | |
logger.error(f"ASCII fallback also failed: {ascii_error}") | |
raise alignment_error | |
else: | |
raise | |
logger.info( | |
f"Alignment completed, got {len(audio_segments)} character segments" | |
) | |
# Convert character-level segments back to word-level segments | |
# Map character segments to original word tokens | |
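        # NOTE: this mapping assumes the character-level alignment tokens line up roughly
        # one-to-one with the characters of the joined transcription text; when
        # romanization changes the character count, the fallback branches below supply
        # approximate per-word timings instead.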
aligned_segments = [] | |
transcription_text = " ".join(transcription_tokens) | |
word_idx = 0 | |
char_idx = 0 | |
        for word in transcription_tokens:
# Find the start and end character indices for this word | |
word_start_char = char_idx | |
word_end_char = char_idx + len(word) | |
# Find corresponding segments within this character range | |
word_segments = [] | |
            for seg_idx, segment in enumerate(audio_segments):
                if word_start_char <= seg_idx < word_end_char:
                    word_segments.append(segment)
if word_segments: | |
# Get timing from first and last character segments of the word | |
start_time = word_segments[0][alignment_struct.segment_start_sec] | |
last_segment = word_segments[-1] | |
end_time = ( | |
last_segment[alignment_struct.segment_start_sec] | |
+ last_segment[alignment_struct.segment_duration] | |
) | |
duration = end_time - start_time | |
else: | |
# Fallback timing if no segments found | |
if word_idx < len(audio_segments): | |
segment = audio_segments[min(word_idx, len(audio_segments) - 1)] | |
start_time = segment[alignment_struct.segment_start_sec] | |
duration = segment[alignment_struct.segment_duration] | |
end_time = start_time + duration | |
else: | |
# Final fallback | |
duration = 0.5 # Default duration | |
start_time = word_idx * duration | |
end_time = start_time + duration | |
aligned_segments.append( | |
{ | |
"text": word, | |
"start": start_time, | |
"end": end_time, | |
"duration": duration, | |
} | |
) | |
logger.info( | |
f"Word '{word}': {start_time:.3f}s - {end_time:.3f}s ({duration:.3f}s)" | |
) | |
# Update indices | |
char_idx += len(word) | |
if ( | |
char_idx < len(transcription_text) | |
and transcription_text[char_idx] == " " | |
): | |
char_idx += 1 # Skip space | |
word_idx += 1 | |
logger.info(f"Forced alignment completed: {len(aligned_segments)} segments") | |
return aligned_segments | |
except Exception as e: | |
logger.error(f"Error in forced alignment: {str(e)}", exc_info=True) | |
# Fallback: create uniform timestamps based on audio data length | |
logger.info("Using fallback uniform timestamps") | |
try: | |
# Calculate duration from the audio data | |
total_duration = ( | |
len(audio_data) / sample_rate | |
if len(audio_data) > 0 | |
else len(transcription_tokens) * 0.5 | |
) | |
        except Exception:
            total_duration = len(transcription_tokens) * 0.5  # Fallback
segment_duration = ( | |
total_duration / len(transcription_tokens) if transcription_tokens else 1.0 | |
) | |
fallback_segments = [] | |
for i, token in enumerate(transcription_tokens): | |
start_time = i * segment_duration | |
end_time = (i + 1) * segment_duration | |
fallback_segments.append( | |
{ | |
"text": token, | |
"start": start_time, | |
"end": end_time, | |
"duration": segment_duration, | |
} | |
) | |
logger.info( | |
f"Using fallback uniform timestamps: {len(fallback_segments)} segments" | |
) | |
return fallback_segments | |
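# Hedged usage sketch (audio_data is the tensor returned by transcribe_audio; model and
# device come from the server singleton; all other names are defined above):
#
#     segments = perform_forced_alignment(audio_data, ["hello", "world"], model, device)
#     for seg in segments:
#         print(f"{seg['text']}: {seg['start']:.2f}s - {seg['end']:.2f}s")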
def transcribe_audio_with_alignment(
    wav_path: str,
    ckpt_path: str | None = None,
    tokenizer_path: str | None = None,
    max_duration_seconds: int = 2,
) -> Dict:
""" | |
Complete pipeline to transcribe audio and perform forced alignment. | |
Uses pre-processed audio data from prepare_audio_batch for both steps. | |
Args: | |
wav_path (str): Path to the WAV file | |
ckpt_path (str): Path to the model checkpoint (not used, kept for compatibility) | |
tokenizer_path (str): Path to the tokenizer (not used, kept for compatibility) | |
max_duration_seconds (int): Maximum duration to process | |
Returns: | |
Dict: Transcription results with alignment information | |
""" | |
logger = logging.getLogger(__name__) | |
try: | |
        # Get model and device first; alignment is skipped below if they are unavailable
        from server import get_device, get_model
        model = get_model()
        device = get_device()
        alignment_available = model is not None and device is not None
        if not alignment_available:
            logger.warning(
                "Model not available for alignment, returning transcription only"
            )
# Get the transcription and processed audio data | |
transcription_results, audio_data = transcribe_audio( | |
wav_path, ckpt_path, tokenizer_path, max_duration_seconds | |
) | |
if not transcription_results: | |
return { | |
"transcription": "", | |
"tokens": [], | |
"aligned_segments": [], | |
"total_duration": 0.0, | |
} | |
transcription_text = ( | |
transcription_results[0] | |
if isinstance(transcription_results, list) | |
else str(transcription_results) | |
) | |
# Tokenize the transcription for alignment | |
tokens = transcription_text.split() if transcription_text else [] | |
        # Perform forced alignment using the same preprocessed audio data
        if alignment_available:
            logger.info("Performing forced alignment with preprocessed audio...")
            aligned_segments = perform_forced_alignment(
                audio_data, tokens, model, device
            )
        else:
            aligned_segments = []
# Calculate total duration | |
total_duration = aligned_segments[-1]["end"] if aligned_segments else 0.0 | |
result = { | |
"transcription": transcription_text, | |
"tokens": tokens, | |
"aligned_segments": aligned_segments, | |
"total_duration": total_duration, | |
"num_segments": len(aligned_segments), | |
} | |
logger.info( | |
f"Transcription with alignment completed: {len(aligned_segments)} segments, {total_duration:.2f}s total" | |
) | |
return result | |
except Exception as e: | |
logger.error(f"Error in transcription with alignment: {str(e)}", exc_info=True) | |
# Return basic transcription without alignment | |
try: | |
transcription_results, _ = transcribe_audio( | |
wav_path, ckpt_path, tokenizer_path, max_duration_seconds | |
) | |
transcription_text = ( | |
transcription_results[0] if transcription_results else "" | |
) | |
tokens = transcription_text.split() if transcription_text else [] | |
return { | |
"transcription": transcription_text, | |
"tokens": tokens, | |
"aligned_segments": [], | |
"total_duration": 0.0, | |
"alignment_error": str(e), | |
} | |
except Exception as e2: | |
logger.error(f"Error in fallback transcription: {str(e2)}", exc_info=True) | |
return { | |
"transcription": "", | |
"tokens": [], | |
"aligned_segments": [], | |
"total_duration": 0.0, | |
"error": str(e2), | |
} | |
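# End-to-end sketch of the alignment pipeline (placeholder path; requires the server
# singleton to be initialized before the call):
#
#     result = transcribe_audio_with_alignment("clip.wav", max_duration_seconds=2)
#     # on success, result keys: "transcription", "tokens", "aligned_segments",
#     #                          "total_duration", "num_segments"
#     for seg in result["aligned_segments"]:
#         print(seg["text"], seg["start"], seg["end"])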