import gradio as gr
import os
import uuid
import asyncio
import threading
import json
import base64
import numpy as np
import io
import soundfile as sf
from pydub import AudioSegment
from websockets import connect, Data, ClientConnection
from dotenv import load_dotenv
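
# The contents of openai_transcription_settings.json are not shown here; for a
# transcription-intent Realtime session it would plausibly be a
# "transcription_session.update" event along these lines (the exact fields and
# model name are assumptions; check the OpenAI Realtime API docs):
#
#   {
#     "type": "transcription_session.update",
#     "session": {
#       "input_audio_format": "pcm16",
#       "input_audio_transcription": {"model": "gpt-4o-transcribe"},
#       "turn_detection": {"type": "server_vad"}
#     }
#   }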

# ========== WebSocket Client Setup ==========

class WebSocketClient:
    def __init__(self, uri: str, headers: dict, client_id: str):
        self.uri = uri
        self.headers = headers
        self.websocket: ClientConnection | None = None
        self.queue = asyncio.Queue(maxsize=10)
        self.loop = None
        self.client_id = client_id
        self.transcript = ""

    async def connect(self):
        # Open the realtime connection, send the transcription session
        # settings, then run the receive and send loops concurrently.
        self.websocket = await connect(self.uri, additional_headers=self.headers)
        with open("openai_transcription_settings.json", "r") as f:
            await self.websocket.send(f.read())
        await asyncio.gather(self.receive_messages(), self.send_audio_chunks())

    def run(self):
        # Entry point for the background thread: give it its own event loop.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.loop.run_until_complete(self.connect())

    def process_websocket_message(self, message: Data):
        # Append partial text as it streams in; separate completed
        # utterances with a space.
        msg = json.loads(message)
        if msg["type"] == "conversation.item.input_audio_transcription.delta":
            self.transcript += msg["delta"]
        elif msg["type"] == "conversation.item.input_audio_transcription.completed":
            self.transcript += " "

    async def send_audio_chunks(self):
        while True:
            sr, audio_array = await self.queue.get()
            # Downmix stereo to mono and peak-normalize to [-1, 1].
            if audio_array.ndim > 1:
                audio_array = audio_array.mean(axis=1)
            peak = np.max(np.abs(audio_array))
            if peak > 0:
                audio_array = audio_array / peak
            int16 = (audio_array * 32767).astype(np.int16)
            # Encode as 16-bit PCM WAV, resample to 24 kHz, and
            # base64-encode the bytes for the JSON payload.
            buffer = io.BytesIO()
            sf.write(buffer, int16, sr, format="WAV", subtype="PCM_16")
            buffer.seek(0)
            audio = AudioSegment.from_file(buffer, format="wav").set_frame_rate(24000)
            out = io.BytesIO()
            audio.export(out, format="wav")
            out.seek(0)
            encoded = base64.b64encode(out.read()).decode("utf-8")
            await self.websocket.send(json.dumps({"type": "input_audio_buffer.append", "audio": encoded}))

    async def receive_messages(self):
        async for message in self.websocket:
            self.process_websocket_message(message)

    def enqueue_audio_chunk(self, sr, chunk):
        # Called from the Gradio thread; hand the chunk to this client's
        # event loop. Drop chunks until the loop exists or while the queue
        # is full, rather than blocking the UI.
        if self.loop is not None and not self.queue.full():
            asyncio.run_coroutine_threadsafe(self.queue.put((sr, chunk)), self.loop)

    async def close(self):
        if self.websocket:
            await self.websocket.close()
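
# Illustrative usage of WebSocketClient outside the Gradio app (not executed
# here; URI and HEADERS are defined below, and the silent chunk is a
# hypothetical stand-in for microphone audio):
#
#   client = WebSocketClient(URI, HEADERS, client_id="demo")
#   threading.Thread(target=client.run, daemon=True).start()
#   client.enqueue_audio_chunk(24000, np.zeros(12000, dtype=np.int16))  # 0.5 s of silence
#   print(client.transcript)  # grows as transcription deltas arrive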

# ========== Transcription Helpers ==========

connections = {}
load_dotenv()  # pick up OPENAI_API_KEY from a local .env file if present
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
HEADERS = {"Authorization": f"Bearer {OPENAI_API_KEY}", "OpenAI-Beta": "realtime=v1"}
URI = "wss://api.openai.com/v1/realtime?intent=transcription"

def create_websocket():
    # One client (and one background event loop) per browser session.
    cid = str(uuid.uuid4())
    connections[cid] = WebSocketClient(URI, HEADERS, cid)
    threading.Thread(target=connections[cid].run, daemon=True).start()
    return cid

def send_audio(new_chunk, cid):
    if cid not in connections:
        return "Starting connection..."
    sr, y = new_chunk
    connections[cid].enqueue_audio_chunk(sr, y)
    return connections[cid].transcript

def clear_transcript(cid):
    if cid in connections:
        connections[cid].transcript = ""
    return ""

# ========== Gradio UI Layout ==========

with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("# 🧠 Document AI Assistant with Voice & Viewer")

    # State
    client_id = gr.State()

    # Layout
    with gr.Row():
        # 🟢 Chat Section (Main)
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Chat Assistant")
            msg = gr.Textbox(label="Ask something", placeholder="e.g., Summarize this document...")
            send_btn = gr.Button("Send")

            def chat_response(user_msg, history):
                history = history or []
                reply = f"🤖 This is a placeholder reply to: {user_msg}"
                history.append((user_msg, reply))
                return "", history

            send_btn.click(chat_response, inputs=[msg, chatbot], outputs=[msg, chatbot])

        # Smaller widgets section
        with gr.Column(scale=1):
            # 🟡 Image Viewer
            viewer = gr.Image(label="📄 Document Viewer", type="filepath")

            # 🔵 Voice Transcription
            transcript = gr.Textbox(label="🎤 Transcript", lines=5, interactive=False)
            audio = gr.Audio(label="🎙️ Audio", streaming=True)
            clear = gr.Button("Clear Transcript")

            audio.stream(fn=send_audio, inputs=[audio, client_id], outputs=transcript, stream_every=0.5)
            clear.click(fn=clear_transcript, inputs=[client_id], outputs=transcript)

    # Create the per-session WebSocket client when the page loads.
    app.load(create_websocket, outputs=client_id)

app.launch()