Spaces:
Sleeping
Sleeping
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, status | |
from typing import List, Dict | |
import logging | |
from datetime import datetime | |
import asyncio | |
import json | |
import os | |
from dotenv import load_dotenv | |
from app.database.mongodb import session_collection | |
from app.utils.utils import get_local_time | |
# Read a local .env file (if one exists) into the process environment before
# any of the settings below are looked up.
load_dotenv()

# WebSocket settings. Each value is a string taken from the environment, with
# a local-development fallback when the variable is unset.
WEBSOCKET_SERVER = os.environ.get("WEBSOCKET_SERVER", "localhost")
WEBSOCKET_PORT = os.environ.get("WEBSOCKET_PORT", "7860")
WEBSOCKET_PATH = os.environ.get("WEBSOCKET_PATH", "/notify")

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Router grouping the WebSocket-related endpoints under a single "WebSocket" tag.
router = APIRouter(tags=["WebSocket"])
# Store active WebSocket connections
class ConnectionManager:
    """Track accepted WebSocket connections and broadcast JSON messages to them."""

    def __init__(self):
        # Sockets accepted via connect() that have not been removed yet.
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the handshake for *websocket* and start tracking it."""
        await websocket.accept()
        self.active_connections.append(websocket)
        # Starlette's `WebSocket.client` may be None, and hasattr() is always true
        # for it, so check the value itself before reading host/port.
        client = getattr(websocket, "client", None)
        client_info = f"{client.host}:{client.port}" if client else "Unknown"
        logger.info(
            f"New WebSocket connection from {client_info}. "
            f"Total connections: {len(self.active_connections)}"
        )

    def disconnect(self, websocket: WebSocket):
        """Stop tracking *websocket*; a no-op if it is not tracked any more.

        broadcast() already drops sockets whose send failed, so an endpoint's
        cleanup can legitimately call this for a socket that is gone. Raising
        ValueError there would abort the rest of the caller's cleanup.
        """
        if websocket not in self.active_connections:
            logger.debug("disconnect() called for an untracked WebSocket; ignoring")
            return
        self.active_connections.remove(websocket)
        logger.info(f"WebSocket connection removed. Total connections: {len(self.active_connections)}")

    async def broadcast(self, message: Dict):
        """Send *message* as JSON to every tracked socket.

        Sockets whose send raises are logged and removed from tracking; the
        broadcast continues with the remaining sockets.
        """
        if not self.active_connections:
            logger.warning("No active WebSocket connections to broadcast to")
            return
        disconnected = []
        # Iterate over a snapshot: every send awaits, and connect()/disconnect()
        # may change the live list in the meantime, which could skip sockets.
        for connection in list(self.active_connections):
            try:
                await connection.send_json(message)
                logger.info("Message sent to WebSocket connection")
            except Exception as e:
                logger.error(f"Error sending message to WebSocket: {e}")
                disconnected.append(connection)
        # Remove disconnected connections
        for conn in disconnected:
            if conn in self.active_connections:
                self.active_connections.remove(conn)
                logger.info(f"Removed disconnected WebSocket. Remaining: {len(self.active_connections)}")


# Initialize connection manager (shared by the endpoint and send_notification)
manager = ConnectionManager()
# Create full URL of WebSocket server from environment variables
def get_full_websocket_url(server_side=False, *, server=None, port=None, path=None):
    """Build the WebSocket URL: relative for the server, absolute for clients.

    Args:
        server_side: If true, return only the path (e.g. "/notify").
        server: Host override; defaults to WEBSOCKET_SERVER.
        port: Port override (str or int); defaults to WEBSOCKET_PORT.
        path: Path override; defaults to WEBSOCKET_PATH.

    Returns:
        The path when *server_side* is true. Otherwise an absolute URL.
        Port 443 selects "wss://" and any other port selects "ws://".
        Ports 80 and 443 are the defaults for ws/wss, so they are omitted.

    Raises:
        ValueError: If the port is not an integer.
    """
    path = WEBSOCKET_PATH if path is None else path
    if server_side:
        # Relative URL (for server side)
        return path

    # Full URL (for client)
    server = WEBSOCKET_SERVER if server is None else server
    port_number = int(WEBSOCKET_PORT if port is None else port)
    protocol = "wss" if port_number == 443 else "ws"
    if port_number in (80, 443):
        return f"{protocol}://{server}{path}"
    return f"{protocol}://{server}:{port_number}{path}"
# Add GET endpoint to display WebSocket information in Swagger

# Example client shown in the documentation response. It reads the same flat
# payload that send_notification() broadcasts (user names under "user_info").
_CLIENT_EXAMPLE = """
import websocket
import json
import os
import time
import threading
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Get WebSocket configuration from environment variables
WEBSOCKET_SERVER = os.getenv("WEBSOCKET_SERVER", "localhost")
WEBSOCKET_PORT = os.getenv("WEBSOCKET_PORT", "7860")
WEBSOCKET_PATH = os.getenv("WEBSOCKET_PATH", "/notify")

# Create full URL
ws_url = f"ws://{WEBSOCKET_SERVER}:{WEBSOCKET_PORT}{WEBSOCKET_PATH}"
# If using HTTPS, replace ws:// with wss://
# ws_url = f"wss://{WEBSOCKET_SERVER}{WEBSOCKET_PATH}"

# Send keepalive periodically
def send_keepalive(ws):
    while True:
        try:
            if ws.sock and ws.sock.connected:
                ws.send("keepalive")
                print("Sent keepalive message")
            time.sleep(300)  # 5 minutes
        except Exception as e:
            print(f"Error sending keepalive: {e}")
            time.sleep(60)

def on_message(ws, message):
    try:
        data = json.loads(message)
        print(f"Received notification: {data}")
        # Process notification, e.g.: send to Telegram Admin
        if data.get("type") == "sorry_response":
            # Fields are top-level; user names are nested under "user_info"
            user_info = data.get("user_info", {})
            user_question = data.get("message", "")
            user_name = user_info.get("first_name", "Unknown User")
            print(f"User {user_name} asked: {user_question}")
            # Code to send message to Telegram Admin
    except json.JSONDecodeError:
        print(f"Received non-JSON message: {message}")
    except Exception as e:
        print(f"Error processing message: {e}")

def on_error(ws, error):
    print(f"WebSocket error: {error}")

def on_close(ws, close_status_code, close_msg):
    print(f"WebSocket connection closed: code={close_status_code}, message={close_msg}")

def on_open(ws):
    print(f"WebSocket connection opened to {ws_url}")
    # Send keepalive messages periodically in a separate thread
    keepalive_thread = threading.Thread(target=send_keepalive, args=(ws,), daemon=True)
    keepalive_thread.start()

def run_forever_with_reconnect():
    while True:
        try:
            # Connect WebSocket with ping to maintain connection
            ws = websocket.WebSocketApp(
                ws_url,
                on_open=on_open,
                on_message=on_message,
                on_error=on_error,
                on_close=on_close
            )
            ws.run_forever(ping_interval=60, ping_timeout=30)
            print("WebSocket connection lost, reconnecting in 5 seconds...")
            time.sleep(5)
        except Exception as e:
            print(f"WebSocket connection error: {e}")
            time.sleep(5)

# Start WebSocket client in a separate thread
websocket_thread = threading.Thread(target=run_forever_with_reconnect, daemon=True)
websocket_thread.start()

# Keep the program running
try:
    while True:
        time.sleep(1)
except KeyboardInterrupt:
    print("Stopping WebSocket client...")
"""


async def websocket_documentation():
    """
    Provides information about how to use the WebSocket endpoint /notify.
    This endpoint is for documentation purposes only. To use WebSocket, please connect to the WebSocket URL.
    """
    ws_url = get_full_websocket_url()
    # Take the scheme from the advertised URL so "protocol" always agrees with
    # "full_url" (get_full_websocket_url switches to wss:// on port 443).
    protocol = ws_url.split("://", 1)[0] + "://"
    return {
        "websocket_endpoint": WEBSOCKET_PATH,
        "connection_type": "WebSocket",
        "protocol": protocol,
        "server": WEBSOCKET_SERVER,
        "port": WEBSOCKET_PORT,
        "full_url": ws_url,
        "description": "Endpoint to receive notifications about new sessions requiring attention",
        # Mirrors the payload built in send_notification(): flat fields plus a
        # nested "user_info" object.
        # NOTE(review): the timestamp comes from get_local_time(); confirm that
        # its format matches this placeholder.
        "notification_format": {
            "type": "sorry_response",
            "timestamp": "YYYY-MM-DD HH:MM:SS",
            "user_id": "user id",
            "message": "User question",
            "response": "I'm sorry...",
            "session_id": "session id",
            "user_info": {
                "first_name": "user's first name",
                "last_name": "user's last name",
                "username": "username",
            },
        },
        "client_example": _CLIENT_EXAMPLE,
    }
async def websocket_endpoint(websocket: WebSocket):
    """
    WebSocket endpoint to receive notifications about new sessions.
    Admin Bot will connect to this endpoint to receive notifications when there are new sessions requiring attention.

    Each text message from the client is echoed back as JSON. A background task
    sends a JSON ping whenever the client has been quiet for more than 15 seconds
    (checked every 20 seconds). On exit, the ping task is cancelled and the
    socket is removed from the connection manager.
    """
    await manager.connect(websocket)
    ping_task = None
    # Time of the last client message; the ping task reads it via the closure.
    last_activity = datetime.now()

    async def send_periodic_ping():
        """Every 20 s, send a ping if no client message arrived in the last 15 s."""
        try:
            while True:
                await asyncio.sleep(20)
                current_time = datetime.now()
                time_since_activity = (current_time - last_activity).total_seconds()
                if time_since_activity > 15:
                    logger.debug("Sending ping to client to keep connection alive")
                    await websocket.send_json({"type": "ping", "timestamp": current_time.isoformat()})
        except asyncio.CancelledError:
            # Task was cancelled by the endpoint's cleanup; exit quietly.
            pass
        except Exception as e:
            logger.error(f"Error in ping task: {e}")

    try:
        ping_task = asyncio.create_task(send_periodic_ping())
        # Main message loop
        while True:
            data = await websocket.receive_text()
            # Stamp activity when the message actually arrives (not when we
            # started waiting), so the echo timestamp and idle check are accurate.
            last_activity = datetime.now()
            # Echo back to keep connection active
            await websocket.send_json({
                "status": "connected",
                "echo": data,
                "timestamp": last_activity.isoformat()
            })
            logger.info(f"Received message from WebSocket: {data}")
    except WebSocketDisconnect:
        logger.info("WebSocket client disconnected")
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
    finally:
        # Stop the ping task first, so it cannot outlive the connection even if
        # the bookkeeping below fails.
        if ping_task is not None:
            ping_task.cancel()
            try:
                await ping_task
            except asyncio.CancelledError:
                pass
        # broadcast() may already have dropped this socket after a failed send.
        if websocket in manager.active_connections:
            manager.disconnect(websocket)
# Function to send notifications over WebSocket
async def send_notification(data: dict):
    """
    Send notification to all active WebSocket connections.

    This function is used to notify admin bots about new issues or questions that
    need attention. It's triggered when the system cannot answer a user's
    question, i.e. the response starts with "I'm sorry". The match is
    case-insensitive, ignores leading whitespace, and also accepts the
    typographic apostrophe "I’m sorry".

    Args:
        data: Source fields for the notification. Keys read: 'response'
            (must be a non-empty str), 'message', 'session_id', 'user_id',
            'first_name', 'last_name', 'username'.

    Returns:
        None. Exceptions (other than cancellation) are logged, not raised, so a
        notification failure never propagates to the caller.
    """
    try:
        # Log number of active connections and notification attempt
        logger.info(f"Attempting to send notification. Active connections: {len(manager.active_connections)}")
        logger.info(f"Notification data: session_id={data.get('session_id')}, user_id={data.get('user_id')}")

        # Validate before slicing. Previously the response was sliced for
        # logging first, so None/non-str values raised TypeError instead of
        # reaching this warning.
        response = data.get('response', '')
        if not response or not isinstance(response, str):
            logger.warning(f"Invalid response format in notification data: {response}")
            return
        logger.info(f"Response: {response[:50]}...")

        # Treat U+2019 (right single quotation mark) like an ASCII apostrophe.
        normalized = response.strip().lower().replace("\u2019", "'")
        if not normalized.startswith("i'm sorry"):
            logger.info(f"Response doesn't start with 'I'm sorry', notification not needed: {response[:50]}...")
            return

        logger.info("Response starts with 'I'm sorry', sending notification")

        # Format the notification data for the admin, following the Admin_bot format
        notification_data = {
            "type": "sorry_response",  # Type set to sorry_response to match Admin_bot
            "timestamp": get_local_time(),
            "user_id": data.get('user_id', 'unknown'),
            "message": data.get('message', ''),
            "response": response,
            "session_id": data.get('session_id', 'unknown'),
            "user_info": {
                "first_name": data.get('first_name', 'User'),
                "last_name": data.get('last_name', ''),
                "username": data.get('username', '')
            }
        }

        # Check if there are active connections
        if not manager.active_connections:
            logger.warning("No active WebSocket connections for notification broadcast")
            return

        # Broadcast notification to all active connections
        logger.info(f"Broadcasting notification to {len(manager.active_connections)} connections")
        await manager.broadcast(notification_data)
        logger.info("Notification broadcast completed successfully")
    except Exception as e:
        # Best-effort: log the error together with its traceback, and do not re-raise.
        logger.exception(f"Error sending notification: {e}")