Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
""" | |
WebSocket远程客户端 - 在inferless环境中运行 | |
用于连接到主服务器并执行远程命令和文件传输 | |
""" | |
import asyncio | |
import websockets | |
import subprocess | |
import json | |
import logging | |
import sys | |
import urllib.parse | |
import base64 | |
import os | |
import threading | |
import time | |
from pathlib import Path | |
# Logging setup: INFO and above, formatted as "time - level - message".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
class RemoteClient:
    """WebSocket client that runs commands and uploads files for a remote server.

    After ``connect()`` opens a session it registers itself, then handles the
    server's JSON messages one at a time: ``command`` (run a shell command),
    ``upload_files`` (send every file of a directory) and ``ping`` (reply
    ``pong``). All replies are JSON text frames sent on the same socket.
    """

    def __init__(self, server_url, user_token, space_id):
        """Store the connection parameters; no network I/O happens here.

        Args:
            server_url: Base URL of the server (``http(s)://`` or ``ws(s)://``).
            user_token: Token passed to the server as the ``token`` query parameter.
            space_id: Identifier placed in the ``/ws/client/<space_id>`` path.
        """
        self.server_url = server_url
        self.user_token = user_token
        self.space_id = space_id
        self.websocket = None  # the open connection while connect() runs, else None
        self.is_connected = False

    @staticmethod
    def _build_ws_uri(server_url, user_token, space_id):
        """Build the URI ``<ws-base>/ws/client/<space_id>?token=<token>``.

        Only a leading ``https://`` / ``http://`` scheme is rewritten to
        ``wss://`` / ``ws://``. A blanket ``str.replace("http", "ws")`` would
        also corrupt hosts or paths containing "http" (e.g. ``httpbin.org``).
        Trailing slashes are stripped so the path never starts with ``//``.
        """
        base = server_url
        lowered = base.lower()
        if lowered.startswith("https://"):
            base = "wss://" + base[len("https://"):]
        elif lowered.startswith("http://"):
            base = "ws://" + base[len("http://"):]
        base = base.rstrip("/")
        space = urllib.parse.quote(str(space_id), safe="")
        token = urllib.parse.quote(user_token)
        return f"{base}/ws/client/{space}?token={token}"

    async def connect(self):
        """Open one WebSocket session, register, and serve messages until it closes.

        Connection errors are logged, not raised, so the caller can simply call
        this again to reconnect.
        """
        uri = self._build_ws_uri(self.server_url, self.user_token, self.space_id)
        # Drop the query string when logging so the auth token never reaches the logs.
        safe_uri = uri.partition("?")[0]
        logger.info(f"正在连接到服务器: {safe_uri}")
        try:
            async with websockets.connect(uri, ping_interval=20, ping_timeout=60) as websocket:
                self.websocket = websocket
                self.is_connected = True
                logger.info("✅ 已连接到WebSocket服务器")
                # Announce this client to the server.
                await self.register_client()
                # Serve messages until the connection closes.
                await self.message_loop()
        except Exception as e:
            logger.error(f"❌ 连接失败: {e}")
        finally:
            # Never leave a stale (closed) socket or "connected" flag behind.
            self.is_connected = False
            self.websocket = None

    async def register_client(self):
        """Send a ``register`` message describing this client's environment."""
        registration = {
            "type": "register",
            "space_id": self.space_id,
            "client_info": {
                "environment": "inferless",
                "python_version": sys.version,
                "working_directory": os.getcwd(),
            },
        }
        await self.websocket.send(json.dumps(registration))
        logger.info("📝 已发送客户端注册信息")

    async def message_loop(self):
        """Receive and dispatch JSON messages until the connection closes.

        Errors while handling a single message are logged and the loop continues.
        """
        while self.is_connected:
            try:
                message = await self.websocket.recv()
                data = json.loads(message)
                await self.handle_message(data)
            except websockets.exceptions.ConnectionClosed:
                logger.error("🔌 WebSocket连接已关闭")
                self.is_connected = False
                break
            except Exception as e:
                logger.error(f"❌ 处理消息时出错: {e}")

    async def handle_message(self, data):
        """Dispatch one decoded message to its handler based on ``data["type"]``."""
        message_type = data.get("type")
        if message_type == "command":
            await self.execute_command(data)
        elif message_type == "upload_files":
            await self.upload_files(data)
        elif message_type == "ping":
            await self.send_pong()
        else:
            logger.warning(f"⚠️ 未知消息类型: {message_type}")

    async def execute_command(self, data):
        """Run ``data["command"]`` in a shell and reply with a ``command_result``.

        On completion the reply carries ``returncode``, ``stdout``, ``stderr``
        and ``success`` (return code 0). On timeout (300 s) or any other failure
        it carries an ``error`` string and ``success: False`` instead.

        NOTE(security): the command comes from the server and runs verbatim with
        ``shell=True``, so this client grants the server full shell access.
        """
        command = data.get("command", "")
        logger.info(f"🚀 执行命令: {command}")
        loop = asyncio.get_running_loop()
        try:
            # subprocess.run() blocks for up to 5 minutes. Running it in the default
            # thread pool keeps the event loop free to answer WebSocket keepalive
            # pings, which would otherwise time out (ping_timeout=60) and drop the
            # connection during long commands.
            process = await loop.run_in_executor(
                None,
                lambda: subprocess.run(
                    command,
                    shell=True,
                    capture_output=True,
                    text=True,
                    errors="replace",  # undecodable output must not abort the result
                    timeout=300,  # 5-minute limit
                ),
            )
        except subprocess.TimeoutExpired:
            error_result = {
                "type": "command_result",
                "command": command,
                "error": "命令执行超时",
                "success": False,
            }
            await self.websocket.send(json.dumps(error_result))
            logger.error("⏰ 命令执行超时")
            return
        except Exception as e:
            error_result = {
                "type": "command_result",
                "command": command,
                "error": str(e),
                "success": False,
            }
            await self.websocket.send(json.dumps(error_result))
            logger.error(f"❌ 命令执行异常: {e}")
            return

        result = {
            "type": "command_result",
            "command": command,
            "returncode": process.returncode,
            "stdout": process.stdout,
            "stderr": process.stderr,
            "success": process.returncode == 0,
        }
        await self.websocket.send(json.dumps(result))
        if process.returncode == 0:
            logger.info("✅ 命令执行成功")
        else:
            logger.error(f"❌ 命令执行失败,返回码: {process.returncode}")

    async def upload_files(self, data):
        """Upload every file under ``data["directory"]`` (default ``"output"``).

        Each file is sent as a ``file_upload`` message. A final ``upload_result``
        message summarizes success/failure counts, or reports why the upload
        could not start.
        """
        upload_dir = data.get("directory", "output")
        logger.info(f"📁 开始上传目录: {upload_dir}")
        if not os.path.exists(upload_dir):
            error_msg = f"目录不存在: {upload_dir}"
            logger.error(error_msg)
            await self.send_upload_result(False, error_msg)
            return
        try:
            # Collect regular files recursively (isfile skips broken symlinks etc.).
            files_to_upload = []
            for root, _dirs, files in os.walk(upload_dir):
                for file in files:
                    file_path = os.path.join(root, file)
                    if os.path.isfile(file_path):
                        files_to_upload.append(file_path)
            if not files_to_upload:
                msg = f"目录 {upload_dir} 中没有找到文件"
                logger.info(msg)
                await self.send_upload_result(True, msg)
                return
            logger.info(f"📦 找到 {len(files_to_upload)} 个文件,开始上传...")
            success_count = 0
            failed_count = 0
            for file_path in files_to_upload:
                try:
                    await self.upload_single_file(file_path)
                    success_count += 1
                    logger.info(f"✅ 上传成功: {os.path.basename(file_path)}")
                except Exception as e:
                    failed_count += 1
                    logger.error(f"❌ 上传失败: {file_path} - {e}")
                # Short pause between files to avoid flooding the server.
                await asyncio.sleep(0.1)
            result_msg = f"上传完成! 成功: {success_count}, 失败: {failed_count}"
            logger.info(f"📊 {result_msg}")
            await self.send_upload_result(True, result_msg)
        except Exception as e:
            error_msg = f"上传过程中出错: {e}"
            logger.error(error_msg)
            await self.send_upload_result(False, error_msg)

    async def upload_single_file(self, file_path):
        """Send one file as a base64-encoded ``file_upload`` message.

        The whole file is read into memory. Only the base name is sent, so files
        with the same name in different subdirectories arrive under the same
        ``filename``.
        """
        filename = os.path.basename(file_path)
        with open(file_path, 'rb') as f:
            file_content = f.read()
        file_b64 = base64.b64encode(file_content).decode('utf-8')
        file_message = {
            "type": "file_upload",
            "filename": filename,
            "content": file_b64,
            "file_size": len(file_content),
        }
        await self.websocket.send(json.dumps(file_message))

    async def send_upload_result(self, success, message):
        """Send an ``upload_result`` message with a success flag and description."""
        result = {
            "type": "upload_result",
            "success": success,
            "message": message,
        }
        await self.websocket.send(json.dumps(result))

    async def send_pong(self):
        """Reply to an application-level ``ping`` message with ``pong``."""
        pong = {"type": "pong"}
        await self.websocket.send(json.dumps(pong))
async def main():
    """Parse command-line arguments and keep a RemoteClient connected.

    Usage: ``python remote_client.py <server_url> <user_token> <space_id>``.
    With fewer than three arguments, prints usage and exits with status 1.
    Otherwise loops forever: connect, and after each session ends (or errors
    out), wait 5 seconds before reconnecting.
    """
    args = sys.argv
    if len(args) < 4:
        print("使用方法: python remote_client.py <server_url> <user_token> <space_id>")
        print("示例: python remote_client.py https://gkbtyo-rqvays-5001.preview.cloudstudio.work abc123 space456")
        sys.exit(1)

    server_url, user_token, space_id = args[1:4]
    client = RemoteClient(server_url, user_token, space_id)

    while True:
        try:
            await client.connect()
        except KeyboardInterrupt:
            logger.info("👋 收到中断信号,正在退出...")
            return
        except Exception as exc:
            logger.error(f"❌ 连接异常: {exc}")
        # Back off before the next connection attempt.
        logger.info("🔄 5秒后重新尝试连接...")
        await asyncio.sleep(5)
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # asyncio.run() raises Ctrl+C here (after cancelling main()), so the
        # handler inside main() does not reliably see it. Exit with a log line
        # instead of a traceback.
        logger.info("👋 收到中断信号,正在退出...")