Spaces:
Runtime error
Runtime error
Update websocket_uploader.py
Browse files- websocket_uploader.py +91 -17
websocket_uploader.py
CHANGED
@@ -7,11 +7,43 @@ import base64
|
|
7 |
import os
|
8 |
import time
|
9 |
import ssl
|
|
|
10 |
|
11 |
# 配置日志
|
12 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
13 |
logger = logging.getLogger(__name__)
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
async def upload_files_via_websocket(space_id, user_token, website_url):
|
16 |
"""通过WebSocket上传output目录中的所有文件"""
|
17 |
|
@@ -34,8 +66,8 @@ async def upload_files_via_websocket(space_id, user_token, website_url):
|
|
34 |
|
35 |
logger.info(f"找到 {len(files_to_upload)} 个文件需要上传")
|
36 |
|
37 |
-
# 构建WebSocket URL
|
38 |
-
ws_url = website_url
|
39 |
|
40 |
try:
|
41 |
# 连接到WebSocket服务器
|
@@ -48,7 +80,16 @@ async def upload_files_via_websocket(space_id, user_token, website_url):
|
|
48 |
ssl_context.check_hostname = False
|
49 |
ssl_context.verify_mode = ssl.CERT_NONE
|
50 |
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
logger.info("WebSocket 连接成功")
|
53 |
|
54 |
# 发送注册消息
|
@@ -61,8 +102,13 @@ async def upload_files_via_websocket(space_id, user_token, website_url):
|
|
61 |
logger.info("已发送注册消息")
|
62 |
|
63 |
# 等待注册确认
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
66 |
|
67 |
if response_data.get('type') == 'registered':
|
68 |
logger.info("注册成功,开始上传文件")
|
@@ -81,6 +127,9 @@ async def upload_files_via_websocket(space_id, user_token, website_url):
|
|
81 |
file_content = f.read()
|
82 |
file_b64 = base64.b64encode(file_content).decode('utf-8')
|
83 |
|
|
|
|
|
|
|
84 |
# 发送文件上传消息
|
85 |
upload_msg = {
|
86 |
'type': 'file_upload',
|
@@ -91,10 +140,17 @@ async def upload_files_via_websocket(space_id, user_token, website_url):
|
|
91 |
}
|
92 |
|
93 |
await websocket.send(json.dumps(upload_msg))
|
|
|
94 |
|
95 |
# 等待上传结果
|
96 |
-
|
97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
|
99 |
if upload_result.get('type') == 'upload_success':
|
100 |
logger.info(f"文件 {filename} 上传成功")
|
@@ -113,15 +169,28 @@ async def upload_files_via_websocket(space_id, user_token, website_url):
|
|
113 |
logger.info(f"上传完成! 成功: {success_count}, 失败: {failed_count}")
|
114 |
|
115 |
# 发送完成消息
|
116 |
-
|
117 |
-
|
|
|
|
|
|
|
|
|
118 |
|
119 |
return success_count > 0
|
120 |
|
|
|
|
|
|
|
121 |
else:
|
122 |
-
logger.error(f"
|
123 |
return False
|
124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
except Exception as e:
|
126 |
logger.error(f"WebSocket 连接或上传过程中发生错误: {e}")
|
127 |
return False
|
@@ -137,16 +206,21 @@ def main():
|
|
137 |
|
138 |
logger.info(f"开始 WebSocket 文件上传")
|
139 |
logger.info(f"Space ID: {space_id}")
|
|
|
140 |
logger.info(f"Website URL: {website_url}")
|
141 |
|
142 |
# 运行异步上传
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
|
|
|
|
|
|
|
|
150 |
sys.exit(1)
|
151 |
|
152 |
if __name__ == "__main__":
|
|
|
7 |
import os
|
8 |
import time
|
9 |
import ssl
|
10 |
+
from urllib.parse import urlparse
|
11 |
|
12 |
# 配置日志
|
13 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
14 |
logger = logging.getLogger(__name__)
|
15 |
|
16 |
+
def build_websocket_url(website_url):
|
17 |
+
"""构建WebSocket URL"""
|
18 |
+
try:
|
19 |
+
# 解析原始URL
|
20 |
+
parsed = urlparse(website_url)
|
21 |
+
|
22 |
+
# 构建WebSocket URL
|
23 |
+
if parsed.scheme == 'https':
|
24 |
+
ws_scheme = 'wss'
|
25 |
+
else:
|
26 |
+
ws_scheme = 'ws'
|
27 |
+
|
28 |
+
# 提取主机名,移除端口号
|
29 |
+
hostname = parsed.hostname
|
30 |
+
|
31 |
+
# 构建WebSocket URL,使用端口8765
|
32 |
+
ws_url = f"{ws_scheme}://{hostname}:8765/ws"
|
33 |
+
|
34 |
+
logger.info(f"原始URL: {website_url}")
|
35 |
+
logger.info(f"WebSocket URL: {ws_url}")
|
36 |
+
|
37 |
+
return ws_url
|
38 |
+
|
39 |
+
except Exception as e:
|
40 |
+
logger.error(f"构建WebSocket URL时出错: {e}")
|
41 |
+
# 回退方案:简单替换
|
42 |
+
if website_url.startswith('https://'):
|
43 |
+
return website_url.replace('https://', 'wss://').replace(':5001', ':8765') + '/ws'
|
44 |
+
else:
|
45 |
+
return website_url.replace('http://', 'ws://').replace(':5001', ':8765') + '/ws'
|
46 |
+
|
47 |
async def upload_files_via_websocket(space_id, user_token, website_url):
|
48 |
"""通过WebSocket上传output目录中的所有文件"""
|
49 |
|
|
|
66 |
|
67 |
logger.info(f"找到 {len(files_to_upload)} 个文件需要上传")
|
68 |
|
69 |
+
# 构建WebSocket URL
|
70 |
+
ws_url = build_websocket_url(website_url)
|
71 |
|
72 |
try:
|
73 |
# 连接到WebSocket服务器
|
|
|
80 |
ssl_context.check_hostname = False
|
81 |
ssl_context.verify_mode = ssl.CERT_NONE
|
82 |
|
83 |
+
# 设置连接超时
|
84 |
+
connect_timeout = 10
|
85 |
+
|
86 |
+
async with websockets.connect(
|
87 |
+
ws_url,
|
88 |
+
ssl=ssl_context,
|
89 |
+
ping_interval=20,
|
90 |
+
ping_timeout=10,
|
91 |
+
close_timeout=10
|
92 |
+
) as websocket:
|
93 |
logger.info("WebSocket 连接成功")
|
94 |
|
95 |
# 发送注册消息
|
|
|
102 |
logger.info("已发送注册消息")
|
103 |
|
104 |
# 等待注册确认
|
105 |
+
try:
|
106 |
+
response = await asyncio.wait_for(websocket.recv(), timeout=10)
|
107 |
+
response_data = json.loads(response)
|
108 |
+
logger.info(f"收到响应: {response_data}")
|
109 |
+
except asyncio.TimeoutError:
|
110 |
+
logger.error("等待注册确认超时")
|
111 |
+
return False
|
112 |
|
113 |
if response_data.get('type') == 'registered':
|
114 |
logger.info("注册成功,开始上传文件")
|
|
|
127 |
file_content = f.read()
|
128 |
file_b64 = base64.b64encode(file_content).decode('utf-8')
|
129 |
|
130 |
+
file_size = len(file_content)
|
131 |
+
logger.info(f"文件 {filename} 大小: {file_size} 字节")
|
132 |
+
|
133 |
# 发送文件上传消息
|
134 |
upload_msg = {
|
135 |
'type': 'file_upload',
|
|
|
140 |
}
|
141 |
|
142 |
await websocket.send(json.dumps(upload_msg))
|
143 |
+
logger.info(f"已发送文件 {filename}")
|
144 |
|
145 |
# 等待上传结果
|
146 |
+
try:
|
147 |
+
upload_response = await asyncio.wait_for(websocket.recv(), timeout=30)
|
148 |
+
upload_result = json.loads(upload_response)
|
149 |
+
logger.info(f"上传响应: {upload_result}")
|
150 |
+
except asyncio.TimeoutError:
|
151 |
+
logger.error(f"等待文件 {filename} 上传结果超时")
|
152 |
+
failed_count += 1
|
153 |
+
continue
|
154 |
|
155 |
if upload_result.get('type') == 'upload_success':
|
156 |
logger.info(f"文件 {filename} 上传成功")
|
|
|
169 |
logger.info(f"上传完成! 成功: {success_count}, 失败: {failed_count}")
|
170 |
|
171 |
# 发送完成消息
|
172 |
+
try:
|
173 |
+
complete_msg = {'type': 'upload_complete'}
|
174 |
+
await websocket.send(json.dumps(complete_msg))
|
175 |
+
logger.info("已发送完成消息")
|
176 |
+
except Exception as e:
|
177 |
+
logger.error(f"发送完成消息时出错: {e}")
|
178 |
|
179 |
return success_count > 0
|
180 |
|
181 |
+
elif response_data.get('type') == 'error':
|
182 |
+
logger.error(f"注册失败: {response_data.get('message', '未知错误')}")
|
183 |
+
return False
|
184 |
else:
|
185 |
+
logger.error(f"未知的响应类型: {response_data}")
|
186 |
return False
|
187 |
|
188 |
+
except websockets.exceptions.InvalidURI as e:
|
189 |
+
logger.error(f"无效的WebSocket URI: {e}")
|
190 |
+
return False
|
191 |
+
except websockets.exceptions.ConnectionClosed as e:
|
192 |
+
logger.error(f"WebSocket连接被关闭: {e}")
|
193 |
+
return False
|
194 |
except Exception as e:
|
195 |
logger.error(f"WebSocket 连接或上传过程中发生错误: {e}")
|
196 |
return False
|
|
|
206 |
|
207 |
logger.info(f"开始 WebSocket 文件上传")
|
208 |
logger.info(f"Space ID: {space_id}")
|
209 |
+
logger.info(f"User Token: {user_token[:8]}...")
|
210 |
logger.info(f"Website URL: {website_url}")
|
211 |
|
212 |
# 运行异步上传
|
213 |
+
try:
|
214 |
+
result = asyncio.run(upload_files_via_websocket(space_id, user_token, website_url))
|
215 |
+
|
216 |
+
if result:
|
217 |
+
logger.info("所有文件上传成功")
|
218 |
+
sys.exit(0)
|
219 |
+
else:
|
220 |
+
logger.error("文件上传失败")
|
221 |
+
sys.exit(1)
|
222 |
+
except Exception as e:
|
223 |
+
logger.error(f"运行上传程序时出错: {e}")
|
224 |
sys.exit(1)
|
225 |
|
226 |
if __name__ == "__main__":
|