import gradio as gr
import random
import difflib
import re
import jiwer
import librosa
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from indic_transliteration import sanscript
from indic_transliteration.sanscript import transliterate
import spaces

# ---------------- CONFIG ---------------- #
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Updated model configurations for each language
MODEL_CONFIGS = {
    "English": "openai/whisper-large-v2",
    "Tamil": "vasista22/whisper-tamil-large-v2", 
    "Malayalam": "thennal/whisper-medium-ml"
}

LANG_CODES = {
    "English": "en",
    "Tamil": "ta", 
    "Malayalam": "ml"
}

LANG_PRIMERS = {
    "English": ("The transcript should be in English only.",
                "Write only in English without translation. Example: This is an English sentence."),
    "Tamil": ("நகல் தமிழ் எழுத்துக்களில் மட்டும் இருக்க வேண்டும்.",
              "தமிழ் எழுத்துக்களில் மட்டும் எழுதவும், மொழிபெயர்ப்பு செய்யக்கூடாது. உதாரணம்: இது ஒரு தமிழ் வாக்கியம்."),
    "Malayalam": ("ട്രാൻസ്ഖ്രിപ്റ്റ് മലയാള ലിപിയിൽ ആയിരിക്കണം.",
                  "മലയാള ലിപിയിൽ മാത്രം എഴുതുക, വിവർത്തനം ചെയ്യരുത്. ഉദാഹരണം: ഇതൊരു മലയാള വാക്യമാണ്. എനിക്ക് മലയാളം അറിയാം.")
}

SCRIPT_PATTERNS = {
    "Tamil": re.compile(r"[஀-௿]"),       # Tamil Unicode block (U+0B80–U+0BFF)
    "Malayalam": re.compile(r"[ഀ-ൿ]"),   # Malayalam Unicode block (U+0D00–U+0D7F)
    "English": re.compile(r"[A-Za-z]")    # Basic Latin letters
}

SENTENCE_BANK = {
    "English": [
        "The sun sets over the horizon.",
        "Learning languages is fun.",
        "I like to drink coffee in the morning.",
        "Technology helps us communicate better.",
        "Reading books expands our knowledge."
    ],
    "Tamil": [
        "இன்று நல்ல வானிலை உள்ளது.",
        "நான் தமிழ் கற்றுக்கொண்டு இருக்கிறேன்.",
        "எனக்கு புத்தகம் படிக்க விருப்பம்.",
        "தமிழ் மொழி மிகவும் அழகானது.",
        "நான் தினமும் பள்ளிக்கு செல்கிறேன்."
    ],
    "Malayalam": [
        "എനിക്ക് മലയാളം വളരെ ഇഷ്ടമാണ്.",
        "ഇന്ന് മഴപെയ്യുന്നു.",
        "ഞാൻ പുസ്തകം വായിക്കുന്നു.",
        "കേരളം എന്റെ സ്വന്തം നാടാണ്.",
        "ഞാൻ മലയാളം പഠിക്കുന്നു."
    ]
}

# Global variables for models (will be loaded lazily)
whisper_models = {}
whisper_processors = {}

def load_model(language_choice):
    """Load model for specific language if not already loaded"""
    if language_choice not in whisper_models:
        model_id = MODEL_CONFIGS[language_choice]
        print(f"Loading {language_choice} model: {model_id}")
        whisper_models[language_choice] = WhisperForConditionalGeneration.from_pretrained(model_id).to(DEVICE)
        whisper_processors[language_choice] = WhisperProcessor.from_pretrained(model_id)
        print(f"{language_choice} model loaded successfully!")

# ---------------- HELPERS ---------------- #
def get_random_sentence(language_choice):
    return random.choice(SENTENCE_BANK[language_choice])

def is_script(text, lang_name):
    pattern = SCRIPT_PATTERNS.get(lang_name)
    return bool(pattern.search(text)) if pattern else True

def transliterate_to_hk(text, lang_choice):
    mapping = {
        "Tamil": sanscript.TAMIL,
        "Malayalam": sanscript.MALAYALAM,
        "English": None
    }
    return transliterate(text, mapping[lang_choice], sanscript.HK) if mapping[lang_choice] else text
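
# Illustration (hypothetical input; exact output depends on the installed
# indic_transliteration version): transliterate("அம்மா", sanscript.TAMIL, sanscript.HK)
# is expected to yield roughly "ammA" in Harvard-Kyoto romanization.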

@spaces.GPU
def transcribe_once(audio_path, language_choice, initial_prompt, beam_size, temperature, condition_on_previous_text):
    """Transcribe one audio file with the language-specific Whisper model.

    Note: `initial_prompt` and `condition_on_previous_text` are accepted for
    interface parity with openai-whisper-style transcribe() arguments, but this
    transformers-based decoding path does not currently apply them (see the
    sketch after this function).
    """
    # Load the model for this language on first use
    load_model(language_choice)
    
    # Get the appropriate model and processor for the language
    model = whisper_models[language_choice]
    processor = whisper_processors[language_choice]
    lang_code = LANG_CODES[language_choice]
    
    # Load audio and resample to the 16 kHz rate Whisper expects
    audio, sr = librosa.load(audio_path, sr=16000)
    
    # Process audio with the specific model's processor
    input_features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features.to(DEVICE)
    
    # Generate forced decoder ids for the language
    forced_decoder_ids = processor.get_decoder_prompt_ids(language=lang_code, task="transcribe")
    
    # Generate transcription
    with torch.no_grad():
        predicted_ids = model.generate(
            input_features,
            forced_decoder_ids=forced_decoder_ids,
            max_length=448,
            num_beams=beam_size,
            # Sample only when temperature > 0; otherwise use plain beam/greedy search
            temperature=temperature if temperature > 0 else None,
            do_sample=temperature > 0,
        )
    
    # Decode the transcription
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    return transcription.strip()
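
# Sketch (not wired in): if the primer text should actually bias decoding in the
# two-pass design below, recent transformers releases expose prompt ids for
# Whisper. This is a minimal, unverified example assuming that API exists in the
# installed version:
#
#     prompt_ids = processor.get_prompt_ids(initial_prompt, return_tensors="pt").to(DEVICE)
#     predicted_ids = model.generate(input_features, prompt_ids=prompt_ids, num_beams=beam_size)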

def highlight_differences(ref, hyp):
    ref_words, hyp_words = ref.strip().split(), hyp.strip().split()
    sm = difflib.SequenceMatcher(None, ref_words, hyp_words)
    out_html = []
    for tag, i1, i2, j1, j2 in sm.get_opcodes():
        if tag == 'equal':
            out_html.extend([f"<span style='color:green'>{w}</span>" for w in ref_words[i1:i2]])
        elif tag == 'replace':
            out_html.extend([f"<span style='color:red'>{w}</span>" for w in ref_words[i1:i2]])
            out_html.extend([f"<span style='color:orange'>{w}</span>" for w in hyp_words[j1:j2]])
        elif tag == 'delete':
            out_html.extend([f"<span style='color:red;text-decoration:line-through'>{w}</span>" for w in ref_words[i1:i2]])
        elif tag == 'insert':
            out_html.extend([f"<span style='color:orange'>{w}</span>" for w in hyp_words[j1:j2]])
    return " ".join(out_html)

def char_level_highlight(ref, hyp):
    sm = difflib.SequenceMatcher(None, list(ref), list(hyp))
    out = []
    for tag, i1, i2, j1, j2 in sm.get_opcodes():
        if tag == 'equal':
            out.extend([f"<span style='color:green'>{c}</span>" for c in ref[i1:i2]])
        elif tag in ('replace', 'delete'):
            out.extend([f"<span style='color:red;text-decoration:underline'>{c}</span>" for c in ref[i1:i2]])
        elif tag == 'insert':
            out.extend([f"<span style='color:orange'>{c}</span>" for c in hyp[j1:j2]])
    return "".join(out)

# ---------------- MAIN ---------------- #
@spaces.GPU
def compare_pronunciation(audio, language_choice, intended_sentence,
                          pass1_beam, pass1_temp, pass1_condition):
    if audio is None or not intended_sentence.strip():
        return ("No audio or intended sentence.", "", "", "", "", "", "", "")

    primer_weak, primer_strong = LANG_PRIMERS[language_choice]

    # Pass 1: raw transcription with user-configured decoding parameters
    actual_text = transcribe_once(audio, language_choice, primer_weak,
                                  pass1_beam, pass1_temp, pass1_condition)

    # Pass 2: strict transcription biased by intended sentence (fixed decoding params)
    strict_prompt = f"{primer_strong}\nTarget: {intended_sentence}"
    corrected_text = transcribe_once(audio, language_choice, strict_prompt,
                                     beam_size=5, temperature=0.0, condition_on_previous_text=False)

    # Compute WER and CER
    wer_val = jiwer.wer(intended_sentence, actual_text)
    cer_val = jiwer.cer(intended_sentence, actual_text)
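    # Illustrative values (hypothetical strings): jiwer.wer("the cat sat", "the cat sit")
    # is 1/3 (one substituted word out of three); jiwer.cer for the same pair is
    # roughly 1/11 (one character edit over eleven reference characters).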

    # Transliteration of Pass 1 output
    hk_translit = transliterate_to_hk(actual_text, language_choice) if is_script(actual_text, language_choice) else f"[Script mismatch: expected {language_choice}]"

    # Highlight word-level and character-level differences
    diff_html = highlight_differences(intended_sentence, actual_text)
    char_html = char_level_highlight(intended_sentence, actual_text)

    return (actual_text, corrected_text, hk_translit, f"{wer_val:.2f}", f"{cer_val:.2f}",
            diff_html, char_html, intended_sentence)

# ---------------- UI ---------------- #
with gr.Blocks(title="Pronunciation Comparator") as demo:
    gr.Markdown("## 🎙 Pronunciation Comparator - English, Tamil & Malayalam")
    gr.Markdown("Practice pronunciation with specialized Whisper models for each language!")

    with gr.Row():
        lang_choice = gr.Dropdown(choices=list(LANG_CODES.keys()), value="Malayalam", label="Language")
        gen_btn = gr.Button("🎲 Generate Sentence")

    intended_display = gr.Textbox(label="Generated Sentence (Read aloud)", interactive=False)

    with gr.Row():
        audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Record your pronunciation")
        
    with gr.Column():
        gr.Markdown("### Transcription Parameters")
        pass1_beam = gr.Slider(1, 10, value=8, step=1, label="Pass 1 Beam Size")
        pass1_temp = gr.Slider(0.0, 1.0, value=0.4, step=0.1, label="Pass 1 Temperature")
        pass1_condition = gr.Checkbox(value=True, label="Pass 1: Condition on previous text")

    submit_btn = gr.Button("🔍 Analyze Pronunciation", variant="primary")

    with gr.Row():
        pass1_out = gr.Textbox(label="Pass 1: What You Actually Said")
        pass2_out = gr.Textbox(label="Pass 2: Target-Biased Output")
        
    with gr.Row():
        hk_out = gr.Textbox(label="Harvard-Kyoto Transliteration (Pass 1)")
        wer_out = gr.Textbox(label="Word Error Rate")
        cer_out = gr.Textbox(label="Character Error Rate")

    gr.Markdown("### Visual Feedback")
    diff_html_box = gr.HTML(label="Word Differences Highlighted")
    char_html_box = gr.HTML(label="Character-Level Highlighting (mispronounced = red underline)")

    # Event handlers
    gen_btn.click(fn=get_random_sentence, inputs=[lang_choice], outputs=[intended_display])

    submit_btn.click(
        fn=compare_pronunciation,
        inputs=[audio_input, lang_choice, intended_display, pass1_beam, pass1_temp, pass1_condition],
        outputs=[
            pass1_out, pass2_out, hk_out, wer_out, cer_out,
            diff_html_box, char_html_box, intended_display
        ]
    )

if __name__ == "__main__":
    demo.launch()