Evaluation-2 / app.py
FarmerlineML's picture
Create app.py
ddf6cde verified
raw
history blame
14.3 kB
# app.py
import os
import csv
import json
import time
import uuid
import gradio as gr
from transformers import pipeline
import numpy as np
import librosa # pip install librosa
# Optional but recommended for better jiwer performance
# pip install python-Levenshtein
try:
from jiwer import compute_measures, wer as jiwer_wer, cer as jiwer_cer
HAS_JIWER = True
except Exception:
HAS_JIWER = False
# -------- CONFIG: storage paths (Space-friendly) --------
# All logs and consented audio live under DATA_DIR. It defaults to
# /home/user/data; set the ASR_DATA_DIR environment variable to relocate it
# (e.g. to a persistent-storage mount). An empty value falls back to the default.
DATA_DIR = os.environ.get("ASR_DATA_DIR") or "/home/user/data"
AUDIO_DIR = os.path.join(DATA_DIR, "audio")
LOG_CSV = os.path.join(DATA_DIR, "logs.csv")
# makedirs creates missing parents, so this also creates DATA_DIR;
# exist_ok=True makes it safe to run on every start-up.
os.makedirs(AUDIO_DIR, exist_ok=True)
# --- EDIT THIS: map display names to your HF Hub model IDs ---
# Keys are the dropdown labels (also logged as "language_display"); values are
# Hugging Face Hub model IDs. Keys must be unique: a repeated key in a dict
# literal silently overwrites the earlier value.
language_models = {
    "Akan (Asante Twi)": "FarmerlineML/w2v-bert-2.0_twi_alpha_v1",
    "Ewe": "FarmerlineML/w2v-bert-2.0_ewe_2",
    "Kiswahili": "FarmerlineML/w2v-bert-2.0_swahili_alpha",
    # Only one model can be served under the "Luganda" label; luganda_fkd is the
    # active one. Give the alternative its own label to expose it as well.
    "Luganda": "FarmerlineML/luganda_fkd",
    # "Luganda (w2v-bert-2.0)": "FarmerlineML/w2v-bert-2.0_luganda",
    "Brazilian Portuguese": "FarmerlineML/w2v-bert-2.0_brazilian_portugese_alpha",
    "Fante": "misterkissi/w2v2-lg-xls-r-300m-fante",
    "Bemba": "DarliAI/kissi-w2v2-lg-xls-r-300m-bemba",
    "Bambara": "DarliAI/kissi-w2v2-lg-xls-r-300m-bambara",
    "Dagaare": "DarliAI/kissi-w2v2-lg-xls-r-300m-dagaare",
    "Kinyarwanda": "DarliAI/kissi-w2v2-lg-xls-r-300m-kinyarwanda",
    "Fula": "DarliAI/kissi-wav2vec2-fula-fleurs-full",
    "Oromo": "DarliAI/kissi-w2v-bert-2.0-oromo",
    # NOTE(review): "Runynakore" looks like a misspelling of "Runyankore" (see the
    # model ID). Left as-is because the label is logged and used by API callers.
    "Runynakore": "misterkissi/w2v2-lg-xls-r-300m-runyankore",
    "Ga": "misterkissi/w2v2-lg-xls-r-300m-ga",
    "Vai": "misterkissi/whisper-small-vai",
    "Kasem": "misterkissi/w2v2-lg-xls-r-300m-kasem",
    "Lingala": "misterkissi/w2v2-lg-xls-r-300m-lingala",
    "Fongbe": "misterkissi/whisper-small-fongbe",
    "Amharic": "misterkissi/w2v2-lg-xls-r-1b-amharic",
    "Xhosa": "misterkissi/w2v2-lg-xls-r-300m-xhosa",
    "Tsonga": "misterkissi/w2v2-lg-xls-r-300m-tsonga",
    # "WOLOF": "misterkissi/w2v2-lg-xls-r-1b-wolof",
    # "HAITIAN CREOLE": "misterkissi/whisper-small-haitian-creole",
    # "KABYLE": "misterkissi/w2v2-lg-xls-r-1b-kabyle",
    "Yoruba": "FarmerlineML/w2v-bert-2.0_yoruba_v1",
    "Luo": "FarmerlineML/w2v-bert-2.0_luo_v2",
    "Somali": "FarmerlineML/w2v-bert-2.0_somali_alpha",
    "Pidgin": "FarmerlineML/pidgin_nigerian",
    "Kikuyu": "FarmerlineML/w2v-bert-2.0_kikuyu",
    "Igbo": "FarmerlineML/w2v-bert-2.0_igbo_v1",
    "Krio": "FarmerlineML/w2v-bert-2.0_krio_v3",
}
# -------- Lazy-load pipeline cache (Space-safe) --------
# Small LRU-style cache to avoid loading all models into RAM
_PIPELINE_CACHE = {}
_CACHE_ORDER = [] # keeps track of usage order
_CACHE_MAX_SIZE = 3 # adjust if you have more RAM
def _touch_cache(key):
if key in _CACHE_ORDER:
_CACHE_ORDER.remove(key)
_CACHE_ORDER.insert(0, key)
def _evict_if_needed():
while len(_PIPELINE_CACHE) > _CACHE_MAX_SIZE:
oldest = _CACHE_ORDER.pop() # least-recently used
try:
del _PIPELINE_CACHE[oldest]
except KeyError:
pass
def get_asr_pipeline(language_display: str):
    """
    Return the CPU speech-recognition pipeline for a dropdown language name.

    A pipeline is built on first request and kept in the small LRU cache;
    later requests for the same language reuse it and refresh its recency.
    Raises KeyError when *language_display* is not in ``language_models``.
    """
    is_cached = language_display in _PIPELINE_CACHE
    if is_cached:
        asr = _PIPELINE_CACHE[language_display]
    else:
        asr = pipeline(
            task="automatic-speech-recognition",
            model=language_models[language_display],
            device=-1,  # force CPU usage on Spaces CPU
            chunk_length_s=30,
        )
        _PIPELINE_CACHE[language_display] = asr
    _touch_cache(language_display)
    if not is_cached:
        # Only an insertion can push the cache over its size limit.
        _evict_if_needed()
    return asr
# -------- Helpers --------
def _model_revision_from_pipeline(pipe) -> str:
# Best-effort capture of revision/hash for reproducibility
for attr in ("hub_revision", "revision", "_commit_hash"):
val = getattr(getattr(pipe, "model", None), attr, None)
if val:
return str(val)
# Fallback to config name_or_path or unknown
try:
return str(getattr(pipe.model.config, "_name_or_path", "unknown"))
except Exception:
return "unknown"
def _append_log_row(row: dict):
    """
    Append one feedback record to LOG_CSV, writing the header row first when
    the file is new or empty.

    Columns follow ``field_order``; any column missing from *row* is written as
    an empty cell. *row* itself is not modified. Keys that are not known
    columns make ``csv.DictWriter`` raise ValueError (its default behaviour).
    """
    field_order = [
        "timestamp", "session_id",
        "language_display", "model_id", "model_revision",
        "audio_duration_s", "sample_rate", "source",
        "decode_params",
        "transcript_hyp",
        "reference_text", "corrected_text",
        "latency_ms", "rtf",
        "wer", "cer",
        "subs", "ins", "dels",
        "score_out_of_10", "feedback_text",
        "tags",
        "store_audio", "audio_path",
    ]
    # Fill defaults on a copy so the caller's dict is left untouched.
    record = dict.fromkeys(field_order, "")
    record.update(row)
    with open(LOG_CSV, "a", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=field_order)
        # Append mode positions the stream at end-of-file, so offset 0 means the
        # file is new *or* empty (e.g. left behind by a failed write); both need
        # a header.
        if f.tell() == 0:
            writer.writeheader()
        writer.writerow(record)
def _compute_metrics(hyp: str, ref_or_corrected: str):
if not HAS_JIWER or not ref_or_corrected or not hyp:
return {
"wer": None, "cer": None,
"subs": None, "ins": None, "dels": None
}
try:
measures = compute_measures(ref_or_corrected, hyp)
return {
"wer": measures.get("wer"),
"cer": jiwer_cer(ref_or_corrected, hyp),
"subs": measures.get("substitutions"),
"ins": measures.get("insertions"),
"dels": measures.get("deletions"),
}
except Exception:
# Be resilient if jiwer errors on edge cases
return {
"wer": None, "cer": None,
"subs": None, "ins": None, "dels": None
}
# -------- Inference --------
def transcribe(audio_path: str, language: str):
    """
    Transcribe one audio clip with the ASR model registered for *language*.

    The file is decoded with librosa (mp3, wav, flac, m4a, ogg, ...) to mono at
    its native sample rate. If the pipeline's feature extractor declares a
    different sampling rate, the waveform is resampled with librosa before
    inference.

    Args:
        audio_path: Path of the uploaded/recorded file (gr.Audio, type="filepath").
        language: A key of ``language_models``.

    Returns:
        ``(transcript, meta)``: the transcript text plus a metadata dict for the
        feedback logger (keys match the CSV columns). If no audio was given or
        the file cannot be decoded, returns ``(warning_message, None)``.
    """
    if not audio_path:
        return "⚠️ Please upload or record an audio clip.", None

    try:
        # sr=None keeps the file's native rate; mono=True downmixes channels.
        speech, sr = librosa.load(audio_path, sr=None, mono=True)
    except Exception as exc:  # UI boundary: show a message instead of a stack trace
        return f"⚠️ Could not read the audio file: {exc}", None
    duration_s = float(librosa.get_duration(y=speech, sr=sr))

    pipe = get_asr_pipeline(language)
    decode_params = {"chunk_length_s": getattr(pipe, "chunk_length_s", 30)}
    model_sr = getattr(getattr(pipe, "feature_extractor", None), "sampling_rate", None)

    t0 = time.perf_counter()  # monotonic clock, unaffected by wall-clock changes
    model_input, input_sr = speech, sr
    if model_sr and model_sr != sr:
        # Resample here; otherwise the transformers ASR pipeline resamples
        # internally, which requires torchaudio to be installed.
        model_input = librosa.resample(speech, orig_sr=sr, target_sr=model_sr)
        input_sr = model_sr
    result = pipe({"sampling_rate": input_sr, "raw": model_input})
    latency_ms = int((time.perf_counter() - t0) * 1000.0)
    hyp_text = result.get("text", "")
    # Real-time factor = processing time / audio duration (guarded for 0 s clips).
    rtf = (latency_ms / 1000.0) / max(duration_s, 1e-9)

    # Prepare metadata for the feedback logger
    meta = {
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "session_id": f"anon-{uuid.uuid4()}",
        "language_display": language,
        "model_id": language_models.get(language, "unknown"),
        "model_revision": _model_revision_from_pipeline(pipe),
        "audio_duration_s": duration_s,
        "sample_rate": sr,  # native rate of the uploaded file
        "source": "upload",  # gr.Audio combines both; we don't distinguish here
        "decode_params": json.dumps(decode_params),
        "transcript_hyp": hyp_text,
        "latency_ms": latency_ms,
        "rtf": rtf,
        # Placeholders to be filled on feedback submit
        "reference_text": "",
        "corrected_text": "",
        "wer": "",
        "cer": "",
        "subs": "",
        "ins": "",
        "dels": "",
        "score_out_of_10": "",
        "feedback_text": "",
        "tags": "",
        "store_audio": False,
        "audio_path": "",
    }
    return hyp_text, meta
# -------- Feedback submit --------
def _store_consented_audio(audio_file_path: str) -> str:
    """
    Copy a user's audio file into AUDIO_DIR under a random UUID file name.

    Best effort: returns the new path, or "" if the copy fails. A partially
    written destination file is removed on failure.
    """
    import shutil  # local import: only needed when a user opts in

    stored_path = ""
    try:
        ext = os.path.splitext(audio_file_path)[1] or ".wav"
        stored_path = os.path.join(AUDIO_DIR, f"{uuid.uuid4()}{ext}")
        # copyfile streams in chunks instead of loading the whole clip into RAM.
        shutil.copyfile(audio_file_path, stored_path)
    except (OSError, TypeError, ValueError):
        if stored_path:
            try:
                os.remove(stored_path)
            except OSError:
                pass  # nothing was written
        return ""
    return stored_path


def submit_feedback(meta, reference_text, corrected_text, score, feedback_text,
                    tags, store_audio, share_publicly, audio_file_path):
    """
    Compute metrics (if possible), optionally store the audio (with consent),
    and append one row to the CSV log.

    Args:
        meta: Metadata dict produced by ``transcribe`` (None/empty if the user
            has not transcribed anything yet).
        reference_text, corrected_text: Optional ground-truth texts; the
            reference wins, otherwise the corrected transcript is used.
        score, feedback_text, tags, store_audio, share_publicly: Form values.
        audio_file_path: Current value of the audio component.

    Returns:
        A compact dict for display: status, metrics (None when unavailable),
        latency/RTF and model identifiers.
    """
    if not meta:
        return {"status": "No transcription metadata available. Please transcribe first."}

    # Compare the hypothesis against the explicit reference if given, else the
    # user's corrected transcript.
    # NOTE(review): the UI pre-fills the corrected box with the hypothesis, so an
    # untouched box with no reference yields WER 0.
    ref_for_metrics = reference_text.strip() if reference_text else ""
    corrected_text = corrected_text.strip() if corrected_text else ""
    if not ref_for_metrics and corrected_text:
        ref_for_metrics = corrected_text
    metrics = _compute_metrics(meta.get("transcript_hyp", ""), ref_for_metrics)

    # NOTE(review): audio_file_path is whatever the audio widget holds now, which
    # may differ from the clip that was transcribed if the user changed it.
    stored_path = ""
    if store_audio and audio_file_path:
        stored_path = _store_consented_audio(audio_file_path)

    metric_keys = ("wer", "cer", "subs", "ins", "dels")
    row = dict(meta)  # start from recorded meta
    # Unavailable metrics become "" so the CSV cell stays empty.
    row.update({k: ("" if metrics[k] is None else metrics[k]) for k in metric_keys})
    row.update({
        "reference_text": reference_text or "",
        "corrected_text": corrected_text or "",
        "score_out_of_10": score if score is not None else "",
        "feedback_text": feedback_text or "",
        "tags": json.dumps({"labels": tags or [], "share_publicly": bool(share_publicly)}),
        "store_audio": bool(store_audio),
        "audio_path": stored_path,
    })

    try:
        _append_log_row(row)
        status = "Feedback saved."
    except Exception as e:
        status = f"Failed to save feedback: {e}"

    # Compact result to show back to the user.
    summary = {"status": status}
    summary.update({k: (row[k] if row[k] != "" else None) for k in metric_keys})
    summary.update({k: row[k] for k in ("latency_ms", "rtf", "model_id", "model_revision")})
    return summary
# -------- UI (original preserved; additions appended) --------
with gr.Blocks(title="🌐 Multilingual ASR Demo") as demo:
    gr.Markdown(
        """
## 🎙️ Multilingual Speech-to-Text
Upload an audio file (MP3, WAV, FLAC, M4A, OGG,…) or record via your microphone.
Then choose the language/model and hit **Transcribe**.
"""
    )
    with gr.Row():
        lang = gr.Dropdown(
            choices=list(language_models.keys()),
            value=list(language_models.keys())[0],
            label="Select Language / Model",
        )
    with gr.Row():
        audio = gr.Audio(
            sources=["upload", "microphone"],
            type="filepath",
            label="Upload or Record Audio",
        )
    btn = gr.Button("Transcribe")
    output = gr.Textbox(label="Transcription")

    # Hidden state to carry metadata from transcribe -> feedback
    meta_state = gr.State(value=None)

    def _transcribe_and_store(audio_path, language):
        """
        Run transcribe() and fan its result out to the UI.

        Returns (transcript, meta, corrected-box prefill). The corrected box is
        pre-filled with the hypothesis for convenience, but cleared when no
        metadata was produced: the "transcript" is then a warning message,
        not model output.
        """
        hyp, meta = transcribe(audio_path, language)
        return hyp, meta, (hyp if meta else "")

    # --- Evaluation & Feedback (appended UI, no style/font changes) ---
    with gr.Accordion("Evaluation & Feedback", open=False):
        with gr.Row():
            reference_tb = gr.Textbox(label="Reference text (optional)", lines=4, value="")
        with gr.Row():
            corrected_tb = gr.Textbox(label="Corrected transcript (optional)", lines=4, value="")
        with gr.Row():
            score_slider = gr.Slider(minimum=0, maximum=10, step=1, label="Score out of 10", value=7)
        with gr.Row():
            feedback_tb = gr.Textbox(label="Feedback (what went right/wrong?)", lines=3, value="")
        with gr.Row():
            tags_cb = gr.CheckboxGroup(
                ["noisy", "far-field", "code-switching", "numbers-heavy", "named-entities", "read-speech", "spontaneous", "call-center", "voicenote"],
                label="Slice tags (select any that apply)",
            )
        with gr.Row():
            store_audio_cb = gr.Checkbox(label="Allow storing my audio for research/eval", value=False)
            share_cb = gr.Checkbox(label="Allow sharing this example publicly", value=False)
        submit_btn = gr.Button("Submit Feedback / Compute Metrics")
        results_json = gr.JSON(label="Metrics & Status")

    # Wire events
    btn.click(
        fn=_transcribe_and_store,
        inputs=[audio, lang],
        outputs=[output, meta_state, corrected_tb],
    )
    submit_btn.click(
        fn=submit_feedback,
        inputs=[
            meta_state,
            reference_tb,
            corrected_tb,
            score_slider,
            feedback_tb,
            tags_cb,
            store_audio_cb,
            share_cb,
            audio,  # raw file path from gr.Audio
        ],
        outputs=results_json,
    )
# Entry point: enable request queuing (keeps the Space stable under
# concurrent load), then start the Gradio server.
if __name__ == "__main__":
    demo.queue().launch()