multimodalart HF Staff commited on
Commit
512c7c4
·
verified ·
1 Parent(s): ca40f58

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -27
app.py CHANGED
@@ -385,35 +385,82 @@ run_lora.zerogpu = True
385
 
386
  def get_huggingface_safetensors(link):
387
  split_link = link.split("/")
388
- if len(split_link) == 2:
389
- model_card = ModelCard.load(link)
390
- base_model = model_card.data.get("base_model")
391
- print(f"Base model: {base_model}")
392
- acceptable_models = {"black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"}
393
- models_to_check = base_model if isinstance(base_model, list) else [base_model]
394
- if not any(model in acceptable_models for model in models_to_check):
395
- raise Exception("Not a FLUX LoRA!")
396
- image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
397
- trigger_word = model_card.data.get("instance_prompt", "")
398
- image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
399
- fs = HfFileSystem()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
400
  safetensors_name = None
401
- try:
402
- list_of_files = fs.ls(link, detail=False)
403
- for file in list_of_files:
404
- if file.endswith(".safetensors"):
405
- safetensors_name = file.split("/")[-1]
406
- if not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
407
- image_elements = file.split("/")
408
- image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
409
- except Exception as e:
410
- print(e)
411
- raise gr.Error("Invalid Hugging Face repository with a *.safetensors LoRA")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
412
  if not safetensors_name:
413
- raise gr.Error("No *.safetensors file found in the repository")
414
- return split_link[1], link, safetensors_name, trigger_word, image_url
415
- else:
416
- raise gr.Error("Invalid Hugging Face repository link")
 
 
 
 
 
 
 
417
 
418
  def check_custom_model(link):
419
  if link.endswith(".safetensors"):
 
385
 
386
  def get_huggingface_safetensors(link):
387
  split_link = link.split("/")
388
+ if len(split_link) != 2:
389
+ raise Exception("Invalid Hugging Face repository link format.")
390
+
391
+ print(f"Repository attempted: {split_link}")
392
+
393
+ # Load model card
394
+ model_card = ModelCard.load(link)
395
+ base_model = model_card.data.get("base_model")
396
+ print(f"Base model: {base_model}")
397
+
398
+ # Validate model type
399
+ acceptable_models = {"black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"}
400
+
401
+ models_to_check = base_model if isinstance(base_model, list) else [base_model]
402
+
403
+ if not any(model in acceptable_models for model in models_to_check):
404
+ raise Exception("Not a FLUX LoRA!")
405
+
406
+ # Extract image and trigger word
407
+ print("Before trying to get image")
408
+ image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
409
+ print(f"Image path {image_path}")
410
+ trigger_word = model_card.data.get("instance_prompt", "")
411
+ print(f"Image path {trigger_word}")
412
+ image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
413
+ print(f"Image URL {image_url}")
414
+
415
+
416
+ # Initialize Hugging Face file system
417
+ fs = HfFileSystem()
418
+ try:
419
+ list_of_files = fs.ls(link, detail=False)
420
+
421
+ # Initialize variables for safetensors selection
422
  safetensors_name = None
423
+ highest_trained_file = None
424
+ highest_steps = -1
425
+ last_safetensors_file = None
426
+ step_pattern = re.compile(r"_0{3,}\d+") # Detects step count `_000...`
427
+
428
+ for file in list_of_files:
429
+ filename = file.split("/")[-1]
430
+
431
+ # Select safetensors file
432
+ if filename.endswith(".safetensors"):
433
+ last_safetensors_file = filename # Track last encountered file
434
+
435
+ match = step_pattern.search(filename)
436
+ if not match:
437
+ # Found a full model without step numbers, return immediately
438
+ safetensors_name = filename
439
+ break
440
+ else:
441
+ # Extract step count and track highest
442
+ steps = int(match.group().lstrip("_"))
443
+ if steps > highest_steps:
444
+ highest_trained_file = filename
445
+ highest_steps = steps
446
+
447
+ # Select an image file if not found in model card
448
+ if not image_url and filename.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
449
+ image_url = f"https://huggingface.co/{link}/resolve/main/{filename}"
450
+
451
+ # If no full model found, fall back to the most trained safetensors file
452
  if not safetensors_name:
453
+ safetensors_name = highest_trained_file if highest_trained_file else last_safetensors_file
454
+
455
+ # If still no safetensors file found, raise an exception
456
+ if not safetensors_name:
457
+ raise Exception("No valid *.safetensors file found in the repository.")
458
+
459
+ except Exception as e:
460
+ print(e)
461
+ raise Exception("You didn't include a valid Hugging Face repository with a *.safetensors LoRA")
462
+
463
+ return split_link[1], link, safetensors_name, trigger_word, image_url
464
 
465
  def check_custom_model(link):
466
  if link.endswith(".safetensors"):