# DarliAI_ASR / app.py — exported from the FarmerlineML Hugging Face Space (revision 0a9945e)
# app.py
import os
import time
import datetime as dt
import pandas as pd
import gradio as gr
from transformers import pipeline
import numpy as np
import librosa # pip install librosa
from jiwer import wer # pip install jiwer
LOG_PATH = "feedback_logs.csv"
# --- EDIT THIS: map display names to your HF Hub model IDs ---
# Keys are the labels shown in the UI dropdown; values are Hub repo ids.
language_models = {
    "Akan (Asante Twi)": "FarmerlineML/w2v-bert-2.0_twi_alpha_v1",
    "Ewe": "FarmerlineML/w2v-bert-2.0_ewe_2",
    "Kiswahili": "FarmerlineML/w2v-bert-2.0_swahili_alpha",
    "Luganda": "FarmerlineML/w2v-bert-2.0_luganda",
    # NOTE: "portugese" is the actual repo id spelling on the Hub — do not "fix" it.
    "Brazilian Portuguese": "FarmerlineML/w2v-bert-2.0_brazilian_portugese_alpha",
    "Fante": "misterkissi/w2v2-lg-xls-r-300m-fante",
    "Bemba": "DarliAI/kissi-w2v2-lg-xls-r-300m-bemba",
    "Bambara": "DarliAI/kissi-w2v2-lg-xls-r-300m-bambara",
    "Dagaare": "DarliAI/kissi-w2v2-lg-xls-r-300m-dagaare",
    "Kinyarwanda": "DarliAI/kissi-w2v2-lg-xls-r-300m-kinyarwanda",
    "Fula": "DarliAI/kissi-wav2vec2-fula-fleurs-full",
    "Oromo": "DarliAI/kissi-w2v-bert-2.0-oromo",
    # Display-name typo fixed ("Runynakore" -> "Runyankore"), matching the repo id.
    "Runyankore": "misterkissi/w2v2-lg-xls-r-300m-runyankore",
    "Ga": "misterkissi/w2v2-lg-xls-r-300m-ga",
    "Vai": "misterkissi/whisper-small-vai",
    "Kasem": "misterkissi/w2v2-lg-xls-r-300m-kasem",
    "Lingala": "misterkissi/w2v2-lg-xls-r-300m-lingala",
    "Fongbe": "misterkissi/whisper-small-fongbe",
    "Amharic": "misterkissi/w2v2-lg-xls-r-1b-amharic",
    "Xhosa": "misterkissi/w2v2-lg-xls-r-300m-xhosa",
    "Tsonga": "misterkissi/w2v2-lg-xls-r-300m-tsonga",
    "Yoruba": "FarmerlineML/w2v-bert-2.0_yoruba_v1",
    "Luganda (FKD)": "FarmerlineML/luganda_fkd",
    "Luo": "FarmerlineML/w2v-bert-2.0_luo_v2",
    "Somali": "FarmerlineML/w2v-bert-2.0_somali_alpha",
    "Pidgin": "FarmerlineML/pidgin_nigerian",
    "Kikuyu": "FarmerlineML/w2v-bert-2.0_kikuyu",
    "Igbo": "FarmerlineML/w2v-bert-2.0_igbo_v1",
    #"Krio": "FarmerlineML/w2v-bert-2.0_krio_v3"
}
# Pipelines are built lazily on first use instead of eagerly at startup:
# instantiating ~28 ASR models at import time on CPU is slow and can exhaust
# memory on a small host. The mapping interface (asr_pipelines[lang]) is
# unchanged; each model is constructed once and then cached.
class _LazyPipelines(dict):
    """dict keyed by display language; builds and caches the ASR pipeline on first access."""

    def __missing__(self, language):
        # An unknown language raises KeyError from language_models, matching
        # the KeyError the original eager dict would have raised.
        pipe = pipeline(
            task="automatic-speech-recognition",
            model=language_models[language],
            device=-1,          # force CPU usage
            chunk_length_s=30   # process long audio in 30 s chunks
        )
        self[language] = pipe   # cache so each model is loaded only once
        return pipe


asr_pipelines = _LazyPipelines()
def transcribe(audio_path: str, language: str):
    """
    Transcribe an audio clip with the ASR pipeline selected by *language*.

    The audio is loaded via librosa (supports mp3, wav, flac, m4a, ogg, etc.),
    downmixed to mono at its native sample rate, then fed to the pipeline.

    Returns:
        (transcript, runtime_seconds, duration_seconds). On missing or
        unreadable input the first element is a warning message and both
        numbers are 0.0.
    """
    if not audio_path:
        return "⚠️ Please upload or record an audio clip.", 0.0, 0.0
    try:
        # librosa.load returns a 1D np.ndarray (mono) and the sample rate
        speech, sr = librosa.load(audio_path, sr=None, mono=True)
    except Exception as e:
        # Corrupt or unsupported files should surface a message in the UI,
        # not an unhandled traceback.
        return f"⚠️ Could not read audio file: {e}", 0.0, 0.0
    duration_s = librosa.get_duration(y=speech, sr=sr)
    t0 = time.time()
    result = asr_pipelines[language]({
        "sampling_rate": sr,
        "raw": speech
    })
    runtime_s = time.time() - t0
    text = result.get("text", "")
    return text, round(runtime_s, 3), round(duration_s, 3)
def compute_wer(pred: str, ref: str) -> float:
    """Word Error Rate of *pred* against *ref*; None when either text is empty or jiwer fails."""
    if not (pred and ref):
        return None
    try:
        return float(wer(ref, pred))
    except Exception:
        # Best-effort metric: any jiwer failure is reported as "no WER".
        return None
def ensure_logfile():
    """Create the feedback CSV with its header row if it does not exist yet."""
    if os.path.exists(LOG_PATH):
        return
    columns = [
        "timestamp", "language", "model_id", "audio_filename",
        "duration_s", "runtime_s", "transcript", "reference",
        "wer", "score_10", "feedback",
        "domain", "environment", "accent_locale",
    ]
    # An empty frame with only columns writes just the CSV header line.
    pd.DataFrame(columns=columns).to_csv(LOG_PATH, index=False)
def save_feedback(language: str,
                  transcript: str,
                  reference: str,
                  score_10: int,
                  feedback: str,
                  audio_file: str,
                  duration_s: float,
                  runtime_s: float,
                  domain: str,
                  environment: str,
                  accent_locale: str):
    """
    Append one feedback row to the CSV log and return a status message.

    WER is computed from *transcript* vs *reference* only when both are
    non-empty. Returns "✅ ..." on success (including the WER when available)
    or "❌ ..." with the error text if the log could not be written.
    """
    ensure_logfile()
    model_id = language_models.get(language, "")
    audio_filename = os.path.basename(audio_file) if audio_file else ""
    w = compute_wer(transcript, reference)  # None unless both texts present
    row = {
        # Timezone-aware UTC timestamp; datetime.utcnow() is deprecated
        # (Python 3.12+) and produced a naive datetime.
        "timestamp": dt.datetime.now(dt.timezone.utc).isoformat(),
        "language": language,
        "model_id": model_id,
        "audio_filename": audio_filename,
        "duration_s": duration_s,
        "runtime_s": runtime_s,
        "transcript": transcript,
        "reference": reference,
        "wer": w,
        "score_10": score_10,
        "feedback": feedback,
        "domain": domain,
        "environment": environment,
        "accent_locale": accent_locale
    }
    try:
        # Read-modify-write keeps the header intact and appends a single row.
        df = pd.read_csv(LOG_PATH)
        df = pd.concat([df, pd.DataFrame([row])], ignore_index=True)
        df.to_csv(LOG_PATH, index=False)
        msg = "✅ Feedback saved."
        if w is not None:
            msg += f" WER: {w:.3f}"
        return msg
    except Exception as e:
        # Surface the failure in the UI rather than crashing the app.
        return f"❌ Could not save feedback: {e}"
def load_metrics():
    """
    Read the feedback CSV and build aggregate summaries.

    Returns (status_message, per_language_df, per_domain_df, raw_df);
    the aggregate frames are None when no feedback has been logged yet.
    """
    ensure_logfile()
    df = pd.read_csv(LOG_PATH)
    if df.empty:
        return "No feedback yet.", None, None, df
    # Per-language means, best (lowest) mean WER first.
    per_lang = df.groupby("language").agg(
        n=("wer", "count"),
        mean_WER=("wer", "mean"),
        mean_score=("score_10", "mean"),
        mean_runtime_s=("runtime_s", "mean"),
        mean_duration_s=("duration_s", "mean")
    ).reset_index().sort_values(by="mean_WER", ascending=True)
    # Per-domain means (domain is optional metadata supplied with feedback).
    per_domain = df.groupby("domain").agg(
        n=("wer", "count"),
        mean_WER=("wer", "mean"),
        mean_score=("score_10", "mean")
    ).reset_index().sort_values(by="mean_WER", ascending=True)
    return "📊 Metrics updated.", per_lang, per_domain, df
# ---- Gradio UI: an ASR tab (transcribe + feedback form) and a Metrics tab ----
with gr.Blocks(title="🌐 Multilingual ASR Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
## 🎙️ Multilingual Speech-to-Text + Feedback & Benchmarking
Upload an audio file (MP3, WAV, FLAC, M4A, OGG,…) or record via your microphone.
Choose the language/model and hit **Transcribe**.
Optionally provide a **reference transcript** to compute WER, then leave a score & feedback.
"""
    )
    with gr.Tabs():
        with gr.Tab("ASR"):
            with gr.Row():
                lang = gr.Dropdown(
                    choices=list(language_models.keys()),
                    value=list(language_models.keys())[0],
                    label="Select Language / Model"
                )
            with gr.Row():
                audio = gr.Audio(
                    sources=["upload", "microphone"],
                    type="filepath",
                    label="Upload or Record Audio"
                )
            btn = gr.Button("Transcribe", variant="primary")
            output = gr.Textbox(label="Transcription", lines=6)
            runtime = gr.Number(label="Model runtime (s)", precision=3, interactive=False)
            duration = gr.Number(label="Audio duration (s)", precision=3, interactive=False)

            # Feedback / Benchmark block
            gr.Markdown("### 📝 Feedback & WER (optional)")
            with gr.Row():
                reference = gr.Textbox(label="Reference transcript (optional, for WER)", lines=4, placeholder="Paste the ground-truth text here to compute WER")
            with gr.Row():
                score = gr.Slider(0, 10, step=1, value=8, label="Overall quality score (0–10)")
            with gr.Row():
                domain = gr.Dropdown(
                    ["General", "Conversational", "News", "Agriculture", "Healthcare", "Education", "Customer support", "Finance", "Legal", "Entertainment", "Other"],
                    value="General",
                    label="Domain/topic"
                )
                environment = gr.Dropdown(
                    ["Quiet", "Office", "Outdoor", "Vehicle", "Crowd/Market", "Radio/Phone", "Other"],
                    value="Quiet",
                    label="Recording environment"
                )
                accent_locale = gr.Textbox(label="Accent / Locale (e.g., Accra, Nairobi, Lagos)", placeholder="Optional")
            feedback = gr.Textbox(label="Free-text feedback", lines=4, placeholder="What worked well? What failed? Any specific words or sounds?")
            save_btn = gr.Button("Save Feedback", variant="secondary")
            save_msg = gr.Markdown("")

            # Wire up: transcription fills the output boxes; the save button
            # logs feedback together with the latest transcription metadata.
            btn.click(
                fn=transcribe,
                inputs=[audio, lang],
                outputs=[output, runtime, duration]
            )
            save_btn.click(
                fn=save_feedback,
                inputs=[lang, output, reference, score, feedback, audio, duration, runtime, domain, environment, accent_locale],
                outputs=save_msg
            )

        with gr.Tab("Metrics"):
            refresh = gr.Button("Refresh metrics", variant="primary")
            metrics_msg = gr.Markdown()
            per_lang_df = gr.Dataframe(interactive=False, label="Per-language summary (lower WER is better)")
            per_domain_df = gr.Dataframe(interactive=False, label="Per-domain summary")
            logs_df = gr.Dataframe(interactive=False, label="Raw feedback log")
            refresh.click(
                fn=load_metrics,
                inputs=[],
                outputs=[metrics_msg, per_lang_df, per_domain_df, logs_df]
            )

if __name__ == "__main__":
    demo.launch()