mgbam commited on
Commit
1effb75
·
verified ·
1 Parent(s): 228cd9c

Update gemini_tts.py

Browse files
Files changed (1) hide show
  1. gemini_tts.py +96 -147
gemini_tts.py CHANGED
@@ -12,16 +12,16 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
- import google.generativeai as genai
16
  import os
17
  import struct
18
  import re
19
  import logging
20
- from cache import cache
 
21
 
22
- # Add these imports for MP3 conversion
23
  from pydub import AudioSegment
24
- import io
25
 
26
  # --- Constants ---
27
  GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
@@ -29,184 +29,133 @@ GENERATE_SPEECH = os.environ.get("GENERATE_SPEECH", "false").lower() == "true"
29
  TTS_MODEL = "gemini-2.5-flash-preview-tts"
30
  DEFAULT_RAW_AUDIO_MIME = "audio/L16;rate=24000"
31
 
32
- # --- Configuration ---
33
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
34
-
35
  genai.configure(api_key=GEMINI_API_KEY)
36
 
37
  class TTSGenerationError(Exception):
38
- """Custom exception for TTS generation failures."""
39
  pass
40
 
41
-
42
- # --- Helper functions for audio processing ---
43
- def parse_audio_mime_type(mime_type: str) -> dict[str, int | None]:
44
  """
45
- Parses bits per sample and rate from an audio MIME type string.
46
- e.g., "audio/L16;rate=24000" -> {"bits_per_sample": 16, "rate": 24000}
47
  """
48
- bits_per_sample = 16 # Default
49
- rate = 24000 # Default
50
 
51
- parts = mime_type.split(";")
52
- for param in parts:
53
  param = param.strip().lower()
54
  if param.startswith("rate="):
55
  try:
56
- rate_str = param.split("=", 1)[1]
57
- rate = int(rate_str)
58
- except (ValueError, IndexError):
59
- pass # Keep default if parsing fails
60
- elif re.match(r"audio/l\d+", param): # Matches audio/L<digits>
61
- try:
62
- bits_str = param.split("l",1)[1]
63
- bits_per_sample = int(bits_str)
64
- except (ValueError, IndexError):
65
- pass # Keep default
66
  return {"bits_per_sample": bits_per_sample, "rate": rate}
67
 
68
  def convert_to_wav(audio_data: bytes, mime_type: str) -> bytes:
69
- """
70
- Generates a WAV file header for the given raw audio data and parameters.
71
- Assumes mono audio.
72
- """
73
- parameters = parse_audio_mime_type(mime_type)
74
- bits_per_sample = parameters["bits_per_sample"]
75
- sample_rate = parameters["rate"]
76
- num_channels = 1 # Mono
77
- data_size = len(audio_data)
78
- bytes_per_sample = bits_per_sample // 8
79
  block_align = num_channels * bytes_per_sample
80
- byte_rate = sample_rate * block_align
 
81
  chunk_size = 36 + data_size
82
 
83
  header = struct.pack(
84
  "<4sI4s4sIHHIIHH4sI",
85
- b"RIFF", chunk_size, b"WAVE", b"fmt ",
86
- 16, 1, num_channels, sample_rate, byte_rate, block_align,
87
- bits_per_sample, b"data", data_size
 
 
 
 
 
 
 
 
 
 
88
  )
89
  return header + audio_data
90
- # --- End of helper functions ---
91
 
92
- def _synthesize_gemini_tts_impl(text: str, gemini_voice_name: str) -> tuple[bytes, str]:
93
- """
94
- Synthesizes English text using the Gemini API via the google-genai library.
95
- Returns a tuple: (processed_audio_data_bytes, final_mime_type).
96
- Raises TTSGenerationError on failure.
97
- """
98
  if not GENERATE_SPEECH:
99
- # This should ideally not be hit if the logic outside this function is correct,
100
- # but as a safeguard, we raise an error.
101
- raise TTSGenerationError(
102
- "GENERATE_SPEECH is not set. Please set it in your environment variables to generate speech."
103
- )
104
 
105
  try:
106
  model = genai.GenerativeModel(TTS_MODEL)
107
-
108
- generation_config = {
109
- "response_modalities": ["AUDIO"],
110
- "speech_config": {
111
- "voice_config": {
112
- "prebuilt_voice_config": {
113
- "voice_name": gemini_voice_name
114
- }
115
- }
116
- }
117
- }
118
-
119
  response = model.generate_content(
120
  contents=[text],
121
- generation_config=generation_config,
 
 
 
 
 
 
 
122
  )
123
-
124
  audio_part = response.candidates[0].content.parts[0]
125
- audio_data_bytes = audio_part.inline_data.data
126
- final_mime_type = audio_part.inline_data.mime_type
127
  except Exception as e:
128
- error_message = f"An unexpected error occurred with google-genai: {e}"
129
- logging.error(error_message)
130
- raise TTSGenerationError(error_message) from e
131
-
132
- if not audio_data_bytes:
133
- error_message = "No audio data was successfully retrieved or decoded."
134
- logging.error(error_message)
135
- raise TTSGenerationError(error_message)
136
-
137
- # --- Audio processing ---
138
- if final_mime_type:
139
- final_mime_type_lower = final_mime_type.lower()
140
- needs_wav_conversion = any(p in final_mime_type_lower for p in ("audio/l16", "audio/l24", "audio/l8")) or \
141
- not final_mime_type_lower.startswith(("audio/wav", "audio/mpeg", "audio/ogg", "audio/opus"))
142
-
143
- if needs_wav_conversion:
144
- processed_audio_data = convert_to_wav(audio_data_bytes, final_mime_type)
145
- processed_audio_mime = "audio/wav"
146
- else:
147
- processed_audio_data = audio_data_bytes
148
- processed_audio_mime = final_mime_type
149
- else:
150
- logging.warning("MIME type not determined. Assuming raw audio and attempting WAV conversion (defaulting to %s).", DEFAULT_RAW_AUDIO_MIME)
151
- processed_audio_data = convert_to_wav(audio_data_bytes, DEFAULT_RAW_AUDIO_MIME)
152
- processed_audio_mime = "audio/wav"
153
-
154
- # --- MP3 compression ---
155
- if processed_audio_data:
156
- try:
157
- # Load audio into AudioSegment
158
- audio_segment = AudioSegment.from_file(io.BytesIO(processed_audio_data), format="wav")
159
- mp3_buffer = io.BytesIO()
160
- audio_segment.export(mp3_buffer, format="mp3")
161
- mp3_bytes = mp3_buffer.getvalue()
162
- return mp3_bytes, "audio/mpeg"
163
- except Exception as e:
164
- logging.warning("MP3 compression failed: %s. Falling back to WAV.", e)
165
- # Fallback to WAV if MP3 conversion fails
166
- return processed_audio_data, processed_audio_mime
167
- else:
168
- error_message = "Audio processing failed."
169
- logging.error(error_message)
170
- raise TTSGenerationError(error_message)
171
-
172
- # Always create the memoized function first, so we can access its .key() method
173
- _memoized_tts_func = cache.memoize()(_synthesize_gemini_tts_impl)
174
 
 
175
  if GENERATE_SPEECH:
176
- def synthesize_gemini_tts_with_error_handling(*args, **kwargs) -> tuple[bytes | None, str | None]:
177
- """
178
- A wrapper for the memoized TTS function that catches errors and returns (None, None).
179
- This makes the audio generation more resilient to individual failures.
180
- """
181
  try:
182
- # Attempt to get the audio from the cache or by generating it.
183
- return _memoized_tts_func(*args, **kwargs)
184
  except TTSGenerationError as e:
185
- # If generation fails, log the error and return None, None.
186
- logging.error("Handled TTS Generation Error: %s. Continuing without audio for this segment.", e)
187
  return None, None
188
-
189
- synthesize_gemini_tts = synthesize_gemini_tts_with_error_handling
190
  else:
191
- # When not generating speech, create a read-only function that only
192
- # checks the cache and does not generate new audio.
193
- def read_only_synthesize_gemini_tts(*args, **kwargs):
194
- """
195
- Checks cache for a result, but never calls the underlying TTS function.
196
- This is a 'read-only' memoization check.
197
- """
198
- # Generate the cache key using the memoized function's key method.
199
- key = _memoized_tts_func.__cache_key__(*args, **kwargs)
200
-
201
- # Check the cache directly using the generated key.
202
- _sentinel = object()
203
- result = cache.get(key, default=_sentinel)
204
-
205
- if result is not _sentinel:
206
- return result # Cache hit
207
-
208
- # Cache miss
209
- logging.info("GENERATE_SPEECH is false and no cached result found for key: %s", key)
210
  return None, None
211
-
212
- synthesize_gemini_tts = read_only_synthesize_gemini_tts
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
 
15
  import os
16
  import struct
17
  import re
18
  import logging
19
+ import io
20
+ from typing import Optional, Dict, Tuple, Union
21
 
22
+ import google.generativeai as genai
23
  from pydub import AudioSegment
24
+ from cache import cache
25
 
26
  # --- Constants ---
27
  GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
 
29
  TTS_MODEL = "gemini-2.5-flash-preview-tts"
30
  DEFAULT_RAW_AUDIO_MIME = "audio/L16;rate=24000"
31
 
32
+ # --- Setup ---
33
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 
34
  genai.configure(api_key=GEMINI_API_KEY)
35
 
36
  class TTSGenerationError(Exception):
37
+ """Raised when Gemini TTS generation fails."""
38
  pass
39
 
40
+ def parse_audio_mime_type(mime_type: str) -> Dict[str, int]:
 
 
41
  """
42
+ Extracts bits_per_sample and sampling rate from a MIME string.
43
+ e.g. "audio/L16;rate=24000" {"bits_per_sample": 16, "rate": 24000}
44
  """
45
+ bits_per_sample = 16
46
+ rate = 24000
47
 
48
+ for param in mime_type.split(";"):
 
49
  param = param.strip().lower()
50
  if param.startswith("rate="):
51
  try:
52
+ rate = int(param.split("=", 1)[1])
53
+ except ValueError:
54
+ pass
55
+ elif re.match(r"audio/l\d+", param):
56
+ try:
57
+ bits_per_sample = int(param.split("l", 1)[1])
58
+ except ValueError:
59
+ pass
60
+
 
61
  return {"bits_per_sample": bits_per_sample, "rate": rate}
62
 
63
  def convert_to_wav(audio_data: bytes, mime_type: str) -> bytes:
64
+ """Wrap raw PCM bytes in a WAV header for mono audio."""
65
+ params = parse_audio_mime_type(mime_type)
66
+ bits = params["bits_per_sample"]
67
+ rate = params["rate"]
68
+
69
+ num_channels = 1
70
+ bytes_per_sample = bits // 8
 
 
 
71
  block_align = num_channels * bytes_per_sample
72
+ byte_rate = rate * block_align
73
+ data_size = len(audio_data)
74
  chunk_size = 36 + data_size
75
 
76
  header = struct.pack(
77
  "<4sI4s4sIHHIIHH4sI",
78
+ b"RIFF",
79
+ chunk_size,
80
+ b"WAVE",
81
+ b"fmt ",
82
+ 16,
83
+ 1,
84
+ num_channels,
85
+ rate,
86
+ byte_rate,
87
+ block_align,
88
+ bits,
89
+ b"data",
90
+ data_size,
91
  )
92
  return header + audio_data
 
93
 
94
+ @cache.memoize()
95
+ def _synthesize_gemini_tts_impl(text: str, gemini_voice_name: str) -> Tuple[bytes, str]:
96
+ """Core function to request audio from Gemini TTS (cached)."""
 
 
 
97
  if not GENERATE_SPEECH:
98
+ raise TTSGenerationError("GENERATE_SPEECH not enabled in environment.")
 
 
 
 
99
 
100
  try:
101
  model = genai.GenerativeModel(TTS_MODEL)
 
 
 
 
 
 
 
 
 
 
 
 
102
  response = model.generate_content(
103
  contents=[text],
104
+ generation_config={
105
+ "response_modalities": ["AUDIO"],
106
+ "speech_config": {
107
+ "voice_config": {
108
+ "prebuilt_voice_config": {"voice_name": gemini_voice_name}
109
+ }
110
+ }
111
+ },
112
  )
 
113
  audio_part = response.candidates[0].content.parts[0]
114
+ raw_data = audio_part.inline_data.data
115
+ mime = audio_part.inline_data.mime_type
116
  except Exception as e:
117
+ logging.error("Gemini TTS API error: %s", e)
118
+ raise TTSGenerationError(f"TTS request failed: {e}")
119
+
120
+ if not raw_data:
121
+ raise TTSGenerationError("Empty audio data from Gemini.")
122
+
123
+ # Convert raw audio to WAV if needed
124
+ mime_lower = mime.lower() if mime else ""
125
+ if mime_lower and (
126
+ mime_lower.startswith("audio/l")
127
+ or not mime_lower.startswith(("audio/wav", "audio/mpeg", "audio/ogg", "audio/opus"))
128
+ ):
129
+ raw_data = convert_to_wav(raw_data, mime_lower)
130
+ mime = "audio/wav"
131
+ elif not mime:
132
+ logging.warning("MIME missing; defaulting to WAV")
133
+ raw_data = convert_to_wav(raw_data, DEFAULT_RAW_AUDIO_MIME)
134
+ mime = "audio/wav"
135
+
136
+ # Attempt MP3 compression
137
+ try:
138
+ segment = AudioSegment.from_file(io.BytesIO(raw_data), format="wav")
139
+ buf = io.BytesIO()
140
+ segment.export(buf, format="mp3")
141
+ return buf.getvalue(), "audio/mpeg"
142
+ except Exception as e:
143
+ logging.warning("MP3 conversion failed (%s); returning WAV", e)
144
+ return raw_data, mime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
+ # Choose wrapper based on GENERATE_SPEECH flag
147
  if GENERATE_SPEECH:
148
+ def synthesize_gemini_tts(text: str, voice: str) -> Tuple[Optional[bytes], Optional[str]]:
 
 
 
 
149
  try:
150
+ return _synthesize_gemini_tts_impl(text, voice)
 
151
  except TTSGenerationError as e:
152
+ logging.error("TTS failed: %s; skipping audio", e)
 
153
  return None, None
 
 
154
  else:
155
+ def synthesize_gemini_tts(text: str, voice: str) -> Tuple[Optional[bytes], Optional[str]]:
156
+ key = _synthesize_gemini_tts_impl.__cache_key__(text, voice)
157
+ result = cache.get(key)
158
+ if result is not None:
159
+ return result
160
+ logging.info("No cached audio; speech disabled.")
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  return None, None