File size: 5,433 Bytes
db56fd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1effb75
 
db56fd6
1effb75
db56fd6
1effb75
db56fd6
 
 
 
 
 
 
1effb75
 
db56fd6
 
 
1effb75
db56fd6
 
1effb75
db56fd6
1effb75
 
db56fd6
1effb75
 
db56fd6
1effb75
db56fd6
 
 
1effb75
 
 
 
 
 
 
 
 
db56fd6
 
 
1effb75
 
 
 
 
 
 
db56fd6
1effb75
 
db56fd6
 
 
 
1effb75
 
 
 
 
 
 
 
 
 
 
 
 
db56fd6
 
 
1effb75
 
 
db56fd6
1effb75
db56fd6
 
 
 
 
1effb75
 
 
 
 
 
 
 
db56fd6
 
1effb75
 
db56fd6
1effb75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db56fd6
1effb75
db56fd6
1effb75
db56fd6
1effb75
db56fd6
1effb75
db56fd6
 
1effb75
 
 
 
 
 
db56fd6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import struct
import re
import logging
import io
from typing import Optional, Dict, Tuple, Union

import google.generativeai as genai
from pydub import AudioSegment
from cache import cache

# --- Constants ---
# Gemini API key read from the environment; None when the variable is unset
# (it is passed to genai.configure below unchanged).
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
# Speech synthesis is enabled only when GENERATE_SPEECH equals "true"
# (case-insensitive); any other value, or an unset variable, disables it.
GENERATE_SPEECH = os.environ.get("GENERATE_SPEECH", "false").lower() == "true"
# Gemini model used for text-to-speech requests.
TTS_MODEL = "gemini-2.5-flash-preview-tts"
# MIME assumed when the TTS response carries no MIME type: 16-bit linear PCM, 24 kHz.
DEFAULT_RAW_AUDIO_MIME = "audio/L16;rate=24000"

# --- Setup ---
# NOTE(review): basicConfig at import time configures the root logger for the
# whole process; consider moving it to the application entry point.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
genai.configure(api_key=GEMINI_API_KEY)

class TTSGenerationError(Exception):
    """Signals that a Gemini text-to-speech request could not produce audio.

    Raised for a disabled speech flag, an API failure, or an empty payload;
    the public wrapper catches it and degrades to returning no audio.
    """

def parse_audio_mime_type(mime_type: str) -> Dict[str, int]:
    """
    Read the PCM bit depth and sample rate out of an audio MIME string.

    Parameters are split on ';' and matched case-insensitively. A "rate=N"
    parameter sets the rate; an "audio/L<N>" type sets the bit depth. Any value
    that is missing or not an integer leaves the default (16 bits, 24000 Hz).

    e.g. "audio/L16;rate=24000" → {"bits_per_sample": 16, "rate": 24000}
    """
    def _int_or_none(text: str) -> Optional[int]:
        # Mirror int() parsing exactly; unparsable text means "keep default".
        try:
            return int(text)
        except ValueError:
            return None

    parsed = {"bits_per_sample": 16, "rate": 24000}

    for chunk in mime_type.split(";"):
        token = chunk.strip().lower()
        if token.startswith("rate="):
            field, candidate = "rate", token.split("=", 1)[1]
        elif re.match(r"audio/l\d+", token):
            # "audio/" contains no 'l', so this takes everything after "audio/l".
            field, candidate = "bits_per_sample", token.split("l", 1)[1]
        else:
            continue

        value = _int_or_none(candidate)
        if value is not None:
            parsed[field] = value

    return parsed

def convert_to_wav(audio_data: bytes, mime_type: str) -> bytes:
    """Wrap raw mono PCM bytes in a canonical 44-byte RIFF/WAVE header.

    Args:
        audio_data: Raw PCM sample bytes. No byte swapping is done, so the
            samples are assumed to already be in WAV's little-endian order.
        mime_type: MIME string such as "audio/L16;rate=24000". Bit depth and
            sample rate are taken from it via parse_audio_mime_type.

    Returns:
        header + audio_data, plus one zero pad byte when audio_data has odd
        length. RIFF requires every chunk to end on an even offset.
    """
    params = parse_audio_mime_type(mime_type)
    bits = params["bits_per_sample"]
    rate = params["rate"]

    num_channels = 1
    bytes_per_sample = bits // 8
    block_align = num_channels * bytes_per_sample
    byte_rate = rate * block_align
    data_size = len(audio_data)

    # An odd-sized data chunk must be followed by a pad byte. The outer RIFF
    # size counts the pad; the data chunk's own size field does not.
    pad = b"\x00" if data_size % 2 else b""
    # 36 = "WAVE" (4) + fmt chunk header and body (8 + 16) + data chunk header (8).
    chunk_size = 36 + data_size + len(pad)

    header = struct.pack(
        "<4sI4s4sIHHIIHH4sI",
        b"RIFF",
        chunk_size,
        b"WAVE",
        b"fmt ",
        16,  # fmt chunk body size for plain PCM
        1,  # audio format 1 = uncompressed PCM
        num_channels,
        rate,
        byte_rate,
        block_align,
        bits,
        b"data",
        data_size,
    )
    return header + audio_data + pad

# MIME prefixes of payloads that are already a WAV container.
_WAV_MIME_PREFIXES = ("audio/wav", "audio/x-wav", "audio/wave", "audio/vnd.wave")
# MIME prefixes of payloads that are already encoded and must NOT be wrapped in
# a PCM WAV header (doing so would embed one container inside another).
_ENCODED_MIME_PREFIXES = _WAV_MIME_PREFIXES + ("audio/mpeg", "audio/mp3", "audio/ogg", "audio/opus")


def _request_gemini_audio(text: str, gemini_voice_name: str) -> Tuple[bytes, Optional[str]]:
    """Send one Gemini TTS request; return the first part's inline bytes and MIME.

    Raises:
        TTSGenerationError: if the API call or response access fails (e.g. no
            candidates or no inline data), or the returned payload is empty.
    """
    try:
        model = genai.GenerativeModel(TTS_MODEL)
        response = model.generate_content(
            contents=[text],
            generation_config={
                "response_modalities": ["AUDIO"],
                "speech_config": {
                    "voice_config": {
                        "prebuilt_voice_config": {"voice_name": gemini_voice_name}
                    }
                }
            },
        )
        audio_part = response.candidates[0].content.parts[0]
        raw_data = audio_part.inline_data.data
        mime = audio_part.inline_data.mime_type
    except Exception as e:
        logging.error("Gemini TTS API error: %s", e)
        raise TTSGenerationError(f"TTS request failed: {e}") from e

    if not raw_data:
        raise TTSGenerationError("Empty audio data from Gemini.")
    return raw_data, mime


def _normalize_audio(raw_data: bytes, mime: Optional[str]) -> Tuple[bytes, str]:
    """Wrap raw PCM in a WAV header; pass already-encoded audio through unchanged."""
    if not mime:
        logging.warning("MIME missing; defaulting to WAV")
        return convert_to_wav(raw_data, DEFAULT_RAW_AUDIO_MIME), "audio/wav"
    if mime.lower().startswith(_ENCODED_MIME_PREFIXES):
        return raw_data, mime
    # Anything else (e.g. "audio/L16;codec=pcm;rate=24000") is treated as raw PCM.
    return convert_to_wav(raw_data, mime.lower()), "audio/wav"


def _compress_wav_to_mp3(wav_data: bytes) -> Optional[bytes]:
    """Re-encode WAV bytes as MP3 with pydub; return None if encoding fails."""
    try:
        segment = AudioSegment.from_file(io.BytesIO(wav_data), format="wav")
        buf = io.BytesIO()
        segment.export(buf, format="mp3")
    except Exception as e:
        # Best effort: the caller falls back to the uncompressed WAV.
        logging.warning("MP3 conversion failed (%s); returning WAV", e)
        return None
    return buf.getvalue()


@cache.memoize()
def _synthesize_gemini_tts_impl(text: str, gemini_voice_name: str) -> Tuple[bytes, str]:
    """Request audio for ``text`` from Gemini TTS (memoized in ``cache``).

    Returns:
        (audio_bytes, mime_type). This is MP3 ("audio/mpeg") when WAV audio
        could be compressed. Otherwise it is the WAV audio, or the payload in
        its original encoding when Gemini already returned MPEG/Ogg/Opus.

    Raises:
        TTSGenerationError: if GENERATE_SPEECH is off, the request fails, or
            the response contains no audio. Exceptions are not memoized.
    """
    if not GENERATE_SPEECH:
        raise TTSGenerationError("GENERATE_SPEECH not enabled in environment.")

    raw_data, mime = _request_gemini_audio(text, gemini_voice_name)
    audio, mime = _normalize_audio(raw_data, mime)

    if not mime.lower().startswith(_WAV_MIME_PREFIXES):
        # Already compressed; decoding it as WAV could only fail.
        return audio, mime

    mp3_data = _compress_wav_to_mp3(audio)
    if mp3_data is None:
        return audio, mime
    return mp3_data, "audio/mpeg"

# The public wrapper is selected once, at import time, from GENERATE_SPEECH.
if not GENERATE_SPEECH:
    def synthesize_gemini_tts(text: str, voice: str) -> Tuple[Optional[bytes], Optional[str]]:
        """Return previously cached audio for (text, voice), or (None, None).

        Speech generation is disabled, so the TTS API is never called. Only the
        cache entry keyed like _synthesize_gemini_tts_impl's memoization is read.
        """
        cached = cache.get(_synthesize_gemini_tts_impl.__cache_key__(text, voice))
        if cached is None:
            logging.info("No cached audio; speech disabled.")
            return None, None
        return cached
else:
    def synthesize_gemini_tts(text: str, voice: str) -> Tuple[Optional[bytes], Optional[str]]:
        """Generate (or fetch memoized) audio for (text, voice).

        A TTSGenerationError is logged and turned into (None, None) so callers
        can simply skip audio.
        """
        try:
            audio = _synthesize_gemini_tts_impl(text, voice)
        except TTSGenerationError as err:
            logging.error("TTS failed: %s; skipping audio", err)
            return None, None
        return audio