Spaces:
Runtime error
Runtime error
File size: 5,433 Bytes
db56fd6 1effb75 db56fd6 1effb75 db56fd6 1effb75 db56fd6 1effb75 db56fd6 1effb75 db56fd6 1effb75 db56fd6 1effb75 db56fd6 1effb75 db56fd6 1effb75 db56fd6 1effb75 db56fd6 1effb75 db56fd6 1effb75 db56fd6 1effb75 db56fd6 1effb75 db56fd6 1effb75 db56fd6 1effb75 db56fd6 1effb75 db56fd6 1effb75 db56fd6 1effb75 db56fd6 1effb75 db56fd6 1effb75 db56fd6 1effb75 db56fd6 1effb75 db56fd6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import struct
import re
import logging
import io
from typing import Optional, Dict, Tuple, Union
import google.generativeai as genai
from pydub import AudioSegment
from cache import cache
# --- Constants ---
# Gemini API key taken from the environment; None when the variable is unset.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
# Speech generation is opt-in: only the value "true" (any letter case) enables it.
GENERATE_SPEECH = os.getenv("GENERATE_SPEECH", "false").lower() == "true"
# Model name requested for text-to-speech.
TTS_MODEL = "gemini-2.5-flash-preview-tts"
# MIME type assumed for raw audio when a response carries no MIME type.
DEFAULT_RAW_AUDIO_MIME = "audio/L16;rate=24000"
# --- Setup ---
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
genai.configure(api_key=GEMINI_API_KEY)
class TTSGenerationError(Exception):
    """Raised when Gemini TTS generation fails."""
def parse_audio_mime_type(mime_type: str) -> Dict[str, int]:
    """
    Extract bits_per_sample and sampling rate from an audio MIME string.

    e.g. "audio/L16;rate=24000" → {"bits_per_sample": 16, "rate": 24000}

    Parsing is case-insensitive and best-effort: any component that is
    missing, malformed, or not a positive integer leaves the corresponding
    default (16 bits, 24000 Hz) in place instead of raising.

    Args:
        mime_type: MIME string such as "audio/L16;codec=pcm;rate=24000".
            An empty string or None yields the defaults.

    Returns:
        A dict with the integer keys "bits_per_sample" and "rate".
    """
    bits_per_sample = 16
    rate = 24000
    for param in (mime_type or "").split(";"):
        param = param.strip().lower()
        key, sep, value = param.partition("=")
        if sep and key.strip() == "rate":
            try:
                parsed_rate = int(value)
            except ValueError:
                continue
            # A zero or negative rate cannot describe real audio and would
            # corrupt (or fail to pack into) a WAV header downstream.
            if parsed_rate > 0:
                rate = parsed_rate
            continue
        # The media type itself carries the sample width, e.g. "audio/l16".
        width = re.fullmatch(r"audio/l(\d+)", param)
        if width:
            parsed_bits = int(width.group(1))
            if parsed_bits > 0:
                bits_per_sample = parsed_bits
    return {"bits_per_sample": bits_per_sample, "rate": rate}
def convert_to_wav(audio_data: bytes, mime_type: str) -> bytes:
    """Wrap raw PCM bytes in a WAV header for mono audio.

    The sample width and rate are taken from ``mime_type`` via
    ``parse_audio_mime_type``. A 44-byte little-endian header (RIFF/WAVE
    with a 16-byte ``fmt`` chunk, format tag 1, one channel) is prepended;
    the sample bytes themselves are copied through unchanged.

    Args:
        audio_data: Raw PCM sample bytes.
        mime_type: MIME string describing the samples, e.g.
            "audio/L16;rate=24000".

    Returns:
        The complete WAV file contents as bytes.
    """
    params = parse_audio_mime_type(mime_type)
    bits = params["bits_per_sample"]
    rate = params["rate"]
    num_channels = 1
    bytes_per_sample = bits // 8
    block_align = num_channels * bytes_per_sample
    byte_rate = rate * block_align
    data_size = len(audio_data)
    # RIFF chunks are word-aligned: an odd-sized data chunk is followed by a
    # single zero pad byte. The pad is counted in the enclosing RIFF size
    # but not in the data chunk's own size field.
    padding = b"\x00" if data_size % 2 else b""
    # 36 = header bytes that follow the RIFF size field (44 - 8).
    chunk_size = 36 + data_size + len(padding)
    header = struct.pack(
        "<4sI4s4sIHHIIHH4sI",
        b"RIFF",
        chunk_size,
        b"WAVE",
        b"fmt ",
        16,  # size of the fmt chunk body
        1,  # audio format tag: PCM
        num_channels,
        rate,
        byte_rate,
        block_align,
        bits,
        b"data",
        data_size,
    )
    return header + audio_data + padding
@cache.memoize()
def _synthesize_gemini_tts_impl(text: str, gemini_voice_name: str) -> Tuple[bytes, str]:
    """Core function to request audio from Gemini TTS (cached).

    Results are memoized through ``cache.memoize()``, keyed on the arguments.

    Args:
        text: The text to be spoken.
        gemini_voice_name: Name of the prebuilt Gemini voice to use.

    Returns:
        A ``(audio_bytes, mime_type)`` tuple. WAV audio (or raw PCM that was
        wrapped into WAV) is transcoded to MP3 and returned as "audio/mpeg";
        if that transcode fails the WAV bytes are returned as "audio/wav".
        Audio the API already labels as MPEG/Ogg/Opus is returned untouched.

    Raises:
        TTSGenerationError: If speech generation is disabled, the API request
            fails, or the response contains no audio data.
    """
    if not GENERATE_SPEECH:
        raise TTSGenerationError("GENERATE_SPEECH not enabled in environment.")
    try:
        model = genai.GenerativeModel(TTS_MODEL)
        # NOTE(review): confirm the installed google.generativeai version
        # accepts these speech options in generation_config.
        response = model.generate_content(
            contents=[text],
            generation_config={
                "response_modalities": ["AUDIO"],
                "speech_config": {
                    "voice_config": {
                        "prebuilt_voice_config": {"voice_name": gemini_voice_name}
                    }
                },
            },
        )
        audio_part = response.candidates[0].content.parts[0]
        raw_data = audio_part.inline_data.data
        mime = audio_part.inline_data.mime_type
    except Exception as e:
        logging.error("Gemini TTS API error: %s", e)
        # Chain the cause so the original traceback is preserved.
        raise TTSGenerationError(f"TTS request failed: {e}") from e
    if not raw_data:
        raise TTSGenerationError("Empty audio data from Gemini.")

    # Convert raw audio to WAV if needed
    mime_lower = mime.lower() if mime else ""
    if not mime_lower:
        logging.warning("MIME missing; defaulting to WAV")
        raw_data = convert_to_wav(raw_data, DEFAULT_RAW_AUDIO_MIME)
        mime = "audio/wav"
    elif not mime_lower.startswith(("audio/wav", "audio/mpeg", "audio/ogg", "audio/opus")):
        # Anything that is not a known container (e.g. "audio/l16;...") is
        # treated as headerless PCM and wrapped.
        raw_data = convert_to_wav(raw_data, mime_lower)
        mime = "audio/wav"
    elif not mime_lower.startswith("audio/wav"):
        # Already a compressed container: the WAV decode below is not
        # expected to succeed on it, so return the payload as received.
        return raw_data, mime

    # Attempt MP3 compression
    try:
        segment = AudioSegment.from_file(io.BytesIO(raw_data), format="wav")
        buf = io.BytesIO()
        segment.export(buf, format="mp3")
        return buf.getvalue(), "audio/mpeg"
    except Exception as e:
        # Best-effort: an encoder failure still yields playable WAV.
        logging.warning("MP3 conversion failed (%s); returning WAV", e)
        return raw_data, mime
# Choose wrapper based on GENERATE_SPEECH flag
if GENERATE_SPEECH:

    def synthesize_gemini_tts(text: str, voice: str) -> Tuple[Optional[bytes], Optional[str]]:
        """Synthesize speech for ``text`` with the given voice.

        Returns:
            ``(audio_bytes, mime_type)`` on success, or ``(None, None)`` when
            generation raises ``TTSGenerationError`` (the error is logged).
        """
        try:
            return _synthesize_gemini_tts_impl(text, voice)
        except TTSGenerationError as e:
            logging.error("TTS failed: %s; skipping audio", e)
            return None, None

else:

    def synthesize_gemini_tts(text: str, voice: str) -> Tuple[Optional[bytes], Optional[str]]:
        """Return previously cached audio for ``text``/``voice`` without calling the API.

        Returns:
            The cached ``(audio_bytes, mime_type)`` entry if one exists,
            otherwise ``(None, None)``.
        """
        # Look up the entry under the same key the memoized implementation
        # would use, without invoking it.
        cache_key = _synthesize_gemini_tts_impl.__cache_key__(text, voice)
        cached = cache.get(cache_key)
        if cached is not None:
            return cached
        logging.info("No cached audio; speech disabled.")
        return None, None
|