import torch.nn as nn
def set_dropout(model, rate):
    """Recursively set the probability of every nn.Dropout submodule in place."""
    for _, child in model.named_children():
        if isinstance(child, nn.Dropout):
            child.p = rate
        set_dropout(child, rate)
    return model
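
# Usage sketch (illustrative; `demo` is a made-up module, not from this repo):
#   demo = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.5))
#   set_dropout(demo, 0.1)  # demo[1].p is now 0.1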
def build_model(args, load_config_dict=None):
    """Build the MusicTransformer variant selected by the config in `args`.

    If `load_config_dict` is given (e.g. a saved training config), it replaces
    `args` entirely.
    """
    if load_config_dict is not None:
        args = load_config_dict
    config = {
        "vocab_size": args["vocab_size"],
        "num_layer": args["n_layer"],
        "num_head": args["n_head"],
        "embedding_dim": args["d_model"],
        "d_inner": args["d_inner"],
        "dropout": args["dropout"],
        "d_condition": args["d_condition"],
        "max_seq": 2048,
        "pad_token": 0,
    }
if not "regression" in list(args.keys()):
args["regression"] = False
if args["regression"]:
config["output_size"] = 2
from models.music_regression \
import MusicRegression as MusicTransformer
elif args["conditioning"] == "continuous_token":
from models.music_continuous_token \
import MusicTransformerContinuousToken as MusicTransformer
del config["d_condition"]
else:
from .music_multi \
import MusicTransformerMulti as MusicTransformer
    model = MusicTransformer(**config)
    # When a saved config was loaded, optionally overwrite its stored dropout
    # rate (e.g. to disable dropout for fine-tuning or inference).
    if load_config_dict is not None and args.get("overwrite_dropout", False):
        rate = args["dropout"]
        model = set_dropout(model, rate)
        print(f"Dropout rate changed to {rate}")
    return model, args
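
# Usage sketch (illustrative; the hyperparameter values here are made up, and
# the real Space reads them from its saved training config):
#   args = {
#       "vocab_size": 512, "n_layer": 6, "n_head": 8, "d_model": 512,
#       "d_inner": 2048, "dropout": 0.1, "d_condition": 64,
#       "conditioning": "continuous_token", "overwrite_dropout": False,
#   }
#   model, args = build_model(args)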