dangthr commited on
Commit
861bfbe
·
verified ·
1 Parent(s): 15bd9da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +158 -105
app.py CHANGED
@@ -1,115 +1,168 @@
1
- import argparse
 
 
 
 
 
2
  import os
3
- import random
4
- import torch
5
- import numpy as np
6
- from diffusers import DiffusionPipeline, AutoencoderKL
7
- from PIL import Image
8
- import re
9
-
10
- def generate_image(pipe, prompt, seed=42, randomize_seed=True, width=768, height=768, guidance_scale=4.5, num_inference_steps=20):
11
- """
12
- 使用 FLUX.1-Krea-dev 模型生成图像。
13
-
14
- Args:
15
- pipe: 配置好的 Diffusers pipeline.
16
- prompt (str): 文本提示.
17
- seed (int): 随机种子.
18
- randomize_seed (bool): 是否随机化种子.
19
- width (int): 图像宽度.
20
- height (int): 图像高度.
21
- guidance_scale (float): 指导比例.
22
- num_inference_steps (int): 推理步数.
23
-
24
- Returns:
25
- tuple[Image.Image, int]: 返回生成的 PIL 图像和使用的种子.
26
- """
27
- MAX_SEED = np.iinfo(np.int32).max
28
- if randomize_seed:
29
- seed = random.randint(0, MAX_SEED)
30
-
31
- generator = torch.Generator(device=pipe.device).manual_seed(seed)
32
-
33
- print(f"ℹ️ 使用种子: {seed}")
34
- print("🚀 开始生成图像...")
35
-
36
- # 直接调用 pipeline 生成 PIL 图像,内部会自动处理解码
37
- image = pipe(
38
- prompt=prompt,
39
- guidance_scale=guidance_scale,
40
- num_inference_steps=num_inference_steps,
41
- width=width,
42
- height=height,
43
- generator=generator,
44
- output_type="pil"
45
- ).images[0]
46
-
47
- return image, seed
48
 
49
- def main():
50
- """
51
- 主执行函数,用于解析参数和调用生成逻辑。
52
- """
53
- # --- 参数解析 ---
54
- parser = argparse.ArgumentParser(description="使用 FLUX.1-Krea-dev 模型从文本提示生成图像。")
55
- parser.add_argument("--prompt", type=str, required=True, help="用于图像生成的文本提示。")
56
- parser.add_argument("--seed", type=int, default=None, help="随机种子。如果未提供,将随机生成。")
57
- parser.add_argument("--steps", type=int, default=20, help="推理步数。")
58
- parser.add_argument("--width", type=int, default=768, help="图像宽度。")
59
- parser.add_argument("--height", type=int, default=768, help="图像高度。")
60
- parser.add_argument("--guidance", type=float, default=4.5, help="指导比例 (Guidance Scale)。")
61
- args = parser.parse_args()
62
 
63
- # --- 模型加载 ---
64
- print("⏳ 正在加载模型,请稍候...")
65
- dtype = torch.bfloat16
66
- device = "cuda" if torch.cuda.is_available() else "cpu"
67
-
68
- # 加载高质量的 VAE 解码器
69
- good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-Krea-dev", subfolder="vae", torch_dtype=dtype)
70
-
71
- # 加载主 pipeline,并直接将高质量的 VAE 传入
72
- pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-Krea-dev", torch_dtype=dtype, vae=good_vae).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
- if device == "cuda":
75
- torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
- print(f"✅ 模型加载完成,使用设备: {device}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
- # --- 图像生成 ---
80
- print(f"🎨 开始为提示生成图像: '{args.prompt}'")
 
 
 
 
81
 
82
- randomize = args.seed is None
83
- # 如果用户没有指定种子,则在调用函数时随机化;否则使用用户指定的种子
84
- seed_value = args.seed if not randomize else 42
85
-
86
- generated_image, used_seed = generate_image(
87
- pipe=pipe,
88
- prompt=args.prompt,
89
- seed=seed_value,
90
- randomize_seed=randomize,
91
- width=args.width,
92
- height=args.height,
93
- num_inference_steps=args.steps,
94
- guidance_scale=args.guidance
95
- )
96
-
97
- # --- 保存图像 ---
98
- output_dir = "output"
99
- os.makedirs(output_dir, exist_ok=True)
100
-
101
- # 清理提示词以用作安全的文件名
102
- safe_prompt = re.sub(r'[^\w\s-]', '', args.prompt).strip()
103
- safe_prompt = re.sub(r'[-\s]+', '_', safe_prompt)
104
 
105
- # 防止文件名过长
106
- filename = f"{safe_prompt[:50]}_{used_seed}.png"
107
- filepath = os.path.join(output_dir, filename)
108
-
109
- print(f"💾 正在保存图像到: {filepath}")
110
- generated_image.save(filepath)
111
-
112
- print("🎉 完成!")
113
 
114
  if __name__ == "__main__":
115
- main()
 
1
+ import asyncio
2
+ import websockets
3
+ import json
4
+ import logging
5
+ import sys
6
+ import base64
7
  import os
8
+ import argparse
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ # 配置日志
11
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
12
+ logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
 
 
 
13
 
14
+ class FileUploader:
15
+ def __init__(self, server_url, user_token, space_id):
16
+ self.server_url = server_url
17
+ self.user_token = user_token
18
+ self.space_id = space_id
19
+ self.websocket = None
20
+
21
+ async def connect_and_upload(self, upload_dir):
22
+ """连接WebSocket并上传文件"""
23
+ # 构建WebSocket URL
24
+ ws_url = self.server_url.replace('http://', 'ws://').replace('https://', 'wss://')
25
+ uri = f"{ws_url}/ws/upload/{self.space_id}"
26
+
27
+ logger.info(f"正在连接到 {uri}")
28
+
29
+ try:
30
+ async with websockets.connect(uri, ping_interval=20, ping_timeout=60) as websocket:
31
+ self.websocket = websocket
32
+ logger.info("WebSocket连接成功")
33
+
34
+ # 发送认证信息
35
+ auth_message = {
36
+ "type": "auth",
37
+ "token": self.user_token,
38
+ "space_id": self.space_id
39
+ }
40
+ await websocket.send(json.dumps(auth_message))
41
+ logger.info("已发送认证信息")
42
+
43
+ # 等待认证响应
44
+ response = await websocket.recv()
45
+ auth_result = json.loads(response)
46
+
47
+ if auth_result.get("type") == "auth_success":
48
+ logger.info("认证成功,开始上传文件")
49
+ await self.upload_directory(upload_dir)
50
+ else:
51
+ logger.error(f"认证失败: {auth_result.get('message', '未知错误')}")
52
+ return
53
+
54
+ except Exception as e:
55
+ logger.error(f"连接失败: {e}")
56
 
57
+ async def upload_directory(self, upload_dir):
58
+ """扫描并上传目录中的所有文件"""
59
+ if not os.path.exists(upload_dir):
60
+ logger.error(f"目录不存在: {upload_dir}")
61
+ return
62
+
63
+ logger.info(f"🔍 开始扫描目录: {upload_dir}")
64
+ logger.info(f"🔑 Space ID: {self.space_id}")
65
+ logger.info("-" * 50)
66
+
67
+ # 获取所有文件
68
+ all_files = []
69
+ for root, dirs, files in os.walk(upload_dir):
70
+ for file in files:
71
+ file_path = os.path.join(root, file)
72
+ if os.path.isfile(file_path):
73
+ all_files.append(file_path)
74
+
75
+ if not all_files:
76
+ logger.info("📁 目录中没有找到任何文件")
77
+ # 发送完成消息
78
+ await self.websocket.send(json.dumps({
79
+ "type": "upload_complete",
80
+ "success_count": 0,
81
+ "failed_count": 0
82
+ }))
83
+ return
84
+
85
+ logger.info(f"📁 找到 {len(all_files)} 个文件,开始上传...")
86
+
87
+ success_count = 0
88
+ failed_count = 0
89
+
90
+ for file_path in all_files:
91
+ try:
92
+ if await self.upload_file(file_path):
93
+ success_count += 1
94
+ else:
95
+ failed_count += 1
96
+ # 稍微延迟一下,避免发送过快
97
+ await asyncio.sleep(0.1)
98
+ except Exception as e:
99
+ logger.error(f"上传文件 {file_path} 时发生异常: {e}")
100
+ failed_count += 1
101
+
102
+ # 发送上传完成消息
103
+ await self.websocket.send(json.dumps({
104
+ "type": "upload_complete",
105
+ "success_count": success_count,
106
+ "failed_count": failed_count
107
+ }))
108
+
109
+ logger.info("-" * 50)
110
+ logger.info(f"📊 上传完成! 成功: {success_count}, 失败: {failed_count}")
111
+
112
+ if success_count > 0:
113
+ logger.info("🎉 文件已成功上传到您的网盘!")
114
+
115
+ return success_count, failed_count
116
 
117
+ async def upload_file(self, file_path):
118
+ """上传单个文件"""
119
+ if not os.path.exists(file_path):
120
+ logger.error(f"文件不存在: {file_path}")
121
+ return False
122
+
123
+ filename = os.path.basename(file_path)
124
+ logger.info(f"正在上传文件: {filename}")
125
+
126
+ try:
127
+ with open(file_path, 'rb') as f:
128
+ file_content = f.read()
129
+ file_b64 = base64.b64encode(file_content).decode('utf-8')
130
+
131
+ file_message = {
132
+ "type": "file_upload",
133
+ "filename": filename,
134
+ "content": file_b64,
135
+ "file_path": os.path.relpath(file_path)
136
+ }
137
+
138
+ await self.websocket.send(json.dumps(file_message))
139
+
140
+ # 等待上传响应
141
+ response = await websocket.recv()
142
+ result = json.loads(response)
143
+
144
+ if result.get("type") == "upload_success":
145
+ logger.info(f"✅ 文件 '{filename}' 上传成功!")
146
+ return True
147
+ else:
148
+ logger.error(f"❌ 上传失败: {result.get('message', '未知错误')}")
149
+ return False
150
+
151
+ except Exception as e:
152
+ logger.error(f"上传文件 {file_path} 时出错: {e}")
153
+ return False
154
 
155
+ async def main():
156
+ parser = argparse.ArgumentParser(description="WebSocket文件上传器 - 扫描并上传指定文件夹中的所有文件")
157
+ parser.add_argument("user_token", help="您的用户令牌(API密钥)")
158
+ parser.add_argument("space_id", help="Space ID")
159
+ parser.add_argument("--server", default="ws://127.0.0.1:5001", help="服务器的 WebSocket URL 地址 (默认: ws://127.0.0.1:5001)")
160
+ parser.add_argument("--upload-dir", default="output", help="要上传的目录 (默认: output)")
161
 
162
+ args = parser.parse_args()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
+ uploader = FileUploader(args.server, args.user_token, args.space_id)
165
+ await uploader.connect_and_upload(args.upload_dir)
 
 
 
 
 
 
166
 
167
  if __name__ == "__main__":
168
+ asyncio.run(main())