Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
import random
|
2 |
import torch
|
3 |
import numpy as np
|
@@ -6,21 +8,13 @@ from PIL import Image
|
|
6 |
import re
|
7 |
|
8 |
def generate_image(pipe, prompt, seed=42, randomize_seed=True, width=768, height=768, guidance_scale=4.5, num_inference_steps=20):
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
prompt (str): 文本提示.
|
16 |
-
seed (int): 随机种子.
|
17 |
-
randomize_seed (bool): 是否随机化种子.
|
18 |
-
generator = torch.Generator(device=pipe.device).manual_seed(seed)
|
19 |
|
20 |
-
print(f"ℹ️ 使用种子: {seed}")
|
21 |
print("🚀 开始生成图像...")
|
22 |
-
|
23 |
-
# 直接调用 pipeline 生成 PIL 图像,内部会自动处理解码
|
24 |
image = pipe(
|
25 |
prompt=prompt,
|
26 |
guidance_scale=guidance_scale,
|
@@ -29,4 +23,61 @@ def generate_image(pipe, prompt, seed=42, randomize_seed=True, width=768, height
|
|
29 |
height=height,
|
30 |
generator=generator,
|
31 |
output_type="pil"
|
32 |
-
).images[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
import random
|
4 |
import torch
|
5 |
import numpy as np
|
|
|
8 |
import re
|
9 |
|
10 |
def generate_image(pipe, prompt, seed=42, randomize_seed=True, width=768, height=768, guidance_scale=4.5, num_inference_steps=20):
|
11 |
+
MAX_SEED = np.iinfo(np.int32).max
|
12 |
+
if randomize_seed:
|
13 |
+
seed = random.randint(0, MAX_SEED)
|
14 |
+
generator = torch.Generator(device=pipe.device).manual_seed(seed)
|
15 |
+
print(f"ℹ️ 使用种子: {seed}")
|
|
|
|
|
|
|
|
|
|
|
16 |
|
|
|
17 |
print("🚀 开始生成图像...")
|
|
|
|
|
18 |
image = pipe(
|
19 |
prompt=prompt,
|
20 |
guidance_scale=guidance_scale,
|
|
|
23 |
height=height,
|
24 |
generator=generator,
|
25 |
output_type="pil"
|
26 |
+
).images[0]
|
27 |
+
|
28 |
+
return image, seed
|
29 |
+
|
30 |
+
def main():
|
31 |
+
parser = argparse.ArgumentParser(description="使用 FLUX.1-Krea-dev 模型从文本提示生成图像。")
|
32 |
+
parser.add_argument("--prompt", type=str, required=True, help="用于图像生成的文本提示。")
|
33 |
+
parser.add_argument("--seed", type=int, default=None, help="随机种子。如果未提供,将随机生成。")
|
34 |
+
parser.add_argument("--steps", type=int, default=20, help="推理步数。")
|
35 |
+
parser.add_argument("--width", type=int, default=768, help="图像宽度。")
|
36 |
+
parser.add_argument("--height", type=int, default=768, help="图像高度。")
|
37 |
+
parser.add_argument("--guidance", type=float, default=4.5, help="指导比例 (Guidance Scale)。")
|
38 |
+
args = parser.parse_args()
|
39 |
+
|
40 |
+
print("⏳ 正在加载模型,请稍候...")
|
41 |
+
dtype = torch.bfloat16
|
42 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
43 |
+
|
44 |
+
good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-Krea-dev", subfolder="vae", torch_dtype=dtype)
|
45 |
+
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-Krea-dev", torch_dtype=dtype, vae=good_vae).to(device)
|
46 |
+
|
47 |
+
if device == "cuda":
|
48 |
+
torch.cuda.empty_cache()
|
49 |
+
|
50 |
+
print(f"✅ 模型加载完成,使用设备: {device}")
|
51 |
+
|
52 |
+
print(f"🎨 开始为提示生成图像: '{args.prompt}'")
|
53 |
+
|
54 |
+
randomize = args.seed is None
|
55 |
+
seed_value = args.seed if not randomize else 42
|
56 |
+
|
57 |
+
generated_image, used_seed = generate_image(
|
58 |
+
pipe=pipe,
|
59 |
+
prompt=args.prompt,
|
60 |
+
seed=seed_value,
|
61 |
+
randomize_seed=randomize,
|
62 |
+
width=args.width,
|
63 |
+
height=args.height,
|
64 |
+
num_inference_steps=args.steps,
|
65 |
+
guidance_scale=args.guidance
|
66 |
+
)
|
67 |
+
|
68 |
+
output_dir = "output"
|
69 |
+
os.makedirs(output_dir, exist_ok=True)
|
70 |
+
|
71 |
+
safe_prompt = re.sub(r'[^\w\s-]', '', args.prompt).strip()
|
72 |
+
safe_prompt = re.sub(r'[-\s]+', '_', safe_prompt)
|
73 |
+
|
74 |
+
filename = f"{safe_prompt[:50]}_{used_seed}.png"
|
75 |
+
filepath = os.path.join(output_dir, filename)
|
76 |
+
|
77 |
+
print(f"💾 正在保存图像到: {filepath}")
|
78 |
+
generated_image.save(filepath)
|
79 |
+
|
80 |
+
print("🎉 完成!")
|
81 |
+
|
82 |
+
if __name__ == "__main__":
|
83 |
+
main()
|