# esl-dialogue-tts/utils/openai_tts.py
# Author: abocha
# Commit 5c85d81: speed, instructions, granular voice change, tts-1 available
import asyncio
import os
import time
from urllib.parse import quote

import httpx  # For NSFW check
from openai import AsyncOpenAI, OpenAIError, RateLimitError
# Voices accepted by the OpenAI TTS `voice` parameter (expanded list based on
# recent OpenAI documentation); main_test assigns them round-robin.
OPENAI_VOICES = ['alloy', 'echo', 'fable', 'onyx', 'nova', 'shimmer', 'ash', 'ballad', 'coral', 'sage', 'verse']
# Concurrency limiter: at most this many TTS requests in flight at once,
# shared by every synthesize_speech_line call in this process.
MAX_CONCURRENT_REQUESTS = 2
semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
# Retry mechanism for RateLimitError: exponential backoff starting at
# INITIAL_BACKOFF_SECONDS and doubling on each of up to MAX_RETRIES attempts.
MAX_RETRIES = 3
INITIAL_BACKOFF_SECONDS = 1
async def is_content_safe(text: str, api_url_template: str | None) -> bool:
"""
Checks if the content is safe using an external NSFW API.
Returns True if safe or if API URL is not provided, False if unsafe.
"""
if not api_url_template:
return True
if "{text}" not in api_url_template:
print("Warning: NSFW_API_URL_TEMPLATE does not contain {text} placeholder. Skipping NSFW check.")
return True
try:
encoded_text = httpx.utils.quote(text)
url = api_url_template.format(text=encoded_text)
async with httpx.AsyncClient() as client:
response = await client.get(url, timeout=10.0)
if response.status_code == 200:
return True
else:
print(f"NSFW Check: API request failed or content flagged. Status: {response.status_code}, Response: {response.text[:200]}")
return False
except httpx.RequestError as e:
print(f"NSFW Check: API request error: {e}")
return False
except Exception as e:
print(f"NSFW Check: An unexpected error occurred: {e}")
return False
async def synthesize_speech_line(
    client: AsyncOpenAI,
    text: str,
    voice: str,
    output_path: str,
    model: str = "tts-1-hd",
    speed: float = 1.0,
    instructions: str | None = None,
    nsfw_api_url_template: str | None = None,
    line_index: int = -1
) -> str | None:
    """
    Synthesize a single line of text to speech with OpenAI TTS.

    Applies `speed` only on tts-1/tts-1-hd and `instructions` only on other
    models (e.g. gpt-4o-mini-tts), retries RateLimitError with exponential
    backoff, and optionally screens the text via an NSFW API first.

    Args:
        client: An initialized AsyncOpenAI client.
        text: Text to synthesize.
        voice: One of the OpenAI voice names (see OPENAI_VOICES).
        output_path: Destination file path for the MP3 audio.
        model: TTS model name; default "tts-1-hd".
        speed: Playback speed, clamped to OpenAI's documented 0.25–4.0 range;
            only sent for tts-1/tts-1-hd and only when != 1.0.
        instructions: Optional voice-direction prompt for models that accept
            it (not tts-1/tts-1-hd).
        nsfw_api_url_template: Optional URL template for is_content_safe().
        line_index: Dialogue line id used in log messages; -1 means unknown.

    Returns:
        output_path on success, None on NSFW rejection or unrecoverable error.
    """
    if nsfw_api_url_template:
        if not await is_content_safe(text, nsfw_api_url_template):
            print(f"Line {line_index if line_index != -1 else 'N/A'}: Content flagged as NSFW. Skipping synthesis.")
            return None
    current_retry = 0
    backoff_seconds = INITIAL_BACKOFF_SECONDS
    # NOTE: the semaphore is held across backoff sleeps as well, which also
    # throttles retries — intentional given the rate-limit context.
    async with semaphore:
        while current_retry < MAX_RETRIES:
            try:
                request_params = {
                    "model": model,
                    "voice": voice,
                    "input": text,
                    "response_format": "mp3"
                }
                # Add speed only for models that support it and when non-default
                # (OpenAI's default is 1.0).
                if model in ("tts-1", "tts-1-hd"):
                    if speed is not None and speed != 1.0:
                        # Clamp to the API's valid range for safety, though the
                        # UI should also constrain this.
                        request_params["speed"] = max(0.25, min(speed, 4.0))
                # Instructions are assumed supported by gpt-4o-mini-tts-style
                # models and rejected by tts-1/tts-1-hd.
                if model not in ("tts-1", "tts-1-hd") and instructions:
                    request_params["instructions"] = instructions
                # BUGFIX: HttpxBinaryResponseContent.astream_to_file() is
                # deprecated (and later removed) in openai>=1.x; the documented
                # replacement is the streaming-response API, which writes the
                # audio to disk chunk by chunk without buffering it all.
                async with client.audio.speech.with_streaming_response.create(**request_params) as response:
                    await response.stream_to_file(output_path)
                return output_path
            except RateLimitError as e:
                current_retry += 1
                if current_retry >= MAX_RETRIES:
                    print(f"Line {line_index if line_index != -1 else ''}: Max retries reached for RateLimitError. Error: {e}")
                    return None
                print(f"Line {line_index if line_index != -1 else ''}: Rate limit hit. Retrying in {backoff_seconds}s... (Attempt {current_retry}/{MAX_RETRIES})")
                await asyncio.sleep(backoff_seconds)
                backoff_seconds *= 2
            except OpenAIError as e:
                print(f"Line {line_index if line_index != -1 else ''}: OpenAI API error: {e}")
                return None
            except Exception as e:
                print(f"Line {line_index if line_index != -1 else ''}: An unexpected error occurred during synthesis: {e}")
                return None
    # Defensive: loop exhausted without returning (should not normally happen).
    return None
if __name__ == '__main__':
    async def main_test():
        """Smoke-test: synthesize a few sample lines (requires OPENAI_API_KEY)."""
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            print("OPENAI_API_KEY not set. Skipping test.")
            return
        client = AsyncOpenAI(api_key=api_key)
        test_lines = [
            {"id": 0, "speaker": "Alice", "text": "Hello, this is a test line for Alice, spoken quickly."},
            {"id": 1, "speaker": "Bob", "text": "And this is Bob, testing his voice with instructions.", "instructions": "Speak in a deep, resonant voice."},
            {"id": 2, "speaker": "Alice", "text": "A short reply, spoken slowly.", "speed": 0.8},
            {"id": 3, "speaker": "Charlie", "text": "Charlie here, normal speed."}
        ]
        temp_dir = "test_audio_output_enhanced"
        os.makedirs(temp_dir, exist_ok=True)
        # Lines carrying "instructions" go through gpt-4o-mini-tts to exercise
        # that parameter; all others use tts-1-hd (which supports "speed").
        tasks = []
        for idx, entry in enumerate(test_lines):
            # Ensure gpt-4o-mini-tts is available for your key before relying on it.
            model_name = "gpt-4o-mini-tts" if "instructions" in entry else "tts-1-hd"
            chosen_voice = OPENAI_VOICES[idx % len(OPENAI_VOICES)]
            dest_path = os.path.join(temp_dir, f"line_{entry['id']}_{model_name}.mp3")
            tasks.append(
                synthesize_speech_line(
                    client,
                    entry["text"],
                    chosen_voice,
                    dest_path,
                    model=model_name,
                    speed=entry.get("speed", 1.0),
                    instructions=entry.get("instructions"),
                    line_index=entry['id'],
                )
            )
        results = await asyncio.gather(*tasks)
        successful_files = [path for path in results if path]
        print(f"\nSuccessfully synthesized {len(successful_files)} out of {len(test_lines)} lines.")
        for f_path in successful_files:
            print(f" - {f_path}")
    if os.name == 'nt':
        # Proactor loop on Windows can trip up some async HTTP stacks.
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
    asyncio.run(main_test())