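"""Gradio app: document assistant with live voice transcription.

Streams microphone audio to the OpenAI realtime transcription API over a
websocket and shows the running transcript alongside a placeholder chat
and a document image viewer.

Assumes OPENAI_API_KEY is set in the environment (or a .env file) and that
an openai_transcription_settings.json with the session configuration sits
next to this script.
"""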
import gradio as gr
import os
import uuid
import asyncio
import threading
import json
import base64
import numpy as np
import io
import soundfile as sf
from pydub import AudioSegment
from websockets import connect, Data, ClientConnection
from dotenv import load_dotenv
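# Architecture: each browser session gets its own WebSocketClient, keyed by a
# UUID stored in gr.State. The client runs an asyncio event loop on a daemon
# thread; Gradio's streaming callback pushes (sample_rate, chunk) pairs into a
# bounded queue, and the loop forwards them to the realtime API while
# collecting transcript deltas.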
# ========== WebSocket Client Setup ==========
class WebSocketClient:
    """Streams audio chunks to the OpenAI realtime transcription endpoint
    and accumulates the transcript returned over the websocket."""

    def __init__(self, uri: str, headers: dict, client_id: str):
        self.uri = uri
        self.headers = headers
        self.websocket: ClientConnection | None = None
        # Bounded queue: drop audio rather than backing up if the socket is slow.
        self.queue = asyncio.Queue(maxsize=10)
        self.loop = None
        self.client_id = client_id
        self.transcript = ""
    async def connect(self):
        self.websocket = await connect(self.uri, additional_headers=self.headers)
        # First message: the transcription session settings read from disk.
        with open("openai_transcription_settings.json", "r") as f:
            await self.websocket.send(f.read())
        await asyncio.gather(self.receive_messages(), self.send_audio_chunks())
    def run(self):
        # Entry point for the daemon thread: one asyncio loop per client.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.loop.run_until_complete(self.connect())
    def process_websocket_message(self, message: Data):
        msg = json.loads(message)
        if msg["type"] == "conversation.item.input_audio_transcription.delta":
            self.transcript += msg["delta"]
        elif msg["type"] == "conversation.item.input_audio_transcription.completed":
            self.transcript += " "
    async def send_audio_chunks(self):
        while True:
            sr, audio_array = await self.queue.get()
            # Downmix to mono and normalize to [-1, 1] before converting to PCM16.
            if audio_array.ndim > 1:
                audio_array = audio_array.mean(axis=1)
            peak = np.max(np.abs(audio_array))
            if peak > 0:
                audio_array = audio_array / peak
            int16 = (audio_array * 32767).astype(np.int16)
            buffer = io.BytesIO()
            sf.write(buffer, int16, sr, format="WAV", subtype="PCM_16")
            buffer.seek(0)
            # Resample to 24 kHz, then send the raw PCM16 bytes (not the WAV
            # container) as a base64 input_audio_buffer.append event.
            audio = AudioSegment.from_file(buffer, format="wav").set_frame_rate(24000)
            encoded = base64.b64encode(audio.raw_data).decode("utf-8")
            await self.websocket.send(json.dumps({"type": "input_audio_buffer.append", "audio": encoded}))
    async def receive_messages(self):
        async for message in self.websocket:
            self.process_websocket_message(message)

    def enqueue_audio_chunk(self, sr, chunk):
        # Called from Gradio's worker thread; hand the chunk to this client's loop.
        if self.loop is not None and not self.queue.full():
            asyncio.run_coroutine_threadsafe(self.queue.put((sr, chunk)), self.loop)

    async def close(self):
        if self.websocket:
            await self.websocket.close()
# ========== Transcription Helpers ==========
connections = {}  # per-session clients, keyed by UUID; live for the process lifetime

load_dotenv()
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
HEADERS = {"Authorization": f"Bearer {OPENAI_API_KEY}", "OpenAI-Beta": "realtime=v1"}
URI = "wss://api.openai.com/v1/realtime?intent=transcription"

def create_websocket():
    cid = str(uuid.uuid4())
    connections[cid] = WebSocketClient(URI, HEADERS, cid)
    threading.Thread(target=connections[cid].run, daemon=True).start()
    return cid
def send_audio(new_chunk, cid):
    if cid not in connections:
        return "Starting connection..."
    sr, y = new_chunk
    connections[cid].enqueue_audio_chunk(sr, y)
    # Return the transcript accumulated so far; Gradio re-renders it each tick.
    return connections[cid].transcript

def clear_transcript(cid):
    if cid in connections:
        connections[cid].transcript = ""
    return ""
# ========== Gradio UI Layout ==========
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("# 🧠 Document AI Assistant with Voice & Viewer")

    # State
    client_id = gr.State()

    # Layout
    with gr.Row():
        # 🟒 Chat Section (Main)
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Chat Assistant")
            msg = gr.Textbox(label="Ask something", placeholder="e.g., Summarize this document...")
            send_btn = gr.Button("Send")

            def chat_response(user_msg, history):
                history = history or []
                reply = f"πŸ€– This is a placeholder reply to: {user_msg}"
                history.append((user_msg, reply))
                return "", history

            send_btn.click(chat_response, inputs=[msg, chatbot], outputs=[msg, chatbot])

        # Smaller widgets section
        with gr.Column(scale=1):
            # 🟑 Image Viewer
            viewer = gr.Image(label="πŸ“„ Document Viewer", type="filepath")

            # πŸ”΅ Voice Transcription
            transcript = gr.Textbox(label="🎀 Transcript", lines=5, interactive=False)
            audio = gr.Audio(label="πŸŽ™οΈ Audio", streaming=True)
            clear = gr.Button("Clear Transcript")

            audio.stream(fn=send_audio, inputs=[audio, client_id], outputs=transcript, stream_every=0.5)
            clear.click(fn=clear_transcript, inputs=[client_id], outputs=transcript)

    app.load(create_websocket, outputs=client_id)

app.launch()