appoint-ready / gemini_tts.py
lirony's picture
Initial public commit
a6bdbe4
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import google.generativeai as genai
import os
import struct
import re
import logging
from cache import cache
# Add these imports for MP3 conversion
from pydub import AudioSegment
import io
# --- Constants ---
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
GENERATE_SPEECH = os.environ.get("GENERATE_SPEECH", "false").lower() == "true"
TTS_MODEL = "gemini-2.5-flash-preview-tts"
DEFAULT_RAW_AUDIO_MIME = "audio/L16;rate=24000"
# --- Configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
genai.configure(api_key=GEMINI_API_KEY)
class TTSGenerationError(Exception):
"""Custom exception for TTS generation failures."""
pass
# --- Helper functions for audio processing ---
def parse_audio_mime_type(mime_type: str) -> dict[str, int | None]:
"""
Parses bits per sample and rate from an audio MIME type string.
e.g., "audio/L16;rate=24000" -> {"bits_per_sample": 16, "rate": 24000}
"""
bits_per_sample = 16 # Default
rate = 24000 # Default
parts = mime_type.split(";")
for param in parts:
param = param.strip().lower()
if param.startswith("rate="):
try:
rate_str = param.split("=", 1)[1]
rate = int(rate_str)
except (ValueError, IndexError):
pass # Keep default if parsing fails
elif re.match(r"audio/l\d+", param): # Matches audio/L<digits>
try:
bits_str = param.split("l",1)[1]
bits_per_sample = int(bits_str)
except (ValueError, IndexError):
pass # Keep default
return {"bits_per_sample": bits_per_sample, "rate": rate}
def convert_to_wav(audio_data: bytes, mime_type: str) -> bytes:
"""
Generates a WAV file header for the given raw audio data and parameters.
Assumes mono audio.
"""
parameters = parse_audio_mime_type(mime_type)
bits_per_sample = parameters["bits_per_sample"]
sample_rate = parameters["rate"]
num_channels = 1 # Mono
data_size = len(audio_data)
bytes_per_sample = bits_per_sample // 8
block_align = num_channels * bytes_per_sample
byte_rate = sample_rate * block_align
chunk_size = 36 + data_size
header = struct.pack(
"<4sI4s4sIHHIIHH4sI",
b"RIFF", chunk_size, b"WAVE", b"fmt ",
16, 1, num_channels, sample_rate, byte_rate, block_align,
bits_per_sample, b"data", data_size
)
return header + audio_data
# --- End of helper functions ---
def _synthesize_gemini_tts_impl(text: str, gemini_voice_name: str) -> tuple[bytes, str]:
"""
Synthesizes English text using the Gemini API via the google-genai library.
Returns a tuple: (processed_audio_data_bytes, final_mime_type).
Raises TTSGenerationError on failure.
"""
if not GENERATE_SPEECH:
# This should ideally not be hit if the logic outside this function is correct,
# but as a safeguard, we raise an error.
raise TTSGenerationError(
"GENERATE_SPEECH is not set. Please set it in your environment variables to generate speech."
)
try:
model = genai.GenerativeModel(TTS_MODEL)
generation_config = {
"response_modalities": ["AUDIO"],
"speech_config": {
"voice_config": {
"prebuilt_voice_config": {
"voice_name": gemini_voice_name
}
}
}
}
response = model.generate_content(
contents=[text],
generation_config=generation_config,
)
audio_part = response.candidates[0].content.parts[0]
audio_data_bytes = audio_part.inline_data.data
final_mime_type = audio_part.inline_data.mime_type
except Exception as e:
error_message = f"An unexpected error occurred with google-genai: {e}"
logging.error(error_message)
raise TTSGenerationError(error_message) from e
if not audio_data_bytes:
error_message = "No audio data was successfully retrieved or decoded."
logging.error(error_message)
raise TTSGenerationError(error_message)
# --- Audio processing ---
if final_mime_type:
final_mime_type_lower = final_mime_type.lower()
needs_wav_conversion = any(p in final_mime_type_lower for p in ("audio/l16", "audio/l24", "audio/l8")) or \
not final_mime_type_lower.startswith(("audio/wav", "audio/mpeg", "audio/ogg", "audio/opus"))
if needs_wav_conversion:
processed_audio_data = convert_to_wav(audio_data_bytes, final_mime_type)
processed_audio_mime = "audio/wav"
else:
processed_audio_data = audio_data_bytes
processed_audio_mime = final_mime_type
else:
logging.warning("MIME type not determined. Assuming raw audio and attempting WAV conversion (defaulting to %s).", DEFAULT_RAW_AUDIO_MIME)
processed_audio_data = convert_to_wav(audio_data_bytes, DEFAULT_RAW_AUDIO_MIME)
processed_audio_mime = "audio/wav"
# --- MP3 compression ---
if processed_audio_data:
try:
# Load audio into AudioSegment
audio_segment = AudioSegment.from_file(io.BytesIO(processed_audio_data), format="wav")
mp3_buffer = io.BytesIO()
audio_segment.export(mp3_buffer, format="mp3")
mp3_bytes = mp3_buffer.getvalue()
return mp3_bytes, "audio/mpeg"
except Exception as e:
logging.warning("MP3 compression failed: %s. Falling back to WAV.", e)
# Fallback to WAV if MP3 conversion fails
return processed_audio_data, processed_audio_mime
else:
error_message = "Audio processing failed."
logging.error(error_message)
raise TTSGenerationError(error_message)
# Always create the memoized function first, so we can access its .key() method
_memoized_tts_func = cache.memoize()(_synthesize_gemini_tts_impl)
if GENERATE_SPEECH:
def synthesize_gemini_tts_with_error_handling(*args, **kwargs) -> tuple[bytes | None, str | None]:
"""
A wrapper for the memoized TTS function that catches errors and returns (None, None).
This makes the audio generation more resilient to individual failures.
"""
try:
# Attempt to get the audio from the cache or by generating it.
return _memoized_tts_func(*args, **kwargs)
except TTSGenerationError as e:
# If generation fails, log the error and return None, None.
logging.error("Handled TTS Generation Error: %s. Continuing without audio for this segment.", e)
return None, None
synthesize_gemini_tts = synthesize_gemini_tts_with_error_handling
else:
# When not generating speech, create a read-only function that only
# checks the cache and does not generate new audio.
def read_only_synthesize_gemini_tts(*args, **kwargs):
"""
Checks cache for a result, but never calls the underlying TTS function.
This is a 'read-only' memoization check.
"""
# Generate the cache key using the memoized function's key method.
key = _memoized_tts_func.__cache_key__(*args, **kwargs)
# Check the cache directly using the generated key.
_sentinel = object()
result = cache.get(key, default=_sentinel)
if result is not _sentinel:
return result # Cache hit
# Cache miss
logging.info("GENERATE_SPEECH is false and no cached result found for key: %s", key)
return None, None
synthesize_gemini_tts = read_only_synthesize_gemini_tts