Spaces:
Configuration error
Configuration error
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Export a trained model from the full checkpoint (with optimizer state etc.) to
a final checkpoint containing only the model itself. The model is always stored
as half-precision floats to save space, since this has zero impact on the final loss.
When DiffQ was used for training, the model is actually quantized and bitpacked."""
from argparse import ArgumentParser | |
from fractions import Fraction | |
import logging | |
from pathlib import Path | |
import sys | |
import torch | |
from demucs import train | |
from demucs.states import serialize_model, save_with_checksum | |
logger = logging.getLogger(__name__)


def _last_metrics(history):
    """Return the most recent ``(valid, test)`` metrics recorded in ``history``.

    ``history`` is scanned in order; each entry may contain a ``'valid'`` and/or a
    ``'test'`` key. For each key the last value seen wins, and a key that never
    appears yields ``None``.
    """
    valid, test = None, None
    for metrics in history:
        if 'valid' in metrics:
            valid = metrics['valid']
        if 'test' in metrics:
            test = metrics['test']
    return valid, test


def _export_signature(sig, out_dir, sign=False):
    """Export the best model state of the XP identified by ``sig`` into ``out_dir``.

    The serialized package (stored as half precision) is written to
    ``out_dir / (sig + '.th')`` with ``torch.save``, or through
    ``save_with_checksum`` when ``sign`` is true.
    """
    xp = train.main.get_xp_from_sig(sig)
    name = train.main.get_name(xp)
    logger.info('Handling %s/%s', sig, name)
    out_path = out_dir / (sig + ".th")

    # Keeping the solver local to this function lets it be released as soon as
    # this export finishes, before the next signature's solver is built.
    solver = train.get_solver_from_sig(sig)
    if len(solver.history) < solver.args.epochs:
        logger.warning(
            'Model %s has less epoch than expected (%d / %d)',
            sig, len(solver.history), solver.args.epochs)

    # Export the best weights tracked by the solver rather than the current ones.
    solver.model.load_state_dict(solver.best_state)
    pkg = serialize_model(solver.model, solver.args, solver.quantizer, half=True)

    if getattr(solver.model, 'use_train_segment', False):
        # Override the stored segment with the length of one augmented training
        # batch divided by the sample rate, kept exact as a Fraction
        # (assumes the last axis of the batch is time -- NOTE(review): confirm).
        batch = solver.augment(next(iter(solver.loaders['train'])))
        segment = Fraction(batch.shape[-1], solver.model.samplerate)
        pkg['kwargs']['segment'] = segment
        # Previously a bare print() to stdout; route through logging like the rest.
        logger.info('Override segment for %s: %s', sig, segment)

    pkg['metrics'] = _last_metrics(solver.history)

    if sign:
        save_with_checksum(pkg, out_path)
    else:
        torch.save(pkg, out_path)


def main(argv=None):
    """Command-line entry point: export each given XP signature as a release model.

    Args:
        argv: Optional list of arguments to parse instead of the process command
            line. ``None`` (the default) parses ``sys.argv[1:]``, so calling
            ``main()`` with no arguments behaves as before.
    """
    logging.basicConfig(level=logging.INFO, stream=sys.stderr)
    parser = ArgumentParser("tools.export", description="Export trained models from XP sigs.")
    parser.add_argument('signatures', nargs='*', help='XP signatures.')
    parser.add_argument('-o', '--out', type=Path, default=Path("release_models"),
                        help="Path where to store release models (default release_models)")
    parser.add_argument('-s', '--sign', action='store_true',
                        help='Add sha256 prefix checksum to the filename.')
    args = parser.parse_args(argv)
    args.out.mkdir(exist_ok=True, parents=True)

    if not args.signatures:
        # nargs='*' accepts an empty list; make the no-op visible instead of silent.
        logger.warning('No XP signatures given, nothing to export.')

    for sig in args.signatures:
        _export_signature(sig, args.out, args.sign)
if __name__ == '__main__':
    # Only run the exporter when executed as a script, never on import.
    main()