import gradio as gr
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import librosa
import torch
import epitran
import re
import difflib
import editdistance
from jiwer import wer
import json
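
# Pipeline: transcribe Arabic speech with a wav2vec2 CTC model, convert both
# the reference text and the transcription to IPA phonemes with Epitran, then
# align the two phoneme sequences word by word and report accuracy metrics.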

# Load the model, processor, and Arabic grapheme-to-phoneme engine once at startup
model_name = "jonatasgrosman/wav2vec2-large-xlsr-53-arabic"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)
model.eval()  # inference only; disables dropout
epi = epitran.Epitran('ara-Arab')


def clean_phonemes(ipa):
    """Remove Arabic diacritics and IPA length markers from a phoneme string."""
    return re.sub(r'[\u064B-\u0652\u02D0]', '', ipa)
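
# Example (illustrative input): clean_phonemes("baːb") returns "bab", since
# the IPA length marker U+02D0 is one of the stripped characters.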


def analyze_phonemes(language, reference_text, audio_file):
    # `language` comes from the dropdown; only Arabic is currently wired up.
    # Convert the reference text to phonemes, word by word
    ref_phonemes = []
    for word in reference_text.split():
        ipa = epi.transliterate(word)
        ipa_clean = clean_phonemes(ipa)
        ref_phonemes.append(list(ipa_clean))

    # Load the audio file, resampled to the model's expected 16 kHz
    audio, sr = librosa.load(audio_file, sr=16000)
    input_values = processor(audio, sampling_rate=16000, return_tensors="pt").input_values

    # Get transcription
    with torch.no_grad():
        logits = model(input_values).logits
    pred_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(pred_ids)[0].strip()
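    # (argmax over the logits is greedy CTC decoding; batch_decode collapses
    # repeated tokens and strips CTC blanks to produce the final text)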

    # Convert the transcription to phonemes the same way
    obs_phonemes = []
    for word in transcription.split():
        ipa = epi.transliterate(word)
        ipa_clean = clean_phonemes(ipa)
        obs_phonemes.append(list(ipa_clean))

    # Prepare results in JSON format
    results = {
        "reference_text": reference_text,
        "transcription": transcription,
        "word_alignment": [],
        "metrics": {}
    }

    # Running totals for the aggregate metrics
    total_phoneme_errors = 0
    total_phoneme_length = 0
    correct_words = 0
    total_word_length = len(ref_phonemes)
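
    # NOTE: zip() pairs words by position and stops at the shorter list, so
    # extra or missing trailing words contribute no phoneme errors below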

    # Word-by-word alignment
    for i, (ref, obs) in enumerate(zip(ref_phonemes, obs_phonemes)):
        ref_str = ''.join(ref)
        obs_str = ''.join(obs)
        edits = editdistance.eval(ref, obs)
        acc = round((1 - edits / max(1, len(ref))) * 100, 2)
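        # Per-word accuracy is 1 - edits/len(ref); it can go negative when a
        # word accumulates more edits than it has reference phonemes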

        # Get error details from the difflib alignment opcodes
        matcher = difflib.SequenceMatcher(None, ref, obs)
        ops = matcher.get_opcodes()
        error_details = []
        for tag, i1, i2, j1, j2 in ops:
            ref_seg = ''.join(ref[i1:i2]) or '-'
            obs_seg = ''.join(obs[j1:j2]) or '-'
            if tag != 'equal':
                error_details.append({
                    "type": tag.upper(),  # REPLACE, DELETE, or INSERT
                    "reference": ref_seg,
                    "observed": obs_seg
                })
results["word_alignment"].append({
"word_index": i,
"reference_phonemes": ref_str,
"observed_phonemes": obs_str,
"edit_distance": edits,
"accuracy": acc,
"is_correct": edits == 0,
"errors": error_details
})
total_phoneme_errors += edits
total_phoneme_length += len(ref)
correct_words += 1 if edits == 0 else 0

    # Calculate aggregate metrics
    phoneme_acc = round((1 - total_phoneme_errors / max(1, total_phoneme_length)) * 100, 2)
    phoneme_er = round((total_phoneme_errors / max(1, total_phoneme_length)) * 100, 2)
    word_acc = round((correct_words / max(1, total_word_length)) * 100, 2)
    word_er = round(((total_word_length - correct_words) / max(1, total_word_length)) * 100, 2)
    text_wer = round(wer(reference_text, transcription) * 100, 2)
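    # jiwer's wer() aligns the raw word sequences, so it can differ from the
    # positional word_error_rate computed above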
results["metrics"] = {
"word_accuracy": word_acc,
"word_error_rate": word_er,
"phoneme_accuracy": phoneme_acc,
"phoneme_error_rate": phoneme_er,
"asr_word_error_rate": text_wer
}
return json.dumps(results, indent=2, ensure_ascii=False)
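

# The analysis function can also be exercised without the UI; the audio path
# below is illustrative:
#   print(analyze_phonemes("Arabic", "فَبِأَيِّ آلَاءِ رَبِّكُمَا تُكَذِّبَانِ", "recording.wav"))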

# Create Gradio interface
demo = gr.Interface(
    fn=analyze_phonemes,
    inputs=[
        gr.Dropdown(["Arabic"], label="Language", value="Arabic"),
        gr.Textbox(label="Reference Text", value="فَبِأَيِّ آلَاءِ رَبِّكُمَا تُكَذِّبَانِ"),
        # type="filepath" hands the uploaded file's path directly to the fn
        gr.File(label="Upload Audio File", type="filepath")
    ],
    outputs=gr.JSON(label="Phoneme Alignment Results"),
    title="Arabic Phoneme Alignment Analysis",
    description="Compare audio pronunciation with reference text at the phoneme level"
)

demo.launch()