ovi054 committed
Commit b0aa5c4 · verified · 1 Parent(s): 3e81ff5

Update app.py

Files changed (1):
  1. app.py +5 -10
app.py CHANGED
@@ -44,7 +44,6 @@ except Exception as e:
 
 print("Initialization complete. Gradio is starting...")
 
-
 @spaces.GPU()
 def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_steps=30, lora_id=None, progress=gr.Progress(track_tqdm=True)):
 
@@ -56,7 +55,6 @@ def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_ste
     if causvid_path:
         try:
             print(f"Loading base LoRA '{BASE_LORA_NAME}'...")
-            # THE CORRECT FIX: Use device_map to load the LoRA directly to the GPU.
             pipe.load_lora_weights(causvid_path, adapter_name=BASE_LORA_NAME, device_map={"":device})
             active_adapters.append(BASE_LORA_NAME)
             adapter_weights.append(1.0)
@@ -69,7 +67,6 @@ def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_ste
     if clean_lora_id:
         try:
             print(f"Loading custom LoRA '{CUSTOM_LORA_NAME}' from '{clean_lora_id}'...")
-            # THE CORRECT FIX: Also use device_map for the custom LoRA.
             pipe.load_lora_weights(clean_lora_id, adapter_name=CUSTOM_LORA_NAME, device_map={"":device})
             active_adapters.append(CUSTOM_LORA_NAME)
             adapter_weights.append(1.0)
@@ -83,6 +80,7 @@ def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_ste
     if active_adapters:
         print(f"Activating adapters: {active_adapters} with weights: {adapter_weights}")
         pipe.set_adapters(active_adapters, adapter_weights)
+        pipe.transformer.to(device) # Explicitly move transformer to GPU after setting adapters
     else:
         # Ensure LoRA is disabled if no adapters were loaded
         pipe.disable_lora()
@@ -104,24 +102,21 @@ def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_ste
         return Image.fromarray(image)
     finally:
         # --- PROPER CLEANUP ---
-        # The most reliable way to clean up in this complex environment is to unload ALL LoRAs.
-        # This avoids leaving dangling configs.
         print("Unloading all LoRAs to ensure a clean state...")
         pipe.unload_lora_weights()
-        gc.collect() # Force garbage collection
-        torch.cuda.empty_cache() # Clear CUDA cache
+        gc.collect() # Force garbage collection
+        torch.cuda.empty_cache() # Clear CUDA cache
         print("✅ LoRAs unloaded and memory cleaned.")
 
-
 iface = gr.Interface(
     fn=generate,
     inputs=[
         gr.Textbox(label="Input prompt"),
-        gr.Textbox(label="Negative prompt", value = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"),
+        gr.Textbox(label="Negative prompt", value="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"),
         gr.Slider(label="Width", minimum=480, maximum=1280, step=16, value=1024),
         gr.Slider(label="Height", minimum=480, maximum=1280, step=16, value=1024),
         gr.Slider(minimum=1, maximum=80, step=1, label="Inference Steps", value=10),
-        gr.Textbox(label="LoRA ID (Optional)"),
+        gr.Textbox(label="LoRA ID (Optional, loads dynamically)"),
     ],
     outputs=gr.Image(label="output"),
 )
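Taken together, the hunks give generate() the shape sketched below. This is a reconstruction from the visible hunks only: pipe, device, causvid_path, the adapter-name constants, and the spaces/gr/gc/torch/Image imports come from parts of app.py the diff does not show, the per-LoRA try/except blocks are omitted for brevity, and the clean_lora_id derivation and the pipe(...) call are assumptions, since the diff elides them.

# Sketch of generate() after this commit; see the caveats above.
@spaces.GPU()
def generate(prompt, negative_prompt, width=1024, height=1024,
             num_inference_steps=30, lora_id=None,
             progress=gr.Progress(track_tqdm=True)):
    active_adapters, adapter_weights = [], []
    clean_lora_id = lora_id.strip() if lora_id else None  # assumed; elided in the diff

    try:
        if causvid_path:
            # device_map={"": device} maps the whole adapter module tree to one
            # device (the empty-string key is the accelerate convention), so the
            # LoRA lands directly on the GPU instead of being loaded to CPU first.
            pipe.load_lora_weights(causvid_path, adapter_name=BASE_LORA_NAME,
                                   device_map={"": device})
            active_adapters.append(BASE_LORA_NAME)
            adapter_weights.append(1.0)

        if clean_lora_id:
            # Optional user-supplied LoRA, stacked on top of the base one.
            pipe.load_lora_weights(clean_lora_id, adapter_name=CUSTOM_LORA_NAME,
                                   device_map={"": device})
            active_adapters.append(CUSTOM_LORA_NAME)
            adapter_weights.append(1.0)

        if active_adapters:
            pipe.set_adapters(active_adapters, adapter_weights)
            pipe.transformer.to(device)  # the line this commit adds
        else:
            pipe.disable_lora()

        # Assumed call shape; the diff only shows that `image` ends up as a
        # NumPy array handed to Image.fromarray.
        image = pipe(prompt=prompt, negative_prompt=negative_prompt,
                     width=width, height=height,
                     num_inference_steps=num_inference_steps).images[0]
        return Image.fromarray(image)
    finally:
        # Always leave the shared pipeline LoRA-free for the next request.
        pipe.unload_lora_weights()
        gc.collect()
        torch.cuda.empty_cache()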
 
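The cleanup this commit restores to the finally: block is the usual sequence for returning a shared pipeline to a clean state on a Space: drop the adapters, then the Python references, then the cached CUDA blocks. As a hypothetical standalone helper (name and factoring mine, not the commit's):

import gc

import torch


def reset_pipeline(pipe):
    """Hypothetical helper factoring out the commit's finally: block.

    Order matters: unload_lora_weights() detaches the adapter modules and
    their configs from the pipeline, gc.collect() frees the Python-side
    references to those tensors, and torch.cuda.empty_cache() then hands
    the freed blocks in PyTorch's caching allocator back to the driver.
    """
    pipe.unload_lora_weights()
    gc.collect()
    torch.cuda.empty_cache()

Run from a finally: clause, this guarantees the next request starts from a LoRA-free pipeline even when generation raises.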