""" | |
TODO: | |
+ [x] Load Configuration | |
+ [ ] Multi ASR Engine | |
+ [ ] Batch / Real Time support | |
""" | |
import sys
from pathlib import Path

import gradio as gr
import jiwer
import torch
import torch.nn as nn
import torchaudio
import yaml
from transformers import pipeline

# local import
sys.path.append("src")
# Load configuration
config_yaml = "config/samples.yaml"
with open(config_yaml, "r") as f:
    try:
        config = yaml.safe_load(f)
    except yaml.YAMLError:
        print("Config file loading error")
        sys.exit(1)
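# Expected keys in config/samples.yaml (inferred from the usage below):
#   ref_txt:  text file with one "<utterance id> <transcript>" line per sample
#   ref_wavs: directory searched recursively for reference .wav files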
# Auto-load examples: each line of ref_txt is "<utterance id> <transcript>"
with open(config["ref_txt"], "r", encoding="utf-8") as f:
    refs = [line.strip() for line in f if line.strip()]
refs_ids = [x.split()[0] for x in refs]
refs_txt = [" ".join(x.split()[1:]) for x in refs]
ref_wavs = [str(x) for x in sorted(Path(config["ref_wavs"]).glob("**/*.wav"))]
with open("src/description.html", "r", encoding="utf-8") as f: | |
description = f.read() | |
# description | |
# Gradio components
reference_textbox = gr.Textbox(
    value="Input reference here",
    placeholder="Input reference here",
    label="Reference",
)
examples = [[x, y] for x, y in zip(ref_wavs, refs_txt)]
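# Each example row maps onto the Interface inputs below: [audio filepath, reference text]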
# ASR: default Hugging Face speech-recognition pipeline
p = pipeline("automatic-speech-recognition")
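# A specific checkpoint can be pinned instead of the pipeline default, e.g.:
#   p = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")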
# WER: normalize both reference and hypothesis before scoring
transformation = jiwer.Compose(
    [
        jiwer.RemovePunctuation(),
        jiwer.ToLowerCase(),
        jiwer.RemoveWhiteSpace(replace_by_space=True),
        jiwer.RemoveMultipleSpaces(),
        jiwer.ReduceToListOfListOfWords(word_delimiter=" "),
    ]
)
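# e.g. the transformation maps "Hello,  World!" -> [["hello", "world"]]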
class ChangeSampleRate(nn.Module):
    """Resample a (1, T) waveform to a new rate by linear interpolation."""

    def __init__(self, input_rate: int, output_rate: int):
        super().__init__()
        self.output_rate = output_rate
        self.input_rate = input_rate

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        # Only accepts 1-channel waveform input
        wav = wav.view(wav.size(0), -1)
        new_length = wav.size(-1) * self.output_rate // self.input_rate
        # Fractional source index for every output sample
        indices = torch.arange(new_length) * (
            self.input_rate / self.output_rate
        )
        round_down = wav[:, indices.long()]
        round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
        # Linearly interpolate between the two nearest source samples
        output = round_down * (1.0 - indices.fmod(1.0)).unsqueeze(0) + (
            round_up * indices.fmod(1.0).unsqueeze(0)
        )
        return output
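# Example (hypothetical rates): downsample a 44.1 kHz mono recording to 16 kHz:
#   csr = ChangeSampleRate(44100, 16000)
#   wav16k = csr(wav)  # wav: (1, T) tensor -> (1, T * 16000 // 44100)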
def calc_wer(audio_path, ref):
    """Transcribe the recording at `audio_path` and score it against `ref`."""
    wav, sr = torchaudio.load(audio_path)
    if wav.shape[0] != 1:
        # Keep only the first channel, preserving the (1, T) shape
        wav = wav[0:1, :]
    # Resample to 16 kHz for engines that consume raw waveforms;
    # the default pipeline below transcribes straight from the file path.
    osr = 16000
    csr = ChangeSampleRate(sr, osr)
    out_wavs = csr(wav)
    # ASR
    trans = jiwer.ToLowerCase()(p(audio_path)["text"])
    # WER
    wer = jiwer.wer(
        ref,
        trans,
        truth_transform=transformation,
        hypothesis_transform=transformation,
    )
    return [trans, wer]
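# Standalone usage (hypothetical path and reference text):
#   hyp, wer = calc_wer("data/sample.wav", "the quick brown fox")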
iface = gr.Interface(
    fn=calc_wer,
    inputs=[
        gr.Audio(
            source="microphone",
            type="filepath",
            label="Audio_to_evaluate",
        ),
        reference_textbox,
    ],
    outputs=[
        gr.Textbox(placeholder="Hypothesis", label="Hypothesis"),
        gr.Textbox(placeholder="Word Error Rate", label="WER"),
    ],
    title="Laronix Automatic Speech Recognition",
    description=description,
    examples=examples,
    css="body {background-color: green}",
)
print("Launch examples") | |
iface.launch( | |
share=False, | |
) |
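# Set share=True to create a temporary public link when running outside Spaces.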