import json
import uuid
import zipfile
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

import evaluate
import gradio as gr
import torch
import torchaudio
from fastapi import BackgroundTasks, FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoProcessor, AutoModelForCTC


app = FastAPI(title="TIMIT Phoneme Transcription Leaderboard")

# Minimal placeholder UI; the actual functionality is exposed via the REST API.
demo = gr.Interface(
    fn=lambda x: x,
    inputs=gr.Textbox(visible=False),
    outputs=gr.Textbox(visible=False),
    title="TIMIT Phoneme Transcription Queue",
    description="API endpoints are available at /api/leaderboard, /api/evaluate, and /api/tasks/{task_id}"
)
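
# Example interaction once the app is running (assumption: served locally on
# port 7860, matching the uvicorn call in the __main__ block below; the model
# id is purely illustrative):
#
#   curl -X POST http://localhost:7860/api/evaluate \
#        -H "Content-Type: application/json" \
#        -d '{"transcription_model": "facebook/wav2vec2-lv-60-espeak-cv-ft",
#             "submission_name": "my-baseline", "max_samples": 10}'
#   curl http://localhost:7860/api/tasks/<task_id>
#   curl http://localhost:7860/api/leaderboard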

CURRENT_DIR = Path(__file__).parent.absolute()

TIMIT_PATH = CURRENT_DIR / ".data" / "TIMIT.zip"
QUEUE_DIR = CURRENT_DIR / "queue"
PATHS = {
    'tasks': QUEUE_DIR / "tasks.json",
    'results': QUEUE_DIR / "results.json",
    'leaderboard': QUEUE_DIR / "leaderboard.json"
}

phone_errors = evaluate.load("ginic/phone_errors")
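# evaluate_model() below reads the 'phone_error_rates' and
# 'phone_feature_error_rates' keys from this metric's compute() output;
# the per-file values are averaged into the leaderboard scores.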


class TimitDataManager:
    """Handles all TIMIT dataset operations"""

    # Mapping from TIMIT phone labels to IPA symbols.
    TIMIT_TO_IPA = {
        # Vowels
        'aa': 'ɑ',
        'ae': 'æ',
        'ah': 'ʌ',
        'ao': 'ɔ',
        'aw': 'aʊ',
        'ay': 'aɪ',
        'eh': 'ɛ',
        'er': 'ɹ',
        'ey': 'eɪ',
        'ih': 'ɪ',
        'ix': 'i',
        'iy': 'i',
        'ow': 'oʊ',
        'oy': 'ɔɪ',
        'uh': 'ʊ',
        'uw': 'u',
        'ux': 'u',
        'ax': 'ə',
        'ax-h': 'ə',
        'axr': 'ɹ',

        # Stops: releases are dropped; closure segments carry the stop symbol
        'b': '',
        'bcl': 'b',
        'd': '',
        'dcl': 'd',
        'g': '',
        'gcl': 'g',
        'p': '',
        'pcl': 'p',
        't': '',
        'tcl': 't',
        'k': '',
        'kcl': 'k',
        'dx': 'ɾ',
        'q': 'ʔ',

        # Affricates and fricatives
        'jh': 'dʒ',
        'ch': 'tʃ',
        's': 's',
        'sh': 'ʃ',
        'z': 'z',
        'zh': 'ʒ',
        'f': 'f',
        'th': 'θ',
        'v': 'v',
        'dh': 'ð',
        'hh': 'h',
        'hv': 'h',

        # Nasals
        'm': 'm',
        'n': 'n',
        'ng': 'ŋ',
        'em': 'm',
        'en': 'n',
        'eng': 'ŋ',
        'nx': 'ɾ',

        # Semivowels and liquids
        'l': 'l',
        'r': 'ɹ',
        'w': 'w',
        'wh': 'ʍ',
        'y': 'j',
        'el': 'l',

        # Silence and non-speech markers are dropped
        'epi': '',
        'h#': '',
        'pau': '',
    }

    def __init__(self, timit_path: Path):
        self.timit_path = timit_path
        self._zip = None
        print(f"TimitDataManager initialized with path: {self.timit_path.absolute()}")
        if not self.timit_path.exists():
            raise FileNotFoundError(f"TIMIT dataset not found at {self.timit_path.absolute()}")
        print("TIMIT dataset file exists!")

    @property
    def zip(self):
        """Lazily open the TIMIT zip archive on first access."""
        if not self._zip:
            try:
                self._zip = zipfile.ZipFile(self.timit_path, 'r')
                print("Successfully opened TIMIT zip file")
            except FileNotFoundError:
                raise FileNotFoundError(f"TIMIT dataset not found at {self.timit_path}")
        return self._zip

    def get_file_list(self, subset: str) -> List[str]:
        """Get list of WAV files for given subset"""
        files = [f for f in self.zip.namelist()
                 if f.endswith('.WAV') and subset.lower() in f.lower()]
        print(f"Found {len(files)} WAV files in {subset} subset")
        if files:
            print("First 3 files:", files[:3])
        return files

    def load_audio(self, filename: str) -> torch.Tensor:
        """Load and preprocess audio file"""
        with self.zip.open(filename) as wav_file:
            waveform, sample_rate = torchaudio.load(wav_file)

        # Downmix multi-channel audio to mono
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Resample to the 16 kHz rate the CTC models expect
        if sample_rate != 16000:
            waveform = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)

        # Normalize to zero mean and unit variance
        waveform = (waveform - waveform.mean()) / (waveform.std() + 1e-7)

        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)

        return waveform

    def get_phonemes(self, filename: str) -> str:
        """Get cleaned phoneme sequence from PHN file and convert to IPA"""
        phn_file = filename.replace('.WAV', '.PHN')
        with self.zip.open(phn_file) as f:
            phonemes = []
            for line in f.read().decode('utf-8').splitlines():
                if line.strip():
                    # Each PHN line is "start_sample end_sample phone"
                    _, _, phone = line.split()
                    phone = self.remove_stress_mark(phone)

                    # Phones mapping to '' (silences, stop releases) are skipped
                    ipa = self.TIMIT_TO_IPA.get(phone.lower(), '')
                    if ipa:
                        phonemes.append(ipa)
        return ''.join(phonemes)
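
    # Example: a PHN excerpt such as
    #
    #   0 9640 h#
    #   9640 11240 sh
    #   11240 12783 ix
    #
    # yields "ʃi": 'h#' maps to '' and is dropped, 'sh' -> 'ʃ', 'ix' -> 'i'.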

    # Optional pre-mapping substitutions; empty by default. (Assumption: this
    # table was referenced by simplify_timit but never defined in the original.)
    PHONE_SUBSTITUTIONS: Dict[str, str] = {}

    def simplify_timit(self, phoneme: str) -> str:
        """Apply substitutions to simplify TIMIT phonemes"""
        return self.PHONE_SUBSTITUTIONS.get(phoneme, phoneme)

    def remove_stress_mark(self, text: str) -> str:
        """Remove the combining double inverted breve (tie bar, ͡) from text"""
        if not isinstance(text, str):
            raise TypeError("Input must be string")
        return text.replace('͡', '')


class ModelManager:
    """Handles model loading and inference"""

    def __init__(self):
        self.models = {}
        self.processors = {}
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.batch_size = 32

    def get_model_and_processor(self, model_name: str):
        """Get or load model and processor"""
        if model_name not in self.models:
            print("Loading processor with phoneme tokenizer...")
            processor = AutoProcessor.from_pretrained(model_name)

            print(f"Loading model {model_name}...")
            model = AutoModelForCTC.from_pretrained(model_name).to(self.device)

            self.models[model_name] = model
            self.processors[model_name] = processor

        return self.models[model_name], self.processors[model_name]

    def transcribe(self, audio_list: List[torch.Tensor], model_name: str) -> List[str]:
        """Transcribe a batch of audio using specified model"""
        model, processor = self.get_model_and_processor(model_name)
        if model is None or processor is None:
            raise Exception("Model and processor not loaded")

        all_predictions = []
        for i in range(0, len(audio_list), self.batch_size):
            batch_audio = audio_list[i:i + self.batch_size]

            # Zero-pad every clip in the batch to the longest clip
            max_length = max(audio.shape[-1] for audio in batch_audio)
            padded_audio = torch.zeros((len(batch_audio), 1, max_length))
            attention_mask = torch.zeros((len(batch_audio), max_length))

            for j, audio in enumerate(batch_audio):
                padded_audio[j, :, :audio.shape[-1]] = audio
                attention_mask[j, :audio.shape[-1]] = 1

            inputs = processor(
                padded_audio.squeeze(1).numpy(),
                sampling_rate=16000,
                return_tensors="pt",
                padding=True
            )

            input_values = inputs.input_values.to(self.device)
            # Prefer the processor's attention mask; fall back to the manual one
            attention_mask = inputs.get("attention_mask", attention_mask).to(self.device)

            with torch.no_grad():
                outputs = model(
                    input_values=input_values,
                    attention_mask=attention_mask
                )
                logits = outputs.logits
                predicted_ids = torch.argmax(logits, dim=-1)
                predictions = processor.batch_decode(predicted_ids, skip_special_tokens=True)
                predictions = [pred.replace(' ', '') for pred in predictions]
                all_predictions.extend(predictions)

        return all_predictions
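
# A standalone usage sketch (assumption: the checkpoint named here is only an
# example; any CTC phoneme model loadable via AutoModelForCTC should work):
#
#   manager = ModelManager()
#   wav = timit_manager.load_audio(timit_manager.get_file_list("test")[0])
#   print(manager.transcribe([wav], "facebook/wav2vec2-lv-60-espeak-cv-ft"))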


class StorageManager:
    """Handles all JSON storage operations"""

    def __init__(self, paths: Dict[str, Path]):
        self.paths = paths
        self._ensure_directories()

    def _ensure_directories(self):
        """Ensure all necessary directories and files exist"""
        for path in self.paths.values():
            path.parent.mkdir(parents=True, exist_ok=True)
            if not path.exists():
                path.write_text('[]')

    def load(self, key: str) -> List:
        """Load JSON file"""
        return json.loads(self.paths[key].read_text())

    def save(self, key: str, data: List):
        """Save data to JSON file"""
        self.paths[key].write_text(json.dumps(data, indent=4, default=str, ensure_ascii=False))

    def update_task(self, task_id: str, updates: Dict):
        """Update specific task with new data"""
        tasks = self.load('tasks')
        for task in tasks:
            if task['id'] == task_id:
                task.update(updates)
                break
        self.save('tasks', tasks)


class EvaluationRequest(BaseModel):
    """Request model for TIMIT evaluation"""
    transcription_model: str
    subset: str = "test"
    max_samples: Optional[int] = None
    submission_name: str
    github_url: Optional[str] = None
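
# Note: `subset` is matched as a case-insensitive substring of the archive
# paths (see TimitDataManager.get_file_list), so "train" and "test" select
# the standard TIMIT splits.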

async def evaluate_model(task_id: str, request: EvaluationRequest):
    """Background task to evaluate model on TIMIT"""
    try:
        storage_manager.update_task(task_id, {"status": "processing"})

        files = timit_manager.get_file_list(request.subset)
        if request.max_samples:
            files = files[:request.max_samples]

        results = []
        total_per = total_pwed = 0

        batch_size = model_manager.batch_size
        for i in range(0, len(files), batch_size):
            batch_files = files[i:i + batch_size]

            # Load audio and reference phonemes for this batch
            batch_audio = []
            batch_ground_truth = []
            for wav_file in batch_files:
                audio = timit_manager.load_audio(wav_file)
                ground_truth = timit_manager.get_phonemes(wav_file)
                batch_audio.append(audio)
                batch_ground_truth.append(ground_truth)

            predictions = model_manager.transcribe(batch_audio, request.transcription_model)

            # Score each file individually
            for wav_file, prediction, ground_truth in zip(batch_files, predictions, batch_ground_truth):
                metrics = phone_errors.compute(
                    predictions=[prediction],
                    references=[ground_truth],
                    is_normalize_pfer=True
                )

                per = metrics['phone_error_rates'][0]
                pwed = metrics['phone_feature_error_rates'][0]

                results.append({
                    "file": wav_file,
                    "ground_truth": ground_truth,
                    "prediction": prediction,
                    "per": per,
                    "pwed": pwed
                })

                total_per += per
                total_pwed += pwed

        if not results:
            raise Exception("No files were successfully processed")

        avg_per = total_per / len(results)
        avg_pwed = total_pwed / len(results)

        result = {
            "task_id": task_id,
            "model": request.transcription_model,
            "subset": request.subset,
            "num_files": len(results),
            "average_per": avg_per,
            "average_pwed": avg_pwed,
            "detailed_results": results[:5],
            "timestamp": datetime.now().isoformat()
        }

        print("Saving results...")
        current_results = storage_manager.load('results')
        current_results.append(result)
        storage_manager.save('results', current_results)

        print("Updating leaderboard...")
        leaderboard = storage_manager.load('leaderboard')
        entry = next((e for e in leaderboard
                      if e["submission_name"] == request.submission_name), None)

        if entry:
            # Update the existing submission in place
            entry.update({
                "average_per": avg_per,
                "average_pwed": avg_pwed,
                "model": request.transcription_model,
                "subset": request.subset,
                "github_url": request.github_url,
                "submission_date": datetime.now().isoformat()
            })
        else:
            leaderboard.append({
                "submission_id": str(uuid.uuid4()),
                "submission_name": request.submission_name,
                "model": request.transcription_model,
                "average_per": avg_per,
                "average_pwed": avg_pwed,
                "subset": request.subset,
                "github_url": request.github_url,
                "submission_date": datetime.now().isoformat()
            })

        storage_manager.save('leaderboard', leaderboard)
        storage_manager.update_task(task_id, {"status": "completed"})
        print("Evaluation completed successfully")

    except Exception as e:
        error_msg = f"Evaluation failed: {str(e)}"
        print(error_msg)
        storage_manager.update_task(task_id, {
            "status": "failed",
            "error": error_msg
        })


def init_directories():
    """Ensure all necessary directories exist"""
    (CURRENT_DIR / ".data").mkdir(parents=True, exist_ok=True)
    QUEUE_DIR.mkdir(parents=True, exist_ok=True)

    for path in PATHS.values():
        if not path.exists():
            path.write_text('[]')


init_directories()
timit_manager = TimitDataManager(TIMIT_PATH)
model_manager = ModelManager()
storage_manager = StorageManager(PATHS)


@app.get("/api/health")
async def health_check():
    """Simple health check endpoint"""
    return {"status": "healthy"}


@app.post("/api/evaluate")
async def submit_evaluation(
    request: EvaluationRequest,
    background_tasks: BackgroundTasks
):
    """Submit new evaluation task"""
    task_id = str(uuid.uuid4())

    task = {
        "id": task_id,
        "model": request.transcription_model,
        "subset": request.subset,
        "submission_name": request.submission_name,
        "github_url": request.github_url,
        "status": "queued",
        "submitted_at": datetime.now().isoformat()
    }

    tasks = storage_manager.load('tasks')
    tasks.append(task)
    storage_manager.save('tasks', tasks)

    background_tasks.add_task(evaluate_model, task_id, request)

    return {
        "message": "Evaluation task submitted successfully",
        "task_id": task_id
    }


@app.get("/api/tasks/{task_id}")
async def get_task(task_id: str):
    """Get specific task status"""
    tasks = storage_manager.load('tasks')
    task = next((t for t in tasks if t["id"] == task_id), None)
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")
    return task


@app.get("/api/leaderboard")
async def get_leaderboard():
    """Get current leaderboard"""
    try:
        leaderboard = storage_manager.load('leaderboard')
        # Sort ascending: lower PER first, ties broken by lower PWED
        sorted_leaderboard = sorted(leaderboard, key=lambda x: (x["average_per"], x["average_pwed"]))
        return sorted_leaderboard
    except Exception as e:
        print(f"Error loading leaderboard: {e}")
        return []
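
# Shape of a leaderboard entry as written by evaluate_model() above
# (values illustrative):
#
#   {
#       "submission_id": "…", "submission_name": "my-baseline",
#       "model": "…", "average_per": 0.21, "average_pwed": 0.09,
#       "subset": "test", "github_url": null, "submission_date": "…"
#   }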


# Serve the Gradio placeholder UI at the root path alongside the API routes.
app = gr.mount_gradio_app(app, demo, path="/")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)