FarmerlineML committed on
Commit
ddf6cde
·
verified ·
1 Parent(s): 7e0bbd2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +368 -0
app.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+
3
+ import os
4
+ import csv
5
+ import json
6
+ import time
7
+ import uuid
8
+ import gradio as gr
9
+ from transformers import pipeline
10
+ import numpy as np
11
+ import librosa # pip install librosa
12
+
13
+ # Optional but recommended for better jiwer performance
14
+ # pip install python-Levenshtein
15
+ try:
16
+ from jiwer import compute_measures, wer as jiwer_wer, cer as jiwer_cer
17
+ HAS_JIWER = True
18
+ except Exception:
19
+ HAS_JIWER = False
20
+
21
+ # -------- CONFIG: storage paths (Space-friendly) --------
22
# Storage layout: the CSV log and the audio folder both live under DATA_DIR.
DATA_DIR = "/home/user/data"
AUDIO_DIR = os.path.join(DATA_DIR, "audio")
LOG_CSV = os.path.join(DATA_DIR, "logs.csv")

# Create both directories at import time so later writes cannot fail on a
# missing folder; exist_ok makes this a no-op when they already exist.
for _storage_dir in (DATA_DIR, AUDIO_DIR):
    os.makedirs(_storage_dir, exist_ok=True)
27
+
28
+ # --- EDIT THIS: map display names to your HF Hub model IDs ---
29
# Maps the display name shown in the language dropdown to the Hugging Face
# Hub model ID that is loaded for it. Insertion order is the dropdown order.
language_models = {
    "Akan (Asante Twi)": "FarmerlineML/w2v-bert-2.0_twi_alpha_v1",
    "Ewe": "FarmerlineML/w2v-bert-2.0_ewe_2",
    "Kiswahili": "FarmerlineML/w2v-bert-2.0_swahili_alpha",
    # NOTE: "Luganda" used to be listed twice in this literal. With duplicate
    # keys the last value wins while the key keeps its first position, so the
    # model actually served was always luganda_fkd. That effective mapping is
    # kept here; the shadowed entry is preserved below for reference.
    "Luganda": "FarmerlineML/luganda_fkd",
    # "Luganda (w2v-bert-2.0)": "FarmerlineML/w2v-bert-2.0_luganda",
    "Brazilian Portuguese": "FarmerlineML/w2v-bert-2.0_brazilian_portugese_alpha",
    "Fante": "misterkissi/w2v2-lg-xls-r-300m-fante",
    "Bemba": "DarliAI/kissi-w2v2-lg-xls-r-300m-bemba",
    "Bambara": "DarliAI/kissi-w2v2-lg-xls-r-300m-bambara",
    "Dagaare": "DarliAI/kissi-w2v2-lg-xls-r-300m-dagaare",
    "Kinyarwanda": "DarliAI/kissi-w2v2-lg-xls-r-300m-kinyarwanda",
    "Fula": "DarliAI/kissi-wav2vec2-fula-fleurs-full",
    "Oromo": "DarliAI/kissi-w2v-bert-2.0-oromo",
    # NOTE(review): display name looks like a misspelling of "Runyankore";
    # left unchanged because it is a user-facing label that is also logged.
    "Runynakore": "misterkissi/w2v2-lg-xls-r-300m-runyankore",
    "Ga": "misterkissi/w2v2-lg-xls-r-300m-ga",
    "Vai": "misterkissi/whisper-small-vai",
    "Kasem": "misterkissi/w2v2-lg-xls-r-300m-kasem",
    "Lingala": "misterkissi/w2v2-lg-xls-r-300m-lingala",
    "Fongbe": "misterkissi/whisper-small-fongbe",
    "Amharic": "misterkissi/w2v2-lg-xls-r-1b-amharic",
    "Xhosa": "misterkissi/w2v2-lg-xls-r-300m-xhosa",
    "Tsonga": "misterkissi/w2v2-lg-xls-r-300m-tsonga",
    # "WOLOF": "misterkissi/w2v2-lg-xls-r-1b-wolof",
    # "HAITIAN CREOLE": "misterkissi/whisper-small-haitian-creole",
    # "KABYLE": "misterkissi/w2v2-lg-xls-r-1b-kabyle",
    "Yoruba": "FarmerlineML/w2v-bert-2.0_yoruba_v1",
    "Luo": "FarmerlineML/w2v-bert-2.0_luo_v2",
    "Somali": "FarmerlineML/w2v-bert-2.0_somali_alpha",
    "Pidgin": "FarmerlineML/pidgin_nigerian",
    "Kikuyu": "FarmerlineML/w2v-bert-2.0_kikuyu",
    "Igbo": "FarmerlineML/w2v-bert-2.0_igbo_v1",
    "Krio": "FarmerlineML/w2v-bert-2.0_krio_v3",
}
63
+
64
+ # -------- Lazy-load pipeline cache (Space-safe) --------
65
+ # Small LRU-style cache to avoid loading all models into RAM
66
# Loaded pipelines, keyed by the language display name.
_PIPELINE_CACHE = dict()
# Cache keys ordered by recency: index 0 is the most recently used.
_CACHE_ORDER = list()
# Upper bound on simultaneously loaded pipelines (raise it if RAM allows).
_CACHE_MAX_SIZE = 3
69
+
70
def _touch_cache(key):
    """Move *key* to the front of ``_CACHE_ORDER`` (most recently used).

    A key that is not tracked yet is simply inserted at the front.
    """
    try:
        _CACHE_ORDER.remove(key)
    except ValueError:
        # First time this key is seen; nothing to move.
        pass
    _CACHE_ORDER.insert(0, key)
74
+
75
+ def _evict_if_needed():
76
+ while len(_PIPELINE_CACHE) > _CACHE_MAX_SIZE:
77
+ oldest = _CACHE_ORDER.pop() # least-recently used
78
+ try:
79
+ del _PIPELINE_CACHE[oldest]
80
+ except KeyError:
81
+ pass
82
+
83
def get_asr_pipeline(language_display: str):
    """Return the ASR pipeline for *language_display*, loading it on demand.

    A cached pipeline is reused and marked as recently used. Otherwise the
    model ID is looked up in ``language_models`` (raising ``KeyError`` for an
    unknown name), a new pipeline is built, cached, and the cache is trimmed
    back to its size limit.
    """
    try:
        cached_pipe = _PIPELINE_CACHE[language_display]
    except KeyError:
        pass
    else:
        _touch_cache(language_display)
        return cached_pipe

    hub_model_id = language_models[language_display]
    asr_pipe = pipeline(
        task="automatic-speech-recognition",
        model=hub_model_id,
        device=-1,  # -1 selects the CPU
        chunk_length_s=30,
    )

    _PIPELINE_CACHE[language_display] = asr_pipe
    _touch_cache(language_display)
    _evict_if_needed()
    return asr_pipe
98
+
99
+ # -------- Helpers --------
100
def _model_revision_from_pipeline(pipe) -> str:
    """Return a best-effort revision identifier for the model behind *pipe*.

    The first truthy value among ``hub_revision``, ``revision`` and
    ``_commit_hash`` on ``pipe.model`` wins. Failing that, the config's
    ``_name_or_path`` is used, and ``"unknown"`` when nothing is available.
    """
    model = getattr(pipe, "model", None)
    attr_values = (
        getattr(model, attr_name, None)
        for attr_name in ("hub_revision", "revision", "_commit_hash")
    )
    revision = next((value for value in attr_values if value), None)
    if revision:
        return str(revision)

    # No revision attribute: fall back to the configured name/path.
    try:
        return str(getattr(pipe.model.config, "_name_or_path", "unknown"))
    except Exception:
        return "unknown"
111
+
112
def _append_log_row(row: dict):
    """Append *row* as one line of the feedback log at ``LOG_CSV``.

    The header is written whenever the file is missing or empty, so a
    pre-created empty file still gets column names. Fields absent from *row*
    are written as empty strings, and *row* itself is not modified.

    Raises:
        ValueError: if *row* contains a key that is not a known column
            (``csv.DictWriter``'s default behaviour).
        OSError: if the log file cannot be opened or written.
    """
    field_order = [
        "timestamp", "session_id",
        "language_display", "model_id", "model_revision",
        "audio_duration_s", "sample_rate", "source",
        "decode_params",
        "transcript_hyp",
        "reference_text", "corrected_text",
        "latency_ms", "rtf",
        "wer", "cer",
        "subs", "ins", "dels",
        "score_out_of_10", "feedback_text",
        "tags",
        "store_audio", "audio_path",
    ]
    # An existing zero-byte file has no header yet either.
    needs_header = not os.path.isfile(LOG_CSV) or os.path.getsize(LOG_CSV) == 0
    with open(LOG_CSV, "a", newline="", encoding="utf-8") as f:
        # restval fills missing columns, so the caller's dict stays untouched.
        writer = csv.DictWriter(f, fieldnames=field_order, restval="")
        if needs_header:
            writer.writeheader()
        writer.writerow(row)
136
+
137
def _compute_metrics(hyp: str, ref_or_corrected: str):
    """Score hypothesis *hyp* against *ref_or_corrected* using jiwer.

    Returns a dict with keys ``wer``, ``cer``, ``subs``, ``ins`` and ``dels``.
    Every value is ``None`` when jiwer is unavailable, when either text is
    empty, or when jiwer raises while scoring.
    """
    no_metrics = dict.fromkeys(("wer", "cer", "subs", "ins", "dels"))

    if not (HAS_JIWER and ref_or_corrected and hyp):
        return no_metrics

    try:
        word_level = compute_measures(ref_or_corrected, hyp)
        scored = {
            "wer": word_level.get("wer"),
            "cer": jiwer_cer(ref_or_corrected, hyp),
            "subs": word_level.get("substitutions"),
            "ins": word_level.get("insertions"),
            "dels": word_level.get("deletions"),
        }
    except Exception:
        # jiwer can raise on edge-case inputs; report "no metrics" instead.
        return no_metrics
    return scored
158
+
159
+ # -------- Inference --------
160
def transcribe(audio_path: str, language: str):
    """Transcribe the audio file at *audio_path* with the model for *language*.

    The file is decoded with librosa as mono at its native sample rate and
    passed to the ASR pipeline returned by ``get_asr_pipeline``.

    Returns:
        A ``(transcript, meta)`` tuple. ``meta`` is a dict holding timing,
        model and audio details plus empty placeholders for the fields filled
        in when feedback is submitted. When *audio_path* is empty, the tuple
        is ``(warning message, None)``.
    """
    if not audio_path:
        return "⚠️ Please upload or record an audio clip.", None

    # librosa.load returns a 1D np.ndarray (mono) and the sample rate
    speech, sr = librosa.load(audio_path, sr=None, mono=True)
    duration_s = float(librosa.get_duration(y=speech, sr=sr))

    pipe = get_asr_pipeline(language)
    decode_params = {"chunk_length_s": getattr(pipe, "chunk_length_s", 30)}

    # perf_counter is monotonic, so the measured latency cannot be skewed
    # (or go negative) when the system clock is adjusted mid-inference.
    t0 = time.perf_counter()
    result = pipe({"sampling_rate": sr, "raw": speech})
    latency_ms = int((time.perf_counter() - t0) * 1000.0)
    hyp_text = result.get("text", "")

    # Real-time factor: processing time / audio duration (guarded against 0).
    rtf = (latency_ms / 1000.0) / max(duration_s, 1e-9)

    # Prepare metadata for the feedback logger
    meta = {
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "session_id": f"anon-{uuid.uuid4()}",
        "language_display": language,
        "model_id": language_models.get(language, "unknown"),
        "model_revision": _model_revision_from_pipeline(pipe),
        "audio_duration_s": duration_s,
        "sample_rate": sr,
        "source": "upload",  # gr.Audio combines both; we don't distinguish here
        "decode_params": json.dumps(decode_params),
        "transcript_hyp": hyp_text,
        "latency_ms": latency_ms,
        "rtf": rtf,
        # Placeholders to be filled on feedback submit
        "reference_text": "",
        "corrected_text": "",
        "wer": "",
        "cer": "",
        "subs": "",
        "ins": "",
        "dels": "",
        "score_out_of_10": "",
        "feedback_text": "",
        "tags": "",
        "store_audio": False,
        "audio_path": "",
    }

    return hyp_text, meta
214
+
215
+ # -------- Feedback submit --------
216
def submit_feedback(meta, reference_text, corrected_text, score, feedback_text,
                    tags, store_audio, share_publicly, audio_file_path):
    """Score the last transcription, optionally keep the audio, and log a row.

    Args:
        meta: metadata dict produced by ``transcribe`` (falsy if none yet).
        reference_text: optional ground-truth text; preferred for metrics.
        corrected_text: optional corrected transcript; used for metrics when
            no reference text is given.
        score: user rating, logged as ``score_out_of_10``.
        feedback_text: free-form user feedback.
        tags: list of selected tag labels.
        store_audio: whether the user consented to storing the audio.
        share_publicly: whether the user consented to public sharing.
        audio_file_path: path of the audio file to copy when consented.

    Returns:
        A dict with a ``status`` message and, once metadata exists, the
        computed metrics plus latency and model details for display.
    """
    import shutil  # local import: only needed for the optional audio copy

    if not meta:
        return {"status": "No transcription metadata available. Please transcribe first."}

    # Normalise both texts once so the logged values match what is scored.
    reference_text = reference_text.strip() if reference_text else ""
    corrected_text = corrected_text.strip() if corrected_text else ""
    # Prefer the explicit reference; fall back to the corrected transcript.
    ref_for_metrics = reference_text or corrected_text

    metrics = _compute_metrics(meta.get("transcript_hyp", ""), ref_for_metrics)

    # Handle audio storage (optional, consented). Best effort: a failed copy
    # must not prevent the feedback itself from being logged.
    stored_path = ""
    if store_audio and audio_file_path:
        try:
            ext = os.path.splitext(audio_file_path)[1] or ".wav"
            destination = os.path.join(AUDIO_DIR, f"{uuid.uuid4()}{ext}")
            # copyfile streams in chunks instead of loading the file in RAM.
            shutil.copyfile(audio_file_path, destination)
            stored_path = destination
        except (OSError, TypeError, ValueError):
            stored_path = ""

    # Build log row, starting from the metadata recorded at transcription.
    row = dict(meta)
    row.update({
        "reference_text": reference_text,
        "corrected_text": corrected_text,
        "score_out_of_10": score if score is not None else "",
        "feedback_text": feedback_text or "",
        "tags": json.dumps({"labels": tags or [], "share_publicly": bool(share_publicly)}),
        "store_audio": bool(store_audio),
        "audio_path": stored_path,
    })
    # Missing metrics are logged as empty CSV cells.
    for metric_name in ("wer", "cer", "subs", "ins", "dels"):
        value = metrics[metric_name]
        row[metric_name] = value if value is not None else ""

    try:
        _append_log_row(row)
        status = "Feedback saved."
    except Exception as e:
        # Surface any logging failure to the user instead of crashing the UI.
        status = f"Failed to save feedback: {e}"

    # Compact result to show back to user
    return {
        "status": status,
        "wer": metrics["wer"],
        "cer": metrics["cer"],
        "subs": metrics["subs"],
        "ins": metrics["ins"],
        "dels": metrics["dels"],
        "latency_ms": row["latency_ms"],
        "rtf": row["rtf"],
        "model_id": row["model_id"],
        "model_revision": row["model_revision"],
    }
282
+
283
+ # -------- UI (original preserved; additions appended) --------
284
# NOTE(review): the nesting below was reconstructed from a flattened paste;
# confirm that the submit button and JSON panel belong inside the accordion.
with gr.Blocks(title="🌐 Multilingual ASR Demo") as demo:
    gr.Markdown(
        """
        ## 🎙️ Multilingual Speech-to-Text
        Upload an audio file (MP3, WAV, FLAC, M4A, OGG,…) or record via your microphone.
        Then choose the language/model and hit **Transcribe**.
        """
    )

    language_choices = list(language_models)

    with gr.Row():
        lang = gr.Dropdown(
            choices=language_choices,
            value=language_choices[0],
            label="Select Language / Model"
        )

    with gr.Row():
        audio = gr.Audio(
            sources=["upload", "microphone"],
            type="filepath",
            label="Upload or Record Audio"
        )

    btn = gr.Button("Transcribe")
    output = gr.Textbox(label="Transcription")

    # Hidden state to carry metadata from transcribe -> feedback
    meta_state = gr.State(value=None)

    def _transcribe_and_store(audio_path, language):
        """Run ``transcribe`` and fan its result out to the three outputs.

        Returns the transcript for the output box, the metadata for the
        hidden state, and the default text for the corrected-transcript box.
        """
        hyp, meta = transcribe(audio_path, language)
        # Pre-fill the corrected transcript only with a real hypothesis: when
        # there is no metadata, ``hyp`` is a warning message, not a transcript.
        corrected_default = hyp if meta else ""
        return hyp, meta, corrected_default

    # --- Evaluation & Feedback (appended UI, no style/font changes) ---
    with gr.Accordion("Evaluation & Feedback", open=False):
        with gr.Row():
            reference_tb = gr.Textbox(label="Reference text (optional)", lines=4, value="")
        with gr.Row():
            corrected_tb = gr.Textbox(label="Corrected transcript (optional)", lines=4, value="")
        with gr.Row():
            score_slider = gr.Slider(minimum=0, maximum=10, step=1, label="Score out of 10", value=7)
        with gr.Row():
            feedback_tb = gr.Textbox(label="Feedback (what went right/wrong?)", lines=3, value="")
        with gr.Row():
            tags_cb = gr.CheckboxGroup(
                ["noisy", "far-field", "code-switching", "numbers-heavy", "named-entities", "read-speech", "spontaneous", "call-center", "voicenote"],
                label="Slice tags (select any that apply)"
            )
        with gr.Row():
            store_audio_cb = gr.Checkbox(label="Allow storing my audio for research/eval", value=False)
            share_cb = gr.Checkbox(label="Allow sharing this example publicly", value=False)

        submit_btn = gr.Button("Submit Feedback / Compute Metrics")
        results_json = gr.JSON(label="Metrics & Status")

    # Wire events: transcription fills the output box, the hidden metadata
    # state, and the corrected-transcript box.
    btn.click(
        fn=_transcribe_and_store,
        inputs=[audio, lang],
        outputs=[output, meta_state, corrected_tb]
    )

    # Feedback submission; the input order must match the parameter order of
    # submit_feedback.
    submit_btn.click(
        fn=submit_feedback,
        inputs=[
            meta_state,
            reference_tb,
            corrected_tb,
            score_slider,
            feedback_tb,
            tags_cb,
            store_audio_cb,
            share_cb,
            audio  # raw file path from gr.Audio
        ],
        outputs=results_json
    )
364
+
365
+ # Use a queue to keep Spaces stable under load
366
if __name__ == "__main__":
    # Turn on request queuing, then start the server. queue() returns the
    # Blocks instance, so the two calls can be chained.
    demo.queue().launch()