nftnik committed on
Commit 91d99a8 · verified · 1 Parent(s): fd4cd12

Update app.py

Files changed (1)
app.py +109 -267
app.py CHANGED
@@ -1,290 +1,132 @@
  import os
  import random
+ import re
+ import requests
  import torch
  import numpy as np
  import gradio as gr
+ import spaces
  from diffusers import FluxPipeline
  from translatepy import Translator

  # -----------------------------------------------------------------------------
  # CONFIGURATION
  # -----------------------------------------------------------------------------
- class Config:
-     MODEL_ID = "black-forest-labs/FLUX.1-dev"
-     DEFAULT_LORA = "nftnik/BR_ohwx_V1"
-     DEFAULT_WEIGHT_NAME = "BR_ohwx.safetensors"
-     MAX_SEED = int(np.iinfo(np.int32).max)
-     CSS = "footer { visibility: hidden; }"
-     DEFAULT_WIDTH = 896
-     DEFAULT_HEIGHT = 1152
-     DEFAULT_GUIDANCE_SCALE = 3.5
-     DEFAULT_STEPS = 35
-     DEFAULT_LORA_SCALE = 1.0
-     DEFAULT_TRIGGER_WORD = "ohwx"
-     # Memory optimization configs
-     ENABLE_MEMORY_EFFICIENT_ATTENTION = True
-     ENABLE_SEQUENTIAL_CPU_OFFLOAD = True
-     ENABLE_ATTENTION_SLICING = "max"
+ config = {
+     "model_id": "black-forest-labs/FLUX.1-dev",
+     "default_lora": "nftnik/BR_ohwx_V1",
+     "default_weight_name": "BR_ohwx.safetensors",
+     "max_seed": int(np.iinfo(np.int32).max),
+     "css": "footer { visibility: hidden; }",
+     "default_width": 896,
+     "default_height": 1152,
+     "default_guidance_scale": 3.5,
+     "default_steps": 35,
+     "default_loRa_scale": 1.0,
+ }

  # -----------------------------------------------------------------------------
- # FluxGenerator class to handle image generation
+ # Environment and device setup
  # -----------------------------------------------------------------------------
- class FluxGenerator:
-     def __init__(self):
-         os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
-         self.translator = Translator()
-         self.device = self._get_optimal_device()
-         print(f"Using {self.device.upper()}")
-
-         self.pipe = None
-         self._initialize_pipeline()
-
-     def _get_optimal_device(self):
-         """Determine the optimal device based on available resources"""
-         if torch.cuda.is_available():
-             try:
-                 gpu_memory = torch.cuda.get_device_properties(0).total_memory
-                 if gpu_memory > 10 * 1024 * 1024 * 1024:  # More than 10GB
-                     return "cuda"
-                 else:
-                     print("Limited GPU memory detected. Will still use CUDA with memory optimizations.")
-                     return "cuda"
-             except:
-                 print("Error checking GPU memory, falling back to CPU")
-                 return "cpu"
-         else:
-             return "cpu"
-
-     def _initialize_pipeline(self):
-         """Initialize the Flux pipeline with memory optimizations"""
-         try:
-             print("Loading Flux model...")
-             pipe_kwargs = {
-                 "torch_dtype": torch.bfloat16 if self.device == "cuda" else torch.float32,
-             }
-             self.pipe = FluxPipeline.from_pretrained(Config.MODEL_ID, **pipe_kwargs)
-
-             # Apply memory optimizations
-             if Config.ENABLE_MEMORY_EFFICIENT_ATTENTION and self.device == "cuda":
-                 print("Enabling memory efficient attention")
-                 self.pipe.enable_xformers_memory_efficient_attention()
-
-             if Config.ENABLE_ATTENTION_SLICING:
-                 print("Enabling attention slicing")
-                 self.pipe.enable_attention_slicing(Config.ENABLE_ATTENTION_SLICING)
-
-             if Config.ENABLE_SEQUENTIAL_CPU_OFFLOAD and self.device == "cuda":
-                 print("Enabling sequential CPU offload")
-                 self.pipe.enable_sequential_cpu_offload()
-             else:
-                 # Only move to device if not offloading
-                 self.pipe.to(self.device)
-
-             print(f"Loading default LoRA: {Config.DEFAULT_LORA}")
-             self.pipe.load_lora_weights(Config.DEFAULT_LORA, weight_name=Config.DEFAULT_WEIGHT_NAME)
-
-             print("Model initialization complete")
-         except Exception as e:
-             error_msg = f"Error initializing pipeline: {str(e)}"
-             print(error_msg)
-             raise
-
-     def load_lora(self, lora_path):
-         """Load a new LoRA model"""
-         try:
-             print("Unloading previous LoRA weights...")
-             self.pipe.unload_lora_weights()
-
-             if not lora_path:
-                 print("No LoRA path provided, skipping LoRA loading.")
-                 return gr.update(value="")
-
-             print(f"Loading LoRA from {lora_path}...")
-             self.pipe.load_lora_weights(lora_path)
-             print("LoRA loaded successfully.")
-             return gr.update(label="LoRA Loaded Successfully")
-         except Exception as e:
-             error_msg = f"Failed to load LoRA from {lora_path}: {str(e)}"
-             print(error_msg)
-             raise gr.Error(error_msg)
-
-     def _clear_memory(self):
-         """Clear CUDA memory cache"""
-         if self.device == "cuda":
-             try:
-                 print("Clearing CUDA memory cache...")
-                 torch.cuda.empty_cache()
-                 if hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast"):
-                     torch.cuda.amp.clear_autocast_cache()
-             except Exception as e:
-                 print(f"Warning: Failed to clear CUDA memory: {str(e)}")
-
-     def generate(self, prompt, lora_word, lora_scale=Config.DEFAULT_LORA_SCALE,
-                  width=Config.DEFAULT_WIDTH, height=Config.DEFAULT_HEIGHT,
-                  guidance_scale=Config.DEFAULT_GUIDANCE_SCALE,
-                  steps=Config.DEFAULT_STEPS, seed=-1, num_images=1):
-         """Generate images from a prompt with memory optimizations."""
-         try:
-             print(f"Generating image for prompt: '{prompt}'")
-             self._clear_memory()
-
-             if not Config.ENABLE_SEQUENTIAL_CPU_OFFLOAD:
-                 print(f"Moving model to {self.device}")
-                 self.pipe.to(self.device)
-
-             seed = random.randint(0, Config.MAX_SEED) if seed == -1 else int(seed)
-             print(f"Using seed: {seed}")
-             generator = torch.Generator(device=self.device).manual_seed(seed)
-
-             print("Translating prompt if needed...")
-             prompt_english = str(self.translator.translate(prompt, "English"))
-             full_prompt = f"{prompt_english} {lora_word}"
-             print(f"Full prompt: '{full_prompt}'")
-
-             # If GPU memory is less than 8GB, scale resolution
-             if (self.device == "cuda" and
-                     torch.cuda.get_device_properties(0).total_memory < 8 * 1024 * 1024 * 1024):
-                 original_width, original_height = width, height
-                 width = int(width * 0.85)
-                 height = int(height * 0.85)
-                 print(f"Memory is tight. Scaled resolution from {original_width}x{original_height} to {width}x{height}")
-
-             print(f"Starting generation with {steps} steps, guidance scale {guidance_scale}")
-             with torch.autocast("cuda", enabled=(self.device == "cuda")):
-                 result = self.pipe(
-                     prompt=full_prompt,
-                     height=height,
-                     width=width,
-                     guidance_scale=guidance_scale,
-                     output_type="pil",
-                     num_inference_steps=steps,
-                     num_images_per_prompt=num_images,
-                     generator=generator,
-                     joint_attention_kwargs={"scale": lora_scale},
-                 )
-             print("Generation complete, returning images.")
-             self._clear_memory()
-             return result.images, seed
-         except Exception as e:
-             error_msg = f"Image generation failed: {str(e)}"
-             print(error_msg)
-             self._clear_memory()
-             raise gr.Error(error_msg)
-
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+ translator = Translator()
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Using {device.upper()}")

  # -----------------------------------------------------------------------------
- # UI Builder class
+ # Initialize the Flux pipeline and load default LoRA
  # -----------------------------------------------------------------------------
- class FluxUI:
-     def __init__(self, generator):
-         self.generator = generator
-         self.example_prompts = [
-             ["Medium-shot portrait, ohwx blue alien, wearing black techwear with a high collar, standing inside a futuristic VR showroom.", "ohwx", 0.9],
-             ["ohwx blue alien, wearing black techwear with a high collar, immersed in a digital cybernetic landscape.", "ohwx", 0.9],
-             ["full-body shot, ohwx blue alien, wearing black techwear with a high collar, black cyber sneakers, running through a neon-lit cyberpunk alley at night.", "ohwx", 0.9],
-             ["ohwx blue alien, wearing black techwear with a high collar, sitting inside a sleek, high-tech VR capsule, immersed in an augmented reality experience.", "ohwx", 0.9]
-         ]
-
-     def build(self):
-         with gr.Blocks(css=Config.CSS) as demo:
-             gr.HTML("<h1><center>BR METAVERSO - Avatar Generator</center></h1>")
-
-             status_markdown = gr.Markdown("**🟢 Ready**", visible=True)
-
-             with gr.Row():
-                 with gr.Column(scale=4):
-                     gallery = gr.Gallery(label="Flux Generated Image", columns=1, preview=True, height=600)
-                     prompt_input = gr.Textbox(
-                         label="Enter Your Prompt",
-                         lines=2,
-                         placeholder="Type your avatar description..."
-                     )
-                     generate_btn = gr.Button(value="Generate", variant="primary")
-
-                     with gr.Accordion("Advanced Options", open=True):
-                         with gr.Row():
-                             with gr.Column():
-                                 width_slider = gr.Slider(label="Width", minimum=512, maximum=1920, step=8, value=Config.DEFAULT_WIDTH)
-                                 height_slider = gr.Slider(label="Height", minimum=512, maximum=1920, step=8, value=Config.DEFAULT_HEIGHT)
-                             with gr.Column():
-                                 guidance_slider = gr.Slider(label="Guidance Scale", minimum=3.5, maximum=7, step=0.1, value=Config.DEFAULT_GUIDANCE_SCALE)
-                                 steps_slider = gr.Slider(label="Steps", minimum=1, maximum=100, step=1, value=Config.DEFAULT_STEPS)
-
-                         with gr.Row():
-                             with gr.Column():
-                                 seed_slider = gr.Slider(label="Seed (-1 random)", minimum=-1, maximum=Config.MAX_SEED, step=1, value=-1)
-                                 nums_slider = gr.Slider(label="Image Count", minimum=1, maximum=2, step=1, value=1)
-                             with gr.Column():
-                                 lora_scale_slider = gr.Slider(label="LoRA Scale", minimum=0.1, maximum=2.0, step=0.1, value=Config.DEFAULT_LORA_SCALE)
-
-                         with gr.Row():
-                             with gr.Column():
-                                 lora_add_text = gr.Textbox(label="Flux LoRA Path", lines=1, value=Config.DEFAULT_LORA)
-                             with gr.Column():
-                                 lora_word_text = gr.Textbox(label="Flux LoRA Trigger Word", lines=1, value=Config.DEFAULT_TRIGGER_WORD)
-
-                         load_lora_btn = gr.Button(value="Load Custom LoRA", variant="secondary")
-
-             # Examples
-             gr.Examples(
-                 examples=self.example_prompts,
-                 inputs=[prompt_input, lora_word_text, lora_scale_slider],
-                 outputs=[],
-                 cache_examples=False,
-                 examples_per_page=4
-             )
-
-             # Helper functions for UI status
-             def update_status_processing():
-                 return "**⏳ Processing...**"
-
-             def update_status_done():
-                 return "**✅ Done!**"
-
-             # Workflow for generate
-             generate_btn.click(
-                 fn=update_status_processing,
-                 inputs=[],
-                 outputs=[status_markdown]
-             ).then(
-                 fn=self.generator.generate,
-                 inputs=[
-                     prompt_input, lora_word_text, lora_scale_slider,
-                     width_slider, height_slider, guidance_slider,
-                     steps_slider, seed_slider, nums_slider
-                 ],
-                 outputs=[gallery, seed_slider]
-             ).then(
-                 fn=update_status_done,
-                 inputs=[],
-                 outputs=[status_markdown]
-             )
-
-             # Load LoRA
-             load_lora_btn.click(
-                 fn=self.generator.load_lora,
-                 inputs=[lora_add_text],
-                 outputs=[lora_add_text]
-             )
-
-         return demo
+ pipe = FluxPipeline.from_pretrained(
+     config["model_id"], torch_dtype=torch.bfloat16
+ ).to(device)
+ pipe.load_lora_weights(config["default_lora"], weight_name=config["default_weight_name"])

  # -----------------------------------------------------------------------------
- # Main entry point
+ # Function to load a new LoRA model
  # -----------------------------------------------------------------------------
- def main():
+ def enable_lora(lora_add: str):
+     pipe.unload_lora_weights()
+     if not lora_add:
+         return gr.update(value="")
+     url = f"https://huggingface.co/{lora_add}/tree/main"
      try:
-         generator = FluxGenerator()
-         ui = FluxUI(generator)
-         demo = ui.build()
-         # Launch with default queue
-         demo.queue().launch()
+         pipe.load_lora_weights(lora_add)
+         return gr.update(label="LoRA Loaded Now")
      except Exception as e:
-         print(f"Application startup failed: {str(e)}")
-         with gr.Blocks() as error_demo:
-             gr.Markdown(f"# Error Starting Application\n\n{str(e)}")
-             gr.Markdown("Check logs for more details.")
-         error_demo.launch()
+         raise gr.Error(f"Failed to load {lora_add}: {e}")

- if __name__ == "__main__":
-     main()
+ # -----------------------------------------------------------------------------
+ # Function to generate an image from a prompt
+ # -----------------------------------------------------------------------------
+ @spaces.GPU()
+ def generate_image(
+     prompt: str, lora_word: str, lora_scale: float = config["default_loRa_scale"],
+     width: int = config["default_width"], height: int = config["default_height"],
+     guidance_scale: float = config["default_guidance_scale"], steps: int = config["default_steps"],
+     seed: int = -1, nums: int = 1
+ ):
+     pipe.to(device)
+     seed = random.randint(0, config["max_seed"]) if seed == -1 else int(seed)
+     prompt_english = str(translator.translate(prompt, "English"))
+     full_prompt = f"{prompt_english} {lora_word}"
+     generator = torch.Generator().manual_seed(seed)
+
+     result = pipe(
+         prompt=full_prompt, height=height, width=width, guidance_scale=guidance_scale,
+         output_type="pil", num_inference_steps=steps, num_images_per_prompt=nums,
+         generator=generator, joint_attention_kwargs={"scale": lora_scale},
+     )
+     return result.images, seed
+
+ # -----------------------------------------------------------------------------
+ # Gradio UI
+ # -----------------------------------------------------------------------------
+ example_prompts = [
+     ["Medium-shot portrait, ohwx blue alien, wearing black techwear with a high collar, standing inside a futuristic VR showroom.", "ohwx", 0.9],
+     ["ohwx blue alien, wearing black techwear with a high collar, immersed in a digital cybernetic landscape.", "ohwx", 0.9],
+     ["full-body shot, ohwx blue alien, wearing black techwear with a high collar, black cyber sneakers, running through a neon-lit cyberpunk alley at night.", "ohwx", 0.9],
+     ["ohwx blue alien, wearing black techwear with a high collar, sitting inside a sleek, high-tech VR capsule, immersed in an augmented reality experience.", "ohwx", 0.9]
+ ]
+
+ with gr.Blocks(css=config["css"]) as demo:
+     gr.HTML("<h1><center>BR METAVERSO - Avatar Generator</center></h1>")
+
+     processing_status = gr.Markdown("**🟢 Ready**", visible=True)  # Status indicator
+
+     with gr.Row():
+         with gr.Column(scale=4):
+             gallery = gr.Gallery(label="Flux Generated Image", columns=1, preview=True, height=600)
+             prompt_input = gr.Textbox(label="Enter Your Prompt", lines=2, placeholder="Enter prompt...")
+             generate_btn = gr.Button(variant="primary")
+             with gr.Accordion("Advanced Options", open=True):
+                 width_slider = gr.Slider(label="Width", minimum=512, maximum=1920, step=8, value=config["default_width"])
+                 height_slider = gr.Slider(label="Height", minimum=512, maximum=1920, step=8, value=config["default_height"])
+                 guidance_slider = gr.Slider(label="Guidance Scale", minimum=3.5, maximum=7, step=0.1, value=config["default_guidance_scale"])
+                 steps_slider = gr.Slider(label="Steps", minimum=1, maximum=100, step=1, value=config["default_steps"])
+                 seed_slider = gr.Slider(label="Seed", minimum=-1, maximum=config["max_seed"], step=1, value=-1)
+                 nums_slider = gr.Slider(label="Image Count", minimum=1, maximum=2, step=1, value=1)
+                 lora_scale_slider = gr.Slider(label="LoRA Scale", minimum=0.1, maximum=2.0, step=0.1, value=config["default_loRa_scale"])
+                 lora_add_text = gr.Textbox(label="Flux LoRA", lines=1, value=config["default_lora"])
+                 lora_word_text = gr.Textbox(label="Flux LoRA Trigger Word", lines=1, value="ohwx")
+                 load_lora_btn = gr.Button(value="Load LoRA", variant="secondary")
+
+     gr.Examples(examples=example_prompts, inputs=[prompt_input, lora_word_text, lora_scale_slider], cache_examples=False, examples_per_page=4)
+
+     # Ensuring processing status updates correctly
+     def update_status():
+         return "**⏳ Processing...**"
+
+     generate_btn.click(fn=update_status, inputs=[], outputs=[processing_status]).then(
+         fn=generate_image,
+         inputs=[prompt_input, lora_word_text, lora_scale_slider, width_slider, height_slider, guidance_slider, steps_slider, seed_slider, nums_slider],
+         outputs=[gallery, seed_slider]
+     ).then(
+         fn=lambda: "**✅ Done!**",
+         inputs=[],
+         outputs=[processing_status]
+     )
+
+     load_lora_btn.click(fn=enable_lora, inputs=[lora_add_text], outputs=lora_add_text)
+
+ demo.queue().launch()
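
For reference, a minimal standalone sketch of the generation path this commit settles on, using the same diffusers calls and defaults as the new app.py. It assumes access to the gated black-forest-labs/FLUX.1-dev checkpoint and a CUDA GPU with enough memory for the bfloat16 weights; the translation step, ZeroGPU decorator, and Gradio UI are omitted:

import torch
from diffusers import FluxPipeline

# Mirror the new module-level setup: base model in bfloat16 plus the default LoRA.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("nftnik/BR_ohwx_V1", weight_name="BR_ohwx.safetensors")

# One generation with the app's defaults; the "ohwx" trigger word sits in the
# prompt the same way generate_image() builds full_prompt.
generator = torch.Generator().manual_seed(42)  # fixed seed for reproducibility
image = pipe(
    prompt="Medium-shot portrait, ohwx blue alien, wearing black techwear with a high collar",
    width=896, height=1152, guidance_scale=3.5, num_inference_steps=35,
    generator=generator, joint_attention_kwargs={"scale": 1.0},
).images[0]
image.save("avatar.png")

On a ZeroGPU Space, the @spaces.GPU() decorator requests a GPU for the duration of each generate_image() call, which is presumably why this commit moves the pipeline to device inside the function and drops the old CPU-offload and attention-slicing logic.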