# docappointemet / gemini_tts.py
# mgbam's picture
# Update gemini_tts.py
# 1effb75 verified
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import struct
import re
import logging
import io
from typing import Optional, Dict, Tuple, Union
import google.generativeai as genai
from pydub import AudioSegment
from cache import cache
# --- Constants ---
# API key and the speech feature flag are read from the environment once, at import.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
# Speech synthesis is opt-in: only a case-insensitive "true" turns it on.
GENERATE_SPEECH = os.getenv("GENERATE_SPEECH", "false").lower() == "true"
# Gemini model used for every text-to-speech request.
TTS_MODEL = "gemini-2.5-flash-preview-tts"
# MIME assumed for raw PCM when the API response carries no MIME type.
DEFAULT_RAW_AUDIO_MIME = "audio/L16;rate=24000"
# --- Setup ---
# Root logging and the Gemini client are configured once, when the module is imported.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
genai.configure(api_key=GEMINI_API_KEY)
class TTSGenerationError(Exception):
    """Signal that Gemini TTS could not produce audio (disabled, API error, or empty result)."""
def parse_audio_mime_type(mime_type: Optional[str]) -> Dict[str, int]:
    """
    Extract bits_per_sample and sampling rate from an audio MIME string.

    e.g. "audio/L16;rate=24000" → {"bits_per_sample": 16, "rate": 24000}

    Matching is case-insensitive, and ``;``-separated parameters may carry
    surrounding whitespace. A value that is missing, unparsable, or not a
    positive integer falls back to the default (16 bits, 24000 Hz). This keeps
    the result safe to write into a WAV header.

    Args:
        mime_type: MIME string such as ``"audio/L16;rate=24000"``. ``None`` or
            an empty string yields the defaults.

    Returns:
        Dict with integer values under ``"bits_per_sample"`` and ``"rate"``.
    """
    bits_per_sample = 16
    rate = 24000
    for param in (mime_type or "").split(";"):
        param = param.strip().lower()
        if param.startswith("rate="):
            try:
                candidate = int(param.split("=", 1)[1])
            except ValueError:
                continue
            # Zero or negative rates would yield a broken (or unpackable) WAV header.
            if candidate > 0:
                rate = candidate
            continue
        # "audio/l<digits>" carries the sample width, e.g. "audio/l16" -> 16 bits.
        match = re.match(r"audio/l(\d+)", param)
        if match:
            candidate = int(match.group(1))
            if candidate > 0:
                bits_per_sample = candidate
    return {"bits_per_sample": bits_per_sample, "rate": rate}
def convert_to_wav(audio_data: bytes, mime_type: str, *, num_channels: int = 1) -> bytes:
    """Wrap raw PCM bytes in a canonical 44-byte WAV (RIFF) header.

    Sample width and rate come from ``parse_audio_mime_type(mime_type)``. The
    sample bytes are copied unchanged. They must therefore already be in the
    interleaved layout WAV expects (NOTE(review): presumably little-endian
    PCM as returned by Gemini; confirm).

    Args:
        audio_data: Headerless PCM samples.
        mime_type: MIME string describing the samples, e.g. ``"audio/L16;rate=24000"``.
        num_channels: Number of interleaved channels (default 1, i.e. mono).

    Returns:
        The WAV file bytes: header, samples, and a trailing pad byte when the
        sample data has odd length.
    """
    params = parse_audio_mime_type(mime_type)
    bits = params["bits_per_sample"]
    rate = params["rate"]
    # Samples occupy whole bytes: e.g. 12-bit audio is stored in 2-byte containers.
    bytes_per_sample = (bits + 7) // 8
    block_align = num_channels * bytes_per_sample
    byte_rate = rate * block_align
    data_size = len(audio_data)
    # RIFF chunks are word-aligned: an odd-sized data chunk is followed by one pad
    # byte that counts toward the RIFF size but not toward the data chunk's size.
    padding = b"\x00" if data_size % 2 else b""
    chunk_size = 36 + data_size + len(padding)
    header = struct.pack(
        "<4sI4s4sIHHIIHH4sI",
        b"RIFF",
        chunk_size,
        b"WAVE",
        b"fmt ",
        16,  # size of the PCM fmt sub-chunk
        1,  # audio format 1 = uncompressed PCM
        num_channels,
        rate,
        byte_rate,
        block_align,
        bits,
        b"data",
        data_size,
    )
    return header + audio_data + padding
@cache.memoize()
def _synthesize_gemini_tts_impl(text: str, gemini_voice_name: str) -> Tuple[bytes, str]:
    """Request speech audio from Gemini TTS and return ``(audio_bytes, mime_type)``.

    Decorated with ``cache.memoize()``, so results are memoized per
    ``(text, gemini_voice_name)``. Every failure path raises rather than
    returning a value.

    Steps:

    1. Call ``TTS_MODEL`` with the requested prebuilt voice.
    2. Take the first response part that carries non-empty inline data.
    3. Wrap headerless PCM (e.g. ``audio/L16;rate=24000``), or audio whose MIME
       type is missing, in a WAV header.
    4. Re-encode WAV audio to MP3 with pydub. If that fails, return the WAV.
       Already-compressed audio (MPEG/Ogg/Opus) is returned unchanged.

    Args:
        text: Text to synthesize.
        gemini_voice_name: Name of the Gemini prebuilt voice to use.

    Returns:
        ``(audio_bytes, mime_type)``. ``mime_type`` is ``"audio/mpeg"`` when MP3
        encoding succeeded.

    Raises:
        TTSGenerationError: speech generation is disabled, the API request (or
            reading its response) fails, the response has no audio, or a WAV
            header cannot be built for the reported format.
    """
    if not GENERATE_SPEECH:
        raise TTSGenerationError("GENERATE_SPEECH not enabled in environment.")

    try:
        model = genai.GenerativeModel(TTS_MODEL)
        # NOTE(review): "response_modalities"/"speech_config" follow the Gemini
        # TTS API; confirm the installed google.generativeai accepts them here.
        response = model.generate_content(
            contents=[text],
            generation_config={
                "response_modalities": ["AUDIO"],
                "speech_config": {
                    "voice_config": {
                        "prebuilt_voice_config": {"voice_name": gemini_voice_name}
                    }
                },
            },
        )
        # The audio is not guaranteed to be parts[0], so scan all candidates'
        # parts for the first one carrying non-empty inline data.
        blob = next(
            (
                part.inline_data
                for candidate in response.candidates
                for part in candidate.content.parts
                if getattr(part, "inline_data", None) is not None
                and part.inline_data.data
            ),
            None,
        )
    except Exception as e:
        logging.error("Gemini TTS API error: %s", e)
        raise TTSGenerationError(f"TTS request failed: {e}") from e

    if blob is None:
        raise TTSGenerationError("Empty audio data from Gemini.")
    raw_data = blob.data
    mime = blob.mime_type or ""

    wav_types = ("audio/wav", "audio/x-wav")  # "audio/wave" is covered by the prefix
    compressed_types = ("audio/mpeg", "audio/ogg", "audio/opus")

    # Anything that is not a known container (audio/L16, audio/pcm, ...) is treated
    # as headerless PCM described by its MIME parameters and wrapped as WAV.
    mime_lower = mime.lower()
    pcm_mime: Optional[str] = None
    if not mime_lower:
        logging.warning("MIME missing; defaulting to WAV")
        pcm_mime = DEFAULT_RAW_AUDIO_MIME
    elif not mime_lower.startswith(wav_types + compressed_types):
        pcm_mime = mime_lower
    if pcm_mime is not None:
        try:
            raw_data = convert_to_wav(raw_data, pcm_mime)
        except struct.error as e:
            # Keep the documented contract: callers only need to catch TTSGenerationError.
            raise TTSGenerationError(
                f"Could not build WAV header for {pcm_mime!r}: {e}"
            ) from e
        mime = "audio/wav"

    if mime.lower().startswith(compressed_types):
        # Already compressed; decoding it as WAV could only fail.
        return raw_data, mime

    # Best-effort MP3 compression; pydub/ffmpeg can fail in many ways, so fall back to WAV.
    try:
        segment = AudioSegment.from_file(io.BytesIO(raw_data), format="wav")
        buf = io.BytesIO()
        segment.export(buf, format="mp3")
        return buf.getvalue(), "audio/mpeg"
    except Exception as e:
        logging.warning("MP3 conversion failed (%s); returning WAV", e)
        return raw_data, mime
# Pick the public wrapper once, at import time, from the GENERATE_SPEECH flag.
if GENERATE_SPEECH:

    def synthesize_gemini_tts(text: str, voice: str) -> Tuple[Optional[bytes], Optional[str]]:
        """Synthesize ``text`` with ``voice``, or return ``(None, None)`` on TTS failure.

        Delegates to ``_synthesize_gemini_tts_impl``. A ``TTSGenerationError``
        is logged and swallowed so callers can continue without audio.
        """
        try:
            audio = _synthesize_gemini_tts_impl(text, voice)
        except TTSGenerationError as err:
            logging.error("TTS failed: %s; skipping audio", err)
            return None, None
        return audio

else:

    def synthesize_gemini_tts(text: str, voice: str) -> Tuple[Optional[bytes], Optional[str]]:
        """Return cached audio for ``(text, voice)`` without calling the API.

        Looks up the cache entry keyed by
        ``_synthesize_gemini_tts_impl.__cache_key__(text, voice)``. Returns
        ``(None, None)`` when nothing is cached.
        """
        cached = cache.get(_synthesize_gemini_tts_impl.__cache_key__(text, voice))
        if cached is None:
            logging.info("No cached audio; speech disabled.")
            return None, None
        return cached