File size: 4,051 Bytes
417719d
 
c925b1b
 
417719d
6b7873d
417719d
 
c925b1b
6b7873d
417719d
 
c925b1b
417719d
6b7873d
417719d
 
 
 
 
 
 
c925b1b
417719d
 
e5681ab
417719d
c925b1b
 
 
417719d
c925b1b
417719d
6b7873d
417719d
6b7873d
 
417719d
 
 
 
 
 
6b7873d
 
c925b1b
417719d
c925b1b
417719d
 
 
 
 
 
 
 
 
 
 
6b7873d
417719d
c925b1b
417719d
 
 
 
 
6b7873d
 
417719d
6b7873d
 
417719d
 
 
 
 
 
 
6b7873d
417719d
 
6b7873d
 
417719d
 
 
 
 
 
 
 
6b7873d
 
c925b1b
 
417719d
 
 
 
6b7873d
417719d
 
 
6b7873d
417719d
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import argparse
import os
import random
import torch
import numpy as np
from diffusers import DiffusionPipeline, AutoencoderKL
from PIL import Image
import re

def generate_image(pipe, prompt, seed=42, randomize_seed=True, width=768, height=768, guidance_scale=4.5, num_inference_steps=20):
    """
    使用 FLUX.1-Krea-dev 模型生成图像。

    Args:
        pipe: 配置好的 Diffusers pipeline.
        prompt (str): 文本提示.
        seed (int): 随机种子.
        randomize_seed (bool): 是否随机化种子.
        width (int): 图像宽度.
        height (int): 图像高度.
        guidance_scale (float): 指导比例.
        num_inference_steps (int): 推理步数.

    Returns:
        tuple[Image.Image, int]: 返回生成的 PIL 图像和使用的种子.
    """
    MAX_SEED = np.iinfo(np.int32).max
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    
    generator = torch.Generator(device=pipe.device).manual_seed(seed)
    
    print(f"ℹ️  使用种子: {seed}")
    print("🚀 开始生成图像...")
    
    # 直接调用 pipeline 生成 PIL 图像,内部会自动处理解码
    image = pipe(
        prompt=prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
        output_type="pil"
    ).images[0]
    
    return image, seed

def main():
    """
    主执行函数,用于解析参数和调用生成逻辑。
    """
    # --- 参数解析 ---
    parser = argparse.ArgumentParser(description="使用 FLUX.1-Krea-dev 模型从文本提示生成图像。")
    parser.add_argument("--prompt", type=str, required=True, help="用于图像生成的文本提示。")
    parser.add_argument("--seed", type=int, default=None, help="随机种子。如果未提供,将随机生成。")
    parser.add_argument("--steps", type=int, default=20, help="推理步数。")
    parser.add_argument("--width", type=int, default=768, help="图像宽度。")
    parser.add_argument("--height", type=int, default=768, help="图像高度。")
    parser.add_argument("--guidance", type=float, default=4.5, help="指导比例 (Guidance Scale)。")
    args = parser.parse_args()

    # --- 模型加载 ---
    print("⏳ 正在加载模型,请稍候...")
    dtype = torch.bfloat16
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    # 加载高质量的 VAE 解码器
    good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-Krea-dev", subfolder="vae", torch_dtype=dtype)
    
    # 加载主 pipeline,并直接将高质量的 VAE 传入
    pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-Krea-dev", torch_dtype=dtype, vae=good_vae).to(device)
    
    if device == "cuda":
        torch.cuda.empty_cache()
    
    print(f"✅ 模型加载完成,使用设备: {device}")

    # --- 图像生成 ---
    print(f"🎨 开始为提示生成图像: '{args.prompt}'")
    
    randomize = args.seed is None
    # 如果用户没有指定种子,则在调用函数时随机化;否则使用用户指定的种子
    seed_value = args.seed if not randomize else 42 

    generated_image, used_seed = generate_image(
        pipe=pipe,
        prompt=args.prompt,
        seed=seed_value,
        randomize_seed=randomize,
        width=args.width,
        height=args.height,
        num_inference_steps=args.steps,
        guidance_scale=args.guidance
    )

    # --- 保存图像 ---
    output_dir = "output"
    os.makedirs(output_dir, exist_ok=True)
    
    # 清理提示词以用作安全的文件名
    safe_prompt = re.sub(r'[^\w\s-]', '', args.prompt).strip()
    safe_prompt = re.sub(r'[-\s]+', '_', safe_prompt)
    
    # 防止文件名过长
    filename = f"{safe_prompt[:50]}_{used_seed}.png"
    filepath = os.path.join(output_dir, filename)

    print(f"💾 正在保存图像到: {filepath}")
    generated_image.save(filepath)

    print("🎉 完成!")

if __name__ == "__main__":
    main()