|
import asyncio |
|
import base64 |
|
import json |
|
import os |
|
from threading import Event |
|
from datetime import datetime |
|
|
|
import gradio as gr |
|
import numpy as np |
|
import websockets.sync.client |
|
from dotenv import load_dotenv |
|
from gradio_webrtc import StreamHandler, WebRTC |
|
|
|
# Load variables from a local .env file (if one exists) into os.environ.
load_dotenv()

# SECURITY: the Gemini API key must never be committed to source control.
# It is read from the environment (populated by load_dotenv() above or by the
# deployment). The key that used to be hard-coded on this line has been
# exposed and should be revoked/rotated.
# An empty string means "not configured"; the connection attempt will then be
# rejected by the service and logged by the handler.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
|
|
|
class MedicalGeminiConfig:
    """Endpoint, model name and system prompt used for the Gemini session."""

    # Instruction text handed to the model as its system instruction.
    _SYSTEM_PROMPT = """You are SocioCare AI, a compassionate and knowledgeable medical preconsultation assistant. You are engaging in a real-time voice conversation with a patient for their preliminary health assessment.

IMPORTANT GUIDELINES:
- Speak naturally and conversationally, as if you're a caring healthcare professional
- Be empathetic, warm, and reassuring while maintaining professionalism
- Ask relevant follow-up questions to understand symptoms and concerns better
- Provide general health guidance and preliminary assessments
- ALWAYS emphasize that this is a preconsultation and not a substitute for professional medical care
- If symptoms seem serious or urgent, encourage immediate medical attention
- Maintain patient confidentiality and professionalism
- Use simple, clear language that patients can understand
- Be patient and allow time for the patient to explain their concerns thoroughly

PRECONSULTATION FLOW:
1. Greet the patient warmly and introduce yourself as SocioCare AI
2. Ask about their main health concern or symptoms
3. Listen actively and ask clarifying questions about symptoms, duration, severity
4. Provide general health information and preliminary guidance
5. Recommend appropriate next steps (rest, hydration, seeing a doctor, specialist referral, etc.)
6. Offer to answer any additional questions about their health concerns
7. Provide a summary of key points discussed

Remember: You are providing preliminary health assessment and information only. For diagnosis, treatment, and comprehensive care, patients should consult with licensed healthcare professionals."""

    def __init__(self, api_key):
        """Remember *api_key* and derive the websocket URL that embeds it."""
        self.api_key = api_key
        self.host = "generativelanguage.googleapis.com"
        self.model = "models/gemini-2.0-flash-exp"
        service_path = (
            "google.ai.generativelanguage.v1alpha"
            ".GenerativeService.BidiGenerateContent"
        )
        self.ws_url = f"wss://{self.host}/ws/{service_path}?key={self.api_key}"

    def get_medical_system_prompt(self):
        """Return the SocioCare AI preconsultation system prompt."""
        return self._SYSTEM_PROMPT
|
|
|
class AudioProcessor:
    """Stateless converters between numpy audio and the base64/JSON wire format."""

    @staticmethod
    def encode_audio(data, sample_rate):
        """Wrap an audio array as a ``realtimeInput`` message.

        The array's raw bytes (``data.tobytes()``) are base64-encoded and
        labelled with an ``audio/pcm`` MIME type carrying *sample_rate*.
        """
        payload = base64.b64encode(data.tobytes()).decode("UTF-8")
        chunk = {
            "mimeType": f"audio/pcm;rate={sample_rate}",
            "data": payload,
        }
        return {"realtimeInput": {"mediaChunks": [chunk]}}

    @staticmethod
    def process_audio_response(data):
        """Decode a base64 string into a one-dimensional ``int16`` array."""
        raw_bytes = base64.b64decode(data)
        return np.frombuffer(raw_bytes, dtype=np.int16)
|
|
|
|
|
class MedicalGeminiHandler(StreamHandler):
    """Bridge a WebRTC audio stream to a Gemini session over a websocket.

    ``receive`` forwards incoming audio frames to the model and ``emit``
    hands the model's audio reply back in frames of ``output_frame_size``
    samples.
    """

    def __init__(
        self, expected_layout="mono", output_sample_rate=24000, output_frame_size=480
    ) -> None:
        """Initialise stream parameters and the (not yet connected) session state."""
        super().__init__(
            expected_layout,
            output_sample_rate,
            output_frame_size,
            input_sample_rate=24000,
        )
        self.config = None  # created lazily on the first received frame
        self.ws = None  # websocket connection, or None when disconnected
        self.all_output_data = None  # decoded reply samples not yet emitted
        self.audio_processor = AudioProcessor()
        self.args_set = Event()
        self.session_started = False  # True once the greeting turn was sent
        self.conversation_log = []

    def copy(self):
        """Return a fresh handler configured with the same stream parameters."""
        return MedicalGeminiHandler(
            expected_layout=self.expected_layout,
            output_sample_rate=self.output_sample_rate,
            output_frame_size=self.output_frame_size,
        )

    def _close_websocket(self):
        """Close the current connection, if any, and reset ``self.ws`` to None.

        Errors raised while closing are logged and swallowed so that cleanup
        never masks the failure that triggered it.
        """
        ws, self.ws = self.ws, None
        if ws is None:
            return
        try:
            ws.close()
        except Exception as e:
            print(f"Error closing websocket: {str(e)}")

    def _websocket_is_open(self):
        """Report whether ``self.ws`` exists and looks open.

        The websockets sync connection does not provide the ``closed``
        attribute of the legacy API, so the connection state is looked up on
        the connection (newer releases) or on its protocol object.
        NOTE(review): attribute names follow the websockets documentation;
        confirm against the installed version.
        """
        if self.ws is None:
            return False
        state = getattr(self.ws, "state", None)
        if state is None:
            state = getattr(getattr(self.ws, "protocol", None), "state", None)
        if state is not None:
            return getattr(state, "name", None) == "OPEN"
        # Fallback for connection objects that still expose ``closed``.
        return not getattr(self.ws, "closed", False)

    def _initialize_websocket(self):
        """Connect, send the setup message and (once per handler) the greeting.

        Connection and setup failures are logged; in that case the socket is
        closed and ``self.ws`` is left as None.

        Raises:
            RuntimeError: if ``self.config`` has not been set.
        """
        if not self.config:
            raise RuntimeError("Config not set")
        try:
            self.ws = websockets.sync.client.connect(self.config.ws_url, timeout=30)
            initial_request = {
                "setup": {
                    "model": self.config.model,
                    "generationConfig": {
                        "responseModalities": ["AUDIO"],
                        "speechConfig": {
                            "voiceConfig": {
                                "prebuiltVoiceConfig": {
                                    "voiceName": "Aoede"
                                }
                            }
                        }
                    },
                    "systemInstruction": {
                        "parts": [
                            {
                                "text": self.config.get_medical_system_prompt()
                            }
                        ]
                    }
                }
            }
            self.ws.send(json.dumps(initial_request))
            setup_response = json.loads(self.ws.recv())
            print(f"SocioCare AI preconsultation setup: {setup_response}")

            # Only greet on the first successful connection, not on reconnects.
            if not self.session_started:
                self._send_initial_greeting()
                self.session_started = True

        except websockets.exceptions.WebSocketException as e:
            print(f"WebSocket connection failed: {str(e)}")
            # Close rather than just drop the reference, so a socket that was
            # opened before the failure is not leaked.
            self._close_websocket()
        except Exception as e:
            print(f"Setup failed: {str(e)}")
            self._close_websocket()

    def _send_initial_greeting(self):
        """Send initial greeting to start the medical preconsultation"""
        try:
            greeting_message = {
                "clientContent": {
                    "turns": [
                        {
                            "role": "user",
                            "parts": [
                                {
                                    "text": "Please start the preconsultation by greeting me as a patient and introducing yourself as SocioCare AI."
                                }
                            ]
                        }
                    ],
                    "turnComplete": True
                }
            }
            self.ws.send(json.dumps(greeting_message))
        except Exception as e:
            print(f"Error sending initial greeting: {str(e)}")

    async def fetch_args(self):
        """Send a "tick" message over the data channel when one is available."""
        if self.channel:
            self.channel.send("tick")

    def set_args(self, args):
        """Store the arguments via the base class and flag that they arrived."""
        super().set_args(args)
        self.args_set.set()

    def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Forward one incoming audio frame to the model.

        Args:
            frame: ``(sample_rate, samples)`` as delivered by the stream.

        The connection is opened lazily on the first frame. Any error is
        logged and the connection is closed, so that a later frame triggers
        a reconnect.
        """
        if not self.channel:
            return
        if not self.config:
            self.config = MedicalGeminiConfig(GEMINI_API_KEY)
        try:
            if not self.ws:
                self._initialize_websocket()
            if not self.ws:
                # Setup failed and was already logged; drop this frame instead
                # of failing on a None connection.
                return

            # Label the audio with the rate of the frame itself rather than
            # the output rate, which only coincidentally matched.
            sample_rate, array = frame
            array = array.squeeze()
            audio_message = self.audio_processor.encode_audio(array, sample_rate)
            self.ws.send(json.dumps(audio_message))
        except Exception as e:
            print(f"Error in receive: {str(e)}")
            self._close_websocket()

    def _process_server_content(self, content):
        """Yield ``(sample_rate, frame)`` tuples decoded from a model turn.

        Audio from every part carrying ``inlineData`` is appended to the
        internal buffer, which is then drained in chunks of exactly
        ``output_frame_size`` samples; a shorter remainder stays buffered.
        """
        for part in content.get("parts", []):
            data = part.get("inlineData", {}).get("data", "")
            if not data:
                continue
            audio_array = self.audio_processor.process_audio_response(data)
            if self.all_output_data is None:
                self.all_output_data = audio_array
            else:
                self.all_output_data = np.concatenate(
                    (self.all_output_data, audio_array)
                )

            while self.all_output_data.shape[-1] >= self.output_frame_size:
                yield (
                    self.output_sample_rate,
                    self.all_output_data[: self.output_frame_size].reshape(1, -1),
                )
                self.all_output_data = self.all_output_data[
                    self.output_frame_size :
                ]

    def generator(self):
        """Endlessly yield audio frames (or None) read from the websocket.

        None is yielded when there is no connection, when no message arrives
        within five seconds, or when receiving/decoding fails.
        """
        while True:
            if not self.ws or not self.config:
                print("WebSocket not connected")
                yield None
                continue

            try:
                message = self.ws.recv(timeout=5)
                msg = json.loads(message)

                if "serverContent" in msg:
                    content = msg["serverContent"].get("modelTurn", {})
                    yield from self._process_server_content(content)
            except TimeoutError:
                print("Timeout waiting for server response")
                yield None
            except Exception as e:
                print(f"Error in generator: {str(e)}")
                yield None

    def emit(self) -> tuple[int, np.ndarray] | None:
        """Return the next output frame, or None when nothing is available."""
        if not self.ws:
            return None
        if not hasattr(self, "_generator"):
            self._generator = self.generator()
        try:
            return next(self._generator)
        except StopIteration:
            self.reset()
            return None

    def reset(self) -> None:
        """Discard the frame generator and any buffered output samples."""
        if hasattr(self, "_generator"):
            delattr(self, "_generator")
        self.all_output_data = None

    def shutdown(self) -> None:
        """Close the websocket connection, if one is open."""
        self._close_websocket()

    def check_connection(self):
        """Ensure a usable connection exists, reconnecting if necessary.

        Returns:
            True when a connection is available afterwards, otherwise False.
        """
        try:
            if not self._websocket_is_open():
                # Release a stale socket before opening a new one.
                self._close_websocket()
                self._initialize_websocket()
            # _initialize_websocket logs failures and leaves ws as None.
            return self.ws is not None
        except Exception as e:
            print(f"Connection check failed: {str(e)}")
            return False
|
|
|
|
|
def get_rtc_configuration():
    """Return an RTC configuration that lists only Google's public STUN servers."""
    stun_hosts = ("stun", "stun1", "stun2", "stun3", "stun4")
    ice_servers = [
        {"urls": f"stun:{host}.l.google.com:19302"} for host in stun_hosts
    ]
    return {"iceServers": ice_servers}
|
|
|
|
|
class SocioCareAIPreconsultation:
    """Gradio application exposing the voice preconsultation interface."""

    # Plain CSS for gr.Blocks(css=...). It must not be wrapped in <style>
    # tags: the tag text is not CSS and invalidates the first rule.
    CUSTOM_CSS = """
/* Global dark theme */
.gradio-container {
    background: linear-gradient(135deg, #0F0C29 0%, #24243e 50%, #302B63 100%) !important;
    min-height: 100vh;
}

.dark {
    background: linear-gradient(135deg, #0F0C29 0%, #24243e 50%, #302B63 100%) !important;
}

/* Main container */
.main-container {
    display: flex;
    flex-direction: column;
    align-items: center;
    justify-content: center;
    min-height: 90vh;
    padding: 2rem;
}

/* AI Icon and waves container */
.ai-icon-container {
    position: relative;
    margin-bottom: 3rem;
    display: flex;
    flex-direction: column;
    align-items: center;
}

/* Audio waves */
.audio-waves {
    display: flex;
    align-items: center;
    justify-content: center;
    gap: 4px;
    margin-bottom: 2rem;
    height: 60px;
}

.wave-dot {
    width: 6px;
    height: 6px;
    background: #667EEA;
    border-radius: 50%;
    animation: pulse 2s ease-in-out infinite;
    opacity: 0.6;
}

.wave-bar {
    width: 4px;
    background: linear-gradient(to top, #667EEA, #764BA2);
    border-radius: 2px;
    animation: wave 1.5s ease-in-out infinite;
}

/* The twelve bars are children 4-15 of .audio-waves (three dots come first) */
.wave-bar:nth-child(4) { height: 20px; animation-delay: 0s; }
.wave-bar:nth-child(5) { height: 35px; animation-delay: 0.1s; }
.wave-bar:nth-child(6) { height: 45px; animation-delay: 0.2s; }
.wave-bar:nth-child(7) { height: 60px; animation-delay: 0.3s; }
.wave-bar:nth-child(8) { height: 50px; animation-delay: 0.4s; }
.wave-bar:nth-child(9) { height: 40px; animation-delay: 0.5s; }
.wave-bar:nth-child(10) { height: 55px; animation-delay: 0.6s; }
.wave-bar:nth-child(11) { height: 35px; animation-delay: 0.7s; }
.wave-bar:nth-child(12) { height: 25px; animation-delay: 0.8s; }
.wave-bar:nth-child(13) { height: 40px; animation-delay: 0.9s; }
.wave-bar:nth-child(14) { height: 50px; animation-delay: 1s; }
.wave-bar:nth-child(15) { height: 30px; animation-delay: 1.1s; }

@keyframes wave {
    0%, 100% { transform: scaleY(0.5); opacity: 0.7; }
    50% { transform: scaleY(1); opacity: 1; }
}

@keyframes pulse {
    0%, 100% { opacity: 0.4; transform: scale(1); }
    50% { opacity: 1; transform: scale(1.2); }
}

/* AI Icon */
.ai-icon {
    width: 120px;
    height: 120px;
    background: linear-gradient(135deg, #667EEA 0%, #764BA2 100%);
    border-radius: 24px;
    display: flex;
    align-items: center;
    justify-content: center;
    font-size: 3rem;
    color: white;
    box-shadow: 0 20px 40px rgba(102, 126, 234, 0.3);
    margin-bottom: 2rem;
    position: relative;
    overflow: hidden;
}

.ai-icon::before {
    content: '';
    position: absolute;
    top: 0;
    left: 0;
    right: 0;
    bottom: 0;
    background: linear-gradient(45deg, transparent 30%, rgba(255,255,255,0.1) 50%, transparent 70%);
    animation: shimmer 3s ease-in-out infinite;
}

@keyframes shimmer {
    0% { transform: translateX(-100%); }
    100% { transform: translateX(100%); }
}

/* Title */
.ai-title {
    font-size: 2.5rem;
    font-weight: 700;
    background: linear-gradient(135deg, #667EEA 0%, #764BA2 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    background-clip: text;
    text-align: center;
    margin-bottom: 0.5rem;
    letter-spacing: -0.02em;
}

/* Subtitle */
.ai-subtitle {
    color: #A0AEC0;
    font-size: 1.1rem;
    text-align: center;
    margin-bottom: 3rem;
    font-weight: 400;
}

/* WebRTC component styling */
.webrtc-container {
    display: flex;
    flex-direction: column;
    align-items: center;
    gap: 1rem;
}

/* Hide default gradio elements */
.gradio-container .wrap,
.gradio-container .container,
footer {
    background: transparent !important;
    border: none !important;
    box-shadow: none !important;
}

/* Custom button styling for WebRTC */
button {
    background: linear-gradient(135deg, #667EEA 0%, #764BA2 100%) !important;
    border: none !important;
    border-radius: 50px !important;
    padding: 1rem 2rem !important;
    color: white !important;
    font-weight: 600 !important;
    font-size: 1.1rem !important;
    cursor: pointer !important;
    transition: all 0.3s ease !important;
    box-shadow: 0 10px 30px rgba(102, 126, 234, 0.3) !important;
}

button:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 15px 35px rgba(102, 126, 234, 0.4) !important;
}

/* Status indicators */
.status-dot {
    width: 12px;
    height: 12px;
    border-radius: 50%;
    background: #4ADE80;
    box-shadow: 0 0 20px rgba(74, 222, 128, 0.5);
    animation: pulse 2s ease-in-out infinite;
    margin-right: 0.5rem;
}

.status-text {
    color: #E2E8F0;
    font-size: 0.9rem;
    display: flex;
    align-items: center;
    justify-content: center;
    margin-top: 1rem;
}

/* Hide default gradio styling */
.gradio-container .block {
    background: transparent !important;
    border: none !important;
    box-shadow: none !important;
}

/* Responsive design */
@media (max-width: 768px) {
    .ai-icon {
        width: 100px;
        height: 100px;
        font-size: 2.5rem;
    }

    .ai-title {
        font-size: 2rem;
    }

    .audio-waves {
        height: 50px;
        gap: 3px;
    }

    .wave-bar {
        width: 3px;
    }
}
"""

    def __init__(self):
        """Build the Gradio interface and keep it in ``self.demo``."""
        self.demo = self._create_interface()

    def _create_interface(self):
        """Assemble and return the ``gr.Blocks`` layout.

        The page consists of a decorative wave animation, the icon, the title,
        the WebRTC audio component wired to ``MedicalGeminiHandler`` and a
        status line.
        """
        with gr.Blocks(theme=gr.themes.Glass(), css=self.CUSTOM_CSS) as demo:
            with gr.Column(elem_classes=["main-container"]):

                # Decorative wave: three dots, twelve bars, three dots.
                gr.HTML("""
                    <div class="audio-waves">
                        <div class="wave-dot"></div>
                        <div class="wave-dot"></div>
                        <div class="wave-dot"></div>
                        <div class="wave-bar"></div>
                        <div class="wave-bar"></div>
                        <div class="wave-bar"></div>
                        <div class="wave-bar"></div>
                        <div class="wave-bar"></div>
                        <div class="wave-bar"></div>
                        <div class="wave-bar"></div>
                        <div class="wave-bar"></div>
                        <div class="wave-bar"></div>
                        <div class="wave-bar"></div>
                        <div class="wave-bar"></div>
                        <div class="wave-bar"></div>
                        <div class="wave-dot"></div>
                        <div class="wave-dot"></div>
                        <div class="wave-dot"></div>
                    </div>
                """)

                gr.HTML("""
                    <div class="ai-icon-container">
                        <div class="ai-icon">
                            AI✨
                        </div>
                    </div>
                """)

                gr.HTML("""
                    <h1 class="ai-title">AI Voice Agent</h1>
                    <p class="ai-subtitle">By SocioCare</p>
                """)

                with gr.Column(elem_classes=["webrtc-container"]):
                    webrtc = WebRTC(
                        label="",
                        modality="audio",
                        mode="send-receive",
                        rtc_configuration=get_rtc_configuration(),
                    )

                    webrtc.stream(
                        MedicalGeminiHandler(),
                        inputs=[webrtc],
                        outputs=[webrtc],
                        time_limit=600,
                        concurrency_limit=3,
                    )

                gr.HTML("""
                    <div class="status-text">
                        <span class="status-dot"></span>
                        Ready to assist with your health consultation
                    </div>
                """)

        return demo

    @staticmethod
    def _find_free_port(start_port=7860):
        """Find a free port starting from the given port number.

        Tries ``start_port`` and the 99 ports after it; returns the first one
        that can be bound, or None when all of them are taken.
        """
        import socket

        for port in range(start_port, start_port + 100):
            try:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.bind(('', port))
                    return port
            except OSError:
                continue
        return None

    @classmethod
    def _resolve_port(cls):
        """Pick the server port: the ``PORT`` env variable, else a free port.

        A ``PORT`` value that is not an integer is reported and ignored
        instead of aborting start-up with a ValueError.
        """
        raw_port = os.environ.get("PORT")
        if raw_port:
            try:
                return int(raw_port)
            except ValueError:
                print(f"Ignoring invalid PORT value {raw_port!r}; searching for a free port.")
        return cls._find_free_port()

    def launch(self):
        """Start the Gradio server on the resolved port.

        Prints a message and returns without launching when no port is
        available.
        """
        port = self._resolve_port()

        if port is None:
            print("Could not find an available port. Please set the PORT environment variable.")
            return

        print(f"Starting AI Voice Agent server on port {port}")

        self.demo.launch(
            server_name="0.0.0.0",
            server_port=port,
            ssl_verify=False,
            ssl_keyfile=None,
            ssl_certfile=None,
            show_api=False,
            quiet=False,
            inbrowser=True
        )
|
|
|
|
|
if __name__ == "__main__":
    # Build the interface and start serving when executed as a script.
    SocioCareAIPreconsultation().launch()