|
import argparse |
|
import json |
|
import torch |
|
from torch.nn.parameter import Parameter |
|
from stable_audio_tools.models import create_model_from_config |
|
|
|
def sync_ema_weights(model, ema_copy):
    """Copy every entry of ``model``'s state dict into ``ema_copy`` in place.

    Seeds a freshly constructed EMA model with the current model's weights so
    that the checkpoint's EMA state can be loaded on top of it afterwards.

    Args:
        model: source ``torch.nn.Module`` whose weights are read.
        ema_copy: destination module with an identical state-dict layout;
            mutated in place.
    """
    ema_state = ema_copy.state_dict()
    for name, param in model.state_dict().items():
        if isinstance(param, Parameter):
            # Backwards compatibility for older checkpoints that stored
            # Parameter objects (rather than plain tensors) in the state dict.
            param = param.data
        ema_state[name].copy_(param)


def make_ema_copy(model, model_config):
    """Build a second model from ``model_config`` with weights synced from ``model``.

    Extracted helper: the original script repeated this construct-and-copy
    sequence verbatim in four branches.
    """
    ema_copy = create_model_from_config(model_config)
    sync_ema_weights(model, ema_copy)
    return ema_copy


if __name__ == '__main__':
    # CLI: unwrap a training checkpoint into a plain exported model file.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-config', type=str, default=None)
    parser.add_argument('--ckpt-path', type=str, default=None)
    parser.add_argument('--name', type=str, default='exported_model')
    parser.add_argument('--use-safetensors', action='store_true')
    args = parser.parse_args()

    with open(args.model_config) as f:
        model_config = json.load(f)

    model = create_model_from_config(model_config)

    model_type = model_config.get('model_type', None)
    assert model_type is not None, 'model_type must be specified in model config'

    training_config = model_config.get('training', None)

    # Each branch imports its wrapper lazily (matching the original script) and
    # re-wraps the bare model so Lightning's load_from_checkpoint can restore
    # the trained weights with strict=False.
    if model_type == 'autoencoder':
        from stable_audio_tools.training.autoencoders import AutoencoderTrainingWrapper

        assert training_config is not None, 'training config must be specified in model config'

        use_ema = training_config.get("use_ema", False)

        # Bug fix: the original constructed the EMA copy twice back-to-back;
        # a single copy seeded from the current weights is sufficient.
        ema_copy = make_ema_copy(model, model_config) if use_ema else None

        training_wrapper = AutoencoderTrainingWrapper.load_from_checkpoint(
            args.ckpt_path,
            autoencoder=model,
            strict=False,
            loss_config=training_config["loss_configs"],
            # Use the .get()-derived flag so a missing "use_ema" key does not
            # raise KeyError (the original indexed training_config directly).
            use_ema=use_ema,
            ema_copy=ema_copy
        )
    elif model_type == 'diffusion_uncond':
        from stable_audio_tools.training.diffusion import DiffusionUncondTrainingWrapper

        training_wrapper = DiffusionUncondTrainingWrapper.load_from_checkpoint(args.ckpt_path, model=model, strict=False)
    elif model_type == 'diffusion_autoencoder':
        from stable_audio_tools.training.diffusion import DiffusionAutoencoderTrainingWrapper

        ema_copy = make_ema_copy(model, model_config)

        training_wrapper = DiffusionAutoencoderTrainingWrapper.load_from_checkpoint(args.ckpt_path, model=model, ema_copy=ema_copy, strict=False)
    elif model_type == 'diffusion_cond':
        from stable_audio_tools.training.diffusion import DiffusionCondTrainingWrapper

        assert training_config is not None, 'training config must be specified in model config'

        # NOTE: unlike the other branches, diffusion_cond defaults use_ema to True.
        use_ema = training_config.get("use_ema", True)

        training_wrapper = DiffusionCondTrainingWrapper.load_from_checkpoint(
            args.ckpt_path,
            model=model,
            use_ema=use_ema,
            lr=training_config.get("learning_rate", None),
            optimizer_configs=training_config.get("optimizer_configs", None),
            strict=False
        )
    elif model_type == 'diffusion_cond_inpaint':
        from stable_audio_tools.training.diffusion import DiffusionCondInpaintTrainingWrapper

        training_wrapper = DiffusionCondInpaintTrainingWrapper.load_from_checkpoint(args.ckpt_path, model=model, strict=False)
    elif model_type == 'diffusion_prior':
        from stable_audio_tools.training.diffusion import DiffusionPriorTrainingWrapper

        ema_copy = make_ema_copy(model, model_config)

        training_wrapper = DiffusionPriorTrainingWrapper.load_from_checkpoint(args.ckpt_path, model=model, strict=False, ema_copy=ema_copy)
    elif model_type == 'lm':
        from stable_audio_tools.training.lm import AudioLanguageModelTrainingWrapper

        assert training_config is not None, 'training config must be specified in model config'

        ema_copy = make_ema_copy(model, model_config) if training_config.get("use_ema", False) else None

        training_wrapper = AudioLanguageModelTrainingWrapper.load_from_checkpoint(
            args.ckpt_path,
            model=model,
            strict=False,
            ema_copy=ema_copy,
            optimizer_configs=training_config.get("optimizer_configs", None)
        )
    else:
        raise ValueError(f"Unknown model type {model_type}")

    print(f"Loaded model from {args.ckpt_path}")

    # File extension follows the chosen serialization format.
    if args.use_safetensors:
        ckpt_path = f"{args.name}.safetensors"
    else:
        ckpt_path = f"{args.name}.ckpt"

    training_wrapper.export_model(ckpt_path, use_safetensors=args.use_safetensors)

    print(f"Exported model to {ckpt_path}")