Spaces:
Running
Running
File size: 2,527 Bytes
9fd1204 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import json
from pathlib import Path
from typing import Optional
import safetensors.torch
from diffusers import DiffusionPipeline
from diffusers.loaders.lora_pipeline import _LOW_CPU_MEM_USAGE_DEFAULT_LORA
from huggingface_hub import repo_exists, snapshot_download
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
from finetrainers.logging import get_logger
from finetrainers.utils import find_files
logger = get_logger()
def load_lora_weights(
pipeline: DiffusionPipeline, pretrained_model_name_or_path: str, adapter_name: Optional[str] = None, **kwargs
) -> None:
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
is_local_file_path = Path(pretrained_model_name_or_path).is_dir()
if not is_local_file_path:
does_repo_exist = repo_exists(pretrained_model_name_or_path, repo_type="model")
if not does_repo_exist:
raise ValueError(f"Model repo {pretrained_model_name_or_path} does not exist on the Hub or locally.")
else:
pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path, repo_type="model")
prefix = "transformer"
state_dict = pipeline.lora_state_dict(pretrained_model_name_or_path)
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
file_list = find_files(pretrained_model_name_or_path, "*.safetensors", depth=1)
if len(file_list) == 0:
raise ValueError(f"No .safetensors files found in {pretrained_model_name_or_path}.")
if len(file_list) > 1:
logger.warning(
f"Multiple .safetensors files found in {pretrained_model_name_or_path}. Using the first one: {file_list[0]}."
)
with safetensors.torch.safe_open(file_list[0], framework="pt") as f:
metadata = f.metadata()
metadata = json.loads(metadata["lora_config"])
transformer = pipeline.transformer
if adapter_name is None:
adapter_name = "default"
lora_config = LoraConfig(**metadata)
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
result = set_peft_model_state_dict(
transformer,
state_dict,
adapter_name=adapter_name,
ignore_mismatched_sizes=False,
low_cpu_mem_usage=low_cpu_mem_usage,
)
logger.debug(
f"Loaded LoRA weights from {pretrained_model_name_or_path} into {pipeline.__class__.__name__}. Result: {result}"
)
|