audio / tools /convert.py
PreciousMposa's picture
Upload 107 files
519d358 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Script to convert option names and model args from the dev branch to
# the cleanup release one. There should be no reason to use that anymore.
import argparse
import io
import json
from pathlib import Path
import subprocess as sp
import torch
from demucs import train, pretrained, states
# Directory under the user's home that this script reads the original
# experiment arguments from (see `get_original_argv`) and uses as the working
# directory when evaluating the old models (see `_eval_old`).
DEV_REPO = Path.home().joinpath('tmp/release_demucs_mdx')
# Arguments that `transform` drops from the original argv: every element of
# argv that is exactly equal to one of these strings is removed.
# Each entry is listed once; repeating an entry has no effect because all
# occurrences are removed anyway.
TO_REMOVE = [
    'demucs.dconv_kw.gelu=True',
    'demucs.dconv_kw.nfreqs=0',
    'demucs.dconv_kw.version=4',
    'demucs.norm=gn',
    'wdemucs.nice=True',
    'wdemucs.good=True',
    'wdemucs.freq_emb=-0.2',
    'special=True',
    'special=False',
]
# (old, new) substring substitutions that `transform` applies, in this order,
# to every element of argv using `str.replace`.
# The order matters: 'wdemucs' must become 'hdemucs' before the
# 'hdemucs.hybrid=...' rules can match, and 'hybrid=True' must be renamed to
# 'hybrid_old=True' before 'hybrid=2' is rewritten to 'hybrid=True', otherwise
# the freshly produced 'hybrid=True' would itself be renamed.
TO_REPLACE = [
    ('power', 'svd'),
    ('wdemucs', 'hdemucs'),
    ('hdemucs.hybrid=True', 'hdemucs.hybrid_old=True'),
    ('hdemucs.hybrid=2', 'hdemucs.hybrid=True'),
]
# (condition, extra_args) rules used by `transform`: when `condition` is an
# element of argv, `extra_args` are prepended to argv. Rules are processed in
# order, so arguments from later rules end up before those of earlier rules.
TO_INJECT = [
    ('model=hdemucs', ['hdemucs.cac=False']),
    ('model=hdemucs', ['hdemucs.norm_starts=999']),
]
def get_original_argv(sig):
    """Load the arguments originally used for the experiment `sig`.

    Reads `outputs/xps/<sig>/.argv.json` under `DEV_REPO` and returns the
    decoded JSON content (`transform` uses it as a mutable list of strings).

    Errors from opening the file or decoding the JSON are propagated unchanged.
    """
    argv_path = Path(DEV_REPO) / f'outputs/xps/{sig}/.argv.json'
    # Use a context manager so the file handle is closed right away instead of
    # being left for the garbage collector.
    with open(argv_path) as argv_file:
        return json.load(argv_file)
def transform(argv, mappings, verbose=False):
    """Rewrite the dev-branch arguments `argv` in place for the release code.

    Steps, applied in this order:

    1. drop every argument listed in `TO_REMOVE`;
    2. apply each `TO_REPLACE` substitution to every argument;
    3. prepend the extra arguments of every `TO_INJECT` rule whose condition
       is present in argv;
    4. for every `continue_from=<sig>` argument, convert the referenced
       experiment with `convert` and point the argument to its new signature.

    Args:
        argv: list of argument strings, modified in place.
        mappings: dict passed through to `convert` for dependency experiments.
        verbose: when true, report dependencies that get converted.
    """
    # Function-scope import: only needed to decode quoted signatures below.
    import ast

    # Single pass instead of repeated `list.remove` calls; the result is the
    # same (all occurrences dropped, order of the remaining items preserved).
    argv[:] = [arg for arg in argv if arg not in TO_REMOVE]

    for old, new in TO_REPLACE:
        argv[:] = [a.replace(old, new) for a in argv]

    for condition, args in TO_INJECT:
        if condition in argv:
            argv[:] = args + argv

    for idx, arg in enumerate(argv):
        if 'continue_from=' in arg:
            # Split on the first '=' only so the value is kept whole.
            dep_sig = arg.split('=', 1)[1]
            if dep_sig.startswith('"'):
                # The value is a quoted string literal. Decode it without
                # `eval`, which would execute arbitrary code from the file.
                dep_sig = ast.literal_eval(dep_sig)
            if verbose:
                print("Need to recursively convert dependency XP", dep_sig)
            new_sig = convert(dep_sig, mappings, verbose).sig
            argv[idx] = f'continue_from="{new_sig}"'
def convert(sig, mappings, verbose=False):
    """Convert the dev-branch experiment `sig` and return the new experiment.

    The original arguments are loaded, rewritten by `transform`, and turned
    into an experiment object through `train.main.get_xp` followed by
    `train.main.init_xp`. The new signature is recorded in `mappings[sig]`.

    Args:
        sig: signature of the original experiment.
        mappings: dict updated in place with `sig -> new signature`.
        verbose: when true, print the arguments before and after the rewrite
            as well as the resulting mapping.

    Returns:
        The experiment object returned by `train.main.get_xp`.
    """
    def log(*parts):
        # Only emit progress output when requested.
        if verbose:
            print(*parts)

    args_list = get_original_argv(sig)
    log("Original argv", args_list)

    transform(args_list, mappings, verbose)
    log("New argv", args_list)

    xp = train.main.get_xp(args_list)
    train.main.init_xp(xp)
    if verbose:
        print("Mapping", sig, "->", xp.sig)

    mappings[sig] = xp.sig
    return xp
def _eval_old(old_sig, x):
    """Evaluate the old pretrained model `old_sig` on `x` in a subprocess.

    `x` is serialized with `torch.save`, sent on the standard input of a
    `python3 -c` process started with `DEV_REPO` as working directory
    (presumably so that `demucs` resolves to the dev checkout -- confirm),
    and the model output is read back from its standard output.

    Args:
        old_sig: signature given to `pretrained.load_pretrained_model` in the
            subprocess.
        x: object that can be serialized with `torch.save`, used as the model
            input.

    Returns:
        The object deserialized with `torch.load` from the subprocess output.

    Raises:
        RuntimeError: if the subprocess exits with a non-zero status; the
            message contains its standard error output.
    """
    script = (
        'from demucs import pretrained; import torch; import sys; import io; '
        'buf = io.BytesIO(sys.stdin.buffer.read()); '
        'x = torch.load(buf); m = pretrained.load_pretrained_model('
        # `!r` yields a valid Python literal whatever characters the
        # signature contains, unlike hand-written surrounding quotes.
        f'{old_sig!r}); torch.save(m(x), sys.stdout.buffer)')

    payload = io.BytesIO()
    torch.save(x, payload)
    proc = sp.run(
        ['python3', '-c', script], input=payload.getvalue(), capture_output=True, cwd=DEV_REPO)
    if proc.returncode != 0:
        # Raise explicitly: an `assert` would be stripped under `python -O`
        # and the failure would then surface later as a confusing load error.
        stderr = proc.stderr.decode(errors='replace')
        raise RuntimeError(
            f"Evaluating old model {old_sig!r} failed "
            f"with exit code {proc.returncode}:\n{stderr}")
    return torch.load(io.BytesIO(proc.stdout))
def compare(old_sig, model):
    """Measure how far `model` is from the old model `old_sig`.

    Both models are run on the same random input of shape
    `(1, 2, 44100 * 10)`; the old one is evaluated through `_eval_old`.

    Args:
        old_sig: signature of the old model.
        model: callable taking the test tensor and returning a tensor.

    Returns:
        `20 * log10(||out - old_out|| / ||out||)` as a Python float, i.e. the
        relative difference in dB (lower means closer).
    """
    test = torch.randn(1, 2, 44100 * 10)
    old_out = _eval_old(old_sig, test)
    # Pure inference: do not record an autograd graph for this long input.
    with torch.no_grad():
        out = model(test)
    delta = 20 * torch.log10((out - old_out).norm() / out.norm()).item()
    return delta
def main():
    """Command-line entry point: convert each given experiment signature.

    For every signature the experiment is converted with `convert`. With
    `--dump` the old weights are loaded into the new model and saved as
    `<output>/<new sig>.th`; with `--compare` the output of the new model is
    compared with the old one and the delta is printed. The final
    old -> new signature mapping is printed at the end.
    """
    torch.manual_seed(1234)

    parser = argparse.ArgumentParser('convert')
    parser.add_argument('sigs', nargs='*')
    parser.add_argument('-o', '--output', type=Path, default=Path('release_models'))
    for short_flag, long_flag in (('-d', '--dump'), ('-c', '--compare'), ('-v', '--verbose')):
        parser.add_argument(short_flag, long_flag, action='store_true')
    opts = parser.parse_args()

    out_dir = opts.output
    out_dir.mkdir(exist_ok=True, parents=True)

    mappings = {}
    for old_sig in opts.sigs:
        xp = convert(old_sig, mappings, opts.verbose)

        if opts.dump or opts.compare:
            # The new model is only built when it is actually needed.
            old_pkg = pretrained._load_package(old_sig, old=True)
            model = train.get_model(xp.cfg)
            model.load_state_dict(old_pkg['state'])

            if opts.dump:
                pkg = states.serialize_model(model, xp.cfg)
                states.save_with_checksum(pkg, out_dir / f'{xp.sig}.th')
            if opts.compare:
                delta = compare(old_sig, model)
                print("Delta for", old_sig, xp.sig, delta)

        mappings[old_sig] = xp.sig

    print("FINAL MAPPINGS")
    for src_sig, dst_sig in mappings.items():
        print(src_sig, " ", dst_sig)
# Run the conversion only when executed as a script, not when imported.
if __name__ == '__main__':
    main()