IAMTFRMZA committed on
Commit 65be86f · verified · 1 Parent(s): b106fa1

Update app.py

Files changed (1)
  1. app.py +100 -92
app.py CHANGED
@@ -1,26 +1,28 @@
  import gradio as gr
- import os
- import uuid
- import asyncio
- import threading
- import json
- import base64
  import numpy as np
- import io
  import soundfile as sf
  from pydub import AudioSegment
  from websockets import connect, Data, ClientConnection
  from dotenv import load_dotenv

- # ========== WebSocket Client Setup ==========
  class WebSocketClient:
-     def __init__(self, uri: str, headers: dict, client_id: str):
-         self.uri = uri
-         self.headers = headers
-         self.websocket: ClientConnection = None
          self.queue = asyncio.Queue(maxsize=10)
-         self.loop = None
-         self.client_id = client_id
          self.transcript = ""

      async def connect(self):
@@ -30,108 +32,114 @@ class WebSocketClient:
          await asyncio.gather(self.receive_messages(), self.send_audio_chunks())

      def run(self):
-         self.loop = asyncio.new_event_loop()
-         asyncio.set_event_loop(self.loop)
-         self.loop.run_until_complete(self.connect())
-
-     def process_websocket_message(self, message: Data):
-         msg = json.loads(message)
-         if msg["type"] == "conversation.item.input_audio_transcription.delta":
-             self.transcript += msg["delta"]
-         elif msg["type"] == "conversation.item.input_audio_transcription.completed":
-             self.transcript += ' '

      async def send_audio_chunks(self):
          while True:
-             sr, audio_array = await self.queue.get()
-             if audio_array.ndim > 1:
-                 audio_array = audio_array.mean(axis=1)
-             audio_array = (audio_array / np.max(np.abs(audio_array))) if np.max(np.abs(audio_array)) > 0 else audio_array
-             int16 = (audio_array * 32767).astype(np.int16)
-             buffer = io.BytesIO()
-             sf.write(buffer, int16, sr, format='WAV', subtype='PCM_16')
-             buffer.seek(0)
-             audio = AudioSegment.from_file(buffer, format="wav").set_frame_rate(24000)
-             out = io.BytesIO()
-             audio.export(out, format="wav")
-             out.seek(0)
-             encoded = base64.b64encode(out.read()).decode("utf-8")
-             await self.websocket.send(json.dumps({"type": "input_audio_buffer.append", "audio": encoded}))

      async def receive_messages(self):
-         async for message in self.websocket:
-             self.process_websocket_message(message)

-     def enqueue_audio_chunk(self, sr, chunk):
          if not self.queue.full():
-             asyncio.run_coroutine_threadsafe(self.queue.put((sr, chunk)), self.loop)
-
-     async def close(self):
-         if self.websocket:
-             await self.websocket.close()

-
- # ========== Transcription Helpers ==========
- connections = {}
- load_dotenv()
- OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
- HEADERS = {"Authorization": f"Bearer {OPENAI_API_KEY}", "OpenAI-Beta": "realtime=v1"}
- URI = "wss://api.openai.com/v1/realtime?intent=transcription"
-
- def create_websocket():
      cid = str(uuid.uuid4())
-     connections[cid] = WebSocketClient(URI, HEADERS, cid)
-     threading.Thread(target=connections[cid].run, daemon=True).start()
      return cid

- def send_audio(new_chunk, cid):
-     if cid not in connections:
-         return "Starting connection..."
-     sr, y = new_chunk
-     connections[cid].enqueue_audio_chunk(sr, y)
      return connections[cid].transcript

  def clear_transcript(cid):
-     if cid in connections:
-         connections[cid].transcript = ""
      return ""

- # ========== Gradio UI Layout ==========
- with gr.Blocks(theme=gr.themes.Soft()) as app:
-
-     gr.Markdown("# 🧠 Document AI Assistant with Voice & Viewer")
-
-     # State
      client_id = gr.State()

-     # Layout
      with gr.Row():
-         # 🟢 Chat Section (Main)
-         with gr.Column(scale=2):
-             chatbot = gr.Chatbot(label="Chat Assistant")
-             msg = gr.Textbox(label="Ask something", placeholder="e.g., Summarize this document...")
-             send_btn = gr.Button("Send")
-
-             def chat_response(user_msg, history):
-                 history = history or []
-                 reply = f"🤖 This is a placeholder reply to: {user_msg}"
-                 history.append((user_msg, reply))
-                 return "", history

-             send_btn.click(chat_response, inputs=[msg, chatbot], outputs=[msg, chatbot])

-         # Smaller widgets section
-         with gr.Column(scale=1):
-             # 🟡 Image Viewer
-             viewer = gr.Image(label="📄 Document Viewer", type="filepath")

-             # 🔵 Voice Transcription
-             transcript = gr.Textbox(label="🎤 Transcript", lines=5, interactive=False)
-             audio = gr.Audio(label="🎙️ Audio", streaming=True)
-             clear = gr.Button("Clear Transcript")

-     audio.stream(fn=send_audio, inputs=[audio, client_id], outputs=transcript, stream_every=0.5)
-     clear.click(fn=clear_transcript, inputs=[client_id], outputs=transcript)
-     app.load(create_websocket, outputs=client_id)

  app.launch()
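
The removed version above and the updated version below share two techniques that the flattened diff makes easy to miss: the chunk conversion (mono-mix, peak-normalize, PCM16 WAV, resample to 24 kHz, base64-encode) and the handoff of audio from Gradio's synchronous streaming callback into the client's background asyncio loop via run_coroutine_threadsafe. A minimal standalone sketch of just those two pieces; MiniStreamer and prepare_chunk are invented names for illustration, not part of the app:

    import asyncio, base64, io, threading, time

    import numpy as np
    import soundfile as sf
    from pydub import AudioSegment

    def prepare_chunk(sr, arr):
        """Mono-mix, peak-normalize, re-encode as 24 kHz PCM16 WAV, return base64 text."""
        if arr.ndim > 1:
            arr = arr.mean(axis=1)                  # stereo -> mono
        peak = np.max(np.abs(arr))
        if peak > 0:
            arr = arr / peak                        # normalize into [-1, 1]
        int16 = (arr * 32767).astype(np.int16)
        buf = io.BytesIO()
        sf.write(buf, int16, sr, format="WAV", subtype="PCM_16")
        buf.seek(0)                                 # rewind before pydub reads it
        wav24 = AudioSegment.from_file(buf, format="wav").set_frame_rate(24000)
        out = io.BytesIO()
        wav24.export(out, format="wav")
        return base64.b64encode(out.getvalue()).decode()

    class MiniStreamer:
        """Owns one event loop on a daemon thread; other threads enqueue into it."""
        def __init__(self):
            self.loop = None
            self.queue = None

        def run(self):                              # target of the daemon thread
            self.loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self.loop)
            self.queue = asyncio.Queue(maxsize=10)  # created on the owning loop's thread
            self.loop.run_until_complete(self._drain())

        async def _drain(self):
            while True:
                sr, arr = await self.queue.get()
                payload = prepare_chunk(sr, arr)    # in the app this is sent over the socket
                print(f"chunk ready: {len(payload)} base64 chars")

        def enqueue(self, sr, arr):
            # run_coroutine_threadsafe is the safe bridge from a foreign thread
            # into this client's own loop; dropping chunks when full mirrors the app.
            if self.loop and self.queue and not self.queue.full():
                asyncio.run_coroutine_threadsafe(self.queue.put((sr, arr)), self.loop)

    streamer = MiniStreamer()
    threading.Thread(target=streamer.run, daemon=True).start()
    streamer.enqueue(48000, np.zeros(4800, dtype=np.float32))  # e.g. a 100 ms silent chunk
    time.sleep(0.5)                                 # give the sketch's thread time to drain

The key point is that the loop handle must belong to the thread that runs the websocket client; scheduling onto whatever loop the caller's thread happens to see is not safe.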
 
  import gradio as gr
+ import os, time, re, json, base64, asyncio, threading, uuid, io
  import numpy as np
  import soundfile as sf
  from pydub import AudioSegment
+ from openai import OpenAI
  from websockets import connect, Data, ClientConnection
  from dotenv import load_dotenv

+ # ---------------- Environment & Client Setup ----------------
+ load_dotenv()
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+ ASSISTANT_ID = os.getenv("ASSISTANT_ID")
+ client = OpenAI(api_key=OPENAI_API_KEY)
+
+ HEADERS = {"Authorization": f"Bearer {OPENAI_API_KEY}", "OpenAI-Beta": "realtime=v1"}
+ WS_URI = "wss://api.openai.com/v1/realtime?intent=transcription"
+ connections = {}
+
+ # ---------------- WebSocket Client for Voice ----------------
  class WebSocketClient:
+     def __init__(self, uri, headers, client_id):
+         self.uri, self.headers, self.client_id = uri, headers, client_id
+         self.websocket = None
+         self.loop = None
          self.queue = asyncio.Queue(maxsize=10)
          self.transcript = ""

      async def connect(self):
          await asyncio.gather(self.receive_messages(), self.send_audio_chunks())

      def run(self):
+         # keep a handle to this client's loop so other threads can schedule onto it
+         self.loop = asyncio.new_event_loop()
+         asyncio.set_event_loop(self.loop)
+         self.loop.run_until_complete(self.connect())

      async def send_audio_chunks(self):
          while True:
+             sr, arr = await self.queue.get()
+             if arr.ndim > 1: arr = arr.mean(axis=1)
+             arr = (arr / np.max(np.abs(arr))) if np.max(np.abs(arr)) > 0 else arr
+             int16 = (arr * 32767).astype(np.int16)
+             buf = io.BytesIO(); sf.write(buf, int16, sr, format='WAV', subtype='PCM_16'); buf.seek(0)
+             audio = AudioSegment.from_file(buf, format="wav").set_frame_rate(24000)
+             out = io.BytesIO(); audio.export(out, format="wav"); out.seek(0)
+             await self.websocket.send(json.dumps({"type": "input_audio_buffer.append", "audio": base64.b64encode(out.read()).decode()}))

      async def receive_messages(self):
+         async for msg in self.websocket:
+             data = json.loads(msg)
+             if data["type"] == "conversation.item.input_audio_transcription.delta":
+                 self.transcript += data["delta"]

+     def enqueue_audio_chunk(self, sr, arr):
          if not self.queue.full():
+             # schedule on this client's own loop, not the calling thread's
+             asyncio.run_coroutine_threadsafe(self.queue.put((sr, arr)), self.loop)

+ def create_ws():
      cid = str(uuid.uuid4())
+     ws = WebSocketClient(WS_URI, HEADERS, cid)
+     threading.Thread(target=ws.run, daemon=True).start()
+     connections[cid] = ws
      return cid

+ def send_audio(chunk, cid):
+     if cid not in connections: return "Connecting..."
+     sr, arr = chunk
+     connections[cid].enqueue_audio_chunk(sr, arr)
      return connections[cid].transcript

  def clear_transcript(cid):
+     if cid in connections: connections[cid].transcript = ""
      return ""

+ # ---------------- Chat Functionality ----------------
+ def handle_chat(user_input, history, thread_id, image_url):
+     if not OPENAI_API_KEY or not ASSISTANT_ID:
+         return "❌ Missing secrets!", history, thread_id, image_url
+
+     try:
+         if thread_id is None:
+             thread = client.beta.threads.create()
+             thread_id = thread.id

+         client.beta.threads.messages.create(thread_id=thread_id, role="user", content=user_input)
+         run = client.beta.threads.runs.create(thread_id=thread_id, assistant_id=ASSISTANT_ID)

+         while True:
+             status = client.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run.id)
+             if status.status == "completed": break
+             time.sleep(1)
+
+         msgs = client.beta.threads.messages.list(thread_id=thread_id)
+         for msg in reversed(msgs.data):
+             if msg.role == "assistant":
+                 content = msg.content[0].text.value
+                 history.append((user_input, content))
+                 match = re.search(r'https://raw\.githubusercontent\.com/AndrewLORTech/surgical-pathology-manual/main/[\w\-/]*\.png', content)
+                 if match: image_url = match.group(0)
+                 break
+
+         return "", history, thread_id, image_url
+
+     except Exception as e:
+         return f"❌ {e}", history, thread_id, image_url
+
+ # ---------------- Gradio UI Layout ----------------
+ with gr.Blocks(theme=gr.themes.Soft()) as app:
+     gr.Markdown("# 📄 Document AI Assistant")
+
+     # STATES
+     chat_state = gr.State([])
+     thread_state = gr.State()
+     image_state = gr.State()
      client_id = gr.State()

      with gr.Row():
+         with gr.Column(scale=1):
+             # IMAGE VIEWER (left)
+             image_display = gr.Image(label="🖼️ Document", type="filepath")

+             # VOICE (under)
+             voice_transcript = gr.Textbox(label="🎙️ Transcript", lines=4, interactive=False)
+             voice_input = gr.Audio(label="🔴 Record", streaming=True)
+             clear_btn = gr.Button("🧹 Clear Transcript")

+         with gr.Column(scale=2):
+             # CHATBOT (right)
+             chat = gr.Chatbot(label="💬 Chat", height=450)
+             user_prompt = gr.Textbox(show_label=False, placeholder="Ask your question...")
+             send_btn = gr.Button("Send")

+     # HANDLERS
+     send_btn.click(handle_chat,
+                    inputs=[user_prompt, chat_state, thread_state, image_state],
+                    outputs=[user_prompt, chat, thread_state, image_state])

+     image_state.change(fn=lambda x: x, inputs=image_state, outputs=image_display)
+     voice_input.stream(fn=send_audio, inputs=[voice_input, client_id], outputs=voice_transcript, stream_every=0.5)
+     clear_btn.click(fn=clear_transcript, inputs=[client_id], outputs=voice_transcript)
+     app.load(create_ws, outputs=[client_id])

  app.launch()
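
In handle_chat, the run poller only exits on "completed", so a run that ends "failed", "cancelled", or "expired" would spin forever. A bounded variant of that loop, sketched here with a hypothetical helper wait_for_run that is not part of this commit:

    import time

    def wait_for_run(client, thread_id, run_id, timeout_s=60, poll_s=1.0):
        # Poll an Assistants run until it reaches a terminal state or times out.
        terminal = {"completed", "failed", "cancelled", "expired"}
        deadline = time.monotonic() + timeout_s
        while True:
            run = client.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run_id)
            if run.status in terminal:
                return run
            if time.monotonic() > deadline:
                raise TimeoutError(f"run {run_id} still {run.status!r} after {timeout_s}s")
            time.sleep(poll_s)

The caller can then branch on run.status instead of assuming success before reading the thread's messages.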