Spaces:
Runtime error
Runtime error
import asyncio | |
import websockets | |
import json | |
import logging | |
import sys | |
import base64 | |
import os | |
import time | |
import ssl | |
# 配置日志 | |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') | |
logger = logging.getLogger(__name__) | |
async def upload_files_via_websocket(space_id, user_token, website_url): | |
"""通过WebSocket上传output目录中的所有文件""" | |
output_dir = "output" | |
if not os.path.exists(output_dir): | |
logger.error(f"输出目录不存在: {output_dir}") | |
return False | |
# 获取所有文件 | |
files_to_upload = [] | |
for root, dirs, files in os.walk(output_dir): | |
for file in files: | |
file_path = os.path.join(root, file) | |
if os.path.isfile(file_path): | |
files_to_upload.append(file_path) | |
if not files_to_upload: | |
logger.info("输出目录中没有找到任何文件") | |
return True | |
logger.info(f"找到 {len(files_to_upload)} 个文件需要上传") | |
# 构建WebSocket URL - 使用原始WebSocket而不是socket.io | |
ws_url = website_url.replace('http://', 'ws://').replace('https://', 'wss://') + '/ws' | |
try: | |
# 连接到WebSocket服务器 | |
logger.info(f"正在连接到 WebSocket 服务器: {ws_url}") | |
# 创建SSL上下文(如果是wss://) | |
ssl_context = None | |
if ws_url.startswith('wss://'): | |
ssl_context = ssl.create_default_context() | |
ssl_context.check_hostname = False | |
ssl_context.verify_mode = ssl.CERT_NONE | |
async with websockets.connect(ws_url, ssl=ssl_context) as websocket: | |
logger.info("WebSocket 连接成功") | |
# 发送注册消息 | |
register_msg = { | |
'type': 'register', | |
'space_id': space_id, | |
'user_token': user_token | |
} | |
await websocket.send(json.dumps(register_msg)) | |
logger.info("已发送注册消息") | |
# 等待注册确认 | |
response = await websocket.recv() | |
response_data = json.loads(response) | |
if response_data.get('type') == 'registered': | |
logger.info("注册成功,开始上传文件") | |
# 上传文件 | |
success_count = 0 | |
failed_count = 0 | |
for file_path in files_to_upload: | |
try: | |
filename = os.path.basename(file_path) | |
logger.info(f"正在上传文件: {filename}") | |
# 读取文件内容并编码为base64 | |
with open(file_path, 'rb') as f: | |
file_content = f.read() | |
file_b64 = base64.b64encode(file_content).decode('utf-8') | |
# 发送文件上传消息 | |
upload_msg = { | |
'type': 'file_upload', | |
'space_id': space_id, | |
'user_token': user_token, | |
'filename': filename, | |
'content': file_b64 | |
} | |
await websocket.send(json.dumps(upload_msg)) | |
# 等待上传结果 | |
upload_response = await websocket.recv() | |
upload_result = json.loads(upload_response) | |
if upload_result.get('type') == 'upload_success': | |
logger.info(f"文件 {filename} 上传成功") | |
success_count += 1 | |
else: | |
logger.error(f"文件 {filename} 上传失败: {upload_result}") | |
failed_count += 1 | |
# 稍微延迟避免服务器压力过大 | |
await asyncio.sleep(0.5) | |
except Exception as e: | |
logger.error(f"上传文件 {file_path} 时发生错误: {e}") | |
failed_count += 1 | |
logger.info(f"上传完成! 成功: {success_count}, 失败: {failed_count}") | |
# 发送完成消息 | |
complete_msg = {'type': 'upload_complete'} | |
await websocket.send(json.dumps(complete_msg)) | |
return success_count > 0 | |
else: | |
logger.error(f"注册失败: {response_data}") | |
return False | |
except Exception as e: | |
logger.error(f"WebSocket 连接或上传过程中发生错误: {e}") | |
return False | |
def main(): | |
if len(sys.argv) != 4: | |
print("用法: python3 websocket_uploader.py <space_id> <user_token> <website_url>") | |
sys.exit(1) | |
space_id = sys.argv[1] | |
user_token = sys.argv[2] | |
website_url = sys.argv[3] | |
logger.info(f"开始 WebSocket 文件上传") | |
logger.info(f"Space ID: {space_id}") | |
logger.info(f"Website URL: {website_url}") | |
# 运行异步上传 | |
result = asyncio.run(upload_files_via_websocket(space_id, user_token, website_url)) | |
if result: | |
logger.info("所有文件上传成功") | |
sys.exit(0) | |
else: | |
logger.error("文件上传失败") | |
sys.exit(1) | |
if __name__ == "__main__": | |
main() |