Raxephion committed on
Commit ce1167c · verified · 1 Parent(s): 049d1c0

Update app.py

Files changed (1): app.py +74 -10
app.py CHANGED
@@ -4,7 +4,7 @@ Author: @Raxephion 2025
 """
 
 import gradio as gr
-import numpy as np # <-- Needed for np.iinfo
+import numpy as np
 import random
 import torch
 from diffusers import StableDiffusionPipeline
@@ -95,12 +95,21 @@ if INITIAL_MODEL_ID:
     print(f"\nLoading initial model '{INITIAL_MODEL_ID}' on startup...")
     try:
         # Load the pipeline onto the initial device and dtype
-        current_pipeline = StableDiffusionPipeline.from_pretrained(
+        pipeline = StableDiffusionPipeline.from_pretrained(
             INITIAL_MODEL_ID,
             torch_dtype=initial_dtype_to_use,
             safety_checker=None, # <<< SAFETY CHECKER DISABLED <<<
         )
-        current_pipeline = current_pipeline.to(initial_device_to_use)
+
+        # --- Apply optimizations during initial load ---
+        # Attention slicing is a good VRAM-saving default on Spaces, but it
+        # is governed by the UI toggle added below, so the default-on call
+        # stays commented out here.
+        # pipeline.enable_attention_slicing()
+
+        pipeline = pipeline.to(initial_device_to_use) # Move to the initial device
+
+        current_pipeline = pipeline
         current_model_id = INITIAL_MODEL_ID
         current_device_loaded = torch.device(initial_device_to_use)
         print(f"Initial model loaded successfully on {current_device_loaded}.")
@@ -146,10 +155,11 @@ def infer(
     size, # From size_dropdown
     seed, # From seed_input (now a Slider)
     randomize_seed, # From randomize_seed_checkbox
+    enable_attention_slicing, # <-- New input for the optimization toggle
     progress=gr.Progress(track_tqdm=True), # Added progress argument from template
 ):
     """Generates an image using the selected model and parameters on the chosen device."""
-    global current_pipeline, current_model_id, current_device_loaded, SCHEDULER_MAP, MAX_SEED # MAX_SEED is global
+    global current_pipeline, current_model_id, current_device_loaded, SCHEDULER_MAP, MAX_SEED
 
     # This check is done before parameter parsing so we can determine device/dtype for loading
     # Need to redo some parameter parsing here to get device_to_use early
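The `randomize_seed` input feeds the usual seed-resolution idiom. A hedged sketch of that plumbing, assuming `MAX_SEED` comes from `np.iinfo` as the removed import comment ("Needed for np.iinfo") suggested; the helper name is hypothetical:

```python
# Sketch of the seed plumbing implied by randomize_seed and MAX_SEED.
# np.iinfo is why the numpy import survives after its comment was cut.
import random
import numpy as np

MAX_SEED = np.iinfo(np.int32).max  # 2**31 - 1

def resolve_seed(seed: int, randomize: bool) -> int:
    """Return the seed actually used for generation (hypothetical helper)."""
    return random.randint(0, MAX_SEED) if randomize else int(seed)
```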
@@ -165,7 +175,6 @@ def infer(
 
     # 1. Load/Switch Model if necessary
     # Check if the requested model identifier OR the requested device has changed
-    # Use string comparison for current_device_loaded as it's a torch.device object
     if current_pipeline is None or current_model_id != model_identifier or (current_device_loaded is not None and str(current_device_loaded) != temp_device_to_use):
 
         print(f"Loading model: {model_identifier} onto {temp_device_to_use} with dtype {temp_dtype_to_use}...")
@@ -180,6 +189,7 @@ def infer(
                 print(f"Warning: Failed to move previous pipeline to CPU: {move_e}")
             del current_pipeline
             current_pipeline = None # Set to None immediately
+            # Attempt to clear the CUDA cache if the previous device was a GPU
             if str(current_device_loaded) == "cuda":
                 try:
                     torch.cuda.empty_cache()
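The `del` / `empty_cache` sequence above is the standard teardown idiom for swapping large models. A sketch under the same assumptions (illustrative model ID):

```python
# Sketch of the teardown idiom this hunk annotates: drop every reference
# to the old pipeline, collect it, then return cached GPU blocks.
import gc
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5"  # illustrative model ID
)
del pipe                      # drop the only reference to the weights
gc.collect()                  # collect them now rather than eventually
if torch.cuda.is_available():
    torch.cuda.empty_cache()  # release cached, now-unreferenced GPU memory
```

Note that `empty_cache()` only frees blocks PyTorch has already cached for reuse; the `del` (and here, setting `current_pipeline = None`) is what actually lets the weights go.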
@@ -190,7 +200,7 @@ def infer(
         # Ensure the device is actually available if not CPU (redundant with earlier check but safe)
         if temp_device_to_use == "cuda":
             if not torch.cuda.is_available():
-                raise gr.Error("CUDA selected but not available to PyTorch on this Space. Please select CPU or ensure the Space is configured with a GPU and the CUDA version of PyTorch is installed.")
+                raise gr.Error("GPU selected but CUDA is not available to PyTorch on this Space. Please select CPU or ensure the Space is configured with a GPU and the CUDA version of PyTorch is installed.")
 
         try:
             pipeline = StableDiffusionPipeline.from_pretrained(
@@ -198,6 +208,24 @@ def infer(
                 torch_dtype=temp_dtype_to_use, # Use the determined dtype for loading
                 safety_checker=None, # DISABLED
             )
+
+            # Apply optimizations based on the UI toggle during load
+            if enable_attention_slicing and temp_device_to_use == "cuda": # Only apply on GPU
+                try:
+                    pipeline.enable_attention_slicing()
+                    print("Attention Slicing enabled.")
+                except Exception as e:
+                    print(f"Warning: Failed to enable Attention Slicing: {e}")
+                    gr.Warning(f"Failed to enable Attention Slicing. Error: {e}")
+            else:
+                try:
+                    pipeline.disable_attention_slicing() # Ensure it's off if the toggle is off or we're on CPU
+                    print("Attention Slicing disabled.")
+                except Exception:
+                    # May fail if it was never enabled; safe to ignore
+                    pass
+
+
             pipeline = pipeline.to(temp_device_to_use) # Use the determined device
 
             current_pipeline = pipeline
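What the new toggle actually switches: attention slicing computes the attention in sequential slices, producing the same output with lower peak VRAM at a small speed cost. A sketch using the same diffusers calls as this hunk (model ID illustrative):

```python
# Sketch of the attention-slicing API this commit wires to the UI toggle.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5"  # illustrative model ID
)

pipe.enable_attention_slicing()       # default "auto": balanced slice size
pipe.enable_attention_slicing("max")  # most aggressive: lowest peak VRAM
pipe.disable_attention_slicing()      # restore full-batch attention
```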
@@ -244,6 +272,8 @@ def infer(
 
     # Re-determine device_to_use and dtype_to_use *after* ensuring pipeline is loaded
     # They should match current_device_loaded and the pipeline's dtype
+    # This is crucial because current_pipeline.device and dtype are the definitive
+    # source after a successful load or switch.
     device_to_use = str(current_pipeline.device) if current_pipeline else ("cuda" if selected_device_str == "GPU" and "GPU" in AVAILABLE_DEVICES else "cpu")
     dtype_to_use = current_pipeline.dtype if current_pipeline else torch.float32 # Fallback if somehow pipeline is still None
@@ -253,6 +283,30 @@ def infer(
         raise gr.Error("Model failed to load during setup or switching. Cannot generate image.")
 
 
+    # --- Apply optimizations *before* generation if the model was already loaded ---
+    # If the model didn't need reloading, slicing still has to be applied/removed here
+    if str(current_pipeline.device) == "cuda": # Only attempt on GPU
+        if enable_attention_slicing:
+            try:
+                current_pipeline.enable_attention_slicing()
+                # print("Attention Slicing enabled for generation.") # Too verbose
+            except Exception as e:
+                print(f"Warning: Failed to enable Attention Slicing before generation: {e}")
+                gr.Warning(f"Failed to enable Attention Slicing. Error: {e}")
+        else:
+            try:
+                current_pipeline.disable_attention_slicing()
+                # print("Attention Slicing disabled for generation.") # Too verbose
+            except Exception:
+                # May fail if it was never enabled; safe to ignore
+                pass
+    else: # Ensure slicing is off on CPU
+        try:
+            current_pipeline.disable_attention_slicing()
+        except Exception:
+            pass # Ignore
+
+
     # 2. Configure Scheduler
     selected_scheduler_class = SCHEDULER_MAP.get(scheduler_name)
     if selected_scheduler_class is None:
@@ -348,7 +402,7 @@ def infer(
     if width <= 0 or height <= 0:
         raise ValueError("Image width and height must be positive.")
 
-    print(f"Generating: Prompt='{prompt[:80]}{'...' if len(prompt) > 80 else ''}', NegPrompt='{negative_prompt[:80]}{'...' if len(negative_prompt) > 80 else ''}', Steps={num_inference_steps_int}, CFG={guidance_scale_float}, Size={width}x{height}, Scheduler={scheduler_name}, Seed={seed_int if generator else 'System Random'}, Device={device_to_use}, Dtype={dtype_to_use}")
+    print(f"Generating: Prompt='{prompt[:80]}{'...' if len(prompt) > 80 else ''}', NegPrompt='{negative_prompt[:80]}{'...' if len(negative_prompt) > 80 else ''}', Steps={num_inference_steps_int}, CFG={guidance_scale_float}, Size={width}x{height}, Scheduler={scheduler_name}, Seed={seed_int if generator else 'System Random'}, Device={device_to_use}, Dtype={dtype_to_use}, Slicing Enabled={enable_attention_slicing and device_to_use == 'cuda'}")
     start_time = time.time()
 
     try:
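The `Seed={seed_int if generator else 'System Random'}` fragment in the log implies the usual generator setup, sketched below as an assumption about the surrounding code; the helper name is hypothetical:

```python
# Hypothetical sketch of the generator logic the log line refers to:
# a seeded torch.Generator makes the sample reproducible, while
# generator=None lets PyTorch draw its own entropy ("System Random").
import torch

def make_generator(seed, device="cpu"):
    if seed is None:
        return None  # logged as 'System Random'
    return torch.Generator(device=device).manual_seed(int(seed))
```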
@@ -367,8 +421,6 @@ def infer(
 
             # Add VAE usage here if needed for specific models that require it
             # vae=...
-            # Potentially add attention slicing/xformers/etc. for memory efficiency
-            # enable_attention_slicing="auto", # Can help with VRAM on smaller GPUs
             # enable_xformers_memory_efficient_attention() # Needs xformers installed & compatible GPU
         )
         end_time = time.time()
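The surviving comment points at the xformers alternative to slicing. A guarded sketch (the enable call is the real diffusers API; the model ID is illustrative):

```python
# Sketch of the xformers route mentioned in the comment above. It needs
# the xformers package and a compatible CUDA GPU, so it is best guarded.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5"  # illustrative model ID
)
try:
    pipe.enable_xformers_memory_efficient_attention()
except Exception as e:  # xformers missing or GPU incompatible
    print(f"xformers unavailable, keeping default attention: {e}")
```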
@@ -488,6 +540,17 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo: # Added Soft theme from
                 seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, interactive=True) # Use 0 as default, interactive initially
                 randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True) # Simplified label
 
+            # --- New: Memory Optimization Toggle ---
+            with gr.Row():
+                # Default to enabled if a GPU is available, otherwise off
+                default_slicing = "GPU" in AVAILABLE_DEVICES
+                enable_attention_slicing_checkbox = gr.Checkbox(
+                    label="Enable Attention Slicing (Memory Optimization - GPU only)",
+                    value=default_slicing,
+                    interactive="GPU" in AVAILABLE_DEVICES # Only interactive if GPU is an option
+                )
+                gr.Markdown("*(Helps reduce VRAM usage; may slightly reduce speed)*")
+
 
             generate_button = gr.Button("✨ Generate Image ✨", variant="primary", scale=1) # Added emojis
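The wiring pattern behind the new checkbox, as a minimal self-contained sketch; the component and function names are illustrative stand-ins for the Space's real ones:

```python
# Minimal sketch of Checkbox-to-handler wiring, as this commit does with
# enable_attention_slicing_checkbox. All names here are illustrative.
import gradio as gr

def generate(prompt, slicing_enabled):
    return f"prompt={prompt!r}, attention slicing={slicing_enabled}"

with gr.Blocks() as demo:
    prompt_box = gr.Textbox(label="Prompt")
    slicing_box = gr.Checkbox(label="Enable Attention Slicing", value=True)
    out = gr.Textbox(label="Result")
    gr.Button("Generate").click(
        fn=generate,
        inputs=[prompt_box, slicing_box],  # order must match the fn signature
        outputs=[out],
    )
```

The positional correspondence between `inputs` and the handler's parameters is why the hunk below appends `enable_attention_slicing_checkbox` at the same position where `infer` gained its new parameter.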
493
 
@@ -520,7 +583,8 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo: # Added Soft theme from
520
  scheduler_dropdown,
521
  size_dropdown,
522
  seed_input,
523
- randomize_seed_checkbox, # Pass the checkbox value
 
524
  ],
525
  outputs=[output_image, actual_seed_output], # Return image and the actual seed used
526
  api_name="generate" # Optional: For API access
 