Update app.py
app.py CHANGED
@@ -4,7 +4,7 @@ Author: @Raxephion 2025
 """
 
 import gradio as gr
-import numpy as np
+import numpy as np
 import random
 import torch
 from diffusers import StableDiffusionPipeline
@@ -95,12 +95,21 @@ if INITIAL_MODEL_ID:
     print(f"\nLoading initial model '{INITIAL_MODEL_ID}' on startup...")
     try:
         # Load the pipeline onto the initial device and dtype
-
+        pipeline = StableDiffusionPipeline.from_pretrained(
             INITIAL_MODEL_ID,
             torch_dtype=initial_dtype_to_use,
             safety_checker=None, # <<< SAFETY CHECKER DISABLED <<<
         )
-
+
+        # --- Apply Optimizations during initial load ---
+        # Apply attention slicing by default for memory efficiency on Spaces
+        # Can be turned off via UI toggle later, but good default for VRAM
+        # We'll add the UI toggle later, for now, just enable it here
+        # pipeline.enable_attention_slicing() # Enable by default on initial load
+
+        pipeline = pipeline.to(initial_device_to_use) # Move to the initial device
+
+        current_pipeline = pipeline
         current_model_id = INITIAL_MODEL_ID
         current_device_loaded = torch.device(initial_device_to_use)
         print(f"Initial model loaded successfully on {current_device_loaded}.")
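Side note (not part of the commit): the startup load above reduces to a small, self-contained pattern. In the sketch below the model id is a placeholder and the device/dtype are derived locally, whereas app.py computes initial_device_to_use and initial_dtype_to_use elsewhere.

import torch
from diffusers import StableDiffusionPipeline

MODEL_ID = "runwayml/stable-diffusion-v1-5"  # placeholder; not necessarily the app's default model
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Load once at startup, keep the pipeline in a module-level variable, reuse it per request
pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=dtype, safety_checker=None)
pipe = pipe.to(device)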
@@ -146,10 +155,11 @@ def infer(
     size, # From size_dropdown
     seed, # From seed_input (now a Slider)
     randomize_seed, # From randomize_seed_checkbox
+    enable_attention_slicing, # <-- New input for the optimization toggle
     progress=gr.Progress(track_tqdm=True), # Added progress argument from template
 ):
     """Generates an image using the selected model and parameters on the chosen device."""
-    global current_pipeline, current_model_id, current_device_loaded, SCHEDULER_MAP, MAX_SEED
+    global current_pipeline, current_model_id, current_device_loaded, SCHEDULER_MAP, MAX_SEED
 
     # This check is done before parameter parsing so we can determine device/dtype for loading
     # Need to redo some parameter parsing here to get device_to_use early
@@ -165,7 +175,6 @@ def infer(
 
     # 1. Load/Switch Model if necessary
     # Check if the requested model identifier OR the requested device has changed
-    # Use string comparison for current_device_loaded as it's a torch.device object
     if current_pipeline is None or current_model_id != model_identifier or (current_device_loaded is not None and str(current_device_loaded) != temp_device_to_use):
 
         print(f"Loading model: {model_identifier} onto {temp_device_to_use} with dtype {temp_dtype_to_use}...")
@@ -180,6 +189,7 @@ def infer(
                 print(f"Warning: Failed to move previous pipeline to CPU: {move_e}")
             del current_pipeline
             current_pipeline = None # Set to None immediately
+            # Attempt to clear CUDA cache if using GPU (from the previous device)
             if str(current_device_loaded) == "cuda":
                 try:
                     torch.cuda.empty_cache()
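Side note (not part of the commit): torch.cuda.empty_cache() only releases memory for objects Python has already freed, so the reference to the old pipeline has to be dropped first. A minimal sketch of that release order, with illustrative variable names:

import gc
import torch

current_pipeline = None       # drop the reference to the previous pipeline
gc.collect()                  # let Python reclaim the object
if torch.cuda.is_available():
    torch.cuda.empty_cache()  # return cached CUDA blocks to the driver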
@@ -190,7 +200,7 @@ def infer(
         # Ensure the device is actually available if not CPU (redundant with earlier check but safe)
         if temp_device_to_use == "cuda":
             if not torch.cuda.is_available():
-                raise gr.Error("
+                raise gr.Error("GPU selected but CUDA is not available to PyTorch on this Space. Please select CPU or ensure the Space is configured with a GPU and the CUDA version of PyTorch is installed.")
 
         try:
             pipeline = StableDiffusionPipeline.from_pretrained(
@@ -198,6 +208,24 @@ def infer(
                 torch_dtype=temp_dtype_to_use, # Use the determined dtype for loading
                 safety_checker=None, # DISABLED
             )
+
+            # Apply optimizations based on UI input during load
+            if enable_attention_slicing and temp_device_to_use == "cuda": # Only apply on GPU
+                try:
+                    pipeline.enable_attention_slicing()
+                    print("Attention Slicing enabled.")
+                except Exception as e:
+                    print(f"Warning: Failed to enable Attention Slicing: {e}")
+                    gr.Warning(f"Failed to enable Attention Slicing. Error: {e}")
+            else:
+                try:
+                    pipeline.disable_attention_slicing() # Ensure it's off if toggle is off or on CPU
+                    print("Attention Slicing disabled.")
+                except Exception as e:
+                    # May fail if it wasn't enabled, ignore
+                    pass
+
+
             pipeline = pipeline.to(temp_device_to_use) # Use the determined device
 
             current_pipeline = pipeline
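Side note (not part of the commit): attention slicing is a standard diffusers switch that computes attention in chunks, trading a little speed for lower peak VRAM, and it can be toggled on any loaded pipeline. A minimal sketch, assuming a pipeline object pipe like the one in the earlier sketch:

pipe.enable_attention_slicing()   # or enable_attention_slicing("max") for the most aggressive slicing
image = pipe("a watercolor fox", num_inference_steps=20).images[0]

pipe.disable_attention_slicing()  # switch back off when VRAM headroom is not a concern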
@@ -244,6 +272,8 @@ def infer(
 
     # Re-determine device_to_use and dtype_to_use *after* ensuring pipeline is loaded
     # They should match current_device_loaded and the pipeline's dtype
+    # This is crucial because current_pipeline.device and dtype are the definitive source
+    # after a potentially successful load or switch.
     device_to_use = str(current_pipeline.device) if current_pipeline else ("cuda" if selected_device_str == "GPU" and "GPU" in AVAILABLE_DEVICES else "cpu")
     dtype_to_use = current_pipeline.dtype if current_pipeline else torch.float32 # Fallback if somehow pipeline is still None
 
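Side note (not part of the commit): a loaded diffusers pipeline reports its own placement, which is why the code above trusts current_pipeline.device and current_pipeline.dtype over the UI selection. A tiny sketch, again assuming a pipe object:

print(pipe.device)  # e.g. device(type='cuda', index=0) after pipe.to("cuda")
print(pipe.dtype)   # e.g. torch.float16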
@@ -253,6 +283,30 @@ def infer(
         raise gr.Error("Model failed to load during setup or switching. Cannot generate image.")
 
 
+    # --- Apply Optimizations *before* generation if model was already loaded ---
+    # If the model didn't need reloading, we need to apply/remove slicing here
+    if str(current_pipeline.device) == "cuda": # Only attempt on GPU
+        if enable_attention_slicing:
+            try:
+                current_pipeline.enable_attention_slicing()
+                # print("Attention Slicing enabled for generation.") # Too verbose
+            except Exception as e:
+                print(f"Warning: Failed to enable Attention Slicing before generation: {e}")
+                gr.Warning(f"Failed to enable Attention Slicing. Error: {e}")
+        else:
+            try:
+                current_pipeline.disable_attention_slicing()
+                # print("Attention Slicing disabled for generation.") # Too verbose
+            except Exception as e:
+                # May fail if it wasn't enabled, ignore
+                pass
+    else: # Ensure slicing is off on CPU
+        try:
+            current_pipeline.disable_attention_slicing()
+        except Exception as e:
+            pass # Ignore
+
+
     # 2. Configure Scheduler
     selected_scheduler_class = SCHEDULER_MAP.get(scheduler_name)
     if selected_scheduler_class is None:
@@ -348,7 +402,7 @@ def infer(
     if width <= 0 or height <= 0:
         raise ValueError("Image width and height must be positive.")
 
-    print(f"Generating: Prompt='{prompt[:80]}{'...' if len(prompt) > 80 else ''}', NegPrompt='{negative_prompt[:80]}{'...' if len(negative_prompt) > 80 else ''}', Steps={num_inference_steps_int}, CFG={guidance_scale_float}, Size={width}x{height}, Scheduler={scheduler_name}, Seed={seed_int if generator else 'System Random'}, Device={device_to_use}, Dtype={dtype_to_use}")
+    print(f"Generating: Prompt='{prompt[:80]}{'...' if len(prompt) > 80 else ''}', NegPrompt='{negative_prompt[:80]}{'...' if len(negative_prompt) > 80 else ''}', Steps={num_inference_steps_int}, CFG={guidance_scale_float}, Size={width}x{height}, Scheduler={scheduler_name}, Seed={seed_int if generator else 'System Random'}, Device={device_to_use}, Dtype={dtype_to_use}, Slicing Enabled={enable_attention_slicing and device_to_use == 'cuda'}")
     start_time = time.time()
 
     try:
@@ -367,8 +421,6 @@ def infer(
 
             # Add VAE usage here if needed for specific models that require it
             # vae=...
-            # Potentially add attention slicing/xformers/etc. for memory efficiency
-            # enable_attention_slicing="auto", # Can help with VRAM on smaller GPUs
             # enable_xformers_memory_efficient_attention() # Needs xformers installed & compatible GPU
         )
         end_time = time.time()
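Side note (not part of the commit): the remaining comment refers to xformers. If that package is installed and the GPU is compatible, memory-efficient attention can be enabled on the pipeline itself; it is not wired into this UI, so the following is only a sketch with an illustrative pipe object:

try:
    pipe.enable_xformers_memory_efficient_attention()  # requires the xformers package and a supported CUDA GPU
except Exception as exc:
    print(f"xformers unavailable, using default attention: {exc}")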
@@ -488,6 +540,17 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo: # Added Soft theme from
             seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, interactive=True) # Use 0 as default, interactive initially
             randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True) # Simplified label
 
+            # --- New: Memory Optimization Toggle ---
+            with gr.Row():
+                # Default to enabled if GPU is available, otherwise off
+                default_slicing = True if "GPU" in AVAILABLE_DEVICES else False
+                enable_attention_slicing_checkbox = gr.Checkbox(
+                    label="Enable Attention Slicing (Memory Optimization - GPU only)",
+                    value=default_slicing,
+                    interactive="GPU" in AVAILABLE_DEVICES # Only interactive if GPU is an option
+                )
+                gr.Markdown("*(Helps reduce VRAM usage, may slightly affect speed/quality)*")
+
 
         generate_button = gr.Button("✨ Generate Image ✨", variant="primary", scale=1) # Added emojis
 
@@ -520,7 +583,8 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo: # Added Soft theme from
                 scheduler_dropdown,
                 size_dropdown,
                 seed_input,
-                randomize_seed_checkbox,
+                randomize_seed_checkbox,
+                enable_attention_slicing_checkbox, # <-- Pass the new checkbox value
             ],
             outputs=[output_image, actual_seed_output], # Return image and the actual seed used
             api_name="generate" # Optional: For API access
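Side note (not part of the commit): the hunk above works because Gradio passes the components in inputs=[...] to the handler positionally, so enable_attention_slicing_checkbox must sit in the same position as the new enable_attention_slicing parameter of infer. A self-contained sketch of that mapping, with illustrative names:

import gradio as gr

def infer_stub(prompt, steps, slicing_enabled):
    # Parameters bind positionally to the inputs list below
    return f"prompt={prompt!r}, steps={steps}, slicing={slicing_enabled}"

with gr.Blocks() as demo:
    prompt_box = gr.Textbox(label="Prompt")
    steps_slider = gr.Slider(1, 50, value=20, step=1, label="Steps")
    slicing_box = gr.Checkbox(label="Enable attention slicing", value=True)
    result = gr.Textbox(label="Result")
    gr.Button("Generate").click(
        fn=infer_stub,
        inputs=[prompt_box, steps_slider, slicing_box],  # same order as infer_stub's parameters
        outputs=[result],
    )

demo.launch()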