Spaces:
Running
Running
VideoModelStudio
/
docs
/finetrainers-src-codebase
/finetrainers
/patches
/dependencies
/diffusers
/peft.py
import json | |
from pathlib import Path | |
from typing import Optional | |
import safetensors.torch | |
from diffusers import DiffusionPipeline | |
from diffusers.loaders.lora_pipeline import _LOW_CPU_MEM_USAGE_DEFAULT_LORA | |
from huggingface_hub import repo_exists, snapshot_download | |
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict | |
from finetrainers.logging import get_logger | |
from finetrainers.utils import find_files | |
logger = get_logger() | |
def load_lora_weights( | |
pipeline: DiffusionPipeline, pretrained_model_name_or_path: str, adapter_name: Optional[str] = None, **kwargs | |
) -> None: | |
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) | |
is_local_file_path = Path(pretrained_model_name_or_path).is_dir() | |
if not is_local_file_path: | |
does_repo_exist = repo_exists(pretrained_model_name_or_path, repo_type="model") | |
if not does_repo_exist: | |
raise ValueError(f"Model repo {pretrained_model_name_or_path} does not exist on the Hub or locally.") | |
else: | |
pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path, repo_type="model") | |
prefix = "transformer" | |
state_dict = pipeline.lora_state_dict(pretrained_model_name_or_path) | |
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} | |
file_list = find_files(pretrained_model_name_or_path, "*.safetensors", depth=1) | |
if len(file_list) == 0: | |
raise ValueError(f"No .safetensors files found in {pretrained_model_name_or_path}.") | |
if len(file_list) > 1: | |
logger.warning( | |
f"Multiple .safetensors files found in {pretrained_model_name_or_path}. Using the first one: {file_list[0]}." | |
) | |
with safetensors.torch.safe_open(file_list[0], framework="pt") as f: | |
metadata = f.metadata() | |
metadata = json.loads(metadata["lora_config"]) | |
transformer = pipeline.transformer | |
if adapter_name is None: | |
adapter_name = "default" | |
lora_config = LoraConfig(**metadata) | |
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage) | |
result = set_peft_model_state_dict( | |
transformer, | |
state_dict, | |
adapter_name=adapter_name, | |
ignore_mismatched_sizes=False, | |
low_cpu_mem_usage=low_cpu_mem_usage, | |
) | |
logger.debug( | |
f"Loaded LoRA weights from {pretrained_model_name_or_path} into {pipeline.__class__.__name__}. Result: {result}" | |
) | |