# NOTE: removed Hugging Face Spaces page-header scrape residue ("Spaces: Sleeping")
# app.py
import datetime as dt
import os
import time

import gradio as gr
import librosa  # pip install librosa
import numpy as np
import pandas as pd
from jiwer import wer  # pip install jiwer
from transformers import pipeline
# CSV file where user feedback rows are accumulated (created on demand).
LOG_PATH = "feedback_logs.csv"

# --- EDIT THIS: map display names to your HF Hub model IDs ---
language_models = {
    "Akan (Asante Twi)": "FarmerlineML/w2v-bert-2.0_twi_alpha_v1",
    "Ewe": "FarmerlineML/w2v-bert-2.0_ewe_2",
    "Kiswahili": "FarmerlineML/w2v-bert-2.0_swahili_alpha",
    "Luganda": "FarmerlineML/w2v-bert-2.0_luganda",
    "Brazilian Portuguese": "FarmerlineML/w2v-bert-2.0_brazilian_portugese_alpha",
    "Fante": "misterkissi/w2v2-lg-xls-r-300m-fante",
    "Bemba": "DarliAI/kissi-w2v2-lg-xls-r-300m-bemba",
    "Bambara": "DarliAI/kissi-w2v2-lg-xls-r-300m-bambara",
    "Dagaare": "DarliAI/kissi-w2v2-lg-xls-r-300m-dagaare",
    "Kinyarwanda": "DarliAI/kissi-w2v2-lg-xls-r-300m-kinyarwanda",
    "Fula": "DarliAI/kissi-wav2vec2-fula-fleurs-full",
    "Oromo": "DarliAI/kissi-w2v-bert-2.0-oromo",
    # NOTE(review): "Runynakore" is likely a typo for "Runyankore" (cf. the model
    # id below) — kept as-is because saved feedback rows reference this label.
    "Runynakore": "misterkissi/w2v2-lg-xls-r-300m-runyankore",
    "Ga": "misterkissi/w2v2-lg-xls-r-300m-ga",
    "Vai": "misterkissi/whisper-small-vai",
    "Kasem": "misterkissi/w2v2-lg-xls-r-300m-kasem",
    "Lingala": "misterkissi/w2v2-lg-xls-r-300m-lingala",
    "Fongbe": "misterkissi/whisper-small-fongbe",
    "Amharic": "misterkissi/w2v2-lg-xls-r-1b-amharic",
    "Xhosa": "misterkissi/w2v2-lg-xls-r-300m-xhosa",
    "Tsonga": "misterkissi/w2v2-lg-xls-r-300m-tsonga",
    "Yoruba": "FarmerlineML/w2v-bert-2.0_yoruba_v1",
    "Luganda (FKD)": "FarmerlineML/luganda_fkd",
    "Luo": "FarmerlineML/w2v-bert-2.0_luo_v2",
    "Somali": "FarmerlineML/w2v-bert-2.0_somali_alpha",
    "Pidgin": "FarmerlineML/pidgin_nigerian",
    "Kikuyu": "FarmerlineML/w2v-bert-2.0_kikuyu",
    "Igbo": "FarmerlineML/w2v-bert-2.0_igbo_v1",
    "Krio": "FarmerlineML/w2v-bert-2.0_krio_v3",
}
# Pre-load one ASR pipeline per language on CPU (device=-1).
# NOTE(review): this eagerly downloads and loads ~29 models at import time,
# which is slow and memory-hungry; consider lazy loading per-language if
# startup cost or RAM becomes a problem.
asr_pipelines = {
    lang: pipeline(
        task="automatic-speech-recognition",
        model=model_id,
        device=-1,          # force CPU usage
        chunk_length_s=30,  # split long audio into 30 s chunks
    )
    for lang, model_id in language_models.items()
}
def transcribe(audio_path: str, language: str):
    """
    Transcribe an audio clip with the ASR model selected for *language*.

    The audio is loaded via librosa (supports mp3, wav, flac, m4a, ogg, ...),
    downmixed to mono at its native sample rate, and fed to the chosen
    pipeline.

    Returns:
        (transcript, runtime_seconds, duration_seconds). When no audio was
        supplied, the transcript slot carries a warning message and both
        timings are 0.0.
    """
    if not audio_path:
        return "⚠️ Please upload or record an audio clip.", 0.0, 0.0

    # librosa.load returns a 1-D np.ndarray (mono) and the sample rate;
    # sr=None keeps the file's native sampling rate.
    speech, sr = librosa.load(audio_path, sr=None, mono=True)
    duration_s = librosa.get_duration(y=speech, sr=sr)

    # perf_counter is monotonic, so the reported runtime cannot be skewed
    # by wall-clock adjustments (time.time() can).
    t0 = time.perf_counter()
    result = asr_pipelines[language]({"sampling_rate": sr, "raw": speech})
    runtime_s = time.perf_counter() - t0

    text = result.get("text", "")
    return text, round(runtime_s, 3), round(duration_s, 3)
def compute_wer(pred: str, ref: str) -> float: | |
if not ref or not pred: | |
return None | |
try: | |
return float(wer(ref, pred)) | |
except Exception: | |
return None | |
def ensure_logfile():
    """Create the feedback CSV at LOG_PATH with its header row if missing."""
    if not os.path.exists(LOG_PATH):
        # Column order here defines the schema every appended row must match.
        pd.DataFrame(columns=[
            "timestamp", "language", "model_id", "audio_filename",
            "duration_s", "runtime_s", "transcript", "reference",
            "wer", "score_10", "feedback",
            "domain", "environment", "accent_locale"
        ]).to_csv(LOG_PATH, index=False)
def save_feedback(language: str,
                  transcript: str,
                  reference: str,
                  score_10: int,
                  feedback: str,
                  audio_file: str,
                  duration_s: float,
                  runtime_s: float,
                  domain: str,
                  environment: str,
                  accent_locale: str):
    """Append one feedback row (with optional WER) to the CSV log.

    Returns a status string for the UI; I/O errors are reported in the
    message instead of raised so a failed save never crashes the app.
    """
    ensure_logfile()
    model_id = language_models.get(language, "")
    audio_filename = os.path.basename(audio_file) if audio_file else ""
    w = compute_wer(transcript, reference)  # None when no reference supplied
    row = {
        # Timezone-aware UTC timestamp; datetime.utcnow() is deprecated and
        # produced a naive datetime.
        "timestamp": dt.datetime.now(dt.timezone.utc).isoformat(),
        "language": language,
        "model_id": model_id,
        "audio_filename": audio_filename,
        "duration_s": duration_s,
        "runtime_s": runtime_s,
        "transcript": transcript,
        "reference": reference,
        "wer": w,
        "score_10": score_10,
        "feedback": feedback,
        "domain": domain,
        "environment": environment,
        "accent_locale": accent_locale,
    }
    try:
        # Read-append-rewrite keeps the header consistent with ensure_logfile().
        df = pd.read_csv(LOG_PATH)
        df = pd.concat([df, pd.DataFrame([row])], ignore_index=True)
        df.to_csv(LOG_PATH, index=False)
    except Exception as e:
        return f"❌ Could not save feedback: {e}"
    msg = "✅ Feedback saved."
    if w is not None:
        msg += f" WER: {w:.3f}"
    return msg
def load_metrics():
    """Reload the feedback log and build aggregate summary tables.

    Returns:
        (status_message, per_language_df, per_domain_df, raw_df); the two
        summary frames are None when no feedback has been logged yet.
    """
    ensure_logfile()
    df = pd.read_csv(LOG_PATH)
    if df.empty:
        return "No feedback yet.", None, None, df

    # Per-language means, best (lowest) mean WER first.
    per_lang = df.groupby("language").agg(
        n=("wer", "count"),
        mean_WER=("wer", "mean"),
        mean_score=("score_10", "mean"),
        mean_runtime_s=("runtime_s", "mean"),
        mean_duration_s=("duration_s", "mean"),
    ).reset_index().sort_values(by="mean_WER", ascending=True)

    # Per-domain means (domain is a user-selected dropdown value).
    per_domain = df.groupby("domain").agg(
        n=("wer", "count"),
        mean_WER=("wer", "mean"),
        mean_score=("score_10", "mean"),
    ).reset_index().sort_values(by="mean_WER", ascending=True)

    return "📊 Metrics updated.", per_lang, per_domain, df
# ---------------- Gradio UI ----------------
# NOTE(review): emoji in titles/labels were mojibake in the original source
# and have been restored best-effort — confirm against the deployed Space.
with gr.Blocks(title="🌍 Multilingual ASR Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        ## 🎙️ Multilingual Speech-to-Text + Feedback & Benchmarking
        Upload an audio file (MP3, WAV, FLAC, M4A, OGG, …) or record via your microphone.
        Choose the language/model and hit **Transcribe**.
        Optionally provide a **reference transcript** to compute WER, then leave a score & feedback.
        """
    )
    with gr.Tabs():
        with gr.Tab("ASR"):
            with gr.Row():
                lang = gr.Dropdown(
                    choices=list(language_models.keys()),
                    value=next(iter(language_models)),  # first entry as default
                    label="Select Language / Model"
                )
            with gr.Row():
                audio = gr.Audio(
                    sources=["upload", "microphone"],
                    type="filepath",  # transcribe() expects a path on disk
                    label="Upload or Record Audio"
                )
            btn = gr.Button("Transcribe", variant="primary")
            output = gr.Textbox(label="Transcription", lines=6)
            runtime = gr.Number(label="Model runtime (s)", precision=3, interactive=False)
            duration = gr.Number(label="Audio duration (s)", precision=3, interactive=False)

            # Feedback / Benchmark block
            gr.Markdown("### 📝 Feedback & WER (optional)")
            with gr.Row():
                reference = gr.Textbox(label="Reference transcript (optional, for WER)", lines=4, placeholder="Paste the ground-truth text here to compute WER")
            with gr.Row():
                score = gr.Slider(0, 10, step=1, value=8, label="Overall quality score (0–10)")
            with gr.Row():
                domain = gr.Dropdown(
                    ["General", "Conversational", "News", "Agriculture", "Healthcare", "Education", "Customer support", "Finance", "Legal", "Entertainment", "Other"],
                    value="General",
                    label="Domain/topic"
                )
                environment = gr.Dropdown(
                    ["Quiet", "Office", "Outdoor", "Vehicle", "Crowd/Market", "Radio/Phone", "Other"],
                    value="Quiet",
                    label="Recording environment"
                )
            accent_locale = gr.Textbox(label="Accent / Locale (e.g., Accra, Nairobi, Lagos)", placeholder="Optional")
            feedback = gr.Textbox(label="Free-text feedback", lines=4, placeholder="What worked well? What failed? Any specific words or sounds?")
            save_btn = gr.Button("Save Feedback", variant="secondary")
            save_msg = gr.Markdown("")

            # Wire up: transcription fills the textbox + timing fields ...
            btn.click(
                fn=transcribe,
                inputs=[audio, lang],
                outputs=[output, runtime, duration]
            )
            # ... and feedback is logged together with the current outputs.
            save_btn.click(
                fn=save_feedback,
                inputs=[lang, output, reference, score, feedback, audio, duration, runtime, domain, environment, accent_locale],
                outputs=save_msg
            )

        with gr.Tab("Metrics"):
            refresh = gr.Button("Refresh metrics", variant="primary")
            metrics_msg = gr.Markdown()
            per_lang_df = gr.Dataframe(interactive=False, label="Per-language summary (lower WER is better)")
            per_domain_df = gr.Dataframe(interactive=False, label="Per-domain summary")
            logs_df = gr.Dataframe(interactive=False, label="Raw feedback log")
            refresh.click(
                fn=load_metrics,
                inputs=[],
                outputs=[metrics_msg, per_lang_df, per_domain_df, logs_df]
            )

if __name__ == "__main__":
    demo.launch()