nftnik commited on
Commit
19be4eb
·
verified ·
1 Parent(s): 7ed91d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -258
app.py CHANGED
@@ -1,270 +1,89 @@
1
  import os
2
- import random
3
- import torch
4
- import numpy as np
5
  import gradio as gr
6
- import spaces
7
- from diffusers import FluxPipeline
8
- from translatepy import Translator
9
-
10
- # -----------------------------------------------------------------------------
11
- # CONFIGURATION
12
- # -----------------------------------------------------------------------------
13
- class Config:
14
- MODEL_ID = "black-forest-labs/FLUX.1-dev"
15
- DEFAULT_LORA = "nftnik/BR_ohwx_V1"
16
- DEFAULT_WEIGHT_NAME = "BR_ohwx.safetensors"
17
- MAX_SEED = int(np.iinfo(np.int32).max)
18
- CSS = "footer { visibility: hidden; }"
19
- DEFAULT_WIDTH = 896
20
- DEFAULT_HEIGHT = 1152
21
- DEFAULT_GUIDANCE_SCALE = 3.5
22
- DEFAULT_STEPS = 35
23
- DEFAULT_LORA_SCALE = 1.0
24
- DEFAULT_TRIGGER_WORD = "ohwx"
25
-
26
-
27
- # -----------------------------------------------------------------------------
28
- # FluxGenerator class to handle image generation
29
- # -----------------------------------------------------------------------------
30
- class FluxGenerator:
31
- def __init__(self):
32
- # Environment setup
33
- os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
34
- self.translator = Translator()
35
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
36
- print(f"Using {self.device.upper()}")
37
-
38
- # Initialize pipeline
39
- self.pipe = self._initialize_pipeline()
40
-
41
- def _initialize_pipeline(self):
42
- """Initialize the Flux pipeline and load default LoRA"""
43
- try:
44
- pipe = FluxPipeline.from_pretrained(
45
- Config.MODEL_ID, torch_dtype=torch.bfloat16
46
- ).to(self.device)
47
- pipe.load_lora_weights(Config.DEFAULT_LORA, weight_name=Config.DEFAULT_WEIGHT_NAME)
48
- return pipe
49
- except Exception as e:
50
- print(f"Error initializing pipeline: {e}")
51
- raise
52
 
53
- def load_lora(self, lora_path):
54
- """Load a new LoRA model"""
55
- try:
56
- self.pipe.unload_lora_weights()
57
- if not lora_path:
58
- return gr.update(value="")
59
-
60
- self.pipe.load_lora_weights(lora_path)
61
- return gr.update(label="LoRA Loaded Successfully")
62
- except Exception as e:
63
- error_msg = f"Failed to load LoRA from {lora_path}: {e}"
64
- print(error_msg)
65
- raise gr.Error(error_msg)
66
 
67
- @spaces.GPU()
68
- def generate(self, prompt, lora_word, lora_scale=Config.DEFAULT_LORA_SCALE,
69
- width=Config.DEFAULT_WIDTH, height=Config.DEFAULT_HEIGHT,
70
- guidance_scale=Config.DEFAULT_GUIDANCE_SCALE, steps=Config.DEFAULT_STEPS,
71
- seed=-1, num_images=1):
72
- """Generate images from a prompt"""
73
- try:
74
- # Ensure the pipe is on the correct device
75
- self.pipe.to(self.device)
76
-
77
- # Handle seed
78
- seed = random.randint(0, Config.MAX_SEED) if seed == -1 else int(seed)
79
- generator = torch.Generator().manual_seed(seed)
80
-
81
- # Translate prompt if not in English
82
- prompt_english = str(self.translator.translate(prompt, "English"))
83
- full_prompt = f"{prompt_english} {lora_word}"
84
-
85
- # Generate image
86
- result = self.pipe(
87
- prompt=full_prompt,
88
- height=height,
89
- width=width,
90
- guidance_scale=guidance_scale,
91
- output_type="pil",
92
- num_inference_steps=steps,
93
- num_images_per_prompt=num_images,
94
- generator=generator,
95
- joint_attention_kwargs={"scale": lora_scale},
96
- )
97
-
98
- return result.images, seed
99
-
100
- except Exception as e:
101
- error_msg = f"Image generation failed: {e}"
102
- print(error_msg)
103
- raise gr.Error(error_msg)
104
 
 
 
 
 
 
105
 
106
- # -----------------------------------------------------------------------------
107
- # UI Builder class
108
- # -----------------------------------------------------------------------------
109
- class FluxUI:
110
- def __init__(self, generator):
111
- self.generator = generator
112
- self.example_prompts = [
113
- ["Medium-shot portrait, ohwx blue alien, wearing black techwear with a high collar, standing inside a futuristic VR showroom.", "ohwx", 0.9],
114
- ["ohwx blue alien, wearing black techwear with a high collar, immersed in a digital cybernetic landscape.", "ohwx", 0.9],
115
- ["full-body shot, ohwx blue alien, wearing black techwear with a high collar, black cyber sneakers, running through a neon-lit cyberpunk alley at night.", "ohwx", 0.9],
116
- ["ohwx blue alien, wearing black techwear with a high collar, sitting inside a sleek, high-tech VR capsule, immersed in an augmented reality experience.", "ohwx", 0.9]
117
- ]
118
-
119
- def build(self):
120
- """Build and return the Gradio interface"""
121
- with gr.Blocks(css=Config.CSS) as demo:
122
- gr.HTML("<h1><center>BR METAVERSO - Avatar Generator</center></h1>")
123
-
124
- # Status indicator
125
- processing_status = gr.Markdown("**🟢 Ready**", visible=True)
126
-
127
- with gr.Row():
128
- with gr.Column(scale=4):
129
- gallery = gr.Gallery(label="Flux Generated Image", columns=1, preview=True, height=600)
130
- prompt_input = gr.Textbox(
131
- label="Enter Your Prompt",
132
- lines=2,
133
- placeholder="Enter prompt for your avatar..."
134
- )
135
- generate_btn = gr.Button(value="Generate", variant="primary")
136
-
137
- with gr.Accordion("Advanced Options", open=True):
138
- with gr.Row():
139
- with gr.Column():
140
- width_slider = gr.Slider(
141
- label="Width",
142
- minimum=512,
143
- maximum=1920,
144
- step=8,
145
- value=Config.DEFAULT_WIDTH
146
- )
147
- height_slider = gr.Slider(
148
- label="Height",
149
- minimum=512,
150
- maximum=1920,
151
- step=8,
152
- value=Config.DEFAULT_HEIGHT
153
- )
154
- with gr.Column():
155
- guidance_slider = gr.Slider(
156
- label="Guidance Scale",
157
- minimum=3.5,
158
- maximum=7,
159
- step=0.1,
160
- value=Config.DEFAULT_GUIDANCE_SCALE
161
- )
162
- steps_slider = gr.Slider(
163
- label="Steps",
164
- minimum=1,
165
- maximum=100,
166
- step=1,
167
- value=Config.DEFAULT_STEPS
168
- )
169
-
170
- with gr.Row():
171
- with gr.Column():
172
- seed_slider = gr.Slider(
173
- label="Seed (-1 for random)",
174
- minimum=-1,
175
- maximum=Config.MAX_SEED,
176
- step=1,
177
- value=-1
178
- )
179
- nums_slider = gr.Slider(
180
- label="Image Count",
181
- minimum=1,
182
- maximum=2,
183
- step=1,
184
- value=1
185
- )
186
- with gr.Column():
187
- lora_scale_slider = gr.Slider(
188
- label="LoRA Scale",
189
- minimum=0.1,
190
- maximum=2.0,
191
- step=0.1,
192
- value=Config.DEFAULT_LORA_SCALE
193
- )
194
-
195
- with gr.Row():
196
- with gr.Column():
197
- lora_add_text = gr.Textbox(
198
- label="Flux LoRA Path",
199
- lines=1,
200
- value=Config.DEFAULT_LORA
201
- )
202
- with gr.Column():
203
- lora_word_text = gr.Textbox(
204
- label="Flux LoRA Trigger Word",
205
- lines=1,
206
- value=Config.DEFAULT_TRIGGER_WORD
207
- )
208
-
209
- load_lora_btn = gr.Button(value="Load Custom LoRA", variant="secondary")
210
-
211
- # Examples section
212
- gr.Examples(
213
- examples=self.example_prompts,
214
- inputs=[prompt_input, lora_word_text, lora_scale_slider],
215
- cache_examples=False,
216
- examples_per_page=4
217
- )
218
-
219
- # Wire up the event handlers
220
- # Status update functions
221
- def update_status_processing():
222
- return "**⏳ Processing...**"
223
-
224
- def update_status_done():
225
- return "**✅ Done!**"
226
 
227
- # Generate button click workflow
228
- generate_btn.click(
229
- fn=update_status_processing,
230
- inputs=[],
231
- outputs=[processing_status]
232
- ).then(
233
- fn=self.generator.generate,
234
- inputs=[
235
- prompt_input, lora_word_text, lora_scale_slider,
236
- width_slider, height_slider, guidance_slider,
237
- steps_slider, seed_slider, nums_slider
238
- ],
239
- outputs=[gallery, seed_slider]
240
- ).then(
241
- fn=update_status_done,
242
- inputs=[],
243
- outputs=[processing_status]
244
- )
245
-
246
- # Load LoRA button click workflow
247
- load_lora_btn.click(
248
- fn=self.generator.load_lora,
249
- inputs=[lora_add_text],
250
- outputs=[lora_add_text]
251
- )
252
-
253
- return demo
254
 
 
255
 
256
- # -----------------------------------------------------------------------------
257
- # Main application
258
- # -----------------------------------------------------------------------------
259
- def main():
260
- # Initialize generator
261
- generator = FluxGenerator()
 
 
 
262
 
263
- # Build and launch UI
264
- ui = FluxUI(generator)
265
- demo = ui.build()
266
- demo.queue().launch()
 
 
 
 
 
 
 
 
 
 
 
267
 
 
 
 
 
 
268
 
269
- if __name__ == "__main__":
270
- main()
 
 
 
 
 
1
  import os
 
 
 
2
  import gradio as gr
3
+ import json
4
+ import torch
5
+ import random
6
+ import time
7
+ from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL, AutoPipelineForImage2Image
8
+ from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
9
+ from diffusers.utils import load_image
10
+ from huggingface_hub import ModelCard
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ # -------------------------------------------------------------------------
13
+ # CONFIGURAÇÃO GERAL
14
+ # -------------------------------------------------------------------------
15
+ CONFIG = {
16
+ "base_model": "black-forest-labs/FLUX.1-dev",
17
+ "dtype": torch.float16, # Substituído por torch.float16 para economizar VRAM
18
+ "device": "cuda" if torch.cuda.is_available() else "cpu",
19
+ "max_seed": 2**32 - 1
20
+ }
 
 
 
 
21
 
22
+ # Limpa cache da GPU para evitar erro de falta de memória
23
+ torch.cuda.empty_cache()
24
+ torch.cuda.ipc_collect()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
+ # -------------------------------------------------------------------------
27
+ # CARREGANDO O MODELO BASE
28
+ # -------------------------------------------------------------------------
29
+ taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=CONFIG["dtype"]).to(CONFIG["device"])
30
+ good_vae = AutoencoderKL.from_pretrained(CONFIG["base_model"], subfolder="vae", torch_dtype=CONFIG["dtype"]).to(CONFIG["device"])
31
 
32
+ pipe = DiffusionPipeline.from_pretrained(
33
+ CONFIG["base_model"],
34
+ torch_dtype=CONFIG["dtype"],
35
+ vae=taef1,
36
+ low_cpu_mem_usage=True # Economiza memória na GPU
37
+ ).to(CONFIG["device"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
+ pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
40
+ CONFIG["base_model"],
41
+ vae=good_vae,
42
+ transformer=pipe.transformer,
43
+ text_encoder=pipe.text_encoder,
44
+ tokenizer=pipe.tokenizer,
45
+ text_encoder_2=pipe.text_encoder_2,
46
+ tokenizer_2=pipe.tokenizer_2,
47
+ torch_dtype=CONFIG["dtype"]
48
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
+ pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
51
 
52
+ # -------------------------------------------------------------------------
53
+ # FUNÇÃO PARA GERAR IMAGEM
54
+ # -------------------------------------------------------------------------
55
+ def generate_image(prompt, steps, seed, cfg_scale, width, height, lora_scale, progress):
56
+ pipe.to(CONFIG["device"]) # Garante que o modelo está na GPU
57
+ generator = torch.Generator(device=CONFIG["device"]).manual_seed(seed)
58
+
59
+ # Medir tempo de geração
60
+ start_time = time.time()
61
 
62
+ # Gerando a imagem
63
+ for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
64
+ prompt=prompt,
65
+ num_inference_steps=steps,
66
+ guidance_scale=cfg_scale,
67
+ width=width,
68
+ height=height,
69
+ generator=generator,
70
+ joint_attention_kwargs={"scale": lora_scale},
71
+ output_type="pil",
72
+ good_vae=good_vae,
73
+ ):
74
+ end_time = time.time()
75
+ print(f"Tempo de geração: {end_time - start_time:.2f}s")
76
+ yield img
77
 
78
+ # -------------------------------------------------------------------------
79
+ # INTERFACE GRADIO
80
+ # -------------------------------------------------------------------------
81
+ with gr.Blocks(theme=gr.themes.Soft()) as app:
82
+ gr.Markdown("# FLUX Avatar Generator")
83
 
84
+ with gr.Row():
85
+ with gr.Column():
86
+ prompt = gr.Textbox(label="Prompt", placeholder="Descreva sua imagem...")
87
+ steps = gr.Slider(1, 50, 25, step=1, label="Passos")
88
+ cfg_scale = gr.Slider(1, 20, 3.5, step=0.5, label="Escala CFG")
89
+ width = gr.Slider(512, 1536, 896, step=64, label="Largura