# app.py
import os
import csv
import json
import time
import uuid
import gradio as gr
from transformers import pipeline
import numpy as np
import librosa  # pip install librosa
# Optional but recommended for better jiwer performance
# pip install python-Levenshtein
try:
    from jiwer import compute_measures, wer as jiwer_wer, cer as jiwer_cer
    HAS_JIWER = True
except Exception:
    HAS_JIWER = False

# -------- CONFIG: storage paths (Space-friendly) --------
DATA_DIR = "/home/user/data"
AUDIO_DIR = os.path.join(DATA_DIR, "audio")
LOG_CSV = os.path.join(DATA_DIR, "logs.csv")
os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(AUDIO_DIR, exist_ok=True)
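# NOTE (deployment assumption, worth verifying for your Space): files written
# under /home/user are ephemeral on Hugging Face Spaces and disappear on
# restart/rebuild. If the Space has the persistent-storage add-on enabled,
# the durable mount lives at /data, so you would point DATA_DIR there instead.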
# --- EDIT THIS: map display names to your HF Hub model IDs ---
language_models = {
    "Akan (Asante Twi)": "FarmerlineML/w2v-bert-2.0_twi_alpha_v1",
    "Ewe": "FarmerlineML/w2v-bert-2.0_ewe_2",
    "Kiswahili": "FarmerlineML/w2v-bert-2.0_swahili_alpha",
    # "Luganda": "FarmerlineML/w2v-bert-2.0_luganda",  # superseded by the luganda_fkd entry below (a duplicate dict key would silently overwrite this one)
    "Brazilian Portuguese": "FarmerlineML/w2v-bert-2.0_brazilian_portugese_alpha",
    "Fante": "misterkissi/w2v2-lg-xls-r-300m-fante",
    "Bemba": "DarliAI/kissi-w2v2-lg-xls-r-300m-bemba",
    "Bambara": "DarliAI/kissi-w2v2-lg-xls-r-300m-bambara",
    "Dagaare": "DarliAI/kissi-w2v2-lg-xls-r-300m-dagaare",
    "Kinyarwanda": "DarliAI/kissi-w2v2-lg-xls-r-300m-kinyarwanda",
    "Fula": "DarliAI/kissi-wav2vec2-fula-fleurs-full",
    "Oromo": "DarliAI/kissi-w2v-bert-2.0-oromo",
    "Runyankore": "misterkissi/w2v2-lg-xls-r-300m-runyankore",
    "Ga": "misterkissi/w2v2-lg-xls-r-300m-ga",
    "Vai": "misterkissi/whisper-small-vai",
    "Kasem": "misterkissi/w2v2-lg-xls-r-300m-kasem",
    "Lingala": "misterkissi/w2v2-lg-xls-r-300m-lingala",
    "Fongbe": "misterkissi/whisper-small-fongbe",
    "Amharic": "misterkissi/w2v2-lg-xls-r-1b-amharic",
    "Xhosa": "misterkissi/w2v2-lg-xls-r-300m-xhosa",
    "Tsonga": "misterkissi/w2v2-lg-xls-r-300m-tsonga",
    # "Wolof": "misterkissi/w2v2-lg-xls-r-1b-wolof",
    # "Haitian Creole": "misterkissi/whisper-small-haitian-creole",
    # "Kabyle": "misterkissi/w2v2-lg-xls-r-1b-kabyle",
    "Yoruba": "FarmerlineML/w2v-bert-2.0_yoruba_v1",
    "Luganda": "FarmerlineML/luganda_fkd",
    "Luo": "FarmerlineML/w2v-bert-2.0_luo_v2",
    "Somali": "FarmerlineML/w2v-bert-2.0_somali_alpha",
    "Pidgin": "FarmerlineML/pidgin_nigerian",
    "Kikuyu": "FarmerlineML/w2v-bert-2.0_kikuyu",
    "Igbo": "FarmerlineML/w2v-bert-2.0_igbo_v1",
    "Krio": "FarmerlineML/w2v-bert-2.0_krio_v3",
}
# -------- Lazy-load pipeline cache (Space-safe) --------
# Small LRU-style cache to avoid loading all models into RAM
_PIPELINE_CACHE = {}
_CACHE_ORDER = []     # keeps track of usage order (most-recent first)
_CACHE_MAX_SIZE = 3   # adjust if you have more RAM

def _touch_cache(key):
    if key in _CACHE_ORDER:
        _CACHE_ORDER.remove(key)
    _CACHE_ORDER.insert(0, key)

def _evict_if_needed():
    while len(_PIPELINE_CACHE) > _CACHE_MAX_SIZE:
        oldest = _CACHE_ORDER.pop()  # least-recently used
        try:
            del _PIPELINE_CACHE[oldest]
        except KeyError:
            pass

def get_asr_pipeline(language_display: str):
    if language_display in _PIPELINE_CACHE:
        _touch_cache(language_display)
        return _PIPELINE_CACHE[language_display]
    model_id = language_models[language_display]
    pipe = pipeline(
        task="automatic-speech-recognition",
        model=model_id,
        device=-1,  # force CPU usage on Spaces CPU
        chunk_length_s=30,
    )
    _PIPELINE_CACHE[language_display] = pipe
    _touch_cache(language_display)
    _evict_if_needed()
    return pipe
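# Design note: functools.lru_cache gives the same bounded-LRU behavior with
# far less bookkeeping. A minimal sketch follows (not wired into the UI; the
# hand-rolled cache above is kept because it makes eviction explicit and easy
# to instrument, e.g. logging which model was dropped):
from functools import lru_cache

@lru_cache(maxsize=_CACHE_MAX_SIZE)
def get_asr_pipeline_lru(language_display: str):
    # Loads (or returns a cached) pipeline; the decorator evicts the
    # least-recently-used entry once maxsize distinct languages are loaded.
    return pipeline(
        task="automatic-speech-recognition",
        model=language_models[language_display],
        device=-1,
        chunk_length_s=30,
    )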
# -------- Helpers --------
def _model_revision_from_pipeline(pipe) -> str:
    # Best-effort capture of revision/hash for reproducibility
    for attr in ("hub_revision", "revision", "_commit_hash"):
        val = getattr(getattr(pipe, "model", None), attr, None)
        if val:
            return str(val)
    # Fallback to config name_or_path or unknown
    try:
        return str(getattr(pipe.model.config, "_name_or_path", "unknown"))
    except Exception:
        return "unknown"
def _append_log_row(row: dict):
    field_order = [
        "timestamp", "session_id",
        "language_display", "model_id", "model_revision",
        "audio_duration_s", "sample_rate", "source",
        "decode_params",
        "transcript_hyp",
        "reference_text", "corrected_text",
        "latency_ms", "rtf",
        "wer", "cer",
        "subs", "ins", "dels",
        "score_out_of_10", "feedback_text",
        "tags",
        "store_audio", "audio_path",
    ]
    file_exists = os.path.isfile(LOG_CSV)
    with open(LOG_CSV, "a", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=field_order)
        if not file_exists:
            writer.writeheader()
        # Ensure all fields exist
        for k in field_order:
            row.setdefault(k, "")
        writer.writerow(row)
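# Offline analysis sketch (assumes pandas is installed; it is not a
# dependency of the app and this helper is never called by the UI):
def _summarize_logs():
    import pandas as pd  # local import so the app runs without pandas
    df = pd.read_csv(LOG_CSV)
    # Empty WER cells parse as NaN, so the mean skips rows without feedback
    return df.groupby("language_display")["wer"].mean()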
def _compute_metrics(hyp: str, ref_or_corrected: str):
    if not HAS_JIWER or not ref_or_corrected or not hyp:
        return {
            "wer": None, "cer": None,
            "subs": None, "ins": None, "dels": None,
        }
    try:
        measures = compute_measures(ref_or_corrected, hyp)
        return {
            "wer": measures.get("wer"),
            "cer": jiwer_cer(ref_or_corrected, hyp),
            "subs": measures.get("substitutions"),
            "ins": measures.get("insertions"),
            "dels": measures.get("deletions"),
        }
    except Exception:
        # Be resilient if jiwer errors on edge cases
        return {
            "wer": None, "cer": None,
            "subs": None, "ins": None, "dels": None,
        }
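# Sanity check for the helper above (values assume jiwer's defaults):
#   _compute_metrics("the cat sat", "the cat sat on the mat")
# scores the 3-word hypothesis against the 6-word reference as 3 deletions,
# i.e. wer = 3/6 = 0.5.
# Compatibility note: compute_measures() was removed in jiwer 3.x in favor of
# process_words(); pin jiwer<3.0 in requirements.txt or port this call.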
# -------- Inference --------
def transcribe(audio_path: str, language: str):
    """
    Load the audio via librosa (supports mp3, wav, flac, m4a, ogg, etc.),
    convert to mono, then run it through the chosen ASR pipeline.
    Returns the transcript plus a metadata dict; the UI shows the transcript
    and keeps the metadata in a hidden state for the feedback step.
    """
    if not audio_path:
        return "⚠️ Please upload or record an audio clip.", None
    # librosa.load returns a 1D np.ndarray (mono) and the sample rate
    speech, sr = librosa.load(audio_path, sr=None, mono=True)
    duration_s = float(librosa.get_duration(y=speech, sr=sr))
    pipe = get_asr_pipeline(language)
    decode_params = {"chunk_length_s": getattr(pipe, "chunk_length_s", 30)}
    t0 = time.time()
    result = pipe({"sampling_rate": sr, "raw": speech})
    latency_ms = int((time.time() - t0) * 1000.0)
    hyp_text = result.get("text", "")
    rtf = (latency_ms / 1000.0) / max(duration_s, 1e-9)
    # Prepare metadata for the feedback logger
    meta = {
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "session_id": f"anon-{uuid.uuid4()}",
        "language_display": language,
        "model_id": language_models.get(language, "unknown"),
        "model_revision": _model_revision_from_pipeline(pipe),
        "audio_duration_s": duration_s,
        "sample_rate": sr,
        "source": "upload",  # gr.Audio combines both; we don't distinguish here
        "decode_params": json.dumps(decode_params),
        "transcript_hyp": hyp_text,
        "latency_ms": latency_ms,
        "rtf": rtf,
        # Placeholders to be filled on feedback submit
        "reference_text": "",
        "corrected_text": "",
        "wer": "",
        "cer": "",
        "subs": "",
        "ins": "",
        "dels": "",
        "score_out_of_10": "",
        "feedback_text": "",
        "tags": "",
        "store_audio": False,
        "audio_path": "",
    }
    return hyp_text, meta
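# Sample-rate note (assumption worth checking per checkpoint): wav2vec2 and
# w2v-bert models expect 16 kHz input. The HF ASR pipeline resamples raw
# arrays whose rate differs from the feature extractor's, but that path
# requires torchaudio; loading at a fixed rate up front sidesteps it:
#   speech, sr = librosa.load(audio_path, sr=16_000, mono=True)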
# -------- Feedback submit --------
def submit_feedback(meta, reference_text, corrected_text, score, feedback_text,
                    tags, store_audio, share_publicly, audio_file_path):
    """
    Compute metrics (if possible), optionally store audio (consented),
    and append a row to CSV. Returns a compact dict for display.
    """
    if not meta:
        return {"status": "No transcription metadata available. Please transcribe first."}
    # Choose text to compare against hyp: prefer explicit reference, else corrected
    ref_for_metrics = reference_text.strip() if reference_text else ""
    corrected_text = corrected_text.strip() if corrected_text else ""
    if not ref_for_metrics and corrected_text:
        ref_for_metrics = corrected_text
    metrics = _compute_metrics(meta.get("transcript_hyp", ""), ref_for_metrics)
    # Handle audio storage (optional, consented)
    stored_path = ""
    if store_audio and audio_file_path:
        try:
            # Copy the original file to AUDIO_DIR with a random name
            ext = os.path.splitext(audio_file_path)[1] or ".wav"
            stored_path = os.path.join(AUDIO_DIR, f"{uuid.uuid4()}{ext}")
            # Simple byte copy
            with open(audio_file_path, "rb") as src, open(stored_path, "wb") as dst:
                dst.write(src.read())
        except Exception:
            stored_path = ""
    # Build log row
    row = dict(meta)  # start from recorded meta
    row.update({
        "reference_text": reference_text or "",
        "corrected_text": corrected_text or "",
        "wer": metrics["wer"] if metrics["wer"] is not None else "",
        "cer": metrics["cer"] if metrics["cer"] is not None else "",
        "subs": metrics["subs"] if metrics["subs"] is not None else "",
        "ins": metrics["ins"] if metrics["ins"] is not None else "",
        "dels": metrics["dels"] if metrics["dels"] is not None else "",
        "score_out_of_10": score if score is not None else "",
        "feedback_text": feedback_text or "",
        "tags": json.dumps({"labels": tags or [], "share_publicly": bool(share_publicly)}),
        "store_audio": bool(store_audio),
        "audio_path": stored_path,
    })
    try:
        _append_log_row(row)
        status = "Feedback saved."
    except Exception as e:
        status = f"Failed to save feedback: {e}"
    # Compact result to show back to user
    return {
        "status": status,
        "wer": row["wer"] if row["wer"] != "" else None,
        "cer": row["cer"] if row["cer"] != "" else None,
        "subs": row["subs"] if row["subs"] != "" else None,
        "ins": row["ins"] if row["ins"] != "" else None,
        "dels": row["dels"] if row["dels"] != "" else None,
        "latency_ms": row["latency_ms"],
        "rtf": row["rtf"],
        "model_id": row["model_id"],
        "model_revision": row["model_revision"],
    }
# -------- UI (original preserved; additions appended) --------
with gr.Blocks(title="🌐 Multilingual ASR Demo") as demo:
    gr.Markdown(
        """
        ## 🎙️ Multilingual Speech-to-Text
        Upload an audio file (MP3, WAV, FLAC, M4A, OGG, …) or record via your microphone.
        Then choose the language/model and hit **Transcribe**.
        """
    )
    with gr.Row():
        lang = gr.Dropdown(
            choices=list(language_models.keys()),
            value=list(language_models.keys())[0],
            label="Select Language / Model"
        )
    with gr.Row():
        audio = gr.Audio(
            sources=["upload", "microphone"],
            type="filepath",
            label="Upload or Record Audio"
        )
    btn = gr.Button("Transcribe")
    output = gr.Textbox(label="Transcription")
    # Hidden state to carry metadata from transcribe -> feedback
    meta_state = gr.State(value=None)

    # Keep original behavior: output shows transcript.
    # Also capture meta into the hidden state.
    def _transcribe_and_store(audio_path, language):
        hyp, meta = transcribe(audio_path, language)
        # For convenience, populate corrected_text with the hyp by default
        return hyp, meta, hyp

    # --- Evaluation & Feedback (appended UI, no style/font changes) ---
    with gr.Accordion("Evaluation & Feedback", open=False):
        with gr.Row():
            reference_tb = gr.Textbox(label="Reference text (optional)", lines=4, value="")
        with gr.Row():
            corrected_tb = gr.Textbox(label="Corrected transcript (optional)", lines=4, value="")
        with gr.Row():
            score_slider = gr.Slider(minimum=0, maximum=10, step=1, label="Score out of 10", value=7)
        with gr.Row():
            feedback_tb = gr.Textbox(label="Feedback (what went right/wrong?)", lines=3, value="")
        with gr.Row():
            tags_cb = gr.CheckboxGroup(
                ["noisy", "far-field", "code-switching", "numbers-heavy", "named-entities", "read-speech", "spontaneous", "call-center", "voicenote"],
                label="Slice tags (select any that apply)"
            )
        with gr.Row():
            store_audio_cb = gr.Checkbox(label="Allow storing my audio for research/eval", value=False)
            share_cb = gr.Checkbox(label="Allow sharing this example publicly", value=False)
        submit_btn = gr.Button("Submit Feedback / Compute Metrics")
        results_json = gr.JSON(label="Metrics & Status")
    # Wire events
    btn.click(
        fn=_transcribe_and_store,
        inputs=[audio, lang],
        outputs=[output, meta_state, corrected_tb]
    )
    submit_btn.click(
        fn=submit_feedback,
        inputs=[
            meta_state,
            reference_tb,
            corrected_tb,
            score_slider,
            feedback_tb,
            tags_cb,
            store_audio_cb,
            share_cb,
            audio  # raw file path from gr.Audio
        ],
        outputs=results_json
    )
# Use a queue to keep Spaces stable under load
if __name__ == "__main__":
    demo.queue()  # queueing is on by default in Gradio 4.x; the explicit call keeps older versions covered
    demo.launch()
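# Programmatic access sketch (assumes the gradio_client package and a public
# Space; the Space ID below is hypothetical, and the exact endpoint name and
# argument format should be taken from the Space's "Use via API" page):
#   from gradio_client import Client
#   client = Client("your-org/your-asr-space")
#   print(client.predict("clip.wav", "Ewe", api_name="/_transcribe_and_store"))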