ginipick committed (verified)
Commit f48fc96 · Parent: 1e8ff8a

Update app.py

Files changed (1):
  app.py +49 -22
app.py CHANGED
@@ -19,21 +19,25 @@ MAX_SEED = np.iinfo(np.int32).max
 pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
 
 # Load LoRA data (you'll need to create this JSON file or modify to load your LoRAs)
-
-with open("flux_loras.json", "r") as file:
-    data = json.load(file)
-    flux_loras_raw = [
-        {
-            "image": item["image"],
-            "title": item["title"],
-            "repo": item["repo"],
-            "trigger_word": item.get("trigger_word", ""),
-            "trigger_position": item.get("trigger_position", "prepend"),
-            "weights": item.get("weights", "pytorch_lora_weights.safetensors"),
-        }
-        for item in data
-    ]
-    print(f"Loaded {len(flux_loras_raw)} LoRAs from JSON")
+try:
+    with open("flux_loras.json", "r") as file:
+        data = json.load(file)
+    flux_loras_raw = [
+        {
+            "image": item["image"],
+            "title": item["title"],
+            "repo": item["repo"],
+            "trigger_word": item.get("trigger_word", ""),
+            "trigger_position": item.get("trigger_position", "prepend"),
+            "weights": item.get("weights", "pytorch_lora_weights.safetensors"),
+            "likes": item.get("likes", 0),
+        }
+        for item in data
+    ]
+    print(f"Successfully loaded {len(flux_loras_raw)} LoRAs from JSON")
+except Exception as e:
+    print(f"Error loading flux_loras.json: {e}")
+    flux_loras_raw = []
 # Global variables for LoRA management
 current_lora = None
 lora_cache = {}
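For reference, a flux_loras.json entry with the fields the loader above reads might look like the sketch below. All values are illustrative placeholders, not real LoRA repos; "likes" is optional and, after this commit, defaults to 0 so the gallery sort always has a key.

import json

# Illustrative only: placeholder values, not a real LoRA repo.
sample = [
    {
        "image": "thumbnails/example.png",   # gallery thumbnail
        "title": "Example Style",            # gallery caption
        "repo": "user/example-flux-lora",    # Hub repo id (placeholder)
        "trigger_word": "exstyle",           # optional, defaults to ""
        "trigger_position": "prepend",       # optional, defaults to "prepend"
        "weights": "pytorch_lora_weights.safetensors",  # optional, default shown
        "likes": 0,                          # optional, used for gallery sorting
    }
]

with open("flux_loras.json", "w") as f:
    json.dump(sample, f, indent=2)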
@@ -123,8 +127,16 @@ def remove_custom_lora():
 
 def classify_gallery(flux_loras):
     """Sort gallery by likes"""
-    sorted_gallery = sorted(flux_loras, key=lambda x: x.get("likes", 0), reverse=True)
-    return [(item["image"], item["title"]) for item in sorted_gallery], sorted_gallery
+    try:
+        sorted_gallery = sorted(flux_loras, key=lambda x: x.get("likes", 0), reverse=True)
+        gallery_items = []
+        for item in sorted_gallery:
+            if "image" in item and "title" in item:
+                gallery_items.append((item["image"], item["title"]))
+        return gallery_items, sorted_gallery
+    except Exception as e:
+        print(f"Error loading gallery: {e}")
+        return [], []
 
 def infer_with_lora_wrapper(input_image, prompt, selected_index, custom_lora, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.75, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
     """Wrapper function to handle state serialization"""
@@ -149,29 +161,44 @@ def infer_with_lora(input_image, prompt, selected_index, custom_lora, seed=42, r
         lora_to_use = custom_lora
     elif selected_index is not None and flux_loras and selected_index < len(flux_loras):
         lora_to_use = flux_loras[selected_index]
-        print(f"Loaded {len(flux_loras)} LoRAs from JSON")
     # Load LoRA if needed
     if lora_to_use and lora_to_use != current_lora:
         try:
             # Unload current LoRA
             if current_lora:
                 pipe.unload_lora_weights()
+                print(f"Unloaded previous LoRA")
 
             # Load new LoRA
-            lora_path = load_lora_weights(lora_to_use["repo"], lora_to_use["weights"])
+            repo_id = lora_to_use.get("repo", "unknown")
+            weights_file = lora_to_use.get("weights", "pytorch_lora_weights.safetensors")
+            print(f"Loading LoRA: {repo_id} with weights: {weights_file}")
+
+            lora_path = load_lora_weights(repo_id, weights_file)
             if lora_path:
                 pipe.load_lora_weights(lora_path, adapter_name="selected_lora")
                 pipe.set_adapters(["selected_lora"], adapter_weights=[lora_scale])
-                print(f"loaded: {lora_path} with scale {lora_scale}")
+                print(f"Successfully loaded: {lora_path} with scale {lora_scale}")
                 current_lora = lora_to_use
+            else:
+                print(f"Failed to load LoRA from {repo_id}")
+                gr.Warning(f"Failed to load LoRA style. Please try a different one.")
+                return None, seed, gr.update(visible=False)
 
         except Exception as e:
             print(f"Error loading LoRA: {e}")
             # Continue without LoRA
     else:
-        print(f"using already loaded lora: {lora_to_use}")
+        if lora_to_use:
+            print(f"Using already loaded LoRA: {lora_to_use.get('repo', 'unknown')}")
 
-    input_image = input_image.convert("RGB")
+    try:
+        # Convert image to RGB
+        input_image = input_image.convert("RGB")
+    except Exception as e:
+        print(f"Error processing image: {e}")
+        gr.Warning("Error processing the uploaded image. Please try a different image.")
+        return None, seed, gr.update(visible=False)
 
     # Check if LoRA is selected
     if lora_to_use is None:
 
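The load_lora_weights helper this hunk calls is defined elsewhere in app.py and is not part of this commit. A plausible minimal version, assuming the weights file lives in a Hugging Face Hub repo and ignoring the memoization the lora_cache global suggests, might be:

# Hypothetical sketch of the helper called above; the real implementation
# in app.py is not shown in this diff.
from huggingface_hub import hf_hub_download

def load_lora_weights(repo_id, weights_filename):
    """Download a LoRA weights file from the Hub and return its local path,
    or None on failure (the caller above checks for None)."""
    try:
        return hf_hub_download(repo_id=repo_id, filename=weights_filename)
    except Exception as e:
        print(f"Error downloading {weights_filename} from {repo_id}: {e}")
        return None

Returning None rather than raising matches the caller's if lora_path: branch, which after this commit warns the user and aborts generation instead of silently continuing without the style.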