#!/usr/bin/env python3
"""
WebSocket remote client - runs inside the inferless environment.

Connects to the main server over a WebSocket, executes shell commands sent by
the server, and uploads files from a local directory back to the server.

Usage:
    python remote_client.py <server_url> <user_token> <space_id>
"""

import asyncio
import websockets
import subprocess
import json
import logging
import sys
import urllib.parse
import base64
import functools
import os
import threading
import time
from pathlib import Path

# Logging configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Seconds a remote command may run before subprocess.run kills it.
COMMAND_TIMEOUT = 300
# Seconds to wait between reconnection attempts in main().
RECONNECT_DELAY = 5


def build_ws_uri(server_url, space_id, token):
    """Build the client WebSocket endpoint URI for *space_id*.

    Only the URL scheme prefix is rewritten (``http://`` -> ``ws://``,
    ``https://`` -> ``wss://``); any other URL (e.g. one already using
    ``ws://``/``wss://``) is used as-is. Hosts or paths that merely contain the
    substring "http" are left untouched. A trailing slash on *server_url* is
    dropped so the result never contains ``//ws/``.

    Args:
        server_url: Base URL of the server, typically http(s)://host[:port].
        space_id: Identifier inserted into the ``/ws/client/<space_id>`` path.
        token: Auth token, URL-quoted into the ``token`` query parameter.

    Returns:
        The full WebSocket URI string.
    """
    lowered = server_url.lower()
    if lowered.startswith("https://"):
        base = "wss://" + server_url[len("https://"):]
    elif lowered.startswith("http://"):
        base = "ws://" + server_url[len("http://"):]
    else:
        base = server_url
    base = base.rstrip('/')
    return f"{base}/ws/client/{space_id}?token={urllib.parse.quote(token)}"


def _redact_token(uri):
    """Return *uri* with everything after ``?token=`` replaced by ``***``.

    Used so the auth token never ends up in log output.
    """
    head, sep, _ = uri.partition("?token=")
    return f"{head}{sep}***" if sep else uri


class RemoteClient:
    """WebSocket client that runs server-issued commands and uploads files.

    Protocol (JSON text frames), as implemented below:
      * after connecting, sends one ``register`` message;
      * handles incoming ``command``, ``upload_files`` and ``ping`` messages;
      * answers with ``command_result``, ``file_upload`` + ``upload_result``,
        and ``pong`` respectively.
    """

    def __init__(self, server_url, user_token, space_id):
        self.server_url = server_url  # HTTP(S) (or WS) base URL of the server
        self.user_token = user_token  # sent as the ``token`` query parameter
        self.space_id = space_id
        self.websocket = None  # active connection; None while disconnected
        self.is_connected = False

    async def _send_json(self, payload):
        """Serialize *payload* to JSON and send it over the active connection."""
        await self.websocket.send(json.dumps(payload))

    async def connect(self):
        """Connect to the server and process messages until the connection ends.

        Connection and protocol errors are logged, not raised, so the caller's
        reconnect loop can simply call this method again.
        """
        uri = build_ws_uri(self.server_url, self.space_id, self.user_token)
        # The token is a credential: never write it to the log.
        logger.info(f"正在连接到服务器: {_redact_token(uri)}")

        try:
            async with websockets.connect(uri, ping_interval=20, ping_timeout=60) as websocket:
                self.websocket = websocket
                self.is_connected = True
                logger.info("✅ 已连接到WebSocket服务器")

                # Announce this client, then serve requests until disconnect.
                await self.register_client()
                await self.message_loop()
        except Exception as e:
            logger.error(f"❌ 连接失败: {e}")
        finally:
            # Drop the reference to the closed connection in every exit path.
            self.is_connected = False
            self.websocket = None

    async def register_client(self):
        """Send the ``register`` message describing this client's environment."""
        registration = {
            "type": "register",
            "space_id": self.space_id,
            "client_info": {
                "environment": "inferless",
                "python_version": sys.version,
                "working_directory": os.getcwd(),
            },
        }
        await self._send_json(registration)
        logger.info("📝 已发送客户端注册信息")

    async def message_loop(self):
        """Receive and dispatch messages until the connection closes.

        An error while handling one message is logged and the loop continues;
        only ``ConnectionClosed`` (from ``recv`` or from a handler's send)
        ends the loop.
        """
        while self.is_connected:
            try:
                message = await self.websocket.recv()
                data = json.loads(message)
                await self.handle_message(data)
            except websockets.exceptions.ConnectionClosed:
                logger.error("🔌 WebSocket连接已关闭")
                self.is_connected = False
                break
            except Exception as e:
                logger.error(f"❌ 处理消息时出错: {e}")

    async def handle_message(self, data):
        """Dispatch one decoded message to its handler based on its ``type`` key.

        Non-object JSON values (lists, strings, numbers) have no ``type`` and
        are reported as an unknown message type instead of raising.
        """
        message_type = data.get("type") if isinstance(data, dict) else None

        if message_type == "command":
            await self.execute_command(data)
        elif message_type == "upload_files":
            await self.upload_files(data)
        elif message_type == "ping":
            await self.send_pong()
        else:
            logger.warning(f"⚠️ 未知消息类型: {message_type}")

    async def execute_command(self, data):
        """Run ``data["command"]`` through the shell and send back the result.

        Sends a ``command_result`` message containing the return code, stdout
        and stderr. If the command times out or cannot be run, it instead
        contains an ``error`` string and ``success: False``.
        """
        command = data.get("command", "")
        logger.info(f"🚀 执行命令: {command}")

        # NOTE(review): shell=True runs the server-supplied string verbatim;
        # this client trusts the server completely by design.
        run = functools.partial(
            subprocess.run,
            command,
            shell=True,
            capture_output=True,
            text=True,
            # Undecodable output bytes must not turn a finished command into
            # an "exception" result.
            errors="replace",
            timeout=COMMAND_TIMEOUT,
        )
        try:
            # subprocess.run blocks; run it in a worker thread so the event
            # loop keeps answering WebSocket keepalive pings (ping_timeout=60)
            # while a long command is still executing.
            loop = asyncio.get_running_loop()
            process = await loop.run_in_executor(None, run)
        except subprocess.TimeoutExpired:
            error_result = {
                "type": "command_result",
                "command": command,
                "error": "命令执行超时",
                "success": False,
            }
            await self._send_json(error_result)
            logger.error("⏰ 命令执行超时")
            return
        except Exception as e:
            error_result = {
                "type": "command_result",
                "command": command,
                "error": str(e),
                "success": False,
            }
            await self._send_json(error_result)
            logger.error(f"❌ 命令执行异常: {e}")
            return

        result = {
            "type": "command_result",
            "command": command,
            "returncode": process.returncode,
            "stdout": process.stdout,
            "stderr": process.stderr,
            "success": process.returncode == 0,
        }
        await self._send_json(result)

        if process.returncode == 0:
            logger.info("✅ 命令执行成功")
        else:
            logger.error(f"❌ 命令执行失败,返回码: {process.returncode}")

    async def upload_files(self, data):
        """Upload every regular file under ``data["directory"]`` (default "output").

        Each file is sent as its own ``file_upload`` message. A single
        ``upload_result`` summarising the successes and failures follows.
        A closed connection aborts the upload, since no further sends can
        succeed.
        """
        upload_dir = data.get("directory", "output")
        logger.info(f"📁 开始上传目录: {upload_dir}")

        # isdir rather than exists: a regular file at this path would
        # otherwise be walked as an empty directory and reported as success.
        if not os.path.isdir(upload_dir):
            error_msg = f"目录不存在: {upload_dir}"
            logger.error(error_msg)
            await self.send_upload_result(False, error_msg)
            return

        try:
            # Collect all regular files below upload_dir (recursively).
            files_to_upload = []
            for root, _dirs, files in os.walk(upload_dir):
                for name in files:
                    file_path = os.path.join(root, name)
                    if os.path.isfile(file_path):
                        files_to_upload.append(file_path)

            if not files_to_upload:
                msg = f"目录 {upload_dir} 中没有找到文件"
                logger.info(msg)
                await self.send_upload_result(True, msg)
                return

            logger.info(f"📦 找到 {len(files_to_upload)} 个文件,开始上传...")

            success_count = 0
            failed_count = 0

            for file_path in files_to_upload:
                try:
                    await self.upload_single_file(file_path)
                except websockets.exceptions.ConnectionClosed:
                    # Every remaining send would fail too; stop right away.
                    raise
                except Exception as e:
                    failed_count += 1
                    logger.error(f"❌ 上传失败: {file_path} - {e}")
                else:
                    success_count += 1
                    logger.info(f"✅ 上传成功: {os.path.basename(file_path)}")

                # Short pause between files to avoid flooding the server.
                await asyncio.sleep(0.1)

            result_msg = f"上传完成! 成功: {success_count}, 失败: {failed_count}"
            logger.info(f"📊 {result_msg}")
            await self.send_upload_result(True, result_msg)

        except websockets.exceptions.ConnectionClosed:
            # Let message_loop observe the closed connection.
            raise
        except Exception as e:
            error_msg = f"上传过程中出错: {e}"
            logger.error(error_msg)
            await self.send_upload_result(False, error_msg)

    async def upload_single_file(self, file_path):
        """Send one file as a base64-encoded ``file_upload`` message."""
        filename = os.path.basename(file_path)

        # Read in a worker thread so a large file does not stall the loop.
        loop = asyncio.get_running_loop()
        file_content = await loop.run_in_executor(None, Path(file_path).read_bytes)

        file_b64 = base64.b64encode(file_content).decode('utf-8')

        # NOTE(review): only the basename is sent, so same-named files in
        # different subdirectories may overwrite each other server-side —
        # confirm how the server stores uploads.
        file_message = {
            "type": "file_upload",
            "filename": filename,
            "content": file_b64,
            "file_size": len(file_content),
        }
        await self._send_json(file_message)

    async def send_upload_result(self, success, message):
        """Send the final ``upload_result`` message for an upload request."""
        result = {
            "type": "upload_result",
            "success": success,
            "message": message,
        }
        await self._send_json(result)

    async def send_pong(self):
        """Reply to an application-level ``ping`` message with ``pong``."""
        await self._send_json({"type": "pong"})


async def main():
    """Parse command-line arguments and keep the client connected.

    After each disconnect (or failed connection) the client waits
    RECONNECT_DELAY seconds and tries again, until interrupted.
    """
    if len(sys.argv) < 4:
        print("使用方法: python remote_client.py <server_url> <user_token> <space_id>")
        print("示例: python remote_client.py https://gkbtyo-rqvays-5001.preview.cloudstudio.work abc123 space456")
        sys.exit(1)

    server_url, user_token, space_id = sys.argv[1:4]

    client = RemoteClient(server_url, user_token, space_id)

    while True:
        try:
            await client.connect()
        except KeyboardInterrupt:
            logger.info("👋 收到中断信号,正在退出...")
            break
        except Exception as e:
            logger.error(f"❌ 连接异常: {e}")

        logger.info("🔄 5秒后重新尝试连接...")
        await asyncio.sleep(RECONNECT_DELAY)


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # On Python 3.11+ asyncio.run cancels main() on Ctrl-C and re-raises
        # KeyboardInterrupt here rather than inside the coroutine.
        logger.info("👋 收到中断信号,正在退出...")