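"""Standalone script that validates a TTS-generated WAV file for common artifacts and optionally saves an analysis plot."""
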
import argparse
from typing import Any, Dict
from pathlib import Path
import numpy as np
import soundfile as sf
from tqdm import tqdm
def validate_tts(wav_path: str) -> dict:
"""
Validation checks for TTS-generated audio files to detect common artifacts.
"""
try:
# Load and process audio
audio, sr = sf.read(wav_path)
if len(audio.shape) > 1:
audio = np.mean(audio, axis=1)
duration = len(audio) / sr
issues = []
# Basic quality checks
abs_audio = np.abs(audio)
stats = {
"rms": float(np.sqrt(np.mean(audio**2))),
"peak": float(np.max(abs_audio)),
"dc_offset": float(np.mean(audio)),
}
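        # Count samples at or near full scale (|x| >= 0.99) as clipped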
clip_count = np.sum(abs_audio >= 0.99)
clip_percent = (clip_count / len(audio)) * 100
if duration < 0.1:
issues.append(
"WARNING: Audio is suspiciously short - possible failed generation"
)
if stats["peak"] >= 1.0:
if clip_percent > 1.0:
issues.append(
f"WARNING: Significant clipping detected ({clip_percent:.2e}% of samples)"
)
elif clip_percent > 0.01:
issues.append(
f"INFO: Minor peak limiting detected ({clip_percent:.2e}% of samples)"
)
if stats["rms"] < 0.01:
issues.append("WARNING: Audio is very quiet - possible failed generation")
if abs(stats["dc_offset"]) > 0.1:
issues.append(f"WARNING: High DC offset ({stats['dc_offset']:.3f})")
# Check for long silence gaps
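        # Convert amplitude to dB; eps avoids log10(0) on exact zeros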
eps = np.finfo(float).eps
db = 20 * np.log10(abs_audio + eps)
silence_threshold = -45 # dB
min_silence = 2.0 # seconds
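        # Scan the signal in non-overlapping windows of min_silence length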
window_size = int(min_silence * sr)
silence_count = 0
last_silence = -1
start_idx = int(0.2 * sr) # Skip first 0.2s
for i in tqdm(
range(start_idx, len(db) - window_size, window_size),
desc="Checking for silence",
):
window = db[i : i + window_size]
if np.mean(window) < silence_threshold:
silent_ratio = np.mean(window < silence_threshold)
if silent_ratio > 0.9:
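                        # Suppress duplicate reports within 2s of the last flagged silence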
if last_silence == -1 or (i / sr - last_silence) > 2.0:
silence_count += 1
last_silence = i / sr
                        issues.append(
                            f"WARNING: Long silence detected at {i/sr:.2f}s "
                            f"(at least {min_silence:.1f}s long)"
                        )
if silence_count > 2:
issues.append(
f"WARNING: Multiple long silences found ({silence_count} total)"
)
# Detect audio artifacts
diff = np.diff(audio)
abs_diff = np.abs(diff)
window_size = min(int(0.005 * sr), 256)
window = np.ones(window_size) / window_size
local_avg_diff = np.convolve(abs_diff, window, mode="same")
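        # Flag jumps that are both 10x the local average and above 0.1 in absolute terms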
spikes = (abs_diff > (10 * local_avg_diff)) & (abs_diff > 0.1)
artifact_indices = np.nonzero(spikes)[0]
artifacts = []
if len(artifact_indices) > 0:
gaps = np.diff(artifact_indices)
min_gap = int(0.005 * sr)
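            # Group neighboring spikes, splitting wherever the gap exceeds ~5 ms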
break_points = np.nonzero(gaps > min_gap)[0] + 1
groups = np.split(artifact_indices, break_points)
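            # Only report sustained groups (>= 5 spikes) whose largest jump exceeds 0.2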
for group in groups:
if len(group) >= 5:
severity = np.max(abs_diff[group])
if severity > 0.2:
center_idx = group[len(group) // 2]
artifacts.append(
{
"time": float(
center_idx / sr
), # Ensure float for consistent timing
"severity": float(severity),
}
)
issues.append(
f"WARNING: Audio discontinuity at {center_idx/sr:.3f}s "
f"(severity: {severity:.3f})"
)
# Check for repeated speech segments
for chunk_duration in tqdm(
[0.5, 2.5, 5.0, 10.0], desc="Checking for repeated speech"
):
chunk_size = int(chunk_duration * sr)
overlap = int(0.2 * chunk_size)
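            # Step by 20% of the chunk size and compare each chunk against the one after it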
for i in range(0, len(audio) - 2 * chunk_size, overlap):
chunk1 = audio[i : i + chunk_size]
chunk2 = audio[i + chunk_size : i + 2 * chunk_size]
if np.mean(np.abs(chunk1)) < 0.01 or np.mean(np.abs(chunk2)) < 0.01:
continue
try:
correlation = np.corrcoef(chunk1, chunk2)[0, 1]
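                    # Highly correlated adjacent chunks suggest repeated content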
if not np.isnan(correlation) and correlation > 0.92:
issues.append(
f"WARNING: Possible repeated speech at {i/sr:.1f}s "
f"(~{int(chunk_duration*160/60):d} words, correlation: {correlation:.3f})"
)
break
                except Exception:
continue
return {
"file": wav_path,
"duration": f"{duration:.2f}s",
"sample_rate": sr,
"peak_amplitude": f"{stats['peak']:.3f}",
"rms_level": f"{stats['rms']:.3f}",
"dc_offset": f"{stats['dc_offset']:.3f}",
"artifact_count": len(artifacts),
"artifact_locations": [a["time"] for a in artifacts],
"artifact_severities": [a["severity"] for a in artifacts],
"issues": issues,
"valid": len(issues) == 0,
}
except Exception as e:
return {"file": wav_path, "error": str(e), "valid": False}
def generate_analysis_plots(
wav_path: str, output_dir: str, validation_result: Dict[str, Any]
):
"""
Generate analysis plots for audio file with time-aligned visualizations.
"""
import matplotlib.pyplot as plt
from scipy.signal import spectrogram
# Load audio
audio, sr = sf.read(wav_path)
if len(audio.shape) > 1:
audio = np.mean(audio, axis=1)
# Create figure with shared x-axis
fig = plt.figure(figsize=(15, 8))
gs = plt.GridSpec(2, 1, height_ratios=[1.2, 0.8], hspace=0.1)
ax1 = fig.add_subplot(gs[0])
ax2 = fig.add_subplot(gs[1], sharex=ax1)
# Calculate spectrogram
nperseg = 2048
noverlap = 1536
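    # 2048-sample windows with 75% overlap (1536 samples) for the spectrogram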
f, t, Sxx = spectrogram(
audio, sr, nperseg=nperseg, noverlap=noverlap, window="hann", scaling="spectrum"
)
# Plot spectrogram
im = ax1.pcolormesh(
t,
f,
10 * np.log10(Sxx + 1e-10),
shading="gouraud",
cmap="viridis",
vmin=-100,
vmax=-20,
)
ax1.set_ylabel("Frequency [Hz]", fontsize=10)
cbar = plt.colorbar(im, ax=ax1, label="dB")
ax1.set_title("Spectrogram", pad=10, fontsize=12)
# Plot waveform with exact time alignment
times = np.arange(len(audio)) / sr
ax2.plot(times, audio, color="#2E5596", alpha=0.7, linewidth=0.5, label="Audio")
ax2.set_ylabel("Amplitude", fontsize=10)
ax2.set_xlabel("Time [sec]", fontsize=10)
ax2.grid(True, alpha=0.2)
# Add artifact markers
if (
"artifact_locations" in validation_result
and validation_result["artifact_locations"]
):
for loc in validation_result["artifact_locations"]:
ax1.axvline(x=loc, color="red", alpha=0.7, linewidth=2)
ax2.axvline(
x=loc, color="red", alpha=0.7, linewidth=2, label="Detected Artifacts"
)
# Add legend to both plots
if len(validation_result["artifact_locations"]) > 0:
ax1.plot([], [], color="red", linewidth=2, label="Detected Artifacts")
ax1.legend(loc="upper right", fontsize=8)
# Only add unique labels to legend
handles, labels = ax2.get_legend_handles_labels()
unique_labels = dict(zip(labels, handles))
ax2.legend(
unique_labels.values(),
unique_labels.keys(),
loc="upper right",
fontsize=8,
)
# Set common x limits
xlim = (0, len(audio) / sr)
ax1.set_xlim(xlim)
ax2.set_xlim(xlim)
    og_filename = Path(wav_path).stem
# Save plot
plt.savefig(
Path(output_dir) / f"{og_filename}_audio_analysis.png",
dpi=300,
bbox_inches="tight",
)
plt.close()
if __name__ == "__main__":
wav_file = r"C:\Users\jerem\Desktop\Kokoro-FastAPI\examples\assorted_checks\benchmarks\output_audio\chunk_600_tokens.wav"
silent = False
print(f"\n\n Processing:\n\t{wav_file}")
result = validate_tts(wav_file)
if not silent:
wav_root_dir = Path(wav_file).parent
generate_analysis_plots(wav_file, wav_root_dir, result)
print(f"\nValidating: {result['file']}")
if "error" in result:
print(f"Error: {result['error']}")
else:
print(f"Duration: {result['duration']}")
print(f"Sample Rate: {result['sample_rate']} Hz")
print(f"Peak Amplitude: {result['peak_amplitude']}")
print(f"RMS Level: {result['rms_level']}")
print(f"DC Offset: {result['dc_offset']}")
print(f"Detected Artifacts: {result['artifact_count']}")
if result["issues"]:
print("\nIssues Found:")
for issue in result["issues"]:
print(f"- {issue}")
else:
print("\nNo issues found")