import argparse
from pathlib import Path
from typing import Any, Dict

import numpy as np
import soundfile as sf
from tqdm import tqdm
def validate_tts(wav_path: str) -> Dict[str, Any]:
    """
    Validation checks for TTS-generated audio files to detect common artifacts.
    """
    try:
        # Load audio and mix down to mono for analysis
        audio, sr = sf.read(wav_path)
        if len(audio.shape) > 1:
            audio = np.mean(audio, axis=1)
        duration = len(audio) / sr
        issues = []
        # Basic quality checks
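        # Summary stats: RMS catches overly quiet output, peak catches
        # clipping/limiting, and a non-zero mean exposes DC offset.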
        abs_audio = np.abs(audio)
        stats = {
            "rms": float(np.sqrt(np.mean(audio**2))),
            "peak": float(np.max(abs_audio)),
            "dc_offset": float(np.mean(audio)),
        }
        clip_count = np.sum(abs_audio >= 0.99)
        clip_percent = (clip_count / len(audio)) * 100
        if duration < 0.1:
            issues.append(
                "WARNING: Audio is suspiciously short - possible failed generation"
            )
if stats["peak"] >= 1.0: | |
if clip_percent > 1.0: | |
issues.append( | |
f"WARNING: Significant clipping detected ({clip_percent:.2e}% of samples)" | |
) | |
elif clip_percent > 0.01: | |
issues.append( | |
f"INFO: Minor peak limiting detected ({clip_percent:.2e}% of samples)" | |
) | |
if stats["rms"] < 0.01: | |
issues.append("WARNING: Audio is very quiet - possible failed generation") | |
if abs(stats["dc_offset"]) > 0.1: | |
issues.append(f"WARNING: High DC offset ({stats['dc_offset']:.3f})") | |
# Check for long silence gaps | |
eps = np.finfo(float).eps | |
db = 20 * np.log10(abs_audio + eps) | |
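        # 20*log10 converts linear amplitude to dBFS (0 dB = full scale);
        # eps guards against log10(0) on exactly-zero samples.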
        silence_threshold = -45  # dB
        min_silence = 2.0  # seconds
        window_size = int(min_silence * sr)
        silence_count = 0
        last_silence = -1
        start_idx = int(0.2 * sr)  # Skip first 0.2s
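        # Scan non-overlapping windows of min_silence seconds: a window counts
        # as a long silence when its mean level is below the threshold and >90%
        # of its samples are individually below it.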
        for i in tqdm(
            range(start_idx, len(db) - window_size, window_size),
            desc="Checking for silence",
        ):
            window = db[i : i + window_size]
            if np.mean(window) < silence_threshold:
                silent_ratio = np.mean(window < silence_threshold)
                if silent_ratio > 0.9:
                    # Only count a new silence if it is not adjacent to the last one
                    if last_silence == -1 or (i / sr - last_silence) > 2.0:
                        silence_count += 1
                        last_silence = i / sr
                        issues.append(
                            f"WARNING: Long silence detected at {i/sr:.2f}s "
                            f"(at least {min_silence:.1f}s)"
                        )
        if silence_count > 2:
            issues.append(
                f"WARNING: Multiple long silences found ({silence_count} total)"
            )
        # Detect audio artifacts
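        # Heuristic: a discontinuity shows up as a sample-to-sample jump that is
        # both large in absolute terms (>0.1) and >10x the local average jump,
        # where "local" is a ~5 ms moving average. Spikes within 5 ms of each
        # other are merged into groups, each scored by its largest jump.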
        diff = np.diff(audio)
        abs_diff = np.abs(diff)
        window_size = min(int(0.005 * sr), 256)
        window = np.ones(window_size) / window_size
        local_avg_diff = np.convolve(abs_diff, window, mode="same")
        spikes = (abs_diff > (10 * local_avg_diff)) & (abs_diff > 0.1)
        artifact_indices = np.nonzero(spikes)[0]
        artifacts = []
        if len(artifact_indices) > 0:
            gaps = np.diff(artifact_indices)
            min_gap = int(0.005 * sr)
            break_points = np.nonzero(gaps > min_gap)[0] + 1
            groups = np.split(artifact_indices, break_points)
            for group in groups:
                if len(group) >= 5:
                    severity = np.max(abs_diff[group])
                    if severity > 0.2:
                        center_idx = group[len(group) // 2]
                        artifacts.append(
                            {
                                # Ensure float for consistent timing
                                "time": float(center_idx / sr),
                                "severity": float(severity),
                            }
                        )
                        issues.append(
                            f"WARNING: Audio discontinuity at {center_idx/sr:.3f}s "
                            f"(severity: {severity:.3f})"
                        )
        # Check for repeated speech segments
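        # Slide a window over the file at several chunk lengths and compare each
        # chunk with the one immediately after it via Pearson correlation;
        # near-identical adjacent chunks (r > 0.92) suggest looped/repeated
        # speech. Near-silent chunks are skipped to avoid spurious matches.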
        for chunk_duration in tqdm(
            [0.5, 2.5, 5.0, 10.0], desc="Checking for repeated speech"
        ):
            chunk_size = int(chunk_duration * sr)
            overlap = int(0.2 * chunk_size)  # step size: successive windows overlap by 80%
            for i in range(0, len(audio) - 2 * chunk_size, overlap):
                chunk1 = audio[i : i + chunk_size]
                chunk2 = audio[i + chunk_size : i + 2 * chunk_size]
                if np.mean(np.abs(chunk1)) < 0.01 or np.mean(np.abs(chunk2)) < 0.01:
                    continue
                try:
                    correlation = np.corrcoef(chunk1, chunk2)[0, 1]
                    if not np.isnan(correlation) and correlation > 0.92:
                        # Word estimate assumes ~160 words per minute
                        issues.append(
                            f"WARNING: Possible repeated speech at {i/sr:.1f}s "
                            f"(~{int(chunk_duration*160/60):d} words, correlation: {correlation:.3f})"
                        )
                        break
                except Exception:
                    continue
        return {
            "file": wav_path,
            "duration": f"{duration:.2f}s",
            "sample_rate": sr,
            "peak_amplitude": f"{stats['peak']:.3f}",
            "rms_level": f"{stats['rms']:.3f}",
            "dc_offset": f"{stats['dc_offset']:.3f}",
            "artifact_count": len(artifacts),
            "artifact_locations": [a["time"] for a in artifacts],
            "artifact_severities": [a["severity"] for a in artifacts],
            "issues": issues,
            "valid": len(issues) == 0,
        }
    except Exception as e:
        return {"file": wav_path, "error": str(e), "valid": False}
def generate_analysis_plots(
    wav_path: str, output_dir: str, validation_result: Dict[str, Any]
):
    """
    Generate analysis plots for an audio file with time-aligned visualizations.
    """
    import matplotlib.pyplot as plt
    from scipy.signal import spectrogram
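    # Imported here so the validator itself runs without matplotlib/scipy installed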
    # Load audio and mix down to mono
    audio, sr = sf.read(wav_path)
    if len(audio.shape) > 1:
        audio = np.mean(audio, axis=1)
    # Create figure with shared x-axis
    fig = plt.figure(figsize=(15, 8))
    gs = plt.GridSpec(2, 1, height_ratios=[1.2, 0.8], hspace=0.1)
    ax1 = fig.add_subplot(gs[0])
    ax2 = fig.add_subplot(gs[1], sharex=ax1)
    # Calculate spectrogram
    nperseg = 2048
    noverlap = 1536
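    # 2048-sample Hann windows with 75% overlap (hop of 512 samples) trade
    # time resolution for frequency resolution; at a 24 kHz TTS sample rate,
    # for example, each window covers roughly 85 ms.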
    f, t, Sxx = spectrogram(
        audio, sr, nperseg=nperseg, noverlap=noverlap, window="hann", scaling="spectrum"
    )
    # Plot spectrogram
    im = ax1.pcolormesh(
        t,
        f,
        10 * np.log10(Sxx + 1e-10),
        shading="gouraud",
        cmap="viridis",
        vmin=-100,
        vmax=-20,
    )
    ax1.set_ylabel("Frequency [Hz]", fontsize=10)
    plt.colorbar(im, ax=ax1, label="dB")
    ax1.set_title("Spectrogram", pad=10, fontsize=12)
    # Plot waveform with exact time alignment
    times = np.arange(len(audio)) / sr
    ax2.plot(times, audio, color="#2E5596", alpha=0.7, linewidth=0.5, label="Audio")
    ax2.set_ylabel("Amplitude", fontsize=10)
    ax2.set_xlabel("Time [sec]", fontsize=10)
    ax2.grid(True, alpha=0.2)
    # Add artifact markers
    if (
        "artifact_locations" in validation_result
        and validation_result["artifact_locations"]
    ):
        for loc in validation_result["artifact_locations"]:
            ax1.axvline(x=loc, color="red", alpha=0.7, linewidth=2)
            ax2.axvline(
                x=loc, color="red", alpha=0.7, linewidth=2, label="Detected Artifacts"
            )
        # Add legend to both plots
        ax1.plot([], [], color="red", linewidth=2, label="Detected Artifacts")
        ax1.legend(loc="upper right", fontsize=8)
        # Only add unique labels to the waveform legend
        handles, labels = ax2.get_legend_handles_labels()
        unique_labels = dict(zip(labels, handles))
        ax2.legend(
            unique_labels.values(),
            unique_labels.keys(),
            loc="upper right",
            fontsize=8,
        )
    # Set common x limits
    xlim = (0, len(audio) / sr)
    ax1.set_xlim(xlim)
    ax2.set_xlim(xlim)
    og_filename = Path(wav_path).stem
    # Save plot
    plt.savefig(
        Path(output_dir) / f"{og_filename}_audio_analysis.png",
        dpi=300,
        bbox_inches="tight",
    )
    plt.close()
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Validate a TTS-generated wav file")
    parser.add_argument("wav_file", help="Path to the wav file to validate")
    parser.add_argument(
        "--silent", action="store_true", help="Skip generating analysis plots"
    )
    args = parser.parse_args()

    print(f"\n\nProcessing:\n\t{args.wav_file}")
    result = validate_tts(args.wav_file)
    # Only plot when the file actually loaded; plotting would fail otherwise
    if not args.silent and "error" not in result:
        wav_root_dir = Path(args.wav_file).parent
        generate_analysis_plots(args.wav_file, wav_root_dir, result)
    print(f"\nValidating: {result['file']}")
    if "error" in result:
        print(f"Error: {result['error']}")
    else:
        print(f"Duration: {result['duration']}")
        print(f"Sample Rate: {result['sample_rate']} Hz")
        print(f"Peak Amplitude: {result['peak_amplitude']}")
        print(f"RMS Level: {result['rms_level']}")
        print(f"DC Offset: {result['dc_offset']}")
        print(f"Detected Artifacts: {result['artifact_count']}")
        if result["issues"]:
            print("\nIssues Found:")
            for issue in result["issues"]:
                print(f"- {issue}")
        else:
            print("\nNo issues found")