Spaces:
Runtime error
Runtime error
import argparse | |
import os | |
import random | |
import torch | |
import numpy as np | |
from diffusers import DiffusionPipeline, AutoencoderKL | |
from PIL import Image | |
import re | |
def generate_image(pipe, prompt, seed=42, randomize_seed=True, width=768, height=768, guidance_scale=4.5, num_inference_steps=20): | |
""" | |
使用 FLUX.1-Krea-dev 模型生成图像。 | |
Args: | |
pipe: 配置好的 Diffusers pipeline. | |
prompt (str): 文本提示. | |
seed (int): 随机种子. | |
randomize_seed (bool): 是否随机化种子. | |
width (int): 图像宽度. | |
height (int): 图像高度. | |
guidance_scale (float): 指导比例. | |
num_inference_steps (int): 推理步数. | |
Returns: | |
tuple[Image.Image, int]: 返回生成的 PIL 图像和使用的种子. | |
""" | |
MAX_SEED = np.iinfo(np.int32).max | |
if randomize_seed: | |
seed = random.randint(0, MAX_SEED) | |
generator = torch.Generator(device=pipe.device).manual_seed(seed) | |
print(f"ℹ️ 使用种子: {seed}") | |
print("🚀 开始生成图像...") | |
# 直接调用 pipeline 生成 PIL 图像,内部会自动处理解码 | |
image = pipe( | |
prompt=prompt, | |
guidance_scale=guidance_scale, | |
num_inference_steps=num_inference_steps, | |
width=width, | |
height=height, | |
generator=generator, | |
output_type="pil" | |
).images[0] | |
return image, seed | |
def main(): | |
""" | |
主执行函数,用于解析参数和调用生成逻辑。 | |
""" | |
# --- 参数解析 --- | |
parser = argparse.ArgumentParser(description="使用 FLUX.1-Krea-dev 模型从文本提示生成图像。") | |
parser.add_argument("--prompt", type=str, required=True, help="用于图像生成的文本提示。") | |
parser.add_argument("--seed", type=int, default=None, help="随机种子。如果未提供,将随机生成。") | |
parser.add_argument("--steps", type=int, default=20, help="推理步数。") | |
parser.add_argument("--width", type=int, default=768, help="图像宽度。") | |
parser.add_argument("--height", type=int, default=768, help="图像高度。") | |
parser.add_argument("--guidance", type=float, default=4.5, help="指导比例 (Guidance Scale)。") | |
args = parser.parse_args() | |
# --- 模型加载 --- | |
print("⏳ 正在加载模型,请稍候...") | |
dtype = torch.bfloat16 | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
# 加载高质量的 VAE 解码器 | |
good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-Krea-dev", subfolder="vae", torch_dtype=dtype) | |
# 加载主 pipeline,并直接将高质量的 VAE 传入 | |
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-Krea-dev", torch_dtype=dtype, vae=good_vae).to(device) | |
if device == "cuda": | |
torch.cuda.empty_cache() | |
print(f"✅ 模型加载完成,使用设备: {device}") | |
# --- 图像生成 --- | |
print(f"🎨 开始为提示生成图像: '{args.prompt}'") | |
randomize = args.seed is None | |
# 如果用户没有指定种子,则在调用函数时随机化;否则使用用户指定的种子 | |
seed_value = args.seed if not randomize else 42 | |
generated_image, used_seed = generate_image( | |
pipe=pipe, | |
prompt=args.prompt, | |
seed=seed_value, | |
randomize_seed=randomize, | |
width=args.width, | |
height=args.height, | |
num_inference_steps=args.steps, | |
guidance_scale=args.guidance | |
) | |
# --- 保存图像 --- | |
output_dir = "output" | |
os.makedirs(output_dir, exist_ok=True) | |
# 清理提示词以用作安全的文件名 | |
safe_prompt = re.sub(r'[^\w\s-]', '', args.prompt).strip() | |
safe_prompt = re.sub(r'[-\s]+', '_', safe_prompt) | |
# 防止文件名过长 | |
filename = f"{safe_prompt[:50]}_{used_seed}.png" | |
filepath = os.path.join(output_dir, filename) | |
print(f"💾 正在保存图像到: {filepath}") | |
generated_image.save(filepath) | |
print("🎉 完成!") | |
if __name__ == "__main__": | |
main() | |