# Spaces: Runtime error
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os | |
import struct | |
import re | |
import logging | |
import io | |
from typing import Optional, Dict, Tuple, Union | |
import google.generativeai as genai | |
from pydub import AudioSegment | |
from cache import cache | |
# --- Configuration (read once from the environment at import time) ---
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
# Speech synthesis is opt-in: only the value "true" (in any casing) enables it.
GENERATE_SPEECH = os.getenv("GENERATE_SPEECH", "false").lower() == "true"
TTS_MODEL = "gemini-2.5-flash-preview-tts"
# MIME type assumed for audio whose response carries no MIME type of its own.
DEFAULT_RAW_AUDIO_MIME = "audio/L16;rate=24000"

# --- One-time initialisation ---
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
genai.configure(api_key=GEMINI_API_KEY)
class TTSGenerationError(Exception):
    """Signals that audio could not be produced by the Gemini TTS backend."""
def parse_audio_mime_type(mime_type: str) -> Dict[str, int]:
    """Read the sample width and sampling rate out of an audio MIME string.

    The string is split on ``;`` and each piece is trimmed and lower-cased.
    A piece starting with ``rate=`` supplies the sampling rate; a piece
    starting with ``audio/l<digits>`` supplies the bits per sample.  Pieces
    whose numeric part cannot be converted with ``int()`` are ignored, so the
    defaults (16 bits, 24000 Hz) survive malformed input.

    Example: ``"audio/L16;rate=24000"`` gives
    ``{"bits_per_sample": 16, "rate": 24000}``.
    """
    parsed = {"bits_per_sample": 16, "rate": 24000}

    for piece in mime_type.split(";"):
        token = piece.strip().lower()

        if token.startswith("rate="):
            field, digits = "rate", token.partition("=")[2]
        elif re.match(r"audio/l\d+", token):
            # Everything after the first "l" (the one in "audio/l") is the width.
            field, digits = "bits_per_sample", token.partition("l")[2]
        else:
            continue

        try:
            parsed[field] = int(digits)
        except ValueError:
            # Deliberate best effort: keep the default for unparsable values.
            pass

    return parsed
def convert_to_wav(audio_data: bytes, mime_type: str) -> bytes:
    """Prefix raw PCM bytes with a 44-byte mono WAV (RIFF) header.

    The sample width and sampling rate come from ``parse_audio_mime_type``
    applied to ``mime_type``; the channel count is fixed at one.  The audio
    payload itself is appended unchanged.
    """
    fmt = parse_audio_mime_type(mime_type)
    sample_rate = fmt["rate"]
    sample_bits = fmt["bits_per_sample"]

    channels = 1
    frame_bytes = channels * (sample_bits // 8)
    payload_len = len(audio_data)

    fields = (
        b"RIFF",
        36 + payload_len,           # RIFF chunk size: header remainder + payload
        b"WAVE",
        b"fmt ",
        16,                         # size of the fmt sub-chunk
        1,                          # audio format tag 1 = PCM
        channels,
        sample_rate,
        sample_rate * frame_bytes,  # byte rate
        frame_bytes,                # block align
        sample_bits,
        b"data",
        payload_len,
    )
    return struct.pack("<4sI4s4sIHHIIHH4sI", *fields) + audio_data
@cache.memoize()
def _synthesize_gemini_tts_impl(text: str, gemini_voice_name: str) -> Tuple[bytes, str]:
    """Request speech for ``text`` from Gemini TTS and return ``(audio, mime)``.

    The function is memoized through ``cache.memoize()``.  Besides avoiding
    repeated API calls, this provides the ``__cache_key__`` helper that
    ``synthesize_gemini_tts`` uses to look up stored audio when speech
    generation is disabled.
    NOTE(review): assumes ``cache`` is a diskcache-compatible object exposing
    ``memoize()`` -- confirm against the project's ``cache`` module.

    Args:
        text: The text to be spoken.
        gemini_voice_name: Name of the prebuilt Gemini voice to use.

    Returns:
        ``(mp3_bytes, "audio/mpeg")`` when the pydub MP3 conversion succeeds.
        If that conversion fails, the WAV bytes with ``"audio/wav"``.
        Responses that already arrive as MPEG/OGG/Opus are returned unchanged
        together with the MIME type reported by the API.

    Raises:
        TTSGenerationError: If ``GENERATE_SPEECH`` is disabled, the API call or
            response parsing fails, or the response contains no audio data.
    """
    if not GENERATE_SPEECH:
        raise TTSGenerationError("GENERATE_SPEECH not enabled in environment.")

    try:
        model = genai.GenerativeModel(TTS_MODEL)
        # NOTE(review): confirm that the installed google-generativeai release
        # accepts these audio fields inside generation_config.
        response = model.generate_content(
            contents=[text],
            generation_config={
                "response_modalities": ["AUDIO"],
                "speech_config": {
                    "voice_config": {
                        "prebuilt_voice_config": {"voice_name": gemini_voice_name}
                    }
                },
            },
        )
        audio_part = response.candidates[0].content.parts[0]
        raw_data = audio_part.inline_data.data
        mime = audio_part.inline_data.mime_type
    except Exception as e:
        logging.error("Gemini TTS API error: %s", e)
        # Chain the cause so the original traceback is not lost.
        raise TTSGenerationError(f"TTS request failed: {e}") from e

    if not raw_data:
        raise TTSGenerationError("Empty audio data from Gemini.")

    # Normalise the payload to WAV unless it is already in a container format.
    compressed_prefixes = ("audio/mpeg", "audio/ogg", "audio/opus")
    mime_lower = mime.lower() if mime else ""
    if not mime_lower:
        logging.warning("MIME missing; defaulting to WAV")
        raw_data = convert_to_wav(raw_data, DEFAULT_RAW_AUDIO_MIME)
        mime = "audio/wav"
    elif mime_lower.startswith(compressed_prefixes):
        # Already a compressed container: the WAV decoder below cannot read
        # it, so hand the payload back as-is instead of provoking a failure.
        return raw_data, mime
    elif not mime_lower.startswith("audio/wav"):
        # Raw PCM (e.g. "audio/L16;rate=24000") or any other unknown type.
        raw_data = convert_to_wav(raw_data, mime_lower)
        mime = "audio/wav"

    # Attempt MP3 compression; fall back to the WAV bytes on any failure.
    try:
        segment = AudioSegment.from_file(io.BytesIO(raw_data), format="wav")
        buf = io.BytesIO()
        segment.export(buf, format="mp3")
        return buf.getvalue(), "audio/mpeg"
    except Exception as e:
        logging.warning("MP3 conversion failed (%s); returning WAV", e)
        return raw_data, mime
def synthesize_gemini_tts(text: str, voice: str) -> Tuple[Optional[bytes], Optional[str]]:
    """Return ``(audio_bytes, mime_type)`` for ``text``, or ``(None, None)``.

    Behaviour depends on the ``GENERATE_SPEECH`` flag:

    * enabled  -- call ``_synthesize_gemini_tts_impl``; a ``TTSGenerationError``
      is logged and reported as ``(None, None)`` so callers can skip audio.
    * disabled -- never contact the API; only return audio that is already in
      ``cache`` under the key the memoized implementation would use.

    Args:
        text: The text to be spoken.
        voice: Name of the prebuilt Gemini voice.
    """
    if GENERATE_SPEECH:
        try:
            return _synthesize_gemini_tts_impl(text, voice)
        except TTSGenerationError as e:
            logging.error("TTS failed: %s; skipping audio", e)
            return None, None

    # Speech disabled: serve cached audio only.  ``__cache_key__`` exists only
    # when the implementation is wrapped by a memoizing decorator; without it
    # there can be no cached entry, so treat that as a cache miss instead of
    # raising AttributeError on every call.
    make_key = getattr(_synthesize_gemini_tts_impl, "__cache_key__", None)
    if make_key is not None:
        result = cache.get(make_key(text, voice))
        if result is not None:
            return result

    logging.info("No cached audio; speech disabled.")
    return None, None