# Source: FLUX.1-Krea-dev / remote_uploader.py
# Author: dangthr
# Last commit: eb678a5 (verified) - "Rename websocket_uploader.py to remote_uploader.py"
# Size: 11.9 kB
import asyncio
import websockets
import subprocess
import json
import logging
import sys
import urllib.parse
import base64
import os
import threading
import requests
from urllib.parse import urlparse
# Logging setup: everything from DEBUG upward is emitted, prefixed with a
# timestamp and the level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
logger = logging.getLogger(__name__)

# The currently open WebSocket, or None. Kept at module level so the stdin
# listener task can send on the same connection that connect() opened.
global_websocket = None
# Relay endpoint this machine registers with; the space id is appended as a path segment.
_TERMINAL_ENDPOINT = "wss://remote-terminal-worker.nianxi4563.workers.dev/terminal"
# Directory that server-initiated downloads are written into.
_DOWNLOAD_DIR = "./downloads"
# Seconds requests may wait on any single network operation while downloading,
# so an unresponsive server cannot hang the client indefinitely.
_DOWNLOAD_TIMEOUT = 60


def _read_file_b64(path):
    """Return the bytes of *path* base64-encoded as a str."""
    with open(path, 'rb') as f:
        return base64.b64encode(f.read()).decode('utf-8')


async def _send_file_upload(websocket, path, token):
    """Send *path* to the server as one ``file_upload`` message.

    Only the base name of *path* is transmitted. Raises whatever reading the
    file or sending on the socket raises; callers log and continue.
    """
    # Read in a worker thread so a large file does not stall the event loop.
    content = await asyncio.to_thread(_read_file_b64, path)
    file_msg = {
        "type": "file_upload",
        "filename": os.path.basename(path),
        "content": content,
        "token": token,
    }
    await websocket.send(json.dumps(file_msg))


async def _upload_single_file(websocket, upload_file, token):
    """Upload one file given on the command line; log (don't raise) on failure."""
    if not os.path.isfile(upload_file):
        logger.error(f"File not found: {upload_file}")
        return
    try:
        await _send_file_upload(websocket, upload_file, token)
    except Exception as e:
        logger.error(f"Error uploading file {upload_file}: {e}")
    else:
        logger.info(f"Uploaded file: {upload_file}")


def _collect_files(directory):
    """Return the paths of all regular files under *directory*, in os.walk order."""
    found = []
    for root, _dirs, names in os.walk(directory):
        for name in names:
            path = os.path.join(root, name)
            if os.path.isfile(path):
                found.append(path)
    return found


async def _upload_directory(websocket, upload_dir, token):
    """Upload every file under *upload_dir*; per-file failures are logged and skipped.

    NOTE(review): only base names are sent, so same-named files in different
    subdirectories reach the server under the same name.
    """
    if not os.path.isdir(upload_dir):
        logger.error(f"Directory not found: {upload_dir}")
        return
    all_files = _collect_files(upload_dir)
    if not all_files:
        logger.info(f"No files found in {upload_dir}")
        return
    logger.info(f"Found {len(all_files)} files in {upload_dir}, starting upload...")
    for file_path in all_files:
        try:
            await _send_file_upload(websocket, file_path, token)
        except Exception as e:
            logger.error(f"Error uploading file {file_path}: {e}")
        else:
            logger.info(f"Uploaded file from dir: {file_path}")


def _run_shell_command(command):
    """Run *command* through the shell; return (exit code was 0, stdout + stderr)."""
    # shell=True is intentional: this client's purpose is to execute shell command
    # lines sent by the server. Only point it at a relay you control.
    process = subprocess.run(command, shell=True, capture_output=True, text=True)
    return process.returncode == 0, process.stdout + process.stderr


def _safe_download_path(download_dir, filename):
    """Return the target path inside *download_dir* for a server-supplied *filename*.

    Directory components are stripped so a name like ``../../x`` cannot write
    outside *download_dir*. Raises ValueError if no usable name remains.
    """
    if not isinstance(filename, str):
        raise ValueError(f"Invalid download filename: {filename!r}")
    # Treat backslashes as separators too, so Windows-style paths are stripped as well.
    name = os.path.basename(filename.replace("\\", "/"))
    if name in ("", ".", ".."):
        raise ValueError(f"Invalid download filename: {filename!r}")
    return os.path.join(download_dir, name)


def _download_to_dir(url, filename):
    """Download *url* into the download directory under a sanitized *filename*.

    Blocking (uses requests). Returns ``(saved_path, status_code)``;
    ``saved_path`` is None, and nothing is written, unless the status is 200.
    """
    os.makedirs(_DOWNLOAD_DIR, exist_ok=True)
    file_path = _safe_download_path(_DOWNLOAD_DIR, filename)
    response = requests.get(url, timeout=_DOWNLOAD_TIMEOUT)
    if response.status_code != 200:
        return None, response.status_code
    with open(file_path, 'wb') as f:
        f.write(response.content)
    return file_path, response.status_code


async def _reply(websocket, msg_type, text):
    """Send a ``{"type": msg_type, "data": text}`` message to the server."""
    await websocket.send(json.dumps({"type": msg_type, "data": text}))


async def _handle_command(websocket, command):
    """Execute a server ``command``; reply ``output`` on exit 0, else ``error``."""
    logger.info(f"Executing command: {command}")
    try:
        # Run in a worker thread: a blocking subprocess.run() here would freeze the
        # event loop (stdin listener, keepalive pings) for the command's duration,
        # and a long command could get the connection dropped by ping_timeout.
        succeeded, output = await asyncio.to_thread(_run_shell_command, command)
    except Exception as e:
        error_message = f"Error executing command: {e}"
        logger.error(error_message)
        await _reply(websocket, "error", error_message)
        return
    logger.debug(f"Command output: {output}")
    await _reply(websocket, "output" if succeeded else "error", output)


async def _handle_download(websocket, data):
    """Handle a ``file_download_url`` message by fetching ``data["url"]`` locally."""
    url = data["url"]
    filename = data["filename"]
    logger.info(f"Received file download URL: {url}")
    try:
        # requests is blocking, so keep it off the event loop.
        file_path, status_code = await asyncio.to_thread(_download_to_dir, url, filename)
    except Exception as e:
        error_msg = f"Error downloading file from URL: {e}"
        logger.error(error_msg)
        await _reply(websocket, "error", error_msg)
        return
    if file_path is None:
        error_msg = f"Failed to download file. Status code: {status_code}"
        logger.error(error_msg)
        await _reply(websocket, "error", error_msg)
        return
    success_msg = f"File downloaded successfully: {file_path}"
    logger.info(success_msg)
    await _reply(websocket, "output", success_msg)


async def _message_loop(websocket):
    """Receive and dispatch server messages until the connection closes."""
    while True:
        try:
            message = await websocket.recv()
            logger.debug(f"Received message: {message}")
            data = json.loads(message)
            msg_type = data.get("type")
            if msg_type == "command":
                await _handle_command(websocket, data["command"])
            elif msg_type == "file_download_url":
                await _handle_download(websocket, data)
            elif msg_type == "ping":
                logger.debug("Received ping, sending pong")
                await websocket.send(json.dumps({"type": "pong"}))
            else:
                logger.debug(f"Ignoring message of unknown type: {msg_type!r}")
        except websockets.exceptions.ConnectionClosed:
            logger.error("WebSocket connection closed")
            break
        except Exception as e:
            # Report the problem to the server and keep serving later messages.
            logger.error(f"Error processing message: {e}")
            await _reply(websocket, "error", str(e))


async def connect(space_id, machine_secret, token, upload_file=None, upload_dir=None):
    """Connect this machine to the remote-terminal relay and serve it until disconnect.

    Sends a ``{"type": "machine", ...}`` registration, optionally uploads
    *upload_file* and/or every file under *upload_dir*, starts the stdin
    listener, then handles ``command`` (run through a shell),
    ``file_download_url`` and ``ping`` messages until the connection closes.
    Connection-level errors are logged, not raised.

    Args:
        space_id: identifier appended to the relay URL path.
        machine_secret: sent URL-encoded as the ``secret`` query parameter.
        token: sent as the ``token`` query parameter, in the registration
            message, and in every ``file_upload`` message.
        upload_file: optional path of a single file to upload after connecting.
        upload_dir: optional directory whose files are all uploaded after connecting.
    """
    global global_websocket
    # The machine secret and token are passed as URL query parameters.
    encoded_secret = urllib.parse.quote(machine_secret)
    encoded_token = urllib.parse.quote(token)
    base_uri = f"{_TERMINAL_ENDPOINT}/{space_id}"
    uri = f"{base_uri}?secret={encoded_secret}&token={encoded_token}"
    # Log only the base URI: the query string carries credentials.
    logger.info(f"Attempting to connect to {base_uri}")
    input_task = None
    try:
        async with websockets.connect(uri, ping_interval=20, ping_timeout=60) as websocket:
            global_websocket = websocket
            logger.info("Connected to WebSocket")
            machine_info = {"type": "machine", "space_id": space_id, "token": token}
            await websocket.send(json.dumps(machine_info))
            # machine_info itself is not logged because it contains the token.
            logger.debug(f"Sent machine registration for space {space_id}")
            if upload_file:
                await _upload_single_file(websocket, upload_file, token)
            if upload_dir:
                await _upload_directory(websocket, upload_dir, token)
            os.makedirs(_DOWNLOAD_DIR, exist_ok=True)
            # Keep a reference: asyncio holds tasks only weakly, so an unreferenced
            # task may be garbage-collected before it finishes.
            input_task = asyncio.create_task(listen_user_input())
            await _message_loop(websocket)
    except Exception as e:
        logger.error(f"Failed to connect or maintain connection: {e}")
    finally:
        # The socket is closed now; make the stdin listener see "not connected".
        global_websocket = None
        if input_task is not None:
            input_task.cancel()
async def _upload_file_from_input(filename):
    """Send a local file to the server as ``file_upload`` (stdin ``--upload`` action)."""
    if not os.path.isfile(filename):
        logger.error(f"File not found: {filename}")
        return
    try:
        with open(filename, 'rb') as f:
            file_b64 = base64.b64encode(f.read()).decode('utf-8')
        # NOTE(review): unlike the uploads sent from connect(), this message carries
        # no "token" field - confirm the server accepts it.
        file_msg = {
            "type": "file_upload",
            "filename": os.path.basename(filename),
            "content": file_b64,
        }
        if global_websocket:
            await global_websocket.send(json.dumps(file_msg))
            logger.info(f"Uploaded file via input: {filename}")
        else:
            logger.error("WebSocket not connected")
    except Exception as e:
        logger.error(f"Error uploading file {filename}: {e}")


def _download_url_from_input(url):
    """Download *url* into ./downloads (stdin ``--download`` action); log the outcome.

    Blocking (uses requests); the caller runs it in a worker thread.
    """
    try:
        filename = os.path.basename(urlparse(url).path)
        if filename in ("", ".", ".."):
            # e.g. "https://host/dir/": joining "" would target the directory itself.
            logger.error(f"Cannot determine a file name from URL: {url}")
            return
        download_dir = "./downloads"
        os.makedirs(download_dir, exist_ok=True)
        logger.info(f"Downloading file from URL: {url}")
        # Bound each network wait so an unresponsive server cannot hang forever.
        response = requests.get(url, timeout=60)
        if response.status_code == 200:
            file_path = os.path.join(download_dir, filename)
            with open(file_path, 'wb') as f:
                f.write(response.content)
            logger.info(f"File downloaded successfully: {file_path}")
        else:
            logger.error(f"Failed to download file. Status code: {response.status_code}")
    except Exception as e:
        logger.error(f"Error downloading from URL {url}: {e}")


async def listen_user_input():
    """Read lines from stdin and act on them until stdin is closed.

    Each non-blank line is interpreted as one of:

    * ``<path> --upload``  - send the local file to the server as ``file_upload``;
    * ``<url> --download`` - fetch the URL into ``./downloads`` on this machine;
    * anything else        - forward the raw line to the server as a ``command``.

    Messages go through the module-level ``global_websocket``; when it is falsy
    the action is logged as an error and skipped. Per-line errors are logged
    and the loop continues. Note that cancelling this task does not interrupt
    an ``input()`` call already running in its worker thread.
    """
    logger.info("Started listening for user input. Use ' --upload' to upload files to server.")
    while True:
        try:
            # input() blocks, so it runs in a worker thread to keep the event loop free.
            user_input = await asyncio.to_thread(input, "Enter command or file path to upload: ")
        except EOFError:
            # stdin is closed or not attached: every later input() would fail at once,
            # so stop instead of logging an error every second forever.
            logger.info("Standard input closed; stopping user input listener.")
            return
        except Exception as e:
            logger.error(f"Error in user input listener: {e}")
            await asyncio.sleep(1)
            continue
        try:
            parts = user_input.strip().split()
            if len(parts) >= 2 and parts[-1] == "--upload":
                await _upload_file_from_input(" ".join(parts[:-1]))
            elif len(parts) >= 2 and parts[-1] == "--download":
                # requests is blocking: run the download off the event loop.
                await asyncio.to_thread(_download_url_from_input, " ".join(parts[:-1]))
            elif parts:
                if global_websocket:
                    await global_websocket.send(json.dumps({
                        "type": "command",
                        "command": user_input
                    }))
                    logger.info(f"Sent command: {user_input}")
                else:
                    logger.error("WebSocket not connected")
            await asyncio.sleep(0.1)
        except Exception as e:
            logger.error(f"Error in user input listener: {e}")
            await asyncio.sleep(1)
def _parse_cli_args(argv):
    """Parse *argv* into ``(space_id, machine_secret, token, upload_file, upload_dir)``.

    Returns None when the three required positional arguments are missing.
    ``--upload <file_path>`` and ``--upload-dir <dir_path>`` may follow the
    positionals in any order and may be combined. A flag without a value, or
    an unrecognised argument, is ignored. Parsing is done by hand rather than
    with argparse so that secrets/tokens beginning with "-" are still accepted.
    """
    if len(argv) < 4:
        return None
    space_id, machine_secret, token = argv[1:4]
    upload_file = None
    upload_dir = None
    rest = argv[4:]
    i = 0
    while i < len(rest):
        flag = rest[i]
        if flag in ("--upload", "--upload-dir") and i + 1 < len(rest):
            if flag == "--upload":
                upload_file = rest[i + 1]
            else:
                upload_dir = rest[i + 1]
            i += 2
        else:
            i += 1
    return space_id, machine_secret, token, upload_file, upload_dir


if __name__ == "__main__":
    cli_args = _parse_cli_args(sys.argv)
    if cli_args is None:
        # Name the script actually being run (the file has been renamed before).
        script = os.path.basename(sys.argv[0]) or "remote_uploader.py"
        print(f"Usage: python {script} <space_id> <machine_secret> <token> [--upload <file_path>] [--upload-dir <dir_path>]")
        sys.exit(1)
    try:
        asyncio.run(connect(*cli_args))
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to stop the client; exit without a traceback.
        logger.info("Interrupted by user, exiting.")
        sys.exit(130)