# ASR Evaluation Visualizer — Streamlit app for browsing per-entry WER
# results and previewing source audio from known Hugging Face datasets.
import os | |
import streamlit as st | |
import pandas as pd | |
import jiwer | |
import requests | |
from datetime import datetime | |
from pathlib import Path | |
from st_fixed_container import st_fixed_container | |
from visual_eval.visualization import render_visualize_jiwer_result_html | |
from visual_eval.evaluator import HebrewTextNormalizer | |
# Resolve the Hugging Face API token: prefer Streamlit secrets, fall back to
# the environment. st.secrets raises FileNotFoundError when no secrets.toml
# exists, which is the case this except clause handles.
HF_API_TOKEN = None
try:
    HF_API_TOKEN = st.secrets["HF_API_TOKEN"]
except FileNotFoundError:
    HF_API_TOKEN = os.environ.get("HF_API_TOKEN")
# Audio previews query the datasets-server API and need this token.
has_api_token = HF_API_TOKEN is not None

# Datasets whose audio can be previewed. Each tuple is
# ("repo_id:split:text_column", config_or_None, output_name); output_name is
# matched against the uploaded CSV's filename stem / "dataset" column.
known_datasets = [
    ("ivrit-ai/eval-d1:test:text", None, "ivrit_ai_eval_d1"),
    ("upai-inc/saspeech:test:text", None, "saspeech"),
    ("google/fleurs:test:transcription", "he_il", "fleurs"),
    ("mozilla-foundation/common_voice_17_0:test:sentence", "he", "common_voice_17"),
    ("imvladikon/hebrew_speech_kan:validation:sentence", None, "hebrew_speech_kan"),
]

# Initialize session state for audio cache if it doesn't exist
if "audio_cache" not in st.session_state:
    st.session_state.audio_cache = {}
if "audio_preview_active" not in st.session_state:
    st.session_state.audio_preview_active = {}
def on_file_upload():
    """Reset per-file state when a new CSV is uploaded.

    Clears the signed-audio-URL cache and the preview flags, and moves the
    entry selection back to the first row.
    """
    st.session_state.selected_entry_idx = 0
    st.session_state.audio_cache = {}
    st.session_state.audio_preview_active = {}
def display_rtl(html):
    """Render an RTL container with the provided HTML string"""
    # Wrap the fragment in a right-to-left, Hebrew-tagged div so the jiwer
    # visualization HTML lays out with correct text direction.
    st.markdown(
        f"""
        <div dir="rtl" lang="he">
        {html}
        </div>
        """,
        unsafe_allow_html=True,
    )
def calculate_final_metrics(uploaded_file, _df):
    """Compute corpus-level jiwer word metrics over every entry.

    Args:
        uploaded_file: The uploaded file object (For cache hash gen)
        _df: The dataframe containing the evaluation results (not included in cache hash)

    Returns:
        The jiwer word-level output (WER plus hits/subs/dels/ins counts).
    """
    # NOTE(review): uploaded_file is never used in the body; the docstring's
    # "cache hash" wording suggests an @st.cache_data decorator existed on
    # this function — confirm against the original repo.
    ordered = _df.sort_values(by=["id"])
    references = ordered["reference_text"].fillna("").tolist()
    hypotheses = ordered["predicted_text"].fillna("").tolist()

    # Normalize Hebrew text on both sides before aggregate scoring.
    htn = HebrewTextNormalizer()
    return jiwer.process_words(
        [htn(text) for text in references],
        [htn(text) for text in hypotheses],
    )
def get_known_dataset_by_output_name(output_name):
    """Return the known-dataset tuple whose output name matches, else None.

    The output name is the third element of each entry in ``known_datasets``.
    """
    matching = (entry for entry in known_datasets if entry[2] == output_name)
    return next(matching, None)
def get_dataset_entries_audio_urls(dataset, offset=0, max_entries=100):
    """Fetch signed audio URLs for a slice of rows of a known dataset.

    Queries the Hugging Face datasets-server ``/rows`` endpoint using the
    configured API token.

    Args:
        dataset: Known-dataset tuple ("repo_id[:split:column]", config, output_name).
        offset: Index of the first row to fetch.
        max_entries: Number of rows to fetch.

    Returns:
        A list of audio URLs for the requested rows, or None when the dataset
        is unknown, no API token is configured, or the request fails.
    """
    if dataset is None or not has_api_token:
        return None
    dataset_repo_id, dataset_config, _ = dataset
    if not dataset_config:
        dataset_config = "default"
    # The repo id may carry "repo:split:text_column"; only repo and split
    # matter for the rows API.
    if ":" in dataset_repo_id:
        dataset_repo_id, split, _ = dataset_repo_id.split(":")
    else:
        split = "test"
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
    api_query_params = {
        "dataset": dataset_repo_id,
        "config": dataset_config,
        "split": split,
        "offset": offset,
        "length": max_entries,
    }
    query_params_str = "&".join([f"{k}={v}" for k, v in api_query_params.items()])
    API_URL = f"https://datasets-server.huggingface.co/rows?{query_params_str}"
    try:
        # Timeout so a slow/unresponsive API cannot hang the Streamlit app,
        # and raise_for_status so error pages don't reach the JSON parser.
        response = requests.get(API_URL, headers=headers, timeout=30)
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError):
        # Network failure, HTTP error, or non-JSON body: degrade to "no audio"
        # instead of crashing the entry render.
        return None

    def get_audio_url(row):
        # Each row's "audio" feature is a list of variants; take the first.
        audio_feature_list = row["row"]["audio"]
        first_audio = audio_feature_list[0]
        return first_audio["src"]

    if "rows" in data and len(data["rows"]) > 0:
        return [get_audio_url(row) for row in data["rows"]]
    return None
def get_audio_url_for_entry(
    dataset, entry_idx, cache_neighbors=True, neighbor_range=20
):
    """
    Get audio URL for a specific entry and optionally cache neighbors

    Args:
        dataset: Dataset tuple (repo_id, config, output_name)
        entry_idx: Index of the entry to get audio URL for
        cache_neighbors: Whether to cache audio URLs for neighboring entries
        neighbor_range: Range of neighboring entries to cache

    Returns:
        Audio URL for the specified entry
    """
    # Fetch a window around the entry so adjacent navigation hits the cache.
    if cache_neighbors:
        start_idx = max(0, entry_idx - neighbor_range)
        window = neighbor_range * 2 + 1
    else:
        start_idx, window = entry_idx, 1

    audio_urls = get_dataset_entries_audio_urls(dataset, start_idx, window)
    if not audio_urls:
        return None

    # Store each URL in the session cache, together with its expiry time
    # (parsed from the signed URL's "expires=" query parameter when present).
    for position, url in enumerate(audio_urls):
        expires = None
        if "expires=" in url:
            try:
                expires_param = url.split("expires=")[1].split("&")[0]
                expires = datetime.fromtimestamp(int(expires_param))
            except (ValueError, IndexError):
                expires = None
        st.session_state.audio_cache[start_idx + position] = {
            "url": url,
            "expires": expires,
        }

    # Hand back the URL belonging to the requested entry, if it was fetched.
    relative_idx = entry_idx - start_idx
    if 0 <= relative_idx < len(audio_urls):
        return audio_urls[relative_idx]
    return None
def get_cached_audio_url(entry_idx):
    """
    Get audio URL from cache if available and not expired

    Args:
        entry_idx: Index of the entry to get audio URL for

    Returns:
        Audio URL if available in cache and not expired, None otherwise
    """
    cache_entry = st.session_state.audio_cache.get(entry_idx)
    if cache_entry is None:
        return None
    expires = cache_entry["expires"]
    # A missing expiry means the URL is treated as永valid; otherwise compare
    # against the current time.
    if expires and datetime.now() > expires:
        return None
    return cache_entry["url"]
def main():
    """Streamlit entry point.

    Lets the user upload an evaluation-results CSV, browse entries in the
    sidebar, visualize per-entry reference/hypothesis alignment (RTL), and —
    for known datasets — preview the source audio in a sticky container.
    """
    st.set_page_config(
        page_title="ASR Evaluation Visualizer", page_icon="🎤", layout="wide"
    )
    if not has_api_token:
        st.warning("No Hugging Face API token found. Audio previews will not work.")
    st.title("ASR Evaluation Visualizer")
    # File uploader; on_file_upload resets selection and the audio cache.
    uploaded_file = st.file_uploader(
        "Upload evaluation results CSV", type=["csv"], on_change=on_file_upload
    )
    if uploaded_file is not None:
        # Load the data
        try:
            eval_results = pd.read_csv(uploaded_file)
            st.success("File uploaded successfully!")
            with st.sidebar:
                # Toggle for calculating total metrics
                show_total_metrics = st.toggle("Show total metrics", value=False)
                if show_total_metrics:
                    total_metrics = calculate_final_metrics(uploaded_file, eval_results)
                    # Display total metrics in a nice format
                    with st.container():
                        st.metric("WER", f"{total_metrics.wer * 100:.4f}%")
                        st.table(
                            {
                                "Hits": total_metrics.hits,
                                "Subs": total_metrics.substitutions,
                                "Dels": total_metrics.deletions,
                                "Insrt": total_metrics.insertions,
                            }
                        )
            # Create sidebar for entry selection
            st.sidebar.header("Select Entry")
            # Add Next/Prev buttons at the top of the sidebar
            col1, col2 = st.sidebar.columns(2)

            # Define navigation functions (mutate the radio's session key,
            # clamped to [0, len-1]).
            def go_prev():
                if st.session_state.selected_entry_idx > 0:
                    st.session_state.selected_entry_idx -= 1

            def go_next():
                if st.session_state.selected_entry_idx < len(eval_results) - 1:
                    st.session_state.selected_entry_idx += 1

            # Add navigation buttons
            col1.button("← Prev", on_click=go_prev, use_container_width=True)
            col2.button("Next →", on_click=go_next, use_container_width=True)
            # Create a data table with entries and their WER
            entries_data = []
            for i in range(len(eval_results)):
                wer_value = eval_results.iloc[i].get("wer", 0)
                # Format WER as percentage (non-numeric values pass through as-is)
                wer_formatted = (
                    f"{wer_value*100:.2f}%"
                    if isinstance(wer_value, (int, float))
                    else wer_value
                )
                entries_data.append({"Entry": f"Entry #{i+1}", "WER": wer_formatted})
            # Create a selection mechanism using radio buttons that look like a table
            st.sidebar.write("Select an entry:")
            # Use a container for better styling
            entry_container = st.sidebar.container()
            # Create a radio button for each entry, styled to look like a table row
            entry_container.radio(
                "Select an entry",
                options=list(range(len(eval_results))),
                format_func=lambda i: f"Entry #{i+1} ({entries_data[i]['WER']})",
                label_visibility="collapsed",
                key="selected_entry_idx",
            )
            # Use the selected entry
            selected_entry = st.session_state.selected_entry_idx
            # Toggle for normalized vs raw text
            use_normalized = st.sidebar.toggle("Use normalized text", value=True)
            # Get the text columns based on the toggle
            if use_normalized:
                ref_col, hyp_col = "norm_reference_text", "norm_predicted_text"
            else:
                ref_col, hyp_col = "reference_text", "predicted_text"
            # Get the reference and hypothesis texts
            ref, hyp = eval_results.iloc[selected_entry][[ref_col, hyp_col]].values
            st.header("Visualization")
            # Check if the CSV file is from a known dataset.
            dataset_name = None
            # NOTE(review): despite the comment order below, the filename stem
            # takes precedence; the "dataset" column is only consulted when the
            # stem is empty — confirm this is intended.
            # If no dataset column, try to infer from filename
            if uploaded_file is not None:
                filename_stem = Path(uploaded_file.name).stem
                dataset_name = filename_stem
            if not dataset_name and "dataset" in eval_results.columns:
                dataset_name = eval_results.iloc[selected_entry]["dataset"]
            # Get the known dataset if available
            known_dataset = get_known_dataset_by_output_name(dataset_name)
            # Display audio preview button if from a known dataset
            if known_dataset:
                # Check if we have the audio URL in cache
                audio_url = get_cached_audio_url(selected_entry)
                audio_preview_active = st.session_state.audio_preview_active.get(
                    selected_entry, False
                )
                preview_audio = False
                if not audio_preview_active:
                    # Create a button to preview audio
                    preview_audio = st.button("Preview Audio", key="preview_audio")
                if preview_audio or audio_url:
                    st.session_state.audio_preview_active[selected_entry] = True
                    with st_fixed_container(
                        mode="sticky", position="top", border=True, margin=0
                    ):
                        # If button clicked or we already have the URL, get/use the audio URL
                        if not audio_url:
                            with st.spinner("Loading audio..."):
                                audio_url = get_audio_url_for_entry(
                                    known_dataset, selected_entry
                                )
                        # Display the audio player in the sticky container at the top
                        if audio_url:
                            st.audio(audio_url)
                        else:
                            st.error("Failed to load audio for this entry.")
            # Display the visualization
            html = render_visualize_jiwer_result_html(ref, hyp)
            display_rtl(html)
            # Display metadata
            st.header("Metadata")
            metadata_cols = [
                "metadata_uuid",
                "model",
                "dataset",
                "dataset_split",
                "engine",
            ]
            metadata = eval_results.iloc[selected_entry][metadata_cols]
            # Create a DataFrame for better display
            metadata_df = pd.DataFrame(
                {"Field": metadata_cols, "Value": metadata.values}
            )
            st.table(metadata_df)
            # If we have audio URL, display it in the sticky container
            if "audio_url" in locals() and audio_url:
                pass  # CSS is now applied globally
        except Exception as e:
            st.error(f"Error processing file: {str(e)}")
    else:
        st.info(
            "Please upload an evaluation results CSV file to visualize the results."
        )
        st.markdown(
            """
            ### Expected CSV Format
            The CSV should have the following columns:
            - id
            - reference_text
            - predicted_text
            - norm_reference_text
            - norm_predicted_text
            - wer
            - wil
            - substitutions
            - deletions
            - insertions
            - hits
            - metadata_uuid
            - model
            - dataset
            - dataset_split
            - engine
            """
        )
# Script entry point: run the Streamlit app.
if __name__ == "__main__":
    main()