Spaces:
Configuration error
Configuration error
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Script to convert option names and model args from the dev branch to
# the cleaned-up release branch. There should be no reason to use it anymore.
import argparse | |
import io | |
import json | |
from pathlib import Path | |
import subprocess as sp | |
import torch | |
from demucs import train, pretrained, states | |
# Dev-branch repository: XP argv files are read from its outputs/xps folder
# and the old models are evaluated with it as the working directory.
DEV_REPO = Path.home() / 'tmp/release_demucs_mdx'

# Dev-branch overrides that are dropped outright (every occurrence) before
# any renaming happens. Entries must be exact argv strings.
TO_REMOVE = [
    'demucs.dconv_kw.gelu=True',
    'demucs.dconv_kw.nfreqs=0',
    'demucs.dconv_kw.version=4',
    'demucs.norm=gn',
    'wdemucs.nice=True',
    'wdemucs.good=True',
    'wdemucs.freq_emb=-0.2',
    'special=True',
    'special=False',
]

# Ordered (old, new) plain-substring substitutions applied to every argv
# entry, so e.g. 'power' is rewritten wherever it appears. ORDER MATTERS:
# 'wdemucs' must become 'hdemucs' before the hybrid rewrites, and
# 'hybrid=True' must be renamed to 'hybrid_old=True' before 'hybrid=2' is
# turned into 'hybrid=True' (otherwise the new value would be renamed too).
TO_REPLACE = [
    ('power', 'svd'),
    ('wdemucs', 'hdemucs'),
    ('hdemucs.hybrid=True', 'hdemucs.hybrid_old=True'),
    ('hdemucs.hybrid=2', 'hdemucs.hybrid=True'),
]

# (condition, extra_args): when the exact argv entry `condition` is present
# (after renaming), `extra_args` are prepended to argv. Entries are applied
# in order, so later entries end up first in the final argv.
TO_INJECT = [
    ('model=hdemucs', ['hdemucs.cac=False']),
    ('model=hdemucs', ['hdemucs.norm_starts=999']),
]
def get_original_argv(sig):
    """Load the argv recorded for dev-branch XP `sig`.

    Reads `<DEV_REPO>/outputs/xps/<sig>/.argv.json` and returns the parsed
    JSON value (callers treat it as a mutable list of override strings).
    The file is closed deterministically, unlike a bare `open()` passed
    straight to `json.load`.
    """
    path = Path(DEV_REPO) / f'outputs/xps/{sig}/.argv.json'
    with open(path, encoding='utf-8') as f:
        return json.load(f)
def transform(argv, mappings, verbose=False):
    """Rewrite a dev-branch argv list, in place, into its release equivalent.

    Steps, in order:
      1. drop every entry listed in TO_REMOVE (all occurrences);
      2. apply the ordered substring substitutions of TO_REPLACE;
      3. prepend the TO_INJECT extras whose condition entry is present;
      4. point any `continue_from=<sig>` dependency at the converted
         signature of that dependency, converting it recursively if needed.

    Args:
        argv: list of override strings; mutated in place.
        mappings: dict of already-converted `old_sig -> new_sig`. Used as a
            cache so a dependency referenced by several XPs is converted
            only once; `convert` adds an entry for every XP it converts.
        verbose: print a message when a dependency must be converted.
    """
    import ast

    argv[:] = [a for a in argv if a not in TO_REMOVE]
    for old, new in TO_REPLACE:
        argv[:] = [a.replace(old, new) for a in argv]
    for condition, extra in TO_INJECT:
        if condition in argv:
            argv[:] = extra + argv

    for idx, arg in enumerate(argv):
        if 'continue_from=' not in arg:
            continue
        # Split once so a value containing '=' is kept intact.
        dep_sig = arg.split('=', 1)[1]
        if dep_sig.startswith('"'):
            # Quoted form, e.g. continue_from="abc123". ast.literal_eval only
            # accepts Python literals, so text read from the argv file cannot
            # execute code the way eval() could.
            dep_sig = ast.literal_eval(dep_sig)
        if dep_sig in mappings:
            new_sig = mappings[dep_sig]
        else:
            if verbose:
                print("Need to recursively convert dependency XP", dep_sig)
            new_sig = convert(dep_sig, mappings, verbose).sig
        # Assigning by index does not change the length, so it is safe
        # while enumerating.
        argv[idx] = f'continue_from="{new_sig}"'
def convert(sig, mappings, verbose=False):
    """Convert dev-branch XP `sig` into a release XP and return it.

    The recorded argv is rewritten by `transform`. The resulting XP is
    created through `train.main.get_xp` / `init_xp`, and `mappings[sig]` is
    set to the new signature before the XP object is returned.
    """
    def log(*parts):
        # Only emit progress output when asked to.
        if verbose:
            print(*parts)

    overrides = get_original_argv(sig)
    log("Original argv", overrides)
    transform(overrides, mappings, verbose)
    log("New argv", overrides)

    release_xp = train.main.get_xp(overrides)
    train.main.init_xp(release_xp)
    log("Mapping", sig, "->", release_xp.sig)

    mappings[sig] = release_xp.sig
    return release_xp
def _eval_old(old_sig, x):
    """Run the dev-branch pretrained model `old_sig` on tensor `x`.

    `x` is serialized with `torch.save` and piped to a `python3 -c` child
    process started with `cwd=DEV_REPO`. Because `-c` puts the current
    directory first on `sys.path`, the child imports the dev checkout's
    `demucs`. The child loads the model, applies it with autograd disabled,
    and writes the output tensor back on stdout.

    Returns:
        The model output, deserialized with `torch.load`.

    Raises:
        RuntimeError: if the child process exits with a non-zero status.
            Its stderr is printed first. An explicit exception is used
            instead of `assert`, which is stripped under `python -O`.
    """
    # The signature travels through argv instead of being spliced into the
    # source, so quotes or backslashes in it cannot break or alter the
    # script.
    script = (
        'from demucs import pretrained; import torch; import sys; import io; '
        'torch.set_grad_enabled(False); '
        'buf = io.BytesIO(sys.stdin.buffer.read()); '
        'x = torch.load(buf); '
        'm = pretrained.load_pretrained_model(sys.argv[1]); '
        'torch.save(m(x), sys.stdout.buffer)')
    buf = io.BytesIO()
    torch.save(x, buf)
    proc = sp.run(
        ['python3', '-c', script, old_sig],
        input=buf.getvalue(), capture_output=True, cwd=DEV_REPO)
    if proc.returncode != 0:
        print("Error", proc.stderr.decode(errors='replace'))
        raise RuntimeError(
            f"evaluating old model {old_sig!r} failed "
            f"(exit code {proc.returncode})")
    return torch.load(io.BytesIO(proc.stdout))
def compare(old_sig, model):
    """Measure how far `model` drifts from the dev-branch model `old_sig`.

    Both models are fed the same random input of shape (1, 2, 441000)
    (presumably 10 s of stereo audio at 44.1 kHz). The old model runs in a
    subprocess via `_eval_old`. The converted model runs in-process with
    autograd disabled, since only the output values are needed.

    Returns:
        20 * log10(||out - old_out|| / ||out||) as a Python float, i.e. the
        relative error in dB (more negative means closer agreement).
    """
    test = torch.randn(1, 2, 44100 * 10)
    old_out = _eval_old(old_sig, test)
    with torch.no_grad():
        out = model(test)
    delta = 20 * torch.log10((out - old_out).norm() / out.norm()).item()
    return delta
def main():
    """Command-line entry point.

    Converts every dev-branch signature given on the command line. With
    `--dump`, the converted model is written to `--output/<new_sig>.th`.
    With `--compare`, its dB difference against the dev model is printed.
    The accumulated `old new` signature table is printed at the end.
    """
    # Fixed seed so random draws (e.g. the --compare test signal) repeat.
    torch.manual_seed(1234)
    parser = argparse.ArgumentParser('convert')
    parser.add_argument('sigs', nargs='*')
    parser.add_argument('-o', '--output', type=Path, default=Path('release_models'))
    for short_flag, long_flag in (('-d', '--dump'), ('-c', '--compare'), ('-v', '--verbose')):
        parser.add_argument(short_flag, long_flag, action='store_true')
    opts = parser.parse_args()
    opts.output.mkdir(parents=True, exist_ok=True)

    mappings = {}
    for old_sig in opts.sigs:
        release_xp = convert(old_sig, mappings, opts.verbose)
        if opts.dump or opts.compare:
            # Build the release architecture from the converted config and
            # load the dev-branch weights into it unchanged.
            legacy_pkg = pretrained._load_package(old_sig, old=True)
            model = train.get_model(release_xp.cfg)
            model.load_state_dict(legacy_pkg['state'])
            if opts.dump:
                serialized = states.serialize_model(model, release_xp.cfg)
                states.save_with_checksum(serialized, opts.output / f'{release_xp.sig}.th')
            if opts.compare:
                print("Delta for", old_sig, release_xp.sig, compare(old_sig, model))
        mappings[old_sig] = release_xp.sig

    print("FINAL MAPPINGS")
    for old_sig, new_sig in mappings.items():
        print(old_sig, " ", new_sig)


if __name__ == '__main__':
    main()