# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Gemini text-to-speech helpers.

Requests speech from a Gemini TTS model, wraps headerless PCM responses in a
WAV container, attempts WAV -> MP3 compression, and memoizes the result via
the project ``cache`` object. ``synthesize_gemini_tts`` is the public entry
point; when the ``GENERATE_SPEECH`` environment flag is off it only serves
previously cached audio.
"""

import os
import struct
import re
import logging
import io
from typing import Optional, Dict, Tuple, Union

import google.generativeai as genai
from pydub import AudioSegment

from cache import cache

# --- Constants ---
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
# Only the exact string "true" (case-insensitive) enables live API calls.
GENERATE_SPEECH = os.environ.get("GENERATE_SPEECH", "false").lower() == "true"
TTS_MODEL = "gemini-2.5-flash-preview-tts"
# Assumed format of raw audio when the response carries no MIME type.
DEFAULT_RAW_AUDIO_MIME = "audio/L16;rate=24000"

# Lowercase MIME prefixes for audio that already has a WAV container.
# ("audio/wav" also matches "audio/wave" because these are prefix checks.)
_WAV_MIME_PREFIXES = ("audio/wav", "audio/x-wav", "audio/vnd.wave")
# Lowercase MIME prefixes for audio that is already containerized/encoded and
# therefore must NOT be wrapped in a PCM WAV header.
_ENCODED_MIME_PREFIXES = _WAV_MIME_PREFIXES + (
    "audio/mpeg",
    "audio/mp3",
    "audio/ogg",
    "audio/opus",
    "audio/flac",
    "audio/aac",
    "audio/aiff",
)
# Matches a lowercased "audio/l<bits>" MIME type component, e.g. "audio/l16".
_PCM_TYPE_RE = re.compile(r"audio/l(\d+)")

# --- Setup ---
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
# NOTE(review): runs at import time even when GEMINI_API_KEY is unset (None).
genai.configure(api_key=GEMINI_API_KEY)


class TTSGenerationError(Exception):
    """Raised when Gemini TTS generation fails."""


def parse_audio_mime_type(mime_type: str) -> Dict[str, int]:
    """Extract bits-per-sample and sampling rate from a raw-PCM MIME string.

    e.g. "audio/L16;rate=24000" -> {"bits_per_sample": 16, "rate": 24000}

    Matching is case-insensitive. Missing, unparseable or out-of-range values
    fall back to the defaults (16 bits, 24000 Hz). A bit depth must be a
    positive multiple of 8 because ``convert_to_wav`` derives a whole-byte
    block alignment from it.

    Args:
        mime_type: MIME string such as "audio/L16;codec=pcm;rate=24000".

    Returns:
        Dict with integer entries "bits_per_sample" and "rate".
    """
    bits_per_sample = 16
    rate = 24000
    for param in mime_type.split(";"):
        param = param.strip().lower()
        if param.startswith("rate="):
            try:
                value = int(param.split("=", 1)[1])
            except ValueError:
                continue
            if value > 0:
                rate = value
        else:
            match = _PCM_TYPE_RE.fullmatch(param)
            if match:
                bits = int(match.group(1))
                if bits > 0 and bits % 8 == 0:
                    bits_per_sample = bits
    return {"bits_per_sample": bits_per_sample, "rate": rate}


def convert_to_wav(audio_data: bytes, mime_type: str) -> bytes:
    """Wrap headerless mono PCM bytes in a canonical 44-byte RIFF/WAVE header.

    Sample width and rate come from ``parse_audio_mime_type(mime_type)``.
    The "data" chunk size records ``len(audio_data)``; when that length is
    odd, one zero pad byte is appended (RIFF chunks are word-aligned) and the
    RIFF size field accounts for it.

    Args:
        audio_data: Raw little-endian PCM samples.
        mime_type: MIME string describing ``audio_data``.

    Returns:
        The WAV header followed by ``audio_data`` (plus any pad byte).
    """
    params = parse_audio_mime_type(mime_type)
    bits = params["bits_per_sample"]
    rate = params["rate"]
    num_channels = 1
    block_align = num_channels * (bits // 8)
    byte_rate = rate * block_align
    data_size = len(audio_data)
    pad = b"\x00" * (data_size % 2)
    header = struct.pack(
        "<4sI4s4sIHHIIHH4sI",
        b"RIFF",
        36 + data_size + len(pad),  # RIFF size: bytes after this field
        b"WAVE",
        b"fmt ",
        16,  # fmt chunk size for plain PCM
        1,  # audio format 1 = uncompressed PCM
        num_channels,
        rate,
        byte_rate,
        block_align,
        bits,
        b"data",
        data_size,
    )
    return header + audio_data + pad


def _wrap_raw_audio(raw_data: bytes, mime: Optional[str]) -> Tuple[bytes, str]:
    """Return (audio, mime) with headerless PCM wrapped in a WAV container.

    - Missing/empty MIME: assume DEFAULT_RAW_AUDIO_MIME PCM and wrap it.
    - "audio/L<bits>" or any type not in _ENCODED_MIME_PREFIXES: treat as
      raw PCM and wrap it.
    - Known container/codec types: returned unchanged with their MIME.
    """
    if not mime:
        logging.warning("MIME missing; defaulting to WAV")
        return convert_to_wav(raw_data, DEFAULT_RAW_AUDIO_MIME), "audio/wav"
    mime_lower = mime.lower()
    if mime_lower.startswith("audio/l") or not mime_lower.startswith(_ENCODED_MIME_PREFIXES):
        return convert_to_wav(raw_data, mime_lower), "audio/wav"
    return raw_data, mime


def _compress_wav_to_mp3(audio: bytes, mime: str) -> Tuple[bytes, str]:
    """Best-effort WAV -> MP3 transcode.

    Non-WAV input is returned unchanged (decoding it as WAV could never
    succeed), and any decode/encode failure falls back to the WAV input.
    """
    if not mime.lower().startswith(_WAV_MIME_PREFIXES):
        return audio, mime
    try:
        segment = AudioSegment.from_file(io.BytesIO(audio), format="wav")
        buf = io.BytesIO()
        segment.export(buf, format="mp3")
    except Exception as e:  # deliberate best-effort: keep the uncompressed audio
        logging.warning("MP3 conversion failed (%s); returning WAV", e)
        return audio, mime
    return buf.getvalue(), "audio/mpeg"


@cache.memoize()
def _synthesize_gemini_tts_impl(text: str, gemini_voice_name: str) -> Tuple[bytes, str]:
    """Request speech for ``text`` from Gemini TTS (memoized via ``cache``).

    Args:
        text: Text to speak.
        gemini_voice_name: Name of a Gemini prebuilt voice.

    Returns:
        (audio_bytes, mime_type): MP3 ("audio/mpeg") when WAV -> MP3
        compression succeeds, otherwise the WAV or provider-encoded audio
        with its MIME type.

    Raises:
        TTSGenerationError: if GENERATE_SPEECH is off, the API call or
            response parsing fails, or the response contains no audio bytes.
            Failures are raised rather than returned — presumably so the
            memoize decorator does not store them (confirm for ``cache``).
    """
    if not GENERATE_SPEECH:
        raise TTSGenerationError("GENERATE_SPEECH not enabled in environment.")

    try:
        model = genai.GenerativeModel(TTS_MODEL)
        # NOTE(review): confirm the installed google.generativeai version
        # accepts "response_modalities"/"speech_config" in generation_config.
        response = model.generate_content(
            contents=[text],
            generation_config={
                "response_modalities": ["AUDIO"],
                "speech_config": {
                    "voice_config": {
                        "prebuilt_voice_config": {"voice_name": gemini_voice_name}
                    }
                },
            },
        )
        audio_part = response.candidates[0].content.parts[0]
        raw_data = audio_part.inline_data.data
        mime = audio_part.inline_data.mime_type
    except Exception as e:
        logging.error("Gemini TTS API error: %s", e)
        raise TTSGenerationError(f"TTS request failed: {e}") from e

    if not raw_data:
        raise TTSGenerationError("Empty audio data from Gemini.")

    audio, mime = _wrap_raw_audio(raw_data, mime)
    return _compress_wav_to_mp3(audio, mime)


# Choose wrapper based on GENERATE_SPEECH flag
if GENERATE_SPEECH:

    def synthesize_gemini_tts(text: str, voice: str) -> Tuple[Optional[bytes], Optional[str]]:
        """Synthesize ``text`` with Gemini voice ``voice``.

        Returns:
            (audio_bytes, mime_type), or (None, None) if generation failed.
        """
        try:
            return _synthesize_gemini_tts_impl(text, voice)
        except TTSGenerationError as e:
            logging.error("TTS failed: %s; skipping audio", e)
            return None, None

else:

    def synthesize_gemini_tts(text: str, voice: str) -> Tuple[Optional[bytes], Optional[str]]:
        """Return previously cached audio for (text, voice) without calling the API.

        Returns:
            The cached (audio_bytes, mime_type), or (None, None) on a miss.
        """
        # __cache_key__ is the key builder attached by cache.memoize()
        # (diskcache-style API — confirm). It is called positionally, exactly
        # as the live wrapper calls the impl, so the keys line up.
        key = _synthesize_gemini_tts_impl.__cache_key__(text, voice)
        result = cache.get(key)
        if result is not None:
            return result
        logging.info("No cached audio; speech disabled.")
        return None, None