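"""Export LoRA weights from a SAT transformer checkpoint into the diffusers LoRA layout.

This summary is inferred from the script itself: keys are renamed according to the
table below and the result is written out with ``LoraBaseMixin.write_lora_layers``.
"""
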
import argparse
from typing import Any, Dict

import torch
from diffusers.loaders.lora_base import LoraBaseMixin
from diffusers.models.modeling_utils import load_state_dict


def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    # Unwrap the common checkpoint containers ("model", "module", "state_dict")
    # to reach the raw parameter dict.
    state_dict = saved_dict
    if "model" in saved_dict.keys():
        state_dict = state_dict["model"]
    if "module" in saved_dict.keys():
        state_dict = state_dict["module"]
    if "state_dict" in saved_dict.keys():
        state_dict = state_dict["state_dict"]
    return state_dict


# The SAT checkpoint stores Q/K/V as one fused "query_key_value" projection with
# per-slice LoRA matrices; map those (and the output "dense" projection) to the
# per-projection LoRA keys diffusers expects.
LORA_KEYS_RENAME = {
    'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight',
    'attention.query_key_value.matrix_A.1': 'attn1.to_k.lora_A.weight',
    'attention.query_key_value.matrix_A.2': 'attn1.to_v.lora_A.weight',
    'attention.query_key_value.matrix_B.0': 'attn1.to_q.lora_B.weight',
    'attention.query_key_value.matrix_B.1': 'attn1.to_k.lora_B.weight',
    'attention.query_key_value.matrix_B.2': 'attn1.to_v.lora_B.weight',
    'attention.dense.matrix_A.0': 'attn1.to_out.0.lora_A.weight',
    'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight',
}

PREFIX_KEY = "model.diffusion_model."
SAT_UNIT_KEY = "layers"
LORA_PREFIX_KEY = "transformer_blocks"
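
# Illustrative key translation implied by the constants and rename table above
# (the portion between the prefix and "layers" depends on the checkpoint and is
# left as "<...>" here):
#   model.diffusion_model.<...>layers.<i>.attention.query_key_value.matrix_A.0
#     -> <...>transformer_blocks.<i>.attn1.to_q.lora_A.weight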


def export_lora_weight(ckpt_path, lora_save_directory):
    merge_original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))

    lora_state_dict = {}
    for key in list(merge_original_state_dict.keys()):
        # Strip the SAT prefix, rename the attention LoRA matrices, and swap the
        # SAT "layers" container for diffusers' "transformer_blocks".
        new_key = key[len(PREFIX_KEY):]
        for special_key, lora_keys in LORA_KEYS_RENAME.items():
            if new_key.endswith(special_key):
                new_key = new_key.replace(special_key, lora_keys)
        new_key = new_key.replace(SAT_UNIT_KEY, LORA_PREFIX_KEY)
        lora_state_dict[new_key] = merge_original_state_dict[key]

    # With 8 LoRA tensors per attention block (q/k/v/out, each with A and B),
    # the converted dict is expected to hold exactly 240 keys; anything else
    # indicates an unexpected checkpoint layout.
    if len(lora_state_dict) != 240:
        raise ValueError("lora_state_dict length is not 240")

    LoraBaseMixin.write_lora_layers(
        state_dict=lora_state_dict,
        save_directory=lora_save_directory,
        is_main_process=True,
        weight_name=None,
        save_function=None,
        safe_serialization=True,
    )


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--sat_pt_path", type=str, required=True, help="Path to the original SAT transformer checkpoint"
    )
    parser.add_argument(
        "--lora_save_directory", type=str, required=True, help="Path where the converted LoRA weights should be saved"
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    export_lora_weight(args.sat_pt_path, args.lora_save_directory)
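
# Example invocation (illustrative only; the script name and paths are placeholders):
#   python export_sat_lora_weight.py \
#       --sat_pt_path /path/to/sat/lora_checkpoint.pt \
#       --lora_save_directory /path/to/lora_out
# The resulting directory can then be loaded onto a diffusers pipeline, e.g. with
# pipe.load_lora_weights("/path/to/lora_out").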