import math
import os
import re
import tempfile
from dataclasses import dataclass

import torch
from torchaudio.models import wav2vec2_model
# iso codes with specialized rules in uroman
special_isos_uroman = "ara, bel, bul, deu, ell, eng, fas, grc, ell, eng, heb, kaz, kir, lav, lit, mkd, mkd2, oss, pnt, pus, rus, srp, srp2, tur, uig, ukr, yid".split(
    ","
)
special_isos_uroman = [i.strip() for i in special_isos_uroman]

def normalize_uroman(text):
    # Lowercase, replace everything except a-z, apostrophes and spaces with a
    # space, then collapse repeated whitespace.
    text = text.lower()
    text = re.sub("([^a-z' ])", " ", text)
    text = re.sub(" +", " ", text)
    return text.strip()
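
# e.g. normalize_uroman("Hello, World!!") -> "hello world"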

def get_uroman_tokens(norm_transcripts, uroman, iso=None):
    # Romanize the transcripts through uroman via temporary files, then split
    # each romanized line into space-separated characters and normalize it.
    tf = tempfile.NamedTemporaryFile()
    tf2 = tempfile.NamedTemporaryFile()
    with open(tf.name, "w") as f:
        for t in norm_transcripts:
            f.write(t + "\n")

    uroman.romanize_file(
        input_filename=tf.name,
        output_filename=tf2.name,
        lcode=iso if iso in special_isos_uroman else None,
    )

    outtexts = []
    with open(tf2.name) as f:
        for line in f:
            line = " ".join(line.strip())  # space-separate every character
            line = re.sub(r"\s+", " ", line).strip()
            outtexts.append(line)
    assert len(outtexts) == len(norm_transcripts)

    uromans = []
    for ot in outtexts:
        uromans.append(normalize_uroman(ot))
    return uromans
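
# Usage sketch (assumption: `uroman` is an object exposing a
# romanize_file(input_filename=..., output_filename=..., lcode=...) method,
# e.g. an instance of the Uroman class from the `uroman` pip package):
#
#   toks = get_uroman_tokens(["bonjour le monde"], uroman_instance, iso="fra")
#   # -> ["b o n j o u r l e m o n d e"]  (approximately; exact output depends on uroman)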

@dataclass
class Segment:
    label: str
    start: int
    end: int

    def __repr__(self):
        return f"{self.label}: [{self.start:5d}, {self.end:5d})"

    def length(self):
        return self.end - self.start

def merge_repeats(path, idx_to_token_map):
    # Collapse runs of identical token indices in the frame-wise CTC path into
    # Segment objects holding the token label and its frame range.
    i1, i2 = 0, 0
    segments = []
    while i1 < len(path):
        while i2 < len(path) and path[i1] == path[i2]:
            i2 += 1
        segments.append(Segment(idx_to_token_map[path[i1]], i1, i2 - 1))
        i1 = i2
    return segments
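
# Toy example: with idx_to_token_map = {0: "<blank>", 3: "a", 5: "b"} and a
# frame-wise path [3, 3, 0, 5, 5, 5], merge_repeats returns
# [Segment("a", 0, 1), Segment("<blank>", 2, 2), Segment("b", 3, 5)].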

def time_to_frame(time):
    # Convert a time in seconds to an emission-frame index (20 ms stride).
    stride_msec = 20
    frames_per_sec = 1000 / stride_msec
    return int(time * frames_per_sec)
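
# e.g. time_to_frame(1.5) == 75 (1.5 s of audio at 50 frames per second)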

def load_model_dict():
    # Use models directory from environment variable
    models_dir = os.environ.get("MODELS_DIR", "/home/user/app/models")
    model_path_name = os.path.join(models_dir, "ctc_alignment_mling_uroman_model.pt")

    print("Loading model from models directory...")
    if not os.path.exists(model_path_name):
        raise FileNotFoundError(f"Model file not found at {model_path_name}")
    print(f"Model found at: {model_path_name}")

    state_dict = torch.load(model_path_name, map_location="cpu")

    model = wav2vec2_model(
        extractor_mode="layer_norm",
        extractor_conv_layer_config=[
            (512, 10, 5),
            (512, 3, 2),
            (512, 3, 2),
            (512, 3, 2),
            (512, 3, 2),
            (512, 2, 2),
            (512, 2, 2),
        ],
        extractor_conv_bias=True,
        encoder_embed_dim=1024,
        encoder_projection_dropout=0.0,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=24,
        encoder_num_heads=16,
        encoder_attention_dropout=0.0,
        encoder_ff_interm_features=4096,
        encoder_ff_interm_dropout=0.1,
        encoder_dropout=0.0,
        encoder_layer_norm_first=True,
        encoder_layer_drop=0.1,
        aux_num_out=31,
    )
    model.load_state_dict(state_dict)
    model.eval()

    dict_path_name = os.path.join(
        models_dir, "ctc_alignment_mling_uroman_model_dict.txt"
    )
    if not os.path.exists(dict_path_name):
        raise FileNotFoundError(f"Dictionary file not found at {dict_path_name}")
    print(f"Dictionary found at: {dict_path_name}")

    # Map each token in the dictionary file to its line index
    with open(dict_path_name) as f:
        dictionary = {l.strip(): i for i, l in enumerate(f.readlines())}

    return model, dictionary
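
# Usage sketch (assumes MODELS_DIR points at a directory containing both
# ctc_alignment_mling_uroman_model.pt and
# ctc_alignment_mling_uroman_model_dict.txt):
#
#   model, dictionary = load_model_dict()
#   # `dictionary` maps each character token to its index among the model's
#   # 31 output units.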

def get_spans(tokens, segments):
    # Group the merged CTC segments into one span per token, padding each span
    # with (part of) the neighbouring <blank> segments when present.
    ltr_idx = 0
    tokens_idx = 0
    intervals = []
    start, end = (0, 0)
    sil = "<blank>"
    for seg_idx, seg in enumerate(segments):
        if tokens_idx == len(tokens):
            assert seg_idx == len(segments) - 1
            assert seg.label == "<blank>"
            continue
        cur_token = tokens[tokens_idx].split(" ")
        ltr = cur_token[ltr_idx]
        if seg.label == "<blank>":
            continue
        assert seg.label == ltr
        if (ltr_idx) == 0:
            start = seg_idx
        if ltr_idx == len(cur_token) - 1:
            ltr_idx = 0
            tokens_idx += 1
            intervals.append((start, seg_idx))
            while tokens_idx < len(tokens) and len(tokens[tokens_idx]) == 0:
                intervals.append((seg_idx, seg_idx))
                tokens_idx += 1
        else:
            ltr_idx += 1

    spans = []
    for idx, (start, end) in enumerate(intervals):
        span = segments[start : end + 1]
        if start > 0:
            prev_seg = segments[start - 1]
            if prev_seg.label == sil:
                pad_start = (
                    prev_seg.start
                    if (idx == 0)
                    else int((prev_seg.start + prev_seg.end) / 2)
                )
                span = [Segment(sil, pad_start, span[0].start)] + span
        if end + 1 < len(segments):
            next_seg = segments[end + 1]
            if next_seg.label == sil:
                pad_end = (
                    next_seg.end
                    if (idx == len(intervals) - 1)
                    else math.floor((next_seg.start + next_seg.end) / 2)
                )
                span = span + [Segment(sil, span[-1].end, pad_end)]
        spans.append(span)
    return spans
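
# End-to-end sketch of how these helpers fit together (the forced-alignment /
# Viterbi step that produces the frame-wise `path` is not defined in this file
# and is assumed to come from the caller):
#
#   model, dictionary = load_model_dict()
#   tokens = get_uroman_tokens(transcripts, uroman_instance, iso=lang_iso)
#   # ... run `model` on the audio to get frame-wise emissions, then a CTC
#   # forced-alignment step to get `path` (one token index per 20 ms frame) ...
#   segments = merge_repeats(path, {v: k for k, v in dictionary.items()})
#   spans = get_spans(tokens, segments)
#   # spans[i] covers tokens[i]; frame indices convert back to seconds by
#   # multiplying by 0.02 (the inverse of time_to_frame).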