import gradio as gr
import os, time, re, json, base64, asyncio, threading, uuid, io
import numpy as np
import soundfile as sf
from pydub import AudioSegment
from openai import OpenAI
from websockets import connect
from dotenv import load_dotenv

# Load secrets
load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
ASSISTANT_ID = os.getenv("ASSISTANT_ID")
client = OpenAI(api_key=OPENAI_API_KEY)

HEADERS = {"Authorization": f"Bearer {OPENAI_API_KEY}", "OpenAI-Beta": "realtime=v1"}
WS_URI = "wss://api.openai.com/v1/realtime?intent=transcription"
connections = {}
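# connect() below sends the contents of openai_transcription_settings.json as
# the first WebSocket message. That file is not shown in this listing; a
# minimal sketch of what it might contain (the model name and VAD settings
# here are assumptions, adjust to your actual file):
# {
#   "type": "transcription_session.update",
#   "session": {
#     "input_audio_format": "pcm16",
#     "input_audio_transcription": {"model": "gpt-4o-transcribe"},
#     "turn_detection": {"type": "server_vad"}
#   }
# }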
# WebSocket Client: one connection per browser session, streaming mic audio
# to the OpenAI Realtime transcription endpoint on its own event loop.
class WebSocketClient:
    def __init__(self, uri, headers, client_id):
        self.uri = uri
        self.headers = headers
        self.client_id = client_id
        self.websocket = None
        self.queue = asyncio.Queue(maxsize=10)
        self.transcript = ""
        self.loop = asyncio.new_event_loop()

    async def connect(self):
        try:
            self.websocket = await connect(self.uri, additional_headers=self.headers)
            # First message configures the transcription session.
            with open("openai_transcription_settings.json", "r") as f:
                await self.websocket.send(f.read())
            await asyncio.gather(self.receive_messages(), self.send_audio_chunks())
        except Exception as e:
            print(f"🔴 WebSocket Connection Failed: {e}")

    def run(self):
        asyncio.set_event_loop(self.loop)
        self.loop.run_until_complete(self.connect())

    def enqueue_audio_chunk(self, sr, arr):
        # Called from Gradio's thread; hand off to the WS event loop,
        # dropping chunks when the queue is full rather than blocking.
        if not self.queue.full():
            asyncio.run_coroutine_threadsafe(self.queue.put((sr, arr)), self.loop)
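    # The realtime endpoint expects mono 16-bit PCM at 24 kHz, so each chunk
    # is downmixed, peak-normalized, and resampled before being sent.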
    async def send_audio_chunks(self):
        while True:
            sr, arr = await self.queue.get()
            if arr.ndim > 1:
                arr = arr.mean(axis=1)  # downmix to mono
            if np.max(np.abs(arr)) > 0:
                arr = arr / np.max(np.abs(arr))  # peak-normalize to [-1, 1]
            int16 = (arr * 32767).astype(np.int16)
            buf = io.BytesIO()
            sf.write(buf, int16, sr, format='WAV', subtype='PCM_16')
            buf.seek(0)
            audio = AudioSegment.from_file(buf, format="wav").set_frame_rate(24000)
            # Send raw PCM16 samples; a WAV container's header bytes would be
            # decoded as audio by the API.
            await self.websocket.send(json.dumps({
                "type": "input_audio_buffer.append",
                "audio": base64.b64encode(audio.raw_data).decode()
            }))
    async def receive_messages(self):
        async for msg in self.websocket:
            data = json.loads(msg)
            if data["type"] == "conversation.item.input_audio_transcription.delta":
                self.transcript += data["delta"]
def create_ws():
    cid = str(uuid.uuid4())
    ws_client = WebSocketClient(WS_URI, HEADERS, cid)
    threading.Thread(target=ws_client.run, daemon=True).start()
    connections[cid] = ws_client
    return cid
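# Gradio streaming callback: invoked roughly every 0.5 s (see stream_every
# below) with the latest (sample_rate, samples) tuple; the returned string is
# rendered live in the transcript textbox.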
def send_audio(chunk, cid):
    if not cid or cid not in connections:
        return "Connecting..."
    sr, arr = chunk
    connections[cid].enqueue_audio_chunk(sr, arr)
    return connections[cid].transcript.strip()

def clear_transcript(cid):
    if cid in connections:
        connections[cid].transcript = ""
    return ""
def format_response(content, prompt):
    summary = f"""<div class="card">
<h3>❓ {prompt}</h3>
<p><b>🧠 In summary:</b></p>
<p>{content}</p>"""
    # Render any GitHub-hosted PNG links in the reply as source thumbnails.
    thumbnails = re.findall(r'https://raw\.githubusercontent\.com/[^\s)]+\.png', content)
    if thumbnails:
        summary += "<h4>📎 Sources:</h4><div class='thumb-grid'>"
        for url in thumbnails:
            summary += f"<img src='{url}' class='thumb' />"
        summary += "</div>"
    summary += "</div>"
    return summary
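# Assistants API (beta) flow: reuse the session's thread if one exists, append
# the user message, start a run, poll until it reaches a terminal state, then
# render the newest assistant reply.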
def handle_chat(prompt, thread_id):
    if not OPENAI_API_KEY or not ASSISTANT_ID:
        return "❌ Missing API Key or Assistant ID", thread_id
    try:
        if thread_id is None:
            thread = client.beta.threads.create()
            thread_id = thread.id
        client.beta.threads.messages.create(thread_id=thread_id, role="user", content=prompt)
        run = client.beta.threads.runs.create(thread_id=thread_id, assistant_id=ASSISTANT_ID)
        while True:
            status = client.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run.id)
            if status.status == "completed":
                break
            if status.status in ("failed", "cancelled", "expired"):
                return f"❌ Run {status.status}", thread_id
            time.sleep(1)
        # messages.list returns newest-first, so the first assistant message
        # found is the latest reply.
        msgs = client.beta.threads.messages.list(thread_id=thread_id)
        for msg in msgs.data:
            if msg.role == "assistant":
                return format_response(msg.content[0].text.value, prompt), thread_id
        return "⚠️ No assistant reply", thread_id
    except Exception as e:
        return f"❌ {e}", thread_id
def feed_transcript(transcript, thread_id, cid):
    if not transcript.strip():
        return gr.update(), thread_id
    if cid in connections:
        connections[cid].transcript = ""  # reset so the next utterance starts fresh
    return handle_chat(transcript, thread_id)
# Gradio App
with gr.Blocks(css="""
body {
    background-color: #0f0f0f;
    color: white;
    font-family: 'Inter', sans-serif;
}
.card {
    background: #1a1a1a;
    padding: 20px;
    margin-top: 24px;
    border-radius: 14px;
    box-shadow: 0 2px 8px #000;
}
.thumb-grid {
    display: flex;
    gap: 10px;
    flex-wrap: wrap;
    margin-top: 12px;
}
.thumb {
    width: 120px;
    border-radius: 8px;
    border: 1px solid #333;
}
.input-box {
    position: fixed;
    bottom: 16px;
    left: 0;
    right: 0;
    max-width: 700px;
    margin: auto;
    display: flex;
    gap: 8px;
    background: #1f1f1f;
    padding: 14px;
    border-radius: 16px;
    justify-content: space-between;
}
#main-input {
    flex-grow: 1;
    background: #2a2a2a;
    border: none;
    padding: 12px;
    color: white;
    font-size: 16px;
    border-radius: 12px;
}
#send-btn, #mic-btn {
    background: #3f3fff;
    color: white;
    border: none;
    padding: 12px 16px;
    border-radius: 12px;
    font-size: 16px;
}
""") as app:
    thread_state = gr.State()
    client_id = gr.State()
    voice_visible = gr.State(False)

    gr.HTML("<h1 style='text-align:center; margin-top:40px;'>How can I help you today?</h1>")
    output_md = gr.HTML()

    with gr.Row(elem_classes="input-box"):
        user_input = gr.Textbox(elem_id="main-input", show_label=False, placeholder="Ask a question...")
        send_btn = gr.Button("➤", elem_id="send-btn")
        mic_toggle = gr.Button("🎙", elem_id="mic-btn")

    with gr.Column(visible=False) as voice_area:
        mic_audio = gr.Audio(label="Record", streaming=True, type="numpy")
        mic_transcript = gr.Textbox(label="Transcript", lines=2, interactive=False)
        mic_send = gr.Button("Send Voice")
        mic_clear = gr.Button("Clear Transcript")
    # Bindings
    send_btn.click(fn=handle_chat,
                   inputs=[user_input, thread_state],
                   outputs=[output_md, thread_state])

    mic_toggle.click(fn=lambda v: not v,
                     inputs=voice_visible,
                     outputs=voice_visible)

    # Map the boolean state onto the mic panel's visibility.
    voice_visible.change(fn=lambda v: gr.update(visible=v),
                         inputs=voice_visible,
                         outputs=voice_area,
                         show_progress=False)

    mic_audio.stream(fn=send_audio,
                     inputs=[mic_audio, client_id],
                     outputs=mic_transcript,
                     stream_every=0.5)

    mic_send.click(fn=feed_transcript,
                   inputs=[mic_transcript, thread_state, client_id],
                   outputs=[output_md, thread_state])

    mic_clear.click(fn=clear_transcript,
                    inputs=[client_id],
                    outputs=mic_transcript)

    # One realtime transcription socket per page load.
    app.load(fn=create_ws, outputs=[client_id])

app.launch()