Update app.py

app.py CHANGED
@@ -1,270 +1,89 @@
 import os
-import random
-import torch
-import numpy as np
 import gradio as gr
-import
-
-
-
-
-
-
-class Config:
-    MODEL_ID = "black-forest-labs/FLUX.1-dev"
-    DEFAULT_LORA = "nftnik/BR_ohwx_V1"
-    DEFAULT_WEIGHT_NAME = "BR_ohwx.safetensors"
-    MAX_SEED = int(np.iinfo(np.int32).max)
-    CSS = "footer { visibility: hidden; }"
-    DEFAULT_WIDTH = 896
-    DEFAULT_HEIGHT = 1152
-    DEFAULT_GUIDANCE_SCALE = 3.5
-    DEFAULT_STEPS = 35
-    DEFAULT_LORA_SCALE = 1.0
-    DEFAULT_TRIGGER_WORD = "ohwx"
-
-
-# -----------------------------------------------------------------------------
-# FluxGenerator class to handle image generation
-# -----------------------------------------------------------------------------
-class FluxGenerator:
-    def __init__(self):
-        # Environment setup
-        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
-        self.translator = Translator()
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        print(f"Using {self.device.upper()}")
-
-        # Initialize pipeline
-        self.pipe = self._initialize_pipeline()
-
-    def _initialize_pipeline(self):
-        """Initialize the Flux pipeline and load default LoRA"""
-        try:
-            pipe = FluxPipeline.from_pretrained(
-                Config.MODEL_ID, torch_dtype=torch.bfloat16
-            ).to(self.device)
-            pipe.load_lora_weights(Config.DEFAULT_LORA, weight_name=Config.DEFAULT_WEIGHT_NAME)
-            return pipe
-        except Exception as e:
-            print(f"Error initializing pipeline: {e}")
-            raise

-
-
-
-
-
-
-
-
-
-        except Exception as e:
-            error_msg = f"Failed to load LoRA from {lora_path}: {e}"
-            print(error_msg)
-            raise gr.Error(error_msg)

-
-
-
-                 guidance_scale=Config.DEFAULT_GUIDANCE_SCALE, steps=Config.DEFAULT_STEPS,
-                 seed=-1, num_images=1):
-        """Generate images from a prompt"""
-        try:
-            # Ensure the pipe is on the correct device
-            self.pipe.to(self.device)
-
-            # Handle seed
-            seed = random.randint(0, Config.MAX_SEED) if seed == -1 else int(seed)
-            generator = torch.Generator().manual_seed(seed)
-
-            # Translate prompt if not in English
-            prompt_english = str(self.translator.translate(prompt, "English"))
-            full_prompt = f"{prompt_english} {lora_word}"
-
-            # Generate image
-            result = self.pipe(
-                prompt=full_prompt,
-                height=height,
-                width=width,
-                guidance_scale=guidance_scale,
-                output_type="pil",
-                num_inference_steps=steps,
-                num_images_per_prompt=num_images,
-                generator=generator,
-                joint_attention_kwargs={"scale": lora_scale},
-            )
-
-            return result.images, seed
-
-        except Exception as e:
-            error_msg = f"Image generation failed: {e}"
-            print(error_msg)
-            raise gr.Error(error_msg)


-
-
-
-
-
-
-        self.example_prompts = [
-            ["Medium-shot portrait, ohwx blue alien, wearing black techwear with a high collar, standing inside a futuristic VR showroom.", "ohwx", 0.9],
-            ["ohwx blue alien, wearing black techwear with a high collar, immersed in a digital cybernetic landscape.", "ohwx", 0.9],
-            ["full-body shot, ohwx blue alien, wearing black techwear with a high collar, black cyber sneakers, running through a neon-lit cyberpunk alley at night.", "ohwx", 0.9],
-            ["ohwx blue alien, wearing black techwear with a high collar, sitting inside a sleek, high-tech VR capsule, immersed in an augmented reality experience.", "ohwx", 0.9]
-        ]
-
-    def build(self):
-        """Build and return the Gradio interface"""
-        with gr.Blocks(css=Config.CSS) as demo:
-            gr.HTML("<h1><center>BR METAVERSO - Avatar Generator</center></h1>")
-
-            # Status indicator
-            processing_status = gr.Markdown("**🟢 Ready**", visible=True)
-
-            with gr.Row():
-                with gr.Column(scale=4):
-                    gallery = gr.Gallery(label="Flux Generated Image", columns=1, preview=True, height=600)
-                    prompt_input = gr.Textbox(
-                        label="Enter Your Prompt",
-                        lines=2,
-                        placeholder="Enter prompt for your avatar..."
-                    )
-                    generate_btn = gr.Button(value="Generate", variant="primary")
-
-                    with gr.Accordion("Advanced Options", open=True):
-                        with gr.Row():
-                            with gr.Column():
-                                width_slider = gr.Slider(
-                                    label="Width",
-                                    minimum=512,
-                                    maximum=1920,
-                                    step=8,
-                                    value=Config.DEFAULT_WIDTH
-                                )
-                                height_slider = gr.Slider(
-                                    label="Height",
-                                    minimum=512,
-                                    maximum=1920,
-                                    step=8,
-                                    value=Config.DEFAULT_HEIGHT
-                                )
-                            with gr.Column():
-                                guidance_slider = gr.Slider(
-                                    label="Guidance Scale",
-                                    minimum=3.5,
-                                    maximum=7,
-                                    step=0.1,
-                                    value=Config.DEFAULT_GUIDANCE_SCALE
-                                )
-                                steps_slider = gr.Slider(
-                                    label="Steps",
-                                    minimum=1,
-                                    maximum=100,
-                                    step=1,
-                                    value=Config.DEFAULT_STEPS
-                                )
-
-                        with gr.Row():
-                            with gr.Column():
-                                seed_slider = gr.Slider(
-                                    label="Seed (-1 for random)",
-                                    minimum=-1,
-                                    maximum=Config.MAX_SEED,
-                                    step=1,
-                                    value=-1
-                                )
-                                nums_slider = gr.Slider(
-                                    label="Image Count",
-                                    minimum=1,
-                                    maximum=2,
-                                    step=1,
-                                    value=1
-                                )
-                            with gr.Column():
-                                lora_scale_slider = gr.Slider(
-                                    label="LoRA Scale",
-                                    minimum=0.1,
-                                    maximum=2.0,
-                                    step=0.1,
-                                    value=Config.DEFAULT_LORA_SCALE
-                                )
-
-                        with gr.Row():
-                            with gr.Column():
-                                lora_add_text = gr.Textbox(
-                                    label="Flux LoRA Path",
-                                    lines=1,
-                                    value=Config.DEFAULT_LORA
-                                )
-                            with gr.Column():
-                                lora_word_text = gr.Textbox(
-                                    label="Flux LoRA Trigger Word",
-                                    lines=1,
-                                    value=Config.DEFAULT_TRIGGER_WORD
-                                )
-
-                        load_lora_btn = gr.Button(value="Load Custom LoRA", variant="secondary")
-
-            # Examples section
-            gr.Examples(
-                examples=self.example_prompts,
-                inputs=[prompt_input, lora_word_text, lora_scale_slider],
-                cache_examples=False,
-                examples_per_page=4
-            )
-
-            # Wire up the event handlers
-            # Status update functions
-            def update_status_processing():
-                return "**⏳ Processing...**"
-
-            def update_status_done():
-                return "**✅ Done!**"

-
-
-
-
-
-
-
-
-
-                    steps_slider, seed_slider, nums_slider
-                ],
-                outputs=[gallery, seed_slider]
-            ).then(
-                fn=update_status_done,
-                inputs=[],
-                outputs=[processing_status]
-            )
-
-            # Load LoRA button click workflow
-            load_lora_btn.click(
-                fn=self.generator.load_lora,
-                inputs=[lora_add_text],
-                outputs=[lora_add_text]
-            )
-
-        return demo


-#
-#
-#
-def
-#
-generator =

-#
-
-
-


-
-
 import os
 import gradio as gr
+import json
+import torch
+import random
+import time
+from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL, AutoPipelineForImage2Image
+from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
+from diffusers.utils import load_image
+from huggingface_hub import ModelCard

+# -------------------------------------------------------------------------
+# GENERAL CONFIGURATION
+# -------------------------------------------------------------------------
+CONFIG = {
+    "base_model": "black-forest-labs/FLUX.1-dev",
+    "dtype": torch.float16,  # torch.float16 instead of bfloat16 to save VRAM
+    "device": "cuda" if torch.cuda.is_available() else "cpu",
+    "max_seed": 2**32 - 1
+}

+# Clear the GPU cache to avoid out-of-memory errors (only when CUDA is present)
+if torch.cuda.is_available():
+    torch.cuda.empty_cache()
+    torch.cuda.ipc_collect()
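The cleanup above only helps once something has actually been allocated; if the Space still hits out-of-memory errors, logging the allocator's counters around the call shows whether anything is being reclaimed. A minimal sketch using only standard torch calls (device index 0 assumed for illustration):

    import torch

    if torch.cuda.is_available():
        # Bytes held by live tensors vs. bytes reserved by the caching allocator
        allocated = torch.cuda.memory_allocated(0)
        reserved = torch.cuda.memory_reserved(0)
        print(f"allocated={allocated / 1e9:.2f} GB, reserved={reserved / 1e9:.2f} GB")
        torch.cuda.empty_cache()  # returns unused cached blocks to the driver
        print(f"reserved after empty_cache: {torch.cuda.memory_reserved(0) / 1e9:.2f} GB")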

+# -------------------------------------------------------------------------
+# LOADING THE BASE MODEL
+# -------------------------------------------------------------------------
+taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=CONFIG["dtype"]).to(CONFIG["device"])
+good_vae = AutoencoderKL.from_pretrained(CONFIG["base_model"], subfolder="vae", torch_dtype=CONFIG["dtype"]).to(CONFIG["device"])

+pipe = DiffusionPipeline.from_pretrained(
+    CONFIG["base_model"],
+    torch_dtype=CONFIG["dtype"],
+    vae=taef1,
+    low_cpu_mem_usage=True  # Reduces peak memory use while loading the weights
+).to(CONFIG["device"])

+pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
+    CONFIG["base_model"],
+    vae=good_vae,
+    transformer=pipe.transformer,
+    text_encoder=pipe.text_encoder,
+    tokenizer=pipe.tokenizer,
+    text_encoder_2=pipe.text_encoder_2,
+    tokenizer_2=pipe.tokenizer_2,
+    torch_dtype=CONFIG["dtype"]
+)
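Passing `pipe`'s transformer, text encoders, and tokenizers into `AutoPipelineForImage2Image` reuses the already-loaded modules instead of loading a second copy, so VRAM stays flat. A quick sanity check of that sharing (illustrative, not part of the app):

    # Both pipelines should hold the *same* module objects, not copies
    assert pipe_i2i.transformer is pipe.transformer
    assert pipe_i2i.text_encoder is pipe.text_encoder
    print("img2img pipeline shares the text-to-image weights")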

+pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
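The `__get__(pipe)` call above binds a plain function to the pipeline instance so it can be invoked like a method, with the pipeline supplied as `self`. The same descriptor trick in miniature, with toy names unrelated to diffusers:

    class Greeter:
        name = "FLUX"

    def shout(self):
        # 'self' is supplied by the binding below
        return f"hello from {self.name}"

    g = Greeter()
    g.shout = shout.__get__(g)  # bind the plain function to this instance
    print(g.shout())            # -> hello from FLUX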

+# -------------------------------------------------------------------------
+# IMAGE GENERATION FUNCTION
+# -------------------------------------------------------------------------
+def generate_image(prompt, steps, seed, cfg_scale, width, height, lora_scale, progress):
+    pipe.to(CONFIG["device"])  # Make sure the model is on the GPU
+    generator = torch.Generator(device=CONFIG["device"]).manual_seed(seed)
+
+    # Time the generation
+    start_time = time.time()

+    # Generate the image, yielding intermediate previews as they arrive
+    for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
+        prompt=prompt,
+        num_inference_steps=steps,
+        guidance_scale=cfg_scale,
+        width=width,
+        height=height,
+        generator=generator,
+        joint_attention_kwargs={"scale": lora_scale},
+        output_type="pil",
+        good_vae=good_vae,
+    ):
+        end_time = time.time()
+        print(f"Generation time: {end_time - start_time:.2f}s")
+        yield img
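Since `generate_image` is a generator, callers iterate it and treat the last yielded frame as the final render. A hypothetical standalone call (argument values purely illustrative):

    final = None
    for preview in generate_image(
        prompt="ohwx blue alien in black techwear",
        steps=25, seed=42, cfg_scale=3.5,
        width=896, height=1152, lora_scale=1.0, progress=None,
    ):
        final = preview          # keep the most recent frame
    final.save("avatar.png")     # PIL image, since output_type="pil"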

+# -------------------------------------------------------------------------
+# GRADIO INTERFACE
+# -------------------------------------------------------------------------
+with gr.Blocks(theme=gr.themes.Soft()) as app:
+    gr.Markdown("# FLUX Avatar Generator")

+    with gr.Row():
+        with gr.Column():
+            prompt = gr.Textbox(label="Prompt", placeholder="Describe your image...")
+            steps = gr.Slider(1, 50, 25, step=1, label="Steps")
+            cfg_scale = gr.Slider(1, 20, 3.5, step=0.5, label="CFG Scale")
+            width = gr.Slider(512, 1536, 896, step=64, label="Width
|