# mms-transcription / server / align_utils.py
# EC2 Default User
# Added basic frontend, dockerfile
# 0f60365
import math
import os
import re
import tempfile
from dataclasses import dataclass
import torch
from torchaudio.models import wav2vec2_model
# Language codes for which uroman applies specialized romanization rules.
# The source string lists "ell" and "eng" twice; the duplicates are kept as-is
# so the resulting list is unchanged.
special_isos_uroman = [
    code.strip()
    for code in "ara, bel, bul, deu, ell, eng, fas, grc, ell, eng, heb, kaz, kir, lav, lit, mkd, mkd2, oss, pnt, pus, rus, srp, srp2, tur, uig, ukr, yid".split(
        ","
    )
]
def normalize_uroman(text):
    """Reduce romanized text to lowercase letters, apostrophes and single spaces.

    Args:
        text: The string to normalize.

    Returns:
        The lowercased string with every character outside ``a-z``, ``'`` and
        space replaced by a space, runs of spaces collapsed to one, and
        leading/trailing whitespace removed.
    """
    lowered = text.lower()
    # Anything that is not a lowercase ASCII letter, apostrophe or space
    # becomes a space.
    letters_only = re.sub("([^a-z' ])", " ", lowered)
    single_spaced = re.sub(" +", " ", letters_only)
    return single_spaced.strip()
def get_uroman_tokens(norm_transcripts, uroman, iso=None):
    """Romanize transcripts with uroman and return character-level tokens.

    Each transcript is written on its own line to a temporary file, romanized
    via ``uroman.romanize_file`` into a second temporary file, and every
    output line is split into characters joined by single spaces before being
    passed through ``normalize_uroman``.

    Args:
        norm_transcripts: Sized collection of transcript strings (``len()`` is
            taken of it); one line is written per entry.
        uroman: Object exposing ``romanize_file(input_filename=...,
            output_filename=..., lcode=...)``.
        iso: Optional language code. It is forwarded as ``lcode`` only when it
            is listed in ``special_isos_uroman``; otherwise ``None`` is passed.

    Returns:
        A list with one normalized, space-separated character string per
        transcript.

    Raises:
        AssertionError: If the romanized output does not contain exactly one
            line per input transcript.
    """
    lcode = iso if iso in special_isos_uroman else None

    # Context managers close (and thereby delete) both temporary files
    # deterministically, even if romanization raises, instead of relying on
    # garbage collection of the file objects.
    with tempfile.NamedTemporaryFile() as tf, tempfile.NamedTemporaryFile() as tf2:
        with open(tf.name, "w") as f:
            for t in norm_transcripts:
                f.write(t + "\n")

        uroman.romanize_file(
            input_filename=tf.name,
            output_filename=tf2.name,
            lcode=lcode,
        )

        outtexts = []
        with open(tf2.name) as f:
            for line in f:
                # Joining a string with " " puts a space between every
                # character, turning the line into character tokens.
                line = " ".join(line.strip())
                line = re.sub(r"\s+", " ", line).strip()
                outtexts.append(line)

    assert len(outtexts) == len(norm_transcripts)
    return [normalize_uroman(ot) for ot in outtexts]
@dataclass
class Segment:
    """A labelled run of positions described by a start and an end index.

    Attributes:
        label: Text label attached to the run.
        start: Index at which the run begins.
        end: Index at which the run ends.
    """

    label: str
    start: int
    end: int

    @property
    def length(self):
        """Return the difference ``end - start``."""
        return self.end - self.start

    def __repr__(self):
        """Render as ``label: [start, end)`` with 5-wide right-aligned indices."""
        return f"{self.label}: [{self.start:5d}, {self.end:5d})"
def merge_repeats(path, idx_to_token_map):
    """Collapse runs of equal consecutive entries in ``path`` into segments.

    Args:
        path: Indexable, sized sequence whose entries are keys of
            ``idx_to_token_map``.
        idx_to_token_map: Mapping from a path entry to its label.

    Returns:
        A list of ``Segment`` objects in path order, one per run. ``start`` is
        the position of the run's first entry and ``end`` the position of its
        last entry (inclusive).
    """
    segments = []
    total = len(path)
    run_start = 0
    # ``pos`` is the candidate position one past the current run; reaching
    # ``total`` flushes the final run.
    for pos in range(1, total + 1):
        if pos < total and path[run_start] == path[pos]:
            continue
        label = idx_to_token_map[path[run_start]]
        segments.append(Segment(label, run_start, pos - 1))
        run_start = pos
    return segments
def time_to_frame(time, stride_msec=20):
    """Convert a time in seconds to a frame index.

    Args:
        time: Time offset in seconds.
        stride_msec: Duration covered by one frame, in milliseconds. Defaults
            to 20, the value that was previously hard-coded, so existing
            callers are unaffected.

    Returns:
        The frame index as an ``int``; the fractional part is truncated.
    """
    frames_per_sec = 1000 / stride_msec
    return int(time * frames_per_sec)
def load_model_dict():
    """Load the CTC alignment model and its token dictionary from disk.

    Both files are looked up in the directory named by the ``MODELS_DIR``
    environment variable, falling back to ``/home/user/app/models``.

    Returns:
        A ``(model, dictionary)`` tuple: the wav2vec2 model with the
        checkpoint weights loaded and switched to eval mode, and a dict that
        maps each stripped line of the dictionary file to its zero-based line
        index.

    Raises:
        FileNotFoundError: If the checkpoint file or the dictionary file does
            not exist.
    """
    models_dir = os.environ.get("MODELS_DIR", "/home/user/app/models")

    # --- model weights -----------------------------------------------------
    checkpoint_file = os.path.join(models_dir, "ctc_alignment_mling_uroman_model.pt")
    print("Loading model from models directory...")
    if not os.path.exists(checkpoint_file):
        raise FileNotFoundError(f"Model file not found at {checkpoint_file}")
    print(f"Model found at: {checkpoint_file}")
    # NOTE(review): torch.load unpickles the file, so the checkpoint must come
    # from a trusted source.
    weights = torch.load(checkpoint_file, map_location="cpu")

    # Seven conv layers given as 3-tuples: one (512, 10, 5), four (512, 3, 2)
    # and two (512, 2, 2).
    conv_stack = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
    model = wav2vec2_model(
        extractor_mode="layer_norm",
        extractor_conv_layer_config=conv_stack,
        extractor_conv_bias=True,
        encoder_embed_dim=1024,
        encoder_projection_dropout=0.0,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=24,
        encoder_num_heads=16,
        encoder_attention_dropout=0.0,
        encoder_ff_interm_features=4096,
        encoder_ff_interm_dropout=0.1,
        encoder_dropout=0.0,
        encoder_layer_norm_first=True,
        encoder_layer_drop=0.1,
        aux_num_out=31,
    )
    model.load_state_dict(weights)
    model.eval()

    # --- token dictionary --------------------------------------------------
    vocab_file = os.path.join(models_dir, "ctc_alignment_mling_uroman_model_dict.txt")
    if not os.path.exists(vocab_file):
        raise FileNotFoundError(f"Dictionary file not found at {vocab_file}")
    print(f"Dictionary found at: {vocab_file}")

    dictionary = {}
    with open(vocab_file) as f:
        for index, entry in enumerate(f):
            dictionary[entry.strip()] = index
    return model, dictionary
def get_spans(tokens, segments):
    """Group aligned segments into one span per token, padded with blanks.

    Args:
        tokens: Sequence of strings; each is split on single spaces into the
            letters that are expected to appear, in order, as the labels of
            the non-blank segments.
        segments: Sequence of ``Segment``-like objects (``label``, ``start``
            and ``end`` are read).

    Returns:
        A list with one entry per recorded interval; each entry is a list of
        segments covering that token, optionally preceded and/or followed by a
        newly built ``"<blank>"`` segment when the neighbouring segment is a
        blank.

    Raises:
        AssertionError: If a non-blank segment label does not match the next
            expected letter, or if anything other than a single final blank
            segment follows the last token.
    """
    ltr_idx = 0  # position of the expected letter inside the current token
    tokens_idx = 0  # index of the token currently being matched
    intervals = []  # (first_seg_idx, last_seg_idx) per token, both inclusive
    start, end = (0, 0)
    sil = "<blank>"
    for seg_idx, seg in enumerate(segments):
        if tokens_idx == len(tokens):
            # All tokens are consumed: only one trailing blank segment may remain.
            assert seg_idx == len(segments) - 1
            assert seg.label == "<blank>"
            continue
        cur_token = tokens[tokens_idx].split(" ")
        ltr = cur_token[ltr_idx]
        if seg.label == "<blank>":
            # Blank segments never consume a letter.
            continue
        assert seg.label == ltr
        if (ltr_idx) == 0:
            # First letter of the token: remember where the token begins.
            start = seg_idx
        if ltr_idx == len(cur_token) - 1:
            # Last letter of the token: close its interval and move on.
            ltr_idx = 0
            tokens_idx += 1
            intervals.append((start, seg_idx))
            # Empty tokens that follow get a one-segment interval at the
            # segment that ended the previous token.
            while tokens_idx < len(tokens) and len(tokens[tokens_idx]) == 0:
                intervals.append((seg_idx, seg_idx))
                tokens_idx += 1
        else:
            ltr_idx += 1
    spans = []
    for idx, (start, end) in enumerate(intervals):
        span = segments[start : end + 1]
        if start > 0:
            prev_seg = segments[start - 1]
            if prev_seg.label == sil:
                # The first interval takes the whole preceding blank; later
                # intervals start at its midpoint (truncated by int()).
                pad_start = (
                    prev_seg.start
                    if (idx == 0)
                    else int((prev_seg.start + prev_seg.end) / 2)
                )
                span = [Segment(sil, pad_start, span[0].start)] + span
        if end + 1 < len(segments):
            next_seg = segments[end + 1]
            if next_seg.label == sil:
                # The last interval takes the whole following blank; earlier
                # intervals stop at its midpoint (floored).
                pad_end = (
                    next_seg.end
                    if (idx == len(intervals) - 1)
                    else math.floor((next_seg.start + next_seg.end) / 2)
                )
                span = span + [Segment(sil, span[-1].end, pad_end)]
        spans.append(span)
    return spans