import argparse import os import random import torch import numpy as np from diffusers import DiffusionPipeline, AutoencoderKL from PIL import Image import re def generate_image(pipe, prompt, seed=42, randomize_seed=True, width=768, height=768, guidance_scale=4.5, num_inference_steps=20): """ 使用 FLUX.1-Krea-dev 模型生成图像。 Args: pipe: 配置好的 Diffusers pipeline. prompt (str): 文本提示. seed (int): 随机种子. randomize_seed (bool): 是否随机化种子. width (int): 图像宽度. height (int): 图像高度. guidance_scale (float): 指导比例. num_inference_steps (int): 推理步数. Returns: tuple[Image.Image, int]: 返回生成的 PIL 图像和使用的种子. """ MAX_SEED = np.iinfo(np.int32).max if randomize_seed: seed = random.randint(0, MAX_SEED) generator = torch.Generator(device=pipe.device).manual_seed(seed) print(f"ℹ️ 使用种子: {seed}") print("🚀 开始生成图像...") # 直接调用 pipeline 生成 PIL 图像,内部会自动处理解码 image = pipe( prompt=prompt, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, width=width, height=height, generator=generator, output_type="pil" ).images[0] return image, seed def main(): """ 主执行函数,用于解析参数和调用生成逻辑。 """ # --- 参数解析 --- parser = argparse.ArgumentParser(description="使用 FLUX.1-Krea-dev 模型从文本提示生成图像。") parser.add_argument("--prompt", type=str, required=True, help="用于图像生成的文本提示。") parser.add_argument("--seed", type=int, default=None, help="随机种子。如果未提供,将随机生成。") parser.add_argument("--steps", type=int, default=20, help="推理步数。") parser.add_argument("--width", type=int, default=768, help="图像宽度。") parser.add_argument("--height", type=int, default=768, help="图像高度。") parser.add_argument("--guidance", type=float, default=4.5, help="指导比例 (Guidance Scale)。") args = parser.parse_args() # --- 模型加载 --- print("⏳ 正在加载模型,请稍候...") dtype = torch.bfloat16 device = "cuda" if torch.cuda.is_available() else "cpu" # 加载高质量的 VAE 解码器 good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-Krea-dev", subfolder="vae", torch_dtype=dtype) # 加载主 pipeline,并直接将高质量的 VAE 传入 pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-Krea-dev", torch_dtype=dtype, vae=good_vae).to(device) if device == "cuda": torch.cuda.empty_cache() print(f"✅ 模型加载完成,使用设备: {device}") # --- 图像生成 --- print(f"🎨 开始为提示生成图像: '{args.prompt}'") randomize = args.seed is None # 如果用户没有指定种子,则在调用函数时随机化;否则使用用户指定的种子 seed_value = args.seed if not randomize else 42 generated_image, used_seed = generate_image( pipe=pipe, prompt=args.prompt, seed=seed_value, randomize_seed=randomize, width=args.width, height=args.height, num_inference_steps=args.steps, guidance_scale=args.guidance ) # --- 保存图像 --- output_dir = "output" os.makedirs(output_dir, exist_ok=True) # 清理提示词以用作安全的文件名 safe_prompt = re.sub(r'[^\w\s-]', '', args.prompt).strip() safe_prompt = re.sub(r'[-\s]+', '_', safe_prompt) # 防止文件名过长 filename = f"{safe_prompt[:50]}_{used_seed}.png" filepath = os.path.join(output_dir, filename) print(f"💾 正在保存图像到: {filepath}") generated_image.save(filepath) print("🎉 完成!") if __name__ == "__main__": main()