dangthr committed
Commit 417719d · verified · 1 Parent(s): b383184

Update app.py

Files changed (1)
  1. app.py +109 -125
app.py CHANGED
@@ -1,142 +1,126 @@
-import gradio as gr
-import numpy as np
 import random
-import spaces
 import torch
-from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL
-from transformers import CLIPTextModel, CLIPTokenizer,T5EncoderModel, T5TokenizerFast
-from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
-
-dtype = torch.bfloat16
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
-good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-Krea-dev", subfolder="vae", torch_dtype=dtype).to(device)
-pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-Krea-dev", torch_dtype=dtype, vae=taef1).to(device)
-torch.cuda.empty_cache()

-MAX_SEED = np.iinfo(np.int32).max
-MAX_IMAGE_SIZE = 2048

-pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)

-@spaces.GPU(duration=25)
-def infer(prompt, seed=42, randomize_seed=True, width=768, height=768, guidance_scale=4.5, num_inference_steps=20, progress=gr.Progress(track_tqdm=True)):
-    """
-    Generate an image using the Flux.1 Krea-Dev Image Generator
     """
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
-    generator = torch.Generator().manual_seed(seed)

-    for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
-        prompt=prompt,
-        guidance_scale=guidance_scale,
-        num_inference_steps=num_inference_steps,
-        width=width,
-        height=height,
-        generator=generator,
-        output_type="pil",
-        good_vae=good_vae,
-    ):
-        yield img, seed

-examples = [
-    "a tiny astronaut hatching from an egg on mars",
-    "a dog holding a sign that reads 'hello world'",
-    "an anime illustration of an apple strudel",
-]

-css="""
-#col-container {
-    margin: 0 auto;
-    max-width: 620px;
-}
-"""

-with gr.Blocks(css=css) as demo:

-    with gr.Column(elem_id="col-container"):
-        gr.Markdown(f"""# FLUX.1 Krea [dev]
-FLUX.1 Krea [dev] model further tuned and customized with [Krea](https://krea.ai)
-[[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)] [[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-dev)]
-        """)
-
-        with gr.Row():
-
-            prompt = gr.Text(
-                label="Prompt",
-                show_label=False,
-                max_lines=1,
-                placeholder="Enter your prompt",
-                container=False,
-            )
-
-            run_button = gr.Button("Run", scale=0)
-
-        result = gr.Image(label="Result", show_label=False)
-
-        with gr.Accordion("Advanced Settings", open=False):
-
-            seed = gr.Slider(
-                label="Seed",
-                minimum=0,
-                maximum=MAX_SEED,
-                step=1,
-                value=0,
-            )
-
-            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-            with gr.Row():
-
-                width = gr.Slider(
-                    label="Width",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=768,
-                )
-
-                height = gr.Slider(
-                    label="Height",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=768,
-                )
-
-            with gr.Row():

-                guidance_scale = gr.Slider(
-                    label="Guidance Scale",
-                    minimum=1,
-                    maximum=15,
-                    step=0.1,
-                    value=4.5,
-                )
-
-                num_inference_steps = gr.Slider(
-                    label="Number of inference steps",
-                    minimum=1,
-                    maximum=50,
-                    step=1,
-                    value=20,
-                )
-
-        gr.Examples(
-            examples = examples,
-            fn = infer,
-            inputs = [prompt],
-            outputs = [result, seed],
-            cache_examples="lazy"
-        )

-    gr.on(
-        triggers=[run_button.click, prompt.submit],
-        fn = infer,
-        inputs = [prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
-        outputs = [result, seed]
     )

-demo.launch(mcp_server=True)
+import argparse
+import os
 import random
 import torch
+import numpy as np
+from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL
+from PIL import Image
+import re

+def generate_image(pipe, good_vae, prompt, seed=42, randomize_seed=True, width=768, height=768, guidance_scale=4.5, num_inference_steps=20):
+    """
+    Generate an image with the FLUX.1-Krea-dev model.

+    Args:
+        pipe: Diffusers pipeline.
+        good_vae: High-quality VAE decoder.
+        prompt (str): Text prompt.
+        seed (int): Random seed.
+        randomize_seed (bool): Whether to randomize the seed.
+        width (int): Image width.
+        height (int): Image height.
+        guidance_scale (float): Guidance scale.
+        num_inference_steps (int): Number of inference steps.

+    Returns:
+        tuple[Image.Image, int]: The generated PIL image and the seed that was used.
     """
+    MAX_SEED = np.iinfo(np.int32).max
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)

+    generator = torch.Generator(device=pipe.device).manual_seed(seed)

+    print(f"ℹ️ Using seed: {seed}")
+    print("1. Generating latents...")
+
+    # Generate the latents with the pipeline
+    latents = pipe(
+        prompt=prompt,
+        guidance_scale=guidance_scale,
+        num_inference_steps=num_inference_steps,
+        width=width,
+        height=height,
+        generator=generator,
+        output_type="latent"
+    ).images

+    print("2. Decoding the image with the high-quality VAE...")
+
+    # Decode the latents with the high-quality VAE
+    # The latents must be scaled according to the VAE config
+    latents = latents / good_vae.config.scaling_factor
+    image_tensor = good_vae.decode(latents, return_dict=False)[0]

+    print("3. Post-processing the image...")
+
+    # Convert the tensor to a PIL image
+    image = pipe.image_processor.postprocess(image_tensor, output_type="pil")[0]

+    return image, seed

+def main():
+    """
+    Main entry point: parses arguments and runs the generation logic.
+    """
+    # --- Argument parsing ---
+    parser = argparse.ArgumentParser(description="Generate an image from a text prompt with the FLUX.1-Krea-dev model.")
+    parser.add_argument("--prompt", type=str, required=True, help="Text prompt for image generation.")
+    parser.add_argument("--seed", type=int, default=None, help="Random seed. If not provided, one is chosen at random.")
+    parser.add_argument("--steps", type=int, default=20, help="Number of inference steps.")
+    parser.add_argument("--width", type=int, default=768, help="Image width.")
+    parser.add_argument("--height", type=int, default=768, help="Image height.")
+    args = parser.parse_args()

+    # --- Model loading ---
+    print("⏳ Loading models, please wait...")
+    dtype = torch.bfloat16
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # Load two VAEs: one for fast preview (inside the pipeline), one for the high-quality final output
+    taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
+    good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-Krea-dev", subfolder="vae", torch_dtype=dtype).to(device)
+
+    # Load the main pipeline with the smaller VAE for fast latent generation
+    pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-Krea-dev", torch_dtype=dtype, vae=taef1).to(device)
+
+    if device == "cuda":
+        torch.cuda.empty_cache()
+
+    print(f"✅ Models loaded, using device: {device}")
+
+    # --- Image generation ---
+    print(f"🚀 Generating an image for prompt: '{args.prompt}'")
+
+    randomize = args.seed is None
+    seed_value = args.seed if not randomize else 42  # use the given seed if provided; otherwise generate_image picks a random one
+
+    generated_image, used_seed = generate_image(
+        pipe=pipe,
+        good_vae=good_vae,
+        prompt=args.prompt,
+        seed=seed_value,
+        randomize_seed=randomize,
+        width=args.width,
+        height=args.height,
+        num_inference_steps=args.steps
     )

+    # --- Save the image ---
+    output_dir = "output"
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Sanitize the prompt for use as a filename
+    safe_prompt = re.sub(r'[^\w\s-]', '', args.prompt).strip()
+    safe_prompt = re.sub(r'[-\s]+', '_', safe_prompt)
+
+    filename = f"{safe_prompt[:50]}_{used_seed}.png"
+    filepath = os.path.join(output_dir, filename)
+
+    print(f"💾 Saving image to: {filepath}")
+    generated_image.save(filepath)
+
+    print("🎉 Done!")
+
+if __name__ == "__main__":
+    main()
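
After this change, app.py is a command-line script rather than a Gradio demo. A minimal example invocation, assuming torch, diffusers, and Pillow are installed and using one of the prompts from the removed examples list (the image is written under output/):

    python app.py --prompt "a tiny astronaut hatching from an egg on mars" --steps 20 --width 768 --height 768 --seed 42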