# FLUX.1-Krea-dev / websocket_uploader.py
# Last commit: ae1ebcf "Update websocket_uploader.py" by dangthr (verified), 5.66 kB
import asyncio
import websockets
import json
import logging
import sys
import base64
import os
import time
import ssl
# Configure root logging at import time: INFO level, timestamped single-line records.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-scoped logger shared by the uploader coroutine and the CLI entry point.
logger = logging.getLogger(__name__)
def _collect_files(output_dir):
    """Return the paths of all regular files under *output_dir*, recursively.

    The result is sorted so files are uploaded in a deterministic order
    (``os.walk`` makes no ordering guarantee).
    """
    return sorted(
        os.path.join(root, name)
        for root, _dirs, names in os.walk(output_dir)
        for name in names
        # os.walk lists every non-directory entry here; keep only regular files.
        if os.path.isfile(os.path.join(root, name))
    )


def _to_ws_url(website_url):
    """Derive the raw WebSocket endpoint ``<base>/ws`` from a site base URL.

    Only the leading scheme is rewritten (``http://`` -> ``ws://``,
    ``https://`` -> ``wss://``, case-insensitively). Trailing slashes are
    dropped so the result never ends in ``//ws``. Any other scheme is kept
    as given.
    """
    base = website_url.rstrip('/')
    for http_scheme, ws_scheme in (('https://', 'wss://'), ('http://', 'ws://')):
        if base[:len(http_scheme)].lower() == http_scheme:
            return ws_scheme + base[len(http_scheme):] + '/ws'
    return base + '/ws'


async def _upload_one(websocket, space_id, user_token, file_path):
    """Send one file as a ``file_upload`` message and wait for the server's reply.

    Returns True when the reply is a JSON object with ``type == 'upload_success'``
    and False for any other reply. File I/O, JSON decoding and connection errors
    propagate to the caller.
    """
    # NOTE(review): only the basename is sent, so files with the same name in
    # different sub-directories of the output dir collide on the server side.
    filename = os.path.basename(file_path)
    logger.info(f"正在上传文件: {filename}")
    # The protocol carries file bytes inside JSON, hence base64.
    with open(file_path, 'rb') as f:
        file_b64 = base64.b64encode(f.read()).decode('utf-8')
    upload_msg = {
        'type': 'file_upload',
        'space_id': space_id,
        'user_token': user_token,
        'filename': filename,
        'content': file_b64,
    }
    await websocket.send(json.dumps(upload_msg))
    upload_result = json.loads(await websocket.recv())
    if isinstance(upload_result, dict) and upload_result.get('type') == 'upload_success':
        logger.info(f"文件 {filename} 上传成功")
        return True
    logger.error(f"文件 {filename} 上传失败: {upload_result}")
    return False


async def upload_files_via_websocket(space_id, user_token, website_url,
                                     output_dir="output", verify_ssl=False):
    """Upload every file under *output_dir* to the server over a raw WebSocket.

    Protocol as implemented here:
    1. Send ``register``.
    2. Expect a reply with ``type == 'registered'``.
    3. Send one ``file_upload`` per file and await each reply.
    4. Send ``upload_complete``.

    Args:
        space_id: Identifier sent with the registration and every upload.
        user_token: Credential sent with the registration and every upload.
        website_url: Base URL of the server; ``/ws`` is appended and the
            http(s) scheme is mapped to ws(s).
        output_dir: Directory whose files (recursively) are uploaded.
        verify_ssl: When False (the historical default), TLS certificate and
            hostname verification are disabled for ``wss://`` connections.

    Returns:
        True if there was nothing to upload or every file was accepted.
        False if the directory is missing, registration fails, the connection
        fails, or any file fails to upload.
    """
    if not os.path.isdir(output_dir):
        logger.error(f"输出目录不存在: {output_dir}")
        return False

    files_to_upload = _collect_files(output_dir)
    if not files_to_upload:
        logger.info("输出目录中没有找到任何文件")
        return True
    logger.info(f"找到 {len(files_to_upload)} 个文件需要上传")

    # Raw WebSocket endpoint (not socket.io).
    ws_url = _to_ws_url(website_url)
    try:
        logger.info(f"正在连接到 WebSocket 服务器: {ws_url}")
        ssl_context = None
        if ws_url.startswith('wss://'):
            ssl_context = ssl.create_default_context()
            if not verify_ssl:
                # NOTE(review): with verification off, user_token is sent over a
                # connection that a man-in-the-middle can intercept. Pass
                # verify_ssl=True when the server has a valid certificate.
                ssl_context.check_hostname = False
                ssl_context.verify_mode = ssl.CERT_NONE
        async with websockets.connect(ws_url, ssl=ssl_context) as websocket:
            logger.info("WebSocket 连接成功")
            register_msg = {
                'type': 'register',
                'space_id': space_id,
                'user_token': user_token,
            }
            await websocket.send(json.dumps(register_msg))
            logger.info("已发送注册消息")

            response_data = json.loads(await websocket.recv())
            if not (isinstance(response_data, dict)
                    and response_data.get('type') == 'registered'):
                logger.error(f"注册失败: {response_data}")
                return False
            logger.info("注册成功,开始上传文件")

            success_count = 0
            failed_count = 0
            connection_lost = False
            for index, file_path in enumerate(files_to_upload):
                try:
                    uploaded = await _upload_one(websocket, space_id, user_token, file_path)
                except websockets.exceptions.ConnectionClosed as e:
                    # The socket is gone, so every remaining file would fail the
                    # same way: count them all as failed and stop.
                    logger.error(f"上传文件 {file_path} 时发生错误: {e}")
                    failed_count += len(files_to_upload) - index
                    connection_lost = True
                    break
                except Exception as e:
                    # Best-effort per file: one bad file does not abort the batch.
                    logger.error(f"上传文件 {file_path} 时发生错误: {e}")
                    failed_count += 1
                    continue
                if uploaded:
                    success_count += 1
                else:
                    failed_count += 1
                # Short pause between files to avoid overloading the server.
                await asyncio.sleep(0.5)

            logger.info(f"上传完成! 成功: {success_count}, 失败: {failed_count}")
            if connection_lost:
                return False
            await websocket.send(json.dumps({'type': 'upload_complete'}))
            # A True result is reported by main() as "all files uploaded",
            # so any per-file failure must make this False.
            return failed_count == 0
    except Exception as e:
        logger.error(f"WebSocket 连接或上传过程中发生错误: {e}")
        return False
def main():
    """CLI entry point: ``python3 websocket_uploader.py <space_id> <user_token> <website_url>``.

    Runs the async uploader to completion.
    Exits with status 0 when it reports success, and with status 1 when it
    reports failure or the argument count is wrong.
    """
    cli_args = sys.argv[1:]
    if len(cli_args) != 3:
        print("用法: python3 websocket_uploader.py <space_id> <user_token> <website_url>")
        sys.exit(1)
    space_id, user_token, website_url = cli_args

    logger.info("开始 WebSocket 文件上传")
    logger.info(f"Space ID: {space_id}")
    logger.info(f"Website URL: {website_url}")

    # Drive the coroutine on a fresh event loop.
    ok = asyncio.run(upload_files_via_websocket(space_id, user_token, website_url))
    if not ok:
        logger.error("文件上传失败")
        sys.exit(1)
    logger.info("所有文件上传成功")
    sys.exit(0)


if __name__ == "__main__":
    main()