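"""Streamlit app for visualizing ASR evaluation results (word-level jiwer alignment),
with optional audio previews fetched through the Hugging Face datasets-server API."""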
import os
import streamlit as st
import pandas as pd
import jiwer
import requests
from datetime import datetime
from pathlib import Path
from st_fixed_container import st_fixed_container
from visual_eval.visualization import render_visualize_jiwer_result_html
from visual_eval.evaluator import HebrewTextNormalizer
HF_API_TOKEN = None
try:
    HF_API_TOKEN = st.secrets["HF_API_TOKEN"]
except (FileNotFoundError, KeyError):
    # No secrets file (or no such key) - fall back to the environment variable
    HF_API_TOKEN = os.environ.get("HF_API_TOKEN")
has_api_token = HF_API_TOKEN is not None
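# Each entry: ("repo_id:split:text_column", config or None, output name matched against result CSVs)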
known_datasets = [
("ivrit-ai/eval-d1:test:text", None, "ivrit_ai_eval_d1"),
("upai-inc/saspeech:test:text", None, "saspeech"),
("google/fleurs:test:transcription", "he_il", "fleurs"),
("mozilla-foundation/common_voice_17_0:test:sentence", "he", "common_voice_17"),
("imvladikon/hebrew_speech_kan:validation:sentence", None, "hebrew_speech_kan"),
]
# Initialize session state for audio cache if it doesn't exist
if "audio_cache" not in st.session_state:
    st.session_state.audio_cache = {}
if "audio_preview_active" not in st.session_state:
    st.session_state.audio_preview_active = {}
def on_file_upload():
    st.session_state.audio_cache = {}
    st.session_state.audio_preview_active = {}
    st.session_state.selected_entry_idx = 0


def display_rtl(html):
    """Render an RTL container with the provided HTML string"""
    st.markdown(
        f"""
<div dir="rtl" lang="he">
{html}
</div>
""",
        unsafe_allow_html=True,
    )
@st.cache_data
def calculate_final_metrics(uploaded_file, _df):
    """Calculate final metrics for all entries

    Args:
        uploaded_file: The uploaded file object (for cache hash generation)
        _df: The dataframe containing the evaluation results (not included in cache hash)

    Returns:
        A jiwer WordOutput object containing the final metrics
    """
    _df = _df.sort_values(by=["id"])
    _df["reference_text"] = _df["reference_text"].fillna("")
    _df["predicted_text"] = _df["predicted_text"].fillna("")

    # Convert to a list of dicts
    entries_data = _df.to_dict(orient="records")

    htn = HebrewTextNormalizer()

    # Calculate final metrics over the normalized reference/prediction pairs
    results = jiwer.process_words(
        [htn(entry["reference_text"]) for entry in entries_data],
        [htn(entry["predicted_text"]) for entry in entries_data],
    )
    return results
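# Map an output name (typically the uploaded CSV's filename stem) back to its known dataset tuple.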
def get_known_dataset_by_output_name(output_name):
    for dataset in known_datasets:
        if dataset[2] == output_name:
            return dataset
    return None
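# Fetch signed audio URLs for a slice of dataset rows via the Hugging Face datasets-server /rows API.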
def get_dataset_entries_audio_urls(dataset, offset=0, max_entries=100):
    if dataset is None or not has_api_token:
        return None

    dataset_repo_id, dataset_config, _ = dataset
    if not dataset_config:
        dataset_config = "default"
    if ":" in dataset_repo_id:
        dataset_repo_id, split, _ = dataset_repo_id.split(":")
    else:
        split = "test"

    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
    api_query_params = {
        "dataset": dataset_repo_id,
        "config": dataset_config,
        "split": split,
        "offset": offset,
        "length": max_entries,
    }
    query_params_str = "&".join([f"{k}={v}" for k, v in api_query_params.items()])
    API_URL = f"https://datasets-server.huggingface.co/rows?{query_params_str}"

    def query():
        response = requests.get(API_URL, headers=headers)
        return response.json()

    data = query()

    def get_audio_url(row):
        audio_feature_list = row["row"]["audio"]
        first_audio = audio_feature_list[0]
        return first_audio["src"]

    if "rows" in data and len(data["rows"]) > 0:
        return [get_audio_url(row) for row in data["rows"]]
    else:
        return None
def get_audio_url_for_entry(
    dataset, entry_idx, cache_neighbors=True, neighbor_range=20
):
    """
    Get audio URL for a specific entry and optionally cache neighbors

    Args:
        dataset: Dataset tuple (repo_id, config, output_name)
        entry_idx: Index of the entry to get audio URL for
        cache_neighbors: Whether to cache audio URLs for neighboring entries
        neighbor_range: Range of neighboring entries to cache

    Returns:
        Audio URL for the specified entry
    """
    # Calculate the range of entries to load
    if cache_neighbors:
        start_idx = max(0, entry_idx - neighbor_range)
        max_entries = neighbor_range * 2 + 1
    else:
        start_idx = entry_idx
        max_entries = 1

    # Get audio URLs for the range of entries
    audio_urls = get_dataset_entries_audio_urls(dataset, start_idx, max_entries)
    if not audio_urls:
        return None

    # Cache the audio URLs
    for i, url in enumerate(audio_urls):
        idx = start_idx + i

        # Extract expiration time from URL if available
        expires = None
        if "expires=" in url:
            try:
                expires_param = url.split("expires=")[1].split("&")[0]
                expires = datetime.fromtimestamp(int(expires_param))
            except (ValueError, IndexError):
                expires = None

        st.session_state.audio_cache[idx] = {"url": url, "expires": expires}

    # Return the URL for the requested entry
    relative_idx = entry_idx - start_idx
    if 0 <= relative_idx < len(audio_urls):
        return audio_urls[relative_idx]
    return None
def get_cached_audio_url(entry_idx):
    """
    Get audio URL from cache if available and not expired

    Args:
        entry_idx: Index of the entry to get audio URL for

    Returns:
        Audio URL if available in cache and not expired, None otherwise
    """
    if entry_idx not in st.session_state.audio_cache:
        return None

    cache_entry = st.session_state.audio_cache[entry_idx]

    # Check if the URL is expired
    if cache_entry["expires"] and datetime.now() > cache_entry["expires"]:
        return None

    return cache_entry["url"]
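# Streamlit entry point: upload an evaluation CSV, browse entries, and inspect per-entry alignments.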
def main():
    st.set_page_config(
        page_title="ASR Evaluation Visualizer", page_icon="🎤", layout="wide"
    )

    if not has_api_token:
        st.warning("No Hugging Face API token found. Audio previews will not work.")

    st.title("ASR Evaluation Visualizer")

    # File uploader
    uploaded_file = st.file_uploader(
        "Upload evaluation results CSV", type=["csv"], on_change=on_file_upload
    )

    if uploaded_file is not None:
        # Load the data
        try:
            eval_results = pd.read_csv(uploaded_file)
            st.success("File uploaded successfully!")

            with st.sidebar:
                # Toggle for calculating total metrics
                show_total_metrics = st.toggle("Show total metrics", value=False)
                if show_total_metrics:
                    total_metrics = calculate_final_metrics(uploaded_file, eval_results)

                    # Display total metrics in a nice format
                    with st.container():
                        st.metric("WER", f"{total_metrics.wer * 100:.4f}%")
                        st.table(
                            {
                                "Hits": total_metrics.hits,
                                "Subs": total_metrics.substitutions,
                                "Dels": total_metrics.deletions,
                                "Insrt": total_metrics.insertions,
                            }
                        )
            # Create sidebar for entry selection
            st.sidebar.header("Select Entry")

            # Add Next/Prev buttons at the top of the sidebar
            col1, col2 = st.sidebar.columns(2)

            # Define navigation functions
            def go_prev():
                if st.session_state.selected_entry_idx > 0:
                    st.session_state.selected_entry_idx -= 1

            def go_next():
                if st.session_state.selected_entry_idx < len(eval_results) - 1:
                    st.session_state.selected_entry_idx += 1

            # Add navigation buttons
            col1.button("← Prev", on_click=go_prev, use_container_width=True)
            col2.button("Next →", on_click=go_next, use_container_width=True)

            # Create a data table with entries and their WER
            entries_data = []
            for i in range(len(eval_results)):
                wer_value = eval_results.iloc[i].get("wer", 0)

                # Format WER as percentage
                wer_formatted = (
                    f"{wer_value*100:.2f}%"
                    if isinstance(wer_value, (int, float))
                    else wer_value
                )
                entries_data.append({"Entry": f"Entry #{i+1}", "WER": wer_formatted})

            # Create a selection mechanism using radio buttons that look like a table
            st.sidebar.write("Select an entry:")

            # Use a container for better styling
            entry_container = st.sidebar.container()

            # Create a radio button for each entry, styled to look like a table row
            entry_container.radio(
                "Select an entry",
                options=list(range(len(eval_results))),
                format_func=lambda i: f"Entry #{i+1} ({entries_data[i]['WER']})",
                label_visibility="collapsed",
                key="selected_entry_idx",
            )
            # Use the selected entry
            selected_entry = st.session_state.selected_entry_idx

            # Toggle for normalized vs raw text
            use_normalized = st.sidebar.toggle("Use normalized text", value=True)

            # Get the text columns based on the toggle
            if use_normalized:
                ref_col, hyp_col = "norm_reference_text", "norm_predicted_text"
            else:
                ref_col, hyp_col = "reference_text", "predicted_text"

            # Get the reference and hypothesis texts
            ref, hyp = eval_results.iloc[selected_entry][[ref_col, hyp_col]].values

            st.header("Visualization")

            # Check whether the results come from a known dataset:
            # infer the dataset name from the uploaded filename first,
            # then fall back to the per-entry "dataset" column.
            dataset_name = None
            if uploaded_file is not None:
                filename_stem = Path(uploaded_file.name).stem
                dataset_name = filename_stem
            if not dataset_name and "dataset" in eval_results.columns:
                dataset_name = eval_results.iloc[selected_entry]["dataset"]

            # Get the known dataset if available
            known_dataset = get_known_dataset_by_output_name(dataset_name)
            # Display audio preview button if from a known dataset
            if known_dataset:
                # Check if we have the audio URL in cache
                audio_url = get_cached_audio_url(selected_entry)
                audio_preview_active = st.session_state.audio_preview_active.get(
                    selected_entry, False
                )

                preview_audio = False
                if not audio_preview_active:
                    # Create a button to preview audio
                    preview_audio = st.button("Preview Audio", key="preview_audio")

                if preview_audio or audio_url:
                    st.session_state.audio_preview_active[selected_entry] = True
                    with st_fixed_container(
                        mode="sticky", position="top", border=True, margin=0
                    ):
                        # If button clicked or we already have the URL, get/use the audio URL
                        if not audio_url:
                            with st.spinner("Loading audio..."):
                                audio_url = get_audio_url_for_entry(
                                    known_dataset, selected_entry
                                )

                        # Display the audio player in the sticky container at the top
                        if audio_url:
                            st.audio(audio_url)
                        else:
                            st.error("Failed to load audio for this entry.")

            # Display the visualization
            html = render_visualize_jiwer_result_html(ref, hyp)
            display_rtl(html)
            # Display metadata
            st.header("Metadata")
            metadata_cols = [
                "metadata_uuid",
                "model",
                "dataset",
                "dataset_split",
                "engine",
            ]
            metadata = eval_results.iloc[selected_entry][metadata_cols]

            # Create a DataFrame for better display
            metadata_df = pd.DataFrame(
                {"Field": metadata_cols, "Value": metadata.values}
            )
            st.table(metadata_df)

            # If we have audio URL, display it in the sticky container
            if "audio_url" in locals() and audio_url:
                pass  # CSS is now applied globally

        except Exception as e:
            st.error(f"Error processing file: {str(e)}")
    else:
        st.info(
            "Please upload an evaluation results CSV file to visualize the results."
        )
        st.markdown(
            """
### Expected CSV Format

The CSV should have the following columns:
- id
- reference_text
- predicted_text
- norm_reference_text
- norm_predicted_text
- wer
- wil
- substitutions
- deletions
- insertions
- hits
- metadata_uuid
- model
- dataset
- dataset_split
- engine
"""
        )
if __name__ == "__main__":
    main()