from __future__ import annotations

from operator import attrgetter

import torch

from peft.config import PeftConfig
from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING

from .constants import PEFT_TYPE_TO_PREFIX_MAPPING
from .other import infer_device
from .peft_types import PeftType
from .save_and_load import _insert_adapter_name_into_state_dict, load_peft_weights


# Config keys that must match between the old and the new adapter for hot-swapping to give correct results.
# As of now, only LoRA is supported.
CONFIG_KEYS_TO_CHECK = {PeftType.LORA: ["lora_alpha", "use_rslora", "lora_dropout", "alpha_pattern", "use_dora"]}


def hotswap_adapter_from_state_dict(model, state_dict, adapter_name, parameter_prefix="lora_"):
    """
    Swap out the adapter weights from the model with the weights from state_dict.

    As of now, only LoRA is supported.

    This is a low-level function that assumes that the adapters have been checked for compatibility and that the
    state_dict has been correctly mapped to work with PEFT. For a high-level function that performs this work for
    you, use `hotswap_adapter` instead.
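
    Example:

    A rough sketch of the low-level flow (for brevity, the config compatibility check that `hotswap_adapter`
    performs for you is omitted); `base_model` and the adapter paths are placeholders:

    ```py
    >>> from peft import PeftModel
    >>> from peft.utils.hotswap import hotswap_adapter_from_state_dict
    >>> from peft.utils.save_and_load import _insert_adapter_name_into_state_dict, load_peft_weights

    >>> model = PeftModel.from_pretrained(base_model, "path-adapter-0")
    >>> # load the raw weights of the new adapter and insert the adapter name into the keys
    >>> state_dict = load_peft_weights("path-adapter-1", device="cpu")
    >>> state_dict = _insert_adapter_name_into_state_dict(
    ...     state_dict, adapter_name="default", parameter_prefix="lora_"
    ... )
    >>> hotswap_adapter_from_state_dict(model, state_dict, adapter_name="default")
    ```
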
    Args:
        model (`nn.Module`):
            The model with the loaded adapter.
        state_dict (`dict[str, torch.Tensor]`):
            The state dict of the new adapter, which needs to be compatible (targeting the same modules etc.).
        adapter_name (`str`):
            The name of the adapter that should be hot-swapped, e.g. `"default"`. The name will remain the same
            after swapping.
        parameter_prefix (`str`, *optional*, defaults to `"lora_"`):
            The prefix used to identify the adapter's keys in the state dict. For LoRA, this would be `"lora_"` (the
            default).

    Raises:
        RuntimeError
            If the old and the new adapter are not compatible, a RuntimeError is raised.
    """
    # Ensure that all the keys of the new adapter correspond exactly to the keys of the old adapter, otherwise
    # hot-swapping is not possible.

    is_compiled = hasattr(model, "_orig_mod")

    missing_keys = {k for k in model.state_dict() if (parameter_prefix in k) and (adapter_name in k)}
    unexpected_keys = set()

    # first: dry run, don't swap anything yet, only record missing and unexpected keys
    for key, new_val in state_dict.items():
        try:
            old_val = attrgetter(key)(model)
        except AttributeError:
            unexpected_keys.add(key)
            continue

        if is_compiled:
            missing_keys.remove("_orig_mod." + key)
        else:
            missing_keys.remove(key)

    if missing_keys or unexpected_keys:
        msg = "Hot swapping the adapter did not succeed."
        if missing_keys:
            msg += f" Missing keys: {', '.join(sorted(missing_keys))}."
        if unexpected_keys:
            msg += f" Unexpected keys: {', '.join(sorted(unexpected_keys))}."
        raise RuntimeError(msg)

    # actual swapping
    for key, new_val in state_dict.items():
        # no need to account for a potential "_orig_mod." prefix here, attribute access on the compiled model
        # forwards to the original module
        old_val = attrgetter(key)(model)
        if is_compiled:
            # compiled models don't work with swap_tensors because there are weakrefs for the tensor, so fall back
            # to overwriting the data attribute in place
            old_val.data = new_val.data
        else:
            torch.utils.swap_tensors(old_val, new_val)


def _check_hotswap_configs_compatible(config0: PeftConfig, config1: PeftConfig) -> None:
    """
    Check if two configs are compatible for hot-swapping.

    Only LoRA parameters are checked for now.

    To hot-swap two adapters, their configs must be compatible. Otherwise, the results could be incorrect. E.g. if
    they use different alpha values, after hot-swapping, the alphas from the first adapter would still be used with
    the weights from the second adapter, which would result in incorrect behavior. There is probably a way to swap
    these values as well, but that's not implemented yet, and we need to be careful not to trigger re-compilation
    if the model is compiled (so no modification of the dict).
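
    Example:

    A rough sketch of the kind of mismatch this check catches; the concrete config values are made up for
    illustration:

    ```py
    >>> from peft import LoraConfig
    >>> config0 = LoraConfig(lora_alpha=8)
    >>> config1 = LoraConfig(lora_alpha=16)
    >>> _check_hotswap_configs_compatible(config0, config1)  # raises ValueError: "for lora_alpha, 8 != 16"
    ```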
    """
    if config0.peft_type != config1.peft_type:
        msg = f"Incompatible PEFT types found: {config0.peft_type.value} and {config1.peft_type.value}"
        raise ValueError(msg)

    if config0.peft_type not in CONFIG_KEYS_TO_CHECK:
        msg = (
            f"Hotswapping only supports {', '.join(CONFIG_KEYS_TO_CHECK.keys())} but "
            f"{config0.peft_type.value} was passed."
        )
        raise ValueError(msg)
    config_keys_to_check = CONFIG_KEYS_TO_CHECK[config0.peft_type]

    # Compare the relevant config values as plain dict entries. The sentinel ensures that a key that is only
    # present in one of the two configs also counts as a mismatch.
    config0 = config0.to_dict()
    config1 = config1.to_dict()
    sentinel = object()
    for key in config_keys_to_check:
        val0 = config0.get(key, sentinel)
        val1 = config1.get(key, sentinel)
        if val0 != val1:
            raise ValueError(f"Configs are incompatible: for {key}, {val0} != {val1}")


def hotswap_adapter(model, model_name_or_path, adapter_name, torch_device=None, **kwargs):
    """Substitute old adapter data with new adapter data, keeping the rest the same.

    As of now, only LoRA is supported.

    This function is useful when you want to replace the loaded adapter with a new adapter. The adapter name will
    remain the same, but the weights and other parameters will be swapped out.

    If the adapters are incompatible, e.g. targeting different layers or having different alpha values, an error
    will be raised.

    Example:

    ```py
    >>> import torch
    >>> from transformers import AutoModelForCausalLM
    >>> from peft import PeftModel
    >>> from peft.utils.hotswap import hotswap_adapter

    >>> model_id = ...
    >>> inputs = ...
    >>> device = ...
    >>> model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

    >>> # load lora 0
    >>> model = PeftModel.from_pretrained(model, "path-adapter-0")
    >>> model = torch.compile(model)  # optionally compile the model
    >>> with torch.inference_mode():
    ...     output_adapter_0 = model(inputs)

    >>> # replace the "default" lora adapter with the new one
    >>> hotswap_adapter(model, "path-adapter-1", adapter_name="default", torch_device=device)
    >>> with torch.inference_mode():
    ...     output_adapter_1 = model(inputs)
    ```

    Args:
        model ([`~PeftModel`]):
            The PEFT model with the loaded adapter.
        model_name_or_path (`str`):
            The name or path of the model to load the new adapter from.
        adapter_name (`str`):
            The name of the adapter to swap, e.g. `"default"`. The name will stay the same after swapping.
        torch_device (`str`, *optional*, defaults to None):
            The device to load the new adapter onto.
        **kwargs (`optional`):
            Additional keyword arguments used for loading the config and weights.
    """
    if torch_device is None:
        torch_device = infer_device()

    # load the config of the new adapter and ensure that it is compatible with the currently loaded one
    config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[
        PeftConfig._get_peft_type(
            model_name_or_path,
            subfolder=kwargs.get("subfolder", None),
            revision=kwargs.get("revision", None),
            cache_dir=kwargs.get("cache_dir", None),
            use_auth_token=kwargs.get("use_auth_token", None),
            token=kwargs.get("token", None),
        )
    ]
    config = config_cls.from_pretrained(model_name_or_path, **kwargs)

    _check_hotswap_configs_compatible(model.active_peft_config, config)

    state_dict = load_peft_weights(model_name_or_path, device=torch_device, **kwargs)

    # remap the keys of the loaded state dict so that they match the parameter names of the loaded adapter
    parameter_prefix = PEFT_TYPE_TO_PREFIX_MAPPING[config.peft_type]
    peft_model_state_dict = _insert_adapter_name_into_state_dict(
        state_dict, adapter_name=adapter_name, parameter_prefix=parameter_prefix
    )
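    # For illustration (an assumption about the typical LoRA key layout, not taken from this module): the remapping
    # above turns checkpoint keys like
    #   "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"
    # into keys that include the adapter name, e.g.
    #   "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight",
    # which is the naming scheme used by the parameters of the already-loaded adapter.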

    hotswap_adapter_from_state_dict(
        model=model,
        state_dict=peft_model_state_dict,
        adapter_name=adapter_name,
        parameter_prefix=parameter_prefix,
    )