Spaces:
Runtime error
Runtime error
File size: 13,822 Bytes
f5647fb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 |
#!/usr/bin/env python3
# Compatible with both Windows and Linux
"""
Kokoro TTS Race Condition Test
This script creates multiple concurrent requests to a Kokoro TTS service
to reproduce a race condition where audio outputs don't match the requested text.
Each thread generates a simple numbered sentence, which should make mismatches
easy to identify through listening.
To run:
python kokoro_race_condition_test.py --threads 8 --iterations 5 --url http://localhost:8880
"""
import argparse
import base64
import concurrent.futures
import json
import os
import sys
import time
import wave
from pathlib import Path
import requests
def _positive_int(value):
    """argparse ``type=`` callback that accepts only integers >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(
            f"must be a positive integer, got {number}"
        )
    return number


def setup_args(argv=None):
    """Parse command line arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            ``setup_args()`` callers behave as before.

    Returns:
        An ``argparse.Namespace`` with ``url``, ``threads``, ``iterations``,
        ``voice``, ``output_dir`` and ``debug`` attributes.

    Raises:
        SystemExit: Raised by argparse on invalid arguments, including a
            ``--threads`` or ``--iterations`` value below 1.
    """
    parser = argparse.ArgumentParser(description="Test Kokoro TTS for race conditions")
    parser.add_argument(
        "--url",
        default="http://localhost:8880",
        help="Base URL of the Kokoro TTS service",
    )
    # Reject zero/negative counts here with a clear usage error instead of
    # letting them fail later when the thread pool is created or when the
    # summary averages are computed.
    parser.add_argument(
        "--threads",
        type=_positive_int,
        default=8,
        help="Number of concurrent threads to use",
    )
    parser.add_argument(
        "--iterations",
        type=_positive_int,
        default=5,
        help="Number of iterations per thread",
    )
    parser.add_argument("--voice", default="af_heart", help="Voice to use for TTS")
    parser.add_argument(
        "--output-dir",
        default="./tts_test_output",
        help="Directory to save output files",
    )
    parser.add_argument("--debug", action="store_true", help="Enable debug logging")
    return parser.parse_args(argv)
def generate_test_sentence(thread_id, iteration):
    """Build the sentence that one request asks the TTS service to speak.

    The ``thread_id``-``iteration`` pair is embedded twice in the text so a
    listener can tell which request a piece of audio belongs to.
    """
    announcement = f"This is test sentence number {thread_id}-{iteration}. "
    confirmation = f"If you hear this sentence, you should hear the numbers {thread_id}-{iteration}."
    return announcement + confirmation
def log_message(message, debug=False, is_error=False):
    """Print a timestamped message to stdout, but only for errors or debug output.

    Nothing is printed when both ``debug`` and ``is_error`` are falsy.
    Printed lines are prefixed with ``[ERROR]`` when ``is_error`` is set and
    ``[INFO]`` otherwise.
    """
    if not (debug or is_error):
        return

    timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
    if is_error:
        prefix = "[ERROR]"
    else:
        prefix = "[INFO]"
    print(f"{prefix} {timestamp} - {message}")
    # Flush right away so the line shows up promptly in Docker output.
    sys.stdout.flush()
def _store_and_check_wav(test_id, audio_bytes, output_file, debug=False):
    """Write ``audio_bytes`` to ``output_file`` and check that it parses as WAV.

    Returns True only when the file exists after the write and the ``wave``
    module can read its header; otherwise logs the problem and returns False.
    Errors raised while writing the file propagate to the caller.
    """
    with open(output_file, "wb") as f:
        bytes_written = f.write(audio_bytes)
    log_message(
        f"Thread {test_id}: Wrote {bytes_written} bytes to {output_file}",
        debug,
    )

    # Verify the WAV file exists and has content
    if not os.path.exists(output_file):
        log_message(
            f"Thread {test_id}: File was not created: {output_file}",
            debug,
            is_error=True,
        )
        return False

    file_size = os.path.getsize(output_file)
    log_message(
        f"Thread {test_id}: Verified file exists with size: {file_size} bytes",
        debug,
    )

    # Validate WAV file by reading its headers
    try:
        with wave.open(output_file, "rb") as wav_file:
            channels = wav_file.getnchannels()
            sample_width = wav_file.getsampwidth()
            framerate = wav_file.getframerate()
            frames = wav_file.getnframes()
    except (wave.Error, EOFError, OSError) as wav_error:
        log_message(
            f"Thread {test_id}: Invalid WAV file: {wav_error}",
            debug,
            is_error=True,
        )
        return False

    log_message(
        f"Thread {test_id}: Valid WAV file - channels: {channels}, "
        f"sample width: {sample_width}, framerate: {framerate}, frames: {frames}",
        debug,
    )
    return True


def request_tts(url, test_id, text, voice, output_dir, debug=False):
    """Request TTS from the Kokoro API and save the WAV output.

    The requested text is written to ``test_<test_id>.txt`` and the returned
    audio to ``test_<test_id>.wav`` inside ``output_dir`` so the two can be
    compared afterwards.

    Args:
        url: Base URL of the TTS service; a trailing slash is tolerated.
        test_id: Identifier used in log lines and output file names.
        text: Sentence to synthesize.
        voice: Voice name sent in the request body.
        output_dir: Directory that receives the ``.txt`` and ``.wav`` files.
        debug: Forwarded to ``log_message`` to enable non-error output.

    Returns:
        True when the service answered 200 with at least 100 bytes and the
        saved file could be read back as a WAV file. False on any HTTP error,
        timeout, undersized or unreadable audio, or other exception.
    """
    start_time = time.time()
    output_file = os.path.join(output_dir, f"test_{test_id}.wav")
    text_file = os.path.join(output_dir, f"test_{test_id}.txt")

    # Log output paths for debugging
    log_message(f"Thread {test_id}: Text will be saved to: {text_file}", debug)
    log_message(f"Thread {test_id}: Audio will be saved to: {output_file}", debug)

    # Save the text for later comparison. This is best-effort: a failure is
    # reported but does not stop the TTS request itself.
    try:
        with open(text_file, "w") as f:
            f.write(text)
        log_message(f"Thread {test_id}: Successfully saved text file", debug)
    except Exception as e:
        log_message(
            f"Thread {test_id}: Error saving text file: {e}", debug, is_error=True
        )

    # Make the TTS request
    try:
        log_message(f"Thread {test_id}: Requesting TTS for: '{text}'", debug)
        response = requests.post(
            # rstrip avoids a double slash when the base URL ends with "/".
            f"{url.rstrip('/')}/v1/audio/speech",
            json={
                "model": "kokoro",
                "input": text,
                "voice": voice,
                "response_format": "wav",
            },
            headers={"Accept": "audio/wav"},
            timeout=60,
        )
        log_message(
            f"Thread {test_id}: Response status code: {response.status_code}", debug
        )
        log_message(
            f"Thread {test_id}: Response content type: {response.headers.get('Content-Type', 'None')}",
            debug,
        )
        log_message(
            f"Thread {test_id}: Response content length: {len(response.content)} bytes",
            debug,
        )

        if response.status_code != 200:
            log_message(
                f"Thread {test_id}: API error: {response.status_code} - {response.text}",
                debug,
                is_error=True,
            )
            return False

        # Sanity check - WAV files should be larger than this
        if len(response.content) < 100:
            log_message(
                f"Thread {test_id}: Received suspiciously small audio data: {len(response.content)} bytes",
                debug,
                is_error=True,
            )
            log_message(
                f"Thread {test_id}: Content (base64): {base64.b64encode(response.content).decode('utf-8')}",
                debug,
                is_error=True,
            )
            return False

        # Save the audio output with explicit error handling
        try:
            audio_ok = _store_and_check_wav(
                test_id, response.content, output_file, debug
            )
        except Exception as save_error:
            log_message(
                f"Thread {test_id}: Error saving audio file: {save_error}",
                debug,
                is_error=True,
            )
            return False

        # A missing or unreadable WAV has already been logged as an error;
        # report it as a failure instead of claiming the output was saved.
        if not audio_ok:
            return False

        end_time = time.time()
        log_message(
            f"Thread {test_id}: Saved output to {output_file} (time: {end_time - start_time:.2f}s)",
            debug,
        )
        return True
    except requests.exceptions.Timeout:
        log_message(f"Thread {test_id}: Request timed out", debug, is_error=True)
        return False
    except Exception as e:
        log_message(f"Thread {test_id}: Exception: {e}", debug, is_error=True)
        return False
def worker_task(thread_id, args):
    """Run ``args.iterations`` sequential TTS requests on behalf of one thread.

    Each request gets an ID of the form ``<thread>_<iteration>`` (both
    zero-padded to two digits, iterations counted from 1). Failed requests
    are logged as errors and the loop carries on.
    """
    for iteration in range(1, args.iterations + 1):
        test_id = f"{thread_id:02d}_{iteration:02d}"
        sentence = generate_test_sentence(thread_id, iteration)

        succeeded = request_tts(
            args.url, test_id, sentence, args.voice, args.output_dir, args.debug
        )
        if not succeeded:
            log_message(
                f"Thread {thread_id}: Iteration {iteration} failed",
                args.debug,
                is_error=True,
            )

        # Brief pause after every iteration so the API is not overwhelmed.
        time.sleep(0.1)
def run_test(args):
    """Run the concurrent TTS test described by ``args`` and print a summary.

    Steps: create the output directory and probe write access, probe the
    service's ``/health`` endpoint, run ``worker_task`` once per thread ID
    (1..``args.threads``) in a thread pool, then report timing figures.
    Both probes are advisory: failures are logged and the test continues.

    Args:
        args: Namespace providing ``url``, ``threads``, ``iterations``,
            ``output_dir`` and ``debug`` (as produced by ``setup_args``).

    Raises:
        OSError: If the output directory cannot be created.
        ValueError: From ``ThreadPoolExecutor`` if ``args.threads`` is not
            greater than 0.
    """

    def announce(message):
        # log_message() only prints errors or debug output, so plain
        # informational calls would be dropped. Force them to be shown.
        log_message(message, debug=True)

    # Ensure output directory exists and check permissions
    os.makedirs(args.output_dir, exist_ok=True)

    # Test write access to the output directory
    test_file = os.path.join(args.output_dir, "write_test.txt")
    try:
        with open(test_file, "w") as f:
            f.write("Testing write access\n")
        os.remove(test_file)
        announce(
            f"Successfully verified write access to output directory: {args.output_dir}"
        )
    except OSError as e:
        log_message(
            f"Warning: Cannot write to output directory {args.output_dir}: {e}",
            is_error=True,
        )
        log_message(f"Current directory: {os.getcwd()}", is_error=True)
        log_message(f"Directory contents: {os.listdir('.')}", is_error=True)

    # Test connection to Kokoro TTS service. Best-effort: any failure is
    # reported as a warning and the run proceeds.
    try:
        response = requests.get(f"{args.url.rstrip('/')}/health", timeout=5)
        if response.status_code == 200:
            announce(f"Successfully connected to Kokoro TTS service at {args.url}")
        else:
            log_message(
                f"Warning: Kokoro TTS service health check returned status {response.status_code}",
                is_error=True,
            )
    except Exception as e:
        log_message(
            f"Warning: Cannot connect to Kokoro TTS service at {args.url}: {e}",
            is_error=True,
        )

    # Record start time
    start_time = time.time()
    announce(
        f"Starting test with {args.threads} threads, {args.iterations} iterations per thread"
    )

    # Create and start worker threads
    with concurrent.futures.ThreadPoolExecutor(max_workers=args.threads) as executor:
        futures = [
            executor.submit(worker_task, thread_id, args)
            for thread_id in range(1, args.threads + 1)
        ]

        # Wait for all tasks to complete
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except Exception as e:
                log_message(f"Thread execution failed: {e}", args.debug, is_error=True)

    # Record end time and print summary
    total_time = time.time() - start_time
    total_requests = args.threads * args.iterations

    announce(f"Test completed in {total_time:.2f} seconds")
    announce(f"Total requests: {total_requests}")
    # Skip the averages when they would divide by zero (no requests were
    # scheduled, or the clock did not advance).
    if total_requests > 0:
        announce(
            f"Average time per request: {total_time / total_requests:.2f} seconds"
        )
    if total_time > 0:
        announce(f"Requests per second: {total_requests / total_time:.2f}")
    announce(f"Output files saved to: {os.path.abspath(args.output_dir)}")
    announce(
        "To verify, listen to the audio files and check if they match the text files"
    )
    announce(
        "If you hear audio describing a different test number than the filename, you've found a race condition"
    )
def analyze_audio_files(output_dir):
    """Print a summary of the WAV/TXT files found in ``output_dir``.

    For every ``*.wav`` file the duration (frames / frame rate) is shown next
    to the text from the ``.txt`` file with the same stem (``N/A`` if there is
    none). Text files that have no matching WAV file are listed as errors,
    capped at ten entries.

    Args:
        output_dir: Directory to scan (not recursive).
    """

    def report(message):
        # log_message() only prints errors or debug output, so the summary
        # would otherwise never be displayed. Force it to be shown.
        log_message(message, debug=True)

    # Look for both WAV and TXT files; sort WAVs so the table order is stable.
    wav_files = sorted(Path(output_dir).glob("*.wav"))
    txt_files = list(Path(output_dir).glob("*.txt"))

    report(f"Found {len(wav_files)} WAV files and {len(txt_files)} TXT files")

    if not wav_files:
        log_message(
            "No WAV files found! This indicates the TTS service requests may be failing.",
            is_error=True,
        )
        log_message(
            "Check the connection to the TTS service and the response status codes above.",
            is_error=True,
        )

    file_stats = []
    for wav_path in wav_files:
        # Per-file best effort: one unreadable file must not abort the report.
        try:
            with wave.open(str(wav_path), "rb") as wav_file:
                frames = wav_file.getnframes()
                rate = wav_file.getframerate()
            duration = frames / rate

            # Get corresponding text
            text_path = wav_path.with_suffix(".txt")
            if text_path.exists():
                with open(text_path, "r") as text_file:
                    text = text_file.read().strip()
            else:
                text = "N/A"

            file_stats.append(
                {"filename": wav_path.name, "duration": duration, "text": text}
            )
        except Exception as e:
            log_message(f"Error analyzing {wav_path}: {e}", False, is_error=True)

    # Print summary table
    if file_stats:
        report("\nAudio File Summary:")
        report(f"{'Filename':<20}{'Duration':<12}{'Text':<60}")
        report("-" * 92)
        for stat in file_stats:
            text = stat["text"]
            # Keep the text column at 60 characters: long texts are cut to
            # 57 characters plus an ellipsis.
            shown = text[:57] + "..." if len(text) > 60 else text
            report(f"{stat['filename']:<20}{stat['duration']:<12.2f}{shown:<60}")

    # List missing WAV files where text files exist
    missing_wavs = {p.stem for p in txt_files} - {p.stem for p in wav_files}
    if missing_wavs:
        log_message(
            f"\nFound {len(missing_wavs)} text files without corresponding WAV files:",
            is_error=True,
        )
        for stem in sorted(missing_wavs)[:10]:  # Limit to 10 for readability
            log_message(f" - {stem}.txt (no WAV file)", is_error=True)
        if len(missing_wavs) > 10:
            log_message(f" ... and {len(missing_wavs) - 10} more", is_error=True)
if __name__ == "__main__":
    # Script entry point: parse options, run the concurrent requests, then
    # summarize what was written to the output directory.
    args = setup_args()
    run_test(args)
    analyze_audio_files(args.output_dir)

    # log_message() only prints errors or debug output, so pass debug=True
    # to make sure the closing instructions are actually displayed.
    for line in (
        "\nNext Steps:",
        "1. Listen to the generated audio files",
        "2. Verify if each audio correctly says its ID number",
        "3. Check for any mismatches between the audio content and the text files",
        "4. If mismatches are found, you've successfully reproduced the race condition",
    ):
        log_message(line, debug=True)
|