FLUX.1-Krea-dev / remote_client.py
dangthr's picture
Update remote_client.py
28cf1dd verified
raw
history blame
8.94 kB
#!/usr/bin/env python3
"""
WebSocket远程客户端 - 在inferless环境中运行
用于连接到主服务器并执行远程命令和文件传输
"""
import asyncio
import websockets
import subprocess
import json
import logging
import sys
import urllib.parse
import base64
import os
import threading
import time
from pathlib import Path
# Process-wide logging: INFO level, "timestamp - level - message" lines.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module logger shared by RemoteClient and main().
logger = logging.getLogger(__name__)
class RemoteClient:
    """WebSocket client that registers with a control server and serves its requests.

    Incoming messages are JSON objects dispatched on their ``type`` field:
    ``command`` runs a shell command and replies with ``command_result``;
    ``upload_files`` sends every file under a directory as ``file_upload``
    messages followed by one ``upload_result``; ``ping`` is answered with
    ``pong``.

    Args:
        server_url: Base URL of the server; an ``http``/``https`` scheme is
            rewritten to ``ws``/``wss``.
        user_token: Credential passed as the ``token`` query parameter.
        space_id: Identifier used in the ``/ws/client/<space_id>`` path and in
            the registration message.
    """

    def __init__(self, server_url, user_token, space_id):
        self.server_url = server_url
        self.user_token = user_token
        self.space_id = space_id
        self.websocket = None  # the open connection while connect() is running
        self.is_connected = False

    @staticmethod
    def _to_ws_base(server_url):
        """Return *server_url* with http/https mapped to ws/wss and no trailing slash.

        Only the URL scheme is rewritten, so an "http" substring in the host or
        path is left untouched. Other schemes (e.g. already ws/wss) pass through.
        """
        parts = urllib.parse.urlsplit(server_url)
        scheme = {"http": "ws", "https": "wss"}.get(parts.scheme.lower(), parts.scheme)
        return urllib.parse.urlunsplit(parts._replace(scheme=scheme)).rstrip("/")

    async def connect(self):
        """Open one WebSocket session and serve server messages until it ends.

        Any exception is logged and swallowed, so callers are expected to
        retry. ``is_connected`` is always False again when this returns.
        """
        endpoint = (
            f"{self._to_ws_base(self.server_url)}/ws/client/"
            f"{urllib.parse.quote(str(self.space_id), safe='')}"
        )
        uri = f"{endpoint}?token={urllib.parse.quote(self.user_token)}"
        # Log without the query string so the token never reaches the logs.
        logger.info(f"正在连接到服务器: {endpoint}")
        try:
            async with websockets.connect(uri, ping_interval=20, ping_timeout=60) as websocket:
                self.websocket = websocket
                self.is_connected = True
                logger.info("✅ 已连接到WebSocket服务器")
                # Announce ourselves, then serve requests until the socket closes.
                await self.register_client()
                await self.message_loop()
        except Exception as e:
            logger.error(f"❌ 连接失败: {e}")
        finally:
            self.is_connected = False

    async def register_client(self):
        """Send the ``register`` message describing this client's environment."""
        registration = {
            "type": "register",
            "space_id": self.space_id,
            "client_info": {
                "environment": "inferless",
                "python_version": sys.version,
                "working_directory": os.getcwd()
            }
        }
        await self.websocket.send(json.dumps(registration))
        logger.info("📝 已发送客户端注册信息")

    async def message_loop(self):
        """Receive and dispatch JSON messages until the connection closes.

        A malformed or failing message is logged and skipped; only a closed
        connection ends the loop.
        """
        while self.is_connected:
            try:
                message = await self.websocket.recv()
                data = json.loads(message)
                await self.handle_message(data)
            except websockets.exceptions.ConnectionClosed:
                logger.error("🔌 WebSocket连接已关闭")
                self.is_connected = False
                break
            except Exception as e:
                logger.error(f"❌ 处理消息时出错: {e}")

    async def handle_message(self, data):
        """Dispatch one decoded message on its ``type`` field.

        Non-object JSON (e.g. a list) is reported as an unknown message type
        instead of raising AttributeError.
        """
        message_type = data.get("type") if isinstance(data, dict) else None
        if message_type == "command":
            await self.execute_command(data)
        elif message_type == "upload_files":
            await self.upload_files(data)
        elif message_type == "ping":
            await self.send_pong()
        else:
            logger.warning(f"⚠️ 未知消息类型: {message_type}")

    async def execute_command(self, data):
        """Run ``data["command"]`` through the shell and reply with ``command_result``.

        The command string comes from the server and is executed verbatim with
        ``shell=True``. It runs in a worker thread so the event loop, and with
        it the WebSocket keepalive pings, keeps running during long commands
        (up to the 300 s timeout). Exactly one result message is sent.
        """
        command = data.get("command", "")
        logger.info(f"🚀 执行命令: {command}")
        loop = asyncio.get_running_loop()
        try:
            process = await loop.run_in_executor(
                None,
                lambda: subprocess.run(
                    command,
                    shell=True,
                    capture_output=True,
                    text=True,
                    errors="replace",  # undecodable bytes must not discard the output
                    timeout=300,  # 5-minute limit
                ),
            )
        except subprocess.TimeoutExpired:
            result = {
                "type": "command_result",
                "command": command,
                "error": "命令执行超时",
                "success": False
            }
            logger.error("⏰ 命令执行超时")
        except Exception as e:
            result = {
                "type": "command_result",
                "command": command,
                "error": str(e),
                "success": False
            }
            logger.error(f"❌ 命令执行异常: {e}")
        else:
            result = {
                "type": "command_result",
                "command": command,
                "returncode": process.returncode,
                "stdout": process.stdout,
                "stderr": process.stderr,
                "success": process.returncode == 0
            }
            if process.returncode == 0:
                logger.info("✅ 命令执行成功")
            else:
                logger.error(f"❌ 命令执行失败,返回码: {process.returncode}")
        # Sent outside the try: a closed socket propagates ConnectionClosed to
        # message_loop instead of triggering a second send attempt.
        await self.websocket.send(json.dumps(result))

    async def upload_files(self, data):
        """Upload every regular file under ``data["directory"]`` (default "output").

        Files are sent one by one (recursively) with a short pause in between,
        then an ``upload_result`` reports the counts; ``success`` is True only
        when no file failed. A closed connection aborts the upload.
        """
        upload_dir = data.get("directory", "output")
        logger.info(f"📁 开始上传目录: {upload_dir}")
        if not os.path.isdir(upload_dir):
            error_msg = f"目录不存在: {upload_dir}"
            logger.error(error_msg)
            await self.send_upload_result(False, error_msg)
            return
        try:
            files_to_upload = []
            for root, _dirs, files in os.walk(upload_dir):
                for file in files:
                    file_path = os.path.join(root, file)
                    if os.path.isfile(file_path):
                        files_to_upload.append(file_path)
            if not files_to_upload:
                msg = f"目录 {upload_dir} 中没有找到文件"
                logger.info(msg)
                await self.send_upload_result(True, msg)
                return
            logger.info(f"📦 找到 {len(files_to_upload)} 个文件,开始上传...")
            success_count = 0
            failed_count = 0
            for file_path in files_to_upload:
                try:
                    await self.upload_single_file(file_path)
                except websockets.exceptions.ConnectionClosed:
                    raise  # every remaining send would fail too
                except Exception as e:
                    failed_count += 1
                    logger.error(f"❌ 上传失败: {file_path} - {e}")
                else:
                    success_count += 1
                    logger.info(f"✅ 上传成功: {os.path.basename(file_path)}")
                # Short pause between files to avoid flooding the server.
                await asyncio.sleep(0.1)
            result_msg = f"上传完成! 成功: {success_count}, 失败: {failed_count}"
            logger.info(f"📊 {result_msg}")
            await self.send_upload_result(failed_count == 0, result_msg)
        except websockets.exceptions.ConnectionClosed:
            raise  # let message_loop notice the disconnect
        except Exception as e:
            error_msg = f"上传过程中出错: {e}"
            logger.error(error_msg)
            await self.send_upload_result(False, error_msg)

    async def upload_single_file(self, file_path):
        """Send one file as a base64-encoded ``file_upload`` message.

        Only the basename is sent, so same-named files from different
        subdirectories arrive under the same name. The whole file is read into
        memory and sent as one WebSocket message. NOTE(review): large files may
        exceed the server's maximum message size; confirm the server limit.
        """
        filename = os.path.basename(file_path)
        with open(file_path, 'rb') as f:
            file_content = f.read()
        file_b64 = base64.b64encode(file_content).decode('utf-8')
        file_message = {
            "type": "file_upload",
            "filename": filename,
            "content": file_b64,
            "file_size": len(file_content)
        }
        await self.websocket.send(json.dumps(file_message))

    async def send_upload_result(self, success, message):
        """Send the ``upload_result`` summary message."""
        result = {
            "type": "upload_result",
            "success": success,
            "message": message
        }
        await self.websocket.send(json.dumps(result))

    async def send_pong(self):
        """Answer an application-level ``ping`` message with ``pong``."""
        pong = {"type": "pong"}
        await self.websocket.send(json.dumps(pong))
async def main():
    """Read ``<server_url> <user_token> <space_id>`` from argv and stay connected.

    Exits with status 1 and a usage message when arguments are missing.
    ``RemoteClient.connect`` logs and swallows its own errors and simply
    returns when a session ends, so the 5-second back-off runs after every
    attempt, successful or not, rather than only when ``connect`` raises.
    This keeps a refusing server from being hit in a tight reconnect loop.
    """
    if len(sys.argv) < 4:
        print("使用方法: python remote_client.py <server_url> <user_token> <space_id>")
        print("示例: python remote_client.py https://gkbtyo-rqvays-5001.preview.cloudstudio.work abc123 space456")
        sys.exit(1)
    server_url, user_token, space_id = sys.argv[1:4]
    client = RemoteClient(server_url, user_token, space_id)
    while True:
        try:
            await client.connect()
        except KeyboardInterrupt:
            logger.info("👋 收到中断信号,正在退出...")
            break
        except Exception as e:
            logger.error(f"❌ 连接异常: {e}")
        # Back off unconditionally before the next connection attempt.
        logger.info("🔄 5秒后重新尝试连接...")
        await asyncio.sleep(5)
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # On Python 3.11+ asyncio.run converts Ctrl+C into cancellation of the
        # main task and re-raises KeyboardInterrupt here, so the handler inside
        # main() may never run; exit with a log line instead of a traceback.
        logger.info("👋 收到中断信号,正在退出...")