# Spaces status: Runtime error
import argparse
import os
import random
import re
from pathlib import Path

import numpy as np
import torch
from diffusers import AutoencoderKL, DiffusionPipeline
from PIL import Image
def generate_image(
    pipe,
    prompt,
    seed=42,
    randomize_seed=True,
    width=768,
    height=768,
    guidance_scale=4.5,
    num_inference_steps=20,
):
    """Run ``pipe`` once on ``prompt`` and return the first image with its seed.

    Args:
        pipe: A diffusers-style pipeline. It must expose ``.device`` and be
            callable with the keyword arguments used below, returning an
            object whose ``.images`` is indexable.
        prompt: Text prompt forwarded to the pipeline.
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: When true, ``seed`` is ignored and a fresh value in
            ``[0, 2**31 - 1]`` is drawn with :mod:`random`.
        width: Output width in pixels, forwarded unchanged.
        height: Output height in pixels, forwarded unchanged.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of inference steps forwarded to the pipeline.

    Returns:
        ``(image, seed)``: the first entry of the pipeline's ``.images`` and
        the seed that was actually applied, so the run can be reproduced.
    """
    upper_bound = np.iinfo(np.int32).max
    # The generator is created on whatever device the pipeline reports.
    rng = torch.Generator(device=pipe.device)
    chosen_seed = random.randint(0, upper_bound) if randomize_seed else seed
    rng.manual_seed(chosen_seed)

    print(f"ℹ️ 使用种子: {chosen_seed}")
    print("🚀 开始生成图像...")

    call_kwargs = dict(
        prompt=prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=rng,
        output_type="pil",
    )
    result = pipe(**call_kwargs)
    first_image = result.images[0]
    return first_image, chosen_seed
def _positive_int(text):
    """argparse ``type=`` converter that accepts only integers >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"必须为整数: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"必须为正整数: {text!r}")
    return value


def _safe_filename_stem(prompt, max_len=50):
    """Return a filesystem-friendly fragment derived from ``prompt``.

    Characters other than word characters, whitespace and ``-`` are dropped.
    Runs of whitespace or ``-`` collapse to a single ``_``, and the result is
    truncated to ``max_len`` characters. ``\\w`` is Unicode-aware for ``str``
    patterns, so non-ASCII letters such as Chinese are kept. Falls back to
    ``"image"`` when nothing usable remains (e.g. a punctuation-only prompt),
    so the file name is never just ``_<seed>.png``.
    """
    stem = re.sub(r'[^\w\s-]', '', prompt).strip()
    stem = re.sub(r'[-\s]+', '_', stem)
    return stem[:max_len] or "image"


def main(argv=None):
    """Command-line entry point: parse args, load FLUX.1-Krea-dev, render and save one image.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            keeps argparse's default, so existing ``main()`` calls are unchanged.

    Invalid arguments (blank prompt, non-positive steps/width/height) are
    rejected by argparse with exit status 2 *before* the model is loaded, so
    mistakes fail fast instead of after a long download/load.

    The image is written to ``output/<prompt-stem>_<seed>.png`` relative to
    the current working directory.
    """
    parser = argparse.ArgumentParser(description="使用 FLUX.1-Krea-dev 模型从文本提示生成图像。")
    parser.add_argument("--prompt", type=str, required=True, help="用于图像生成的文本提示。")
    parser.add_argument("--seed", type=int, default=None, help="随机种子。如果未提供,将随机生成。")
    parser.add_argument("--steps", type=_positive_int, default=20, help="推理步数。")
    parser.add_argument("--width", type=_positive_int, default=768, help="图像宽度。")
    parser.add_argument("--height", type=_positive_int, default=768, help="图像高度。")
    parser.add_argument("--guidance", type=float, default=4.5, help="指导比例 (Guidance Scale)。")
    args = parser.parse_args(argv)
    if not args.prompt.strip():
        # Exits with status 2, before any expensive model loading happens.
        parser.error("--prompt 不能为空。")

    print("⏳ 正在加载模型,请稍候...")
    model_id = "black-forest-labs/FLUX.1-Krea-dev"
    dtype = torch.bfloat16
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): this loads the repo's own "vae" subfolder, which
    # DiffusionPipeline.from_pretrained would presumably load by default anyway,
    # so passing it explicitly looks redundant. Kept to preserve behaviour;
    # confirm before removing.
    good_vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=dtype)
    pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype, vae=good_vae).to(device)
    if device == "cuda":
        torch.cuda.empty_cache()
    print(f"✅ 模型加载完成,使用设备: {device}")

    print(f"🎨 开始为提示生成图像: '{args.prompt}'")
    # No --seed means "pick a random seed". The 42 is only a placeholder;
    # generate_image replaces it when randomize_seed is true.
    randomize = args.seed is None
    generated_image, used_seed = generate_image(
        pipe=pipe,
        prompt=args.prompt,
        seed=42 if randomize else args.seed,
        randomize_seed=randomize,
        width=args.width,
        height=args.height,
        num_inference_steps=args.steps,
        guidance_scale=args.guidance,
    )

    output_dir = Path("output")
    output_dir.mkdir(parents=True, exist_ok=True)
    filepath = output_dir / f"{_safe_filename_stem(args.prompt)}_{used_seed}.png"
    print(f"💾 正在保存图像到: {filepath}")
    generated_image.save(filepath)
    print("🎉 完成!")


if __name__ == "__main__":
    main()