File size: 12,544 Bytes
97a4ae5
a905808
ed08f62
 
 
 
 
 
 
a905808
ed08f62
 
 
17cb251
ed08f62
 
 
a905808
ed08f62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a905808
ed08f62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a905808
ed08f62
 
 
a905808
ed08f62
 
 
 
a905808
ed08f62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a905808
ed08f62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a905808
ed08f62
 
97a4ae5
ed08f62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4641c1c
97a4ae5
 
ed08f62
97a4ae5
 
ed08f62
 
97a4ae5
ed08f62
 
 
 
97a4ae5
ed08f62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97a4ae5
ed08f62
 
 
 
 
 
17cb251
ed08f62
 
 
 
 
10b8972
ed08f62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97a4ae5
ed08f62
 
a905808
ed08f62
97a4ae5
ed08f62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
import gradio as gr
import json
import time
import os
import asyncio
import websockets
import logging
from fastrtc import RTCComponent
import threading

# Configure logging: install a root handler at INFO level and create the
# module-level logger used throughout this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Service endpoints. The first two can be overridden through environment
# variables of the same name; the literals are the fallbacks.
# Signaling URL handed to the WebRTC component.
RENDER_SIGNALING_URL = os.getenv("RENDER_SIGNALING_URL", "wss://render-signal-audio.onrender.com/stream")
# Bare host name (no scheme); it is prefixed with wss:// or https:// wherever a URL is built.
HF_SPACE_URL = os.getenv("HF_SPACE_URL", "androidguy-speaker-diarization.hf.space")
# WebSocket URL passed to TranscriptionClient when connecting for transcription updates.
WS_TRANSCRIPTION_URL = f"wss://{HF_SPACE_URL}/ws_transcription"

class TranscriptionClient:
    """Client to handle WebSocket connection to transcription service.

    The connection lives on a private event loop that runs in a daemon
    thread started by :meth:`connect`. Incoming messages and lifecycle
    events are reported through the callbacks given to the constructor;
    the callbacks are invoked from that background thread.

    Args:
        url: WebSocket URL to connect to.
        on_message: Called with each received message.
        on_error: Called with the exception when connecting or receiving
            fails. Defaults to logging the error.
        on_close: Called with no arguments once the receive loop has ended.
            Defaults to logging an informational message.
    """

    def __init__(self, url, on_message, on_error=None, on_close=None):
        self.url = url
        self.on_message = on_message
        self.on_error = on_error or (lambda e: logger.error(f"WebSocket error: {e}"))
        self.on_close = on_close or (lambda: logger.info("WebSocket closed"))
        self.ws = None
        self.running = False
        self.connected = False
        self.reconnect_task = None
        self.thread = None
        # Event loop that owns the socket; set by the worker thread so that
        # disconnect() can schedule the close on the correct loop.
        self._loop = None

    async def connect_async(self):
        """Connect to WebSocket server asynchronously and pump messages.

        Runs until ``running`` becomes false, the peer closes the
        connection, or an error occurs. The socket is always closed and
        ``running`` reset on exit, so :meth:`connect` may be called again
        afterwards.
        """
        ws = None
        try:
            ws = await websockets.connect(self.url)
            self.ws = ws
            self.connected = True
            logger.info(f"Connected to {self.url}")

            # Start listening for messages. The local ``ws`` is used rather
            # than ``self.ws`` because disconnect() may clear the attribute
            # from another thread while recv() is pending.
            while self.running:
                try:
                    message = await ws.recv()
                    self.on_message(message)
                except websockets.exceptions.ConnectionClosed:
                    logger.warning("Connection closed")
                    break
                except Exception as e:
                    self.on_error(e)
                    break

            # Handle connection closed
            self.connected = False
            self.on_close()

        except Exception as e:
            logger.error(f"Connection error: {e}")
            self.connected = False
            self.on_error(e)
        finally:
            self.connected = False
            # Allow a later connect() call to start a fresh connection.
            self.running = False
            if ws is not None:
                try:
                    await ws.close()
                except Exception as e:
                    logger.debug(f"Error while closing WebSocket: {e}")
                if self.ws is ws:
                    self.ws = None

    def connect(self):
        """Start connection in a separate thread.

        Does nothing if a connection attempt is already in progress.
        """
        if self.running:
            return

        self.running = True

        def run_async_loop():
            loop = asyncio.new_event_loop()
            self._loop = loop
            asyncio.set_event_loop(loop)
            try:
                loop.run_until_complete(self.connect_async())
                # Let a close scheduled by disconnect() finish before the
                # loop is torn down.
                pending = asyncio.all_tasks(loop)
                if pending:
                    loop.run_until_complete(
                        asyncio.gather(*pending, return_exceptions=True)
                    )
            finally:
                self._loop = None
                loop.close()

        self.thread = threading.Thread(target=run_async_loop, daemon=True)
        self.thread.start()

    def disconnect(self):
        """Disconnect from WebSocket server.

        Safe to call from any thread. The close is scheduled on the event
        loop that owns the socket (closing it from a different loop is not
        valid) and this method returns without waiting for the close
        handshake; the receive loop then ends and ``on_close`` fires.
        """
        self.running = False
        ws, loop = self.ws, self._loop
        self.ws = None
        self.connected = False

        if ws is None or loop is None or loop.is_closed():
            return

        close_coro = ws.close()
        try:
            asyncio.run_coroutine_threadsafe(close_coro, loop)
        except RuntimeError as e:
            # The loop was closed between the check above and scheduling;
            # connect_async() has already closed the socket in that case.
            close_coro.close()
            logger.debug(f"Could not schedule WebSocket close: {e}")

class SpeakerDiarizationUI:
    """Main UI for speaker diarization.

    Holds the state shared between the transcription WebSocket callbacks and
    the Gradio event handlers: the latest conversation HTML, a status dict,
    and whether listening has been started.

    The Gradio handlers return their updates positionally, in the same order
    as the ``outputs`` lists they are registered with in
    ``create_interface``. The components themselves are local to that
    function and cannot be referenced by name from this class.
    """

    def __init__(self):
        self.transcription_client = None
        self.conversation_html = ""
        # Always kept as a dict so that .get() and item assignment are safe.
        self.system_status = {"status": "disconnected"}
        self.webrtc_state = "stopped"

    def handle_transcription_message(self, message):
        """Handle incoming transcription messages.

        Args:
            message: JSON text with a ``type`` field. ``conversation_update``
                refreshes the stored HTML (and status, if present),
                ``connection`` records ``hf_space_status`` when given, and
                ``error`` is logged. Invalid JSON and handler failures are
                logged rather than raised.
        """
        try:
            data = json.loads(message)
            message_type = data.get("type", "unknown")

            if message_type == "conversation_update":
                # Update the conversation display
                self.conversation_html = data.get("conversation_html", "")
                # Update system status if available
                if "status" in data:
                    status = data["status"]
                    if isinstance(status, dict):
                        self.system_status = status
                    else:
                        # Keep system_status a dict even when the server
                        # sends a bare value such as a string.
                        self.system_status["status"] = status

            elif message_type == "connection":
                # Handle connection status update
                status = data.get("status", "unknown")
                logger.info(f"Connection status: {status}")
                if "hf_space_status" in data:
                    self.system_status["hf_space_status"] = data["hf_space_status"]

            elif message_type == "error":
                # Handle error message
                error_msg = data.get("message", "Unknown error")
                logger.error(f"Error from server: {error_msg}")

        except json.JSONDecodeError:
            logger.warning(f"Received invalid JSON: {message}")
        except Exception as e:
            logger.error(f"Error handling message: {e}")

    def handle_transcription_error(self, error):
        """Handle WebSocket errors by logging them."""
        logger.error(f"WebSocket error: {error}")

    def handle_transcription_close(self):
        """Handle WebSocket connection closure."""
        logger.info("WebSocket connection closed")
        self.system_status["status"] = "disconnected"

    def connect_to_transcription(self):
        """Connect to transcription WebSocket unless already connected."""
        if self.transcription_client and self.transcription_client.connected:
            return

        self.transcription_client = TranscriptionClient(
            url=WS_TRANSCRIPTION_URL,
            on_message=self.handle_transcription_message,
            on_error=self.handle_transcription_error,
            on_close=self.handle_transcription_close
        )
        self.transcription_client.connect()

    def disconnect_from_transcription(self):
        """Disconnect from transcription WebSocket and drop the client."""
        if self.transcription_client:
            self.transcription_client.disconnect()
            self.transcription_client = None

    def start_listening(self):
        """Start listening to audio and connect to services.

        Returns:
            Updates for, in order: the WebRTC component, the status display,
            the start button, the stop button and the clear button.
        """
        self.connect_to_transcription()
        self.webrtc_state = "started"
        return (
            gr.update(streaming=True),
            gr.update(value="Status: Connected and listening"),
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=True),
        )

    def stop_listening(self):
        """Stop listening to audio and disconnect from services.

        Returns:
            Updates for, in order: the WebRTC component, the status display,
            the start button, the stop button and the clear button.
        """
        self.disconnect_from_transcription()
        self.webrtc_state = "stopped"
        return (
            gr.update(streaming=False),
            gr.update(value="Status: Disconnected"),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=True),
        )

    def clear_conversation(self):
        """Clear the conversation on the server and in the local display.

        A failure of the remote call is logged; the local display is cleared
        regardless.

        Returns:
            The update for the conversation display.
        """
        # Call API to clear conversation
        import requests
        try:
            # Bounded wait so an unresponsive server cannot hang the handler.
            response = requests.post(f"https://{HF_SPACE_URL}/clear", timeout=10)
            if response.status_code == 200:
                logger.info("Conversation cleared")
            else:
                logger.error(f"Failed to clear conversation: {response.status_code}")
        except Exception as e:
            logger.error(f"Error clearing conversation: {e}")

        # Clear local display
        self.conversation_html = ""
        return gr.update(value="<div class='conversation-container'><p>Conversation cleared</p></div>")

    def update_display(self):
        """Update conversation display - called periodically.

        Returns:
            Updates for, in order: the conversation display and the status
            display.
        """
        status_text = "Status: "
        if self.webrtc_state == "started":
            status_text += "Connected and listening"
        else:
            status_text += "Disconnected"

        if self.system_status.get("hf_space_status"):
            status_text += f" | HF Space: {self.system_status['hf_space_status']}"

        conversation = self.conversation_html or (
            "<div class='conversation-container'><p>No conversation yet. "
            "Start speaking to begin transcription.</p></div>"
        )
        return (
            gr.update(value=conversation),
            gr.update(value=status_text),
        )

# Create UI instance: a single module-level SpeakerDiarizationUI whose bound
# methods are registered as the Gradio event handlers in create_interface().
ui = SpeakerDiarizationUI()

# Custom CSS for better styling. Passed to gr.Blocks in create_interface();
# it styles the conversation container, the speaker / speaker-label classes,
# and the status-display class attached to the status Markdown components.
css = """
.conversation-container {
    border: 1px solid #ddd;
    border-radius: 10px;
    padding: 15px;
    margin-bottom: 10px;
    max-height: 500px;
    overflow-y: auto;
    background-color: white;
}
.speaker {
    margin-bottom: 12px;
    border-radius: 8px;
    padding: 8px 12px;
}
.speaker-label {
    font-weight: bold;
    margin-bottom: 5px;
}
.status-display {
    margin-top: 10px;
    padding: 5px 10px;
    background-color: #f0f0f0;
    border-radius: 5px;
    font-size: 0.9rem;
}
"""

# Create Gradio interface as a function to avoid clashing
def create_interface():
    """Build the Gradio Blocks layout and wire up its event handlers.

    The layout has a conversation panel on the left and, on the right, the
    status line, the WebRTC audio component, the start/stop/clear buttons and
    an "Advanced Settings" accordion. Button clicks are routed to the
    module-level ``ui`` object, and the conversation and status displays are
    refreshed every 0.5 seconds.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched or mounted).
    """
    with gr.Blocks(css=css) as interface:
        gr.Markdown("# Real-Time Speaker Diarization")
        gr.Markdown("This app performs real-time speaker diarization on your audio. It automatically transcribes speech and identifies different speakers.")

        with gr.Row():
            with gr.Column(scale=2):
                conversation_display = gr.HTML("<div class='conversation-container'><p>No conversation yet. Start speaking to begin transcription.</p></div>")

            with gr.Column(scale=1):
                status_display = gr.Markdown("Status: Disconnected", elem_classes=["status-display"])
                webrtc = RTCComponent(url=RENDER_SIGNALING_URL, streaming=False, modality="audio", mode="send-receive")

                with gr.Row():
                    start_button = gr.Button("Start Listening", variant="primary")
                    stop_button = gr.Button("Stop Listening", variant="secondary", visible=False)
                    clear_button = gr.Button("Clear Conversation", visible=True)

                with gr.Accordion("Advanced Settings", open=False):
                    speaker_threshold = gr.Slider(0.5, 0.9, value=0.65, label="Speaker Change Threshold")
                    max_speakers = gr.Slider(2, 8, value=4, step=1, label="Maximum Number of Speakers")

                    def update_settings(threshold, speakers):
                        """Send the slider values to the backend and report the outcome."""
                        import requests
                        try:
                            # Bounded wait so an unresponsive backend cannot
                            # block the Gradio worker indefinitely.
                            response = requests.post(
                                f"https://{HF_SPACE_URL}/settings",
                                params={"threshold": threshold, "max_speakers": speakers},
                                timeout=10,
                            )
                            if response.status_code == 200:
                                return gr.update(value=f"Settings updated: Threshold={threshold}, Max Speakers={speakers}")
                            else:
                                return gr.update(value=f"Failed to update settings: {response.status_code}")
                        except Exception as e:
                            return gr.update(value=f"Error updating settings: {e}")

                    settings_button = gr.Button("Update Settings")
                    settings_status = gr.Markdown("", elem_classes=["status-display"])

                    settings_button.click(
                        update_settings,
                        [speaker_threshold, max_speakers],
                        [settings_status]
                    )

        # Set up event handlers
        start_button.click(
            ui.start_listening,
            [],
            [webrtc, status_display, start_button, stop_button, clear_button]
        )

        stop_button.click(
            ui.stop_listening,
            [],
            [webrtc, status_display, start_button, stop_button, clear_button]
        )

        clear_button.click(
            ui.clear_conversation,
            [],
            [conversation_display]
        )

        # Periodic update every 0.5 seconds
        interface.load(
            ui.update_display,
            [],
            [conversation_display, status_display],
            every=0.5
        )

        return interface

# Global interface instance. Starts as None and is assigned either by the
# script entry point below or by mount_ui().
interface = None

# Launch the app: when this file is executed directly, build the interface
# and start Gradio's own server.
if __name__ == "__main__":
    interface = create_interface()
    interface.launch()

# Add mount_ui function for integration with FastAPI
def mount_ui(app):
    """Mount the Gradio interface at /ui path of the FastAPI app.

    The interface is created on first use and cached in the module-level
    ``interface`` variable, so repeated calls reuse the same Blocks instance.

    Args:
        app: The FastAPI application to mount the UI on.

    Returns:
        The FastAPI application with the Gradio UI mounted at ``/ui``.
    """
    global interface

    # Create interface if it doesn't exist yet
    if interface is None:
        interface = create_interface()

    # Mount Gradio app at /ui path. Mounting goes through the module-level
    # gr.mount_gradio_app helper, which returns the FastAPI app it was given.
    return gr.mount_gradio_app(app, interface, path="/ui")