import gradio as gr
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import librosa
import torch
import epitran
import re
import difflib
import editdistance
from jiwer import wer
import json

# Load model once at startup
model_name = "jonatasgrosman/wav2vec2-large-xlsr-53-arabic"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)
epi = epitran.Epitran('ara-Arab')
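
# Optional sketch: run inference on GPU when available. The rest of this
# script keeps tensors on CPU, so input_values would also need .to(device):
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = model.to(device)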

def clean_phonemes(ipa):
    """Remove diacritics and length markers from phonemes"""
    return re.sub(r'[\u064B-\u0652\u02D0]', '', ipa)
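
# Quick sanity checks (illustrative inputs; epitran's actual IPA output
# depends on how fully the input is vocalized):
assert clean_phonemes("kitaːb") == "kitab"  # IPA length mark removed
assert clean_phonemes("دَرْس") == "درس"  # fatha and sukun removed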

def analyze_phonemes(language, reference_text, audio_file):
    """Transcribe the audio, convert both the reference text and the
    transcription to IPA phonemes, and return a word-level alignment plus
    accuracy metrics as a JSON string. `language` is currently unused
    (the UI only offers Arabic)."""
    # Convert reference text to phonemes
    ref_phonemes = []
    for word in reference_text.split():
        ipa = epi.transliterate(word)
        ipa_clean = clean_phonemes(ipa)
        ref_phonemes.append(list(ipa_clean))
    
    # Load the audio, resampled to the 16 kHz rate the model expects
    audio, sr = librosa.load(audio_file, sr=16000)
    input_values = processor(audio, sampling_rate=16000, return_tensors="pt").input_values
    
    # Greedy CTC decoding: argmax per frame; batch_decode collapses repeats and blanks
    with torch.no_grad():
        logits = model(input_values).logits
        pred_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(pred_ids)[0].strip()
    
    # Convert transcription to phonemes
    obs_phonemes = []
    for word in transcription.split():
        ipa = epi.transliterate(word)
        ipa_clean = clean_phonemes(ipa)
        obs_phonemes.append(list(ipa_clean))
    
    # Prepare results in JSON format
    results = {
        "reference_text": reference_text,
        "transcription": transcription,
        "word_alignment": [],
        "metrics": {}
    }
    
    # Metric accumulators
    total_phoneme_errors = 0
    total_phoneme_length = 0
    correct_words = 0
    total_words = len(ref_phonemes)
    
    # Word-by-word alignment
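    # Note: zip() stops at the shorter list, so trailing words the ASR inserted
    # or dropped are not aligned or scored here.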
    for i, (ref, obs) in enumerate(zip(ref_phonemes, obs_phonemes)):
        ref_str = ''.join(ref)
        obs_str = ''.join(obs)
        edits = editdistance.eval(ref, obs)
        acc = round((1 - edits / max(1, len(ref))) * 100, 2)
        
        # Get error details
        matcher = difflib.SequenceMatcher(None, ref, obs)
        ops = matcher.get_opcodes()
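        # get_opcodes() yields (tag, i1, i2, j1, j2) tuples, where tag is one
        # of 'equal', 'replace', 'delete', or 'insert'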
        error_details = []
        for tag, i1, i2, j1, j2 in ops:
            ref_seg = ''.join(ref[i1:i2]) or '-'
            obs_seg = ''.join(obs[j1:j2]) or '-'
            if tag != 'equal':
                error_details.append({
                    "type": tag.upper(),
                    "reference": ref_seg,
                    "observed": obs_seg
                })
        
        results["word_alignment"].append({
            "word_index": i,
            "reference_phonemes": ref_str,
            "observed_phonemes": obs_str,
            "edit_distance": edits,
            "accuracy": acc,
            "is_correct": edits == 0,
            "errors": error_details
        })
        
        total_phoneme_errors += edits
        total_phoneme_length += len(ref)
        correct_words += 1 if edits == 0 else 0
    
    # Aggregate metrics; max(1, ...) guards against division by zero
    phoneme_acc = round((1 - total_phoneme_errors / max(1, total_phoneme_length)) * 100, 2)
    phoneme_er = round((total_phoneme_errors / max(1, total_phoneme_length)) * 100, 2)
    word_acc = round((correct_words / max(1, total_words)) * 100, 2)
    word_er = round(((total_words - correct_words) / max(1, total_words)) * 100, 2)
    text_wer = round(wer(reference_text, transcription) * 100, 2)
    
    results["metrics"] = {
        "word_accuracy": word_acc,
        "word_error_rate": word_er,
        "phoneme_accuracy": phoneme_acc,
        "phoneme_error_rate": phoneme_er,
        "asr_word_error_rate": text_wer
    }
    
    return json.dumps(results, indent=2, ensure_ascii=False)
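
# Optional local smoke test that bypasses the UI (the audio path below is a
# hypothetical placeholder):
# print(analyze_phonemes("Arabic", "فَبِأَيِّ آلَاءِ رَبِّكُمَا تُكَذِّبَانِ", "sample.wav"))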

# Create Gradio interface
demo = gr.Interface(
    fn=analyze_phonemes,
    inputs=[
        gr.Dropdown(["Arabic"], label="Language", value="Arabic"),
        gr.Textbox(label="Reference Text", value="ููŽุจูุฃูŽูŠู‘ู ุขู„ูŽุงุกู ุฑูŽุจู‘ููƒูู…ูŽุง ุชููƒูŽุฐู‘ูุจูŽุงู†ู"),
        gr.File(label="Upload Audio File", type="filepath")
    ],
    outputs=gr.JSON(label="Phoneme Alignment Results"),
    title="Arabic Phoneme Alignment Analysis",
    description="Compare audio pronunciation with reference text at phoneme level"
)

demo.launch()
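
# When running in a notebook or on a remote machine, Gradio can expose a
# temporary public URL instead: demo.launch(share=True)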