Spaces:
Runtime error
Runtime error
import argparse
import asyncio
import base64
import json
import logging
import os
import sys
import time

import websockets
# Logging setup for the CLI: INFO level, one timestamped line per record.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
class WebSocketFileUploader:
    """Upload the files of a local directory to a server over one WebSocket.

    Protocol as implemented by this client: after connecting it sends an
    ``auth`` message and expects a reply whose ``type`` is ``auth_success``;
    then, for each file, it sends one ``file_upload`` message carrying the
    whole file base64-encoded and waits for a single reply of type
    ``upload_success`` or ``upload_error``.

    Args:
        websocket_url: WebSocket endpoint to connect to.
        user_token: token included in the auth message and in every upload.
        space_id: space id included in the auth message and in every upload.
        response_timeout: seconds to wait for the reply to each upload.
        auth_timeout: seconds to wait for the reply to the auth message.
        upload_delay: pause in seconds between consecutive uploads, to
            avoid putting too much pressure on the server.
    """

    def __init__(self, websocket_url, user_token, space_id, *,
                 response_timeout=30.0, auth_timeout=10.0, upload_delay=0.5):
        self.websocket_url = websocket_url
        self.user_token = user_token
        self.space_id = space_id
        self.response_timeout = response_timeout
        self.auth_timeout = auth_timeout
        self.upload_delay = upload_delay

    async def upload_file(self, websocket, file_path):
        """Upload a single file over an already-authenticated *websocket*.

        Returns True only when the server replies ``upload_success``. Every
        failure (missing file, read error, timeout, error or unknown reply,
        closed connection) is logged and reported as False.
        """
        try:
            return await self._upload_one(websocket, file_path)
        except websockets.ConnectionClosed as e:
            logger.error(f"上传文件 {file_path} 时出错: {e}")
            return False

    async def _upload_one(self, websocket, file_path):
        """Like ``upload_file``, but let ``ConnectionClosed`` propagate.

        Callers uploading many files use this to notice a dead connection
        and stop instead of failing once per remaining file.
        """
        # Directories and other non-regular paths cannot be uploaded.
        if not os.path.isfile(file_path):
            logger.error(f"文件不存在: {file_path}")
            return False

        filename = os.path.basename(file_path)
        logger.info(f"正在上传文件: {filename}")
        try:
            with open(file_path, 'rb') as f:
                file_b64 = base64.b64encode(f.read()).decode('utf-8')
            file_msg = {
                "type": "file_upload",
                "filename": filename,
                "content": file_b64,
                "space_id": self.space_id,
                "user_token": self.user_token,
            }
            await websocket.send(json.dumps(file_msg))
            logger.info(f"已发送文件: {filename}")

            # Wait for the server's verdict.
            # NOTE(review): replies are not matched to files. websockets
            # documents cancelling recv() as safe (the message stays queued),
            # so a reply that arrives after this timeout will be read as the
            # reply to the *next* file.
            try:
                response = await asyncio.wait_for(
                    websocket.recv(), timeout=self.response_timeout
                )
            except asyncio.TimeoutError:
                logger.error(f"上传文件 {filename} 超时")
                return False
            data = json.loads(response)
        except websockets.ConnectionClosed:
            # A dead connection is not a per-file failure; let the caller decide.
            raise
        except Exception as e:
            logger.error(f"上传文件 {file_path} 时出错: {e}")
            return False

        # Valid JSON need not be an object; treat anything else as unknown
        # instead of crashing on .get().
        reply_type = data.get("type") if isinstance(data, dict) else None
        if reply_type == "upload_success":
            logger.info(f"✅ 文件 '{filename}' 上传成功!")
            return True
        if reply_type == "upload_error":
            logger.error(f"❌ 上传失败: {data.get('message', '未知错误')}")
            return False
        logger.warning(f"收到未知响应: {data}")
        return False

    @staticmethod
    def _collect_files(upload_dir):
        """Return the paths of all regular files under *upload_dir*, recursively."""
        all_files = []
        for root, _dirs, files in os.walk(upload_dir):
            for name in files:
                path = os.path.join(root, name)
                # Drop entries that are not regular files (e.g. broken symlinks).
                if os.path.isfile(path):
                    all_files.append(path)
        return all_files

    async def _authenticate(self, websocket):
        """Send the auth message; return True iff the reply is ``auth_success``.

        Non-JSON or non-object replies raise and are handled by the caller.
        """
        auth_msg = {
            "type": "auth",
            "user_token": self.user_token,
            "space_id": self.space_id,
        }
        await websocket.send(json.dumps(auth_msg))
        try:
            auth_response = await asyncio.wait_for(
                websocket.recv(), timeout=self.auth_timeout
            )
        except asyncio.TimeoutError:
            logger.error("认证超时")
            return False
        auth_data = json.loads(auth_response)
        if auth_data.get("type") != "auth_success":
            logger.error(f"认证失败: {auth_data.get('message', '未知错误')}")
            return False
        logger.info("认证成功,开始上传文件")
        return True

    async def _upload_all(self, websocket, all_files):
        """Upload *all_files* in order; return ``(success_count, failed_count)``.

        If the connection closes, stop and count the current file and every
        file after it as failed rather than retrying each on a dead socket.
        """
        success_count = 0
        failed_count = 0
        last_index = len(all_files) - 1
        for index, file_path in enumerate(all_files):
            try:
                ok = await self._upload_one(websocket, file_path)
            except websockets.ConnectionClosed as e:
                remaining = len(all_files) - index
                logger.error(f"上传文件 {file_path} 时发生异常: {e}")
                logger.error(f"连接已关闭,剩余 {remaining} 个文件未上传")
                failed_count += remaining
                break
            except Exception as e:
                logger.error(f"上传文件 {file_path} 时发生异常: {e}")
                failed_count += 1
                continue
            if ok:
                success_count += 1
            else:
                failed_count += 1
            # Throttle between uploads, but don't wait after the last one.
            if index < last_index:
                await asyncio.sleep(self.upload_delay)
        return success_count, failed_count

    async def upload_directory(self, upload_dir):
        """Upload every regular file under *upload_dir* over one connection.

        Returns ``(success_count, failed_count)``. A missing directory or an
        empty one yields ``(0, 0)``. A connection or authentication failure
        counts every discovered file as failed.
        """
        if not os.path.isdir(upload_dir):
            logger.error(f"目录不存在: {upload_dir}")
            return 0, 0

        logger.info(f"🔍 开始扫描目录: {upload_dir}")
        logger.info(f"📡 WebSocket服务器: {self.websocket_url}")
        logger.info(f"🔑 Space ID: {self.space_id}")
        logger.info("-" * 50)

        all_files = self._collect_files(upload_dir)
        if not all_files:
            logger.info("📁 目录中没有找到任何文件")
            return 0, 0
        logger.info(f"📁 找到 {len(all_files)} 个文件,开始上传...")

        try:
            async with websockets.connect(
                self.websocket_url,
                ping_interval=20,
                ping_timeout=60,
                close_timeout=10,
            ) as websocket:
                logger.info("WebSocket连接已建立")
                if not await self._authenticate(websocket):
                    return 0, len(all_files)
                success_count, failed_count = await self._upload_all(
                    websocket, all_files
                )
        except Exception as e:
            logger.error(f"WebSocket连接失败: {e}")
            return 0, len(all_files)

        logger.info("-" * 50)
        logger.info(f"📊 上传完成! 成功: {success_count}, 失败: {failed_count}")
        if success_count > 0:
            logger.info("🎉 文件已成功上传到您的网盘!")
        return success_count, failed_count
async def upload_directory_websocket(upload_dir, websocket_url, user_token, space_id):
    """Upload every file under *upload_dir* over a WebSocket.

    Convenience wrapper: builds a one-off ``WebSocketFileUploader`` from the
    connection settings and returns the ``(success_count, failed_count)``
    tuple produced by its ``upload_directory``.
    """
    return await WebSocketFileUploader(
        websocket_url, user_token, space_id
    ).upload_directory(upload_dir)
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser(description="WebSocket文件上传器 - 扫描并上传指定文件夹中的所有文件") | |
| parser.add_argument("user_token", help="用户令牌 (API Key)") | |
| parser.add_argument("space_id", help="Space ID") | |
| parser.add_argument("--websocket-url", default="ws://127.0.0.1:5001/ws", help="WebSocket服务器地址 (默认: ws://127.0.0.1:5001/ws)") | |
| parser.add_argument("--upload-dir", default="output", help="要上传的目录 (默认: output)") | |
| args = parser.parse_args() | |
| # 开始WebSocket上传 | |
| asyncio.run(upload_directory_websocket( | |
| args.upload_dir, | |
| args.websocket_url, | |
| args.user_token, | |
| args.space_id | |
| )) |