import asyncio import base64 import os import time from io import BytesIO from google.genai import types from google.genai.types import ( LiveConnectConfig, SpeechConfig, VoiceConfig, PrebuiltVoiceConfig, Content, Part, ) import gradio as gr import numpy as np import websockets from dotenv import load_dotenv from fastrtc import ( AsyncAudioVideoStreamHandler, Stream, WebRTC, get_cloudflare_turn_credentials_async, wait_for_item, ) from google import genai from gradio.utils import get_space from PIL import Image # ------------------------------------------ import asyncio import base64 import json import os import pathlib from typing import AsyncGenerator, Literal import gradio as gr import numpy as np from dotenv import load_dotenv from fastapi import FastAPI from fastapi.responses import HTMLResponse from fastrtc import ( AsyncStreamHandler, Stream, get_cloudflare_turn_credentials_async, wait_for_item, ) from google import genai from google.genai.types import ( LiveConnectConfig, PrebuiltVoiceConfig, SpeechConfig, VoiceConfig, ) from gradio.utils import get_space from pydantic import BaseModel # ------------------------------------------------ from dotenv import load_dotenv load_dotenv() import os import io import asyncio from pydub import AudioSegment # Gemini: google-genai from google import genai # --------------------------------------------------- # VAD imports from reference code import collections import webrtcvad import time # helper functions GEMINI_API_KEY="AIzaSyCUCivstFpC9pq_jMHMYdlPrmh9Bx97dFo" TAVILY_API_KEY="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" OPENAI_API_KEY="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" QDRANT_API_KEY="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIiwiZXhwIjoxNzUxMDUxNzg4fQ.I9J-K7OM0BtcNKgj2d4uVM8QYAHYfFCVAyP4rlZkK2E" QDRANT_URL="https://6a3aade6-e8ad-4a6c-a579-21f5af90b7e8.us-east4-0.gcp.cloud.qdrant.io" OPENAI_API_KEY="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" WEAVIATE_URL="yorcqe2sqswhcaivxvt9a.c0.us-west3.gcp.weaviate.cloud" WEAVIATE_API_KEY="d2d0VGdZQTBmdTFlOWdDZl9tT2h3WDVWd1NpT1dQWHdGK0xjR1hYeWxicUxHVnFRazRUSjY2VlRUVlkwPV92MjAw" DEEPINFRA_API_KEY="285LUJulGIprqT6hcPhiXtcrphU04FG4" DEEPINFRA_BASE_URL="https://api.deepinfra.com/v1/openai" def encode_audio(data: np.ndarray) -> dict: """Encode Audio data to send to the server""" return { "mime_type": "audio/pcm", "data": base64.b64encode(data.tobytes()).decode("UTF-8"), } def encode_audio2(data: np.ndarray) -> bytes: """Encode Audio data to send to the server""" return data.tobytes() import soundfile as sf def numpy_array_to_wav_bytes(audio_array, sample_rate=16000): buffer = io.BytesIO() sf.write(buffer, audio_array, sample_rate, format='WAV') return buffer.getvalue() def numpy_array_to_wav_bytes(audio_array, sample_rate=16000): """ Convert a NumPy audio array to WAV bytes. Args: audio_array (np.ndarray): Audio signal (1D or 2D). sample_rate (int): Sample rate in Hz. Returns: bytes: WAV-formatted audio data. """ buffer = io.BytesIO() sf.write(buffer, audio_array, sample_rate, format='WAV') buffer.seek(0) return buffer.read() # webrtc handler class class GeminiHandler(AsyncStreamHandler): """Handler for the Gemini API with chained latency calculation.""" def __init__( self, expected_layout: Literal["mono"] = "mono", output_sample_rate: int = 24000,prompt_dict: dict = {"prompt":"PHQ-9"}, ) -> None: super().__init__( expected_layout, output_sample_rate, input_sample_rate=16000, ) self.input_queue: asyncio.Queue = asyncio.Queue() self.output_queue: asyncio.Queue = asyncio.Queue() self.quit: asyncio.Event = asyncio.Event() self.prompt_dict = prompt_dict # self.model = "gemini-2.5-flash-preview-tts" self.model = "gemini-2.0-flash-live-001" self.t2t_model = "gemini-2.0-flash" self.s2t_model = "gemini-2.0-flash" # --- VAD Initialization --- self.vad = webrtcvad.Vad(3) self.VAD_RATE = 16000 self.VAD_FRAME_MS = 20 self.VAD_FRAME_SAMPLES = int(self.VAD_RATE * (self.VAD_FRAME_MS / 1000.0)) self.VAD_FRAME_BYTES = self.VAD_FRAME_SAMPLES * 2 padding_ms = 300 self.vad_padding_frames = padding_ms // self.VAD_FRAME_MS self.vad_ring_buffer = collections.deque(maxlen=self.vad_padding_frames) self.vad_ratio = 0.9 self.vad_triggered = False self.wav_data = bytearray() self.internal_buffer = bytearray() self.end_of_speech_time: float | None = None self.first_latency_calculated: bool = False def copy(self) -> "GeminiHandler": return GeminiHandler( expected_layout="mono", output_sample_rate=self.output_sample_rate, prompt_dict=self.prompt_dict, ) def t2t(self, text: str) -> str: print(f"Sending text to Gemini: {text}") response = self.chat.send_message(text) print(f"Received response from Gemini: {response.text}") return response.text def s2t(self, audio) -> str: response = self.s2t_client.models.generate_content( model=self.s2t_model, contents=[ types.Part.from_bytes(data=audio, mime_type='audio/wav'), 'Generate a transcript of the speech.' ] ) return response.text async def start_up(self): # Flag for if we are using text-to-text in the middle of the chain or not. self.t2t_bool = False self.sys_prompt = None self.t2t_client = genai.Client(api_key=os.getenv("GEMINI_API_KEY")) self.s2t_client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))#, http_options={"api_version": "v1alpha"}) if self.sys_prompt is not None: chat_config = types.GenerateContentConfig(system_instruction=self.sys_prompt) else: chat_config = types.GenerateContentConfig(system_instruction="You are a helpful assistant.") self.chat = self.t2t_client.chats.create(model=self.t2t_model, config=chat_config) self.t2s_client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY")) voice_name = "Puck" if self.t2t_bool: sys_instruction = f""" You are Wisal, an AI assistant developed by Compumacy AI , and a knowledgeable Autism . Your sole purpose is to provide helpful, respectful, and easy-to-understand answers about Autism Spectrum Disorder (ASD). Always be clear, non-judgmental, and supportive.""" else: sys_instruction = self.sys_prompt if sys_instruction is not None: config = LiveConnectConfig( response_modalities=["AUDIO"], speech_config=SpeechConfig( voice_config=VoiceConfig( prebuilt_voice_config=PrebuiltVoiceConfig(voice_name=voice_name) ) ), system_instruction=Content(parts=[Part.from_text(text=sys_instruction)]) ) else: config = LiveConnectConfig( response_modalities=["AUDIO"], speech_config=SpeechConfig( voice_config=VoiceConfig( prebuilt_voice_config=PrebuiltVoiceConfig(voice_name=voice_name) ) ), ) async with self.t2s_client.aio.live.connect(model=self.model, config=config) as session: async for text_from_user in self.stream(): print("--------------------------------------------") print(f"Received text from user and reading aloud: {text_from_user}") print("--------------------------------------------") if text_from_user and text_from_user.strip(): if self.t2t_bool: prompt = f""" You are Wisal, an AI assistant developed by Compumacy AI , and a knowledgeable Autism . Your sole purpose is to provide helpful, respectful, and easy-to-understand answers about Autism Spectrum Disorder (ASD). Always be clear, non-judgmental, and supportive. {text_from_user} """ else: prompt = text_from_user await session.send_client_content( turns=types.Content( role='user', parts=[types.Part(text=prompt)])) async for resp_chunk in session.receive(): if resp_chunk.data: array = np.frombuffer(resp_chunk.data, dtype=np.int16) self.output_queue.put_nowait((self.output_sample_rate, array)) async def stream(self) -> AsyncGenerator[bytes, None]: while not self.quit.is_set(): try: # Get the text message to be converted to speech text_to_speak = await self.input_queue.get() yield text_to_speak except (asyncio.TimeoutError, TimeoutError): pass async def receive(self, frame: tuple[int, np.ndarray]) -> None: sr, array = frame audio_bytes = array.tobytes() self.internal_buffer.extend(audio_bytes) while len(self.internal_buffer) >= self.VAD_FRAME_BYTES: vad_frame = self.internal_buffer[:self.VAD_FRAME_BYTES] self.internal_buffer = self.internal_buffer[self.VAD_FRAME_BYTES:] is_speech = self.vad.is_speech(vad_frame, self.VAD_RATE) if not self.vad_triggered: self.vad_ring_buffer.append((vad_frame, is_speech)) num_voiced = len([f for f, speech in self.vad_ring_buffer if speech]) if num_voiced > self.vad_ratio * self.vad_ring_buffer.maxlen: print("Speech detected, starting to record...") self.vad_triggered = True for f, s in self.vad_ring_buffer: self.wav_data.extend(f) self.vad_ring_buffer.clear() else: self.wav_data.extend(vad_frame) self.vad_ring_buffer.append((vad_frame, is_speech)) num_unvoiced = len([f for f, speech in self.vad_ring_buffer if not speech]) if num_unvoiced > self.vad_ratio * self.vad_ring_buffer.maxlen: print("End of speech detected.") self.end_of_speech_time = time.monotonic() self.vad_triggered = False full_utterance_np = np.frombuffer(self.wav_data, dtype=np.int16) audio_input_wav = numpy_array_to_wav_bytes(full_utterance_np, sr) text_input = self.s2t(audio_input_wav) if text_input and text_input.strip(): if self.t2t_bool: text_message = self.t2t(text_input) else: text_message = text_input self.input_queue.put_nowait(text_message) else: print("STT returned empty transcript, skipping.") self.vad_ring_buffer.clear() self.wav_data = bytearray() async def emit(self) -> tuple[int, np.ndarray] | None: return await wait_for_item(self.output_queue) def shutdown(self) -> None: self.quit.set() with gr.Blocks() as demo: gr.Markdown("# Gemini Chained Speech-to-Speech Demo") # for audio modality # with gr.Row(visible=(modality_selector.value == "audio")) as row2: with gr.Row() as row2: with gr.Column(): # Optional, can be removed if not needed webrtc2 = WebRTC( label="Audio Chat", modality="audio", mode="send-receive", elem_id="audio-source", rtc_configuration=get_cloudflare_turn_credentials_async, icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png", pulse_color="rgb(255, 255, 255)", icon_button_color="rgb(255, 255, 255)", ) # Corrected inputs and outputs for webrtc2.stream to use webrtc2 webrtc2.stream( GeminiHandler(), inputs=[webrtc2], # Was webrtc outputs=[webrtc2],# Was webrtc time_limit=180 if get_space() else None, concurrency_limit=2 if get_space() else None, ) if __name__ == "__main__": demo.launch(server_port=7860)