ARIA / midi_emotion /src /generate.py
vincentamato's picture
precache clip
6142e6b
from argparse import ArgumentParser
from copy import deepcopy
import os
import sys
import numpy as np
import torch
import torch.nn.functional as F
import datetime
from tqdm import tqdm
from .utils import get_n_instruments
from .models.build_model import build_model
from .data.data_processing_reverse import ind_tensor_to_mid, ind_tensor_to_str
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
def chunks(lst, n):
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield lst[i:i + n]
def generate(model, maps, device, out_dir, conditioning, short_filename=False,
penalty_coeff=0.5, discrete_conditions=None, continuous_conditions=None,
max_input_len=1024, amp=True, step=None,
gen_len=2048, temperatures=[1.2,1.2], top_k=-1,
top_p=0.7, debug=False, varying_condition=None, seed=-1,
verbose=False, primers=[["<START>"]], min_n_instruments=2):
if not debug:
os.makedirs(out_dir, exist_ok=True)
model = model.to(device)
model.eval()
assert len(temperatures) in (1, 2)
if varying_condition is not None:
batch_size = varying_condition[0].size(0)
else:
try:
continuous_conditions = torch.FloatTensor(continuous_conditions).to(device)
except:
continuous_conditions = None
if conditioning == "none":
batch_size = len(primers)
elif conditioning == "discrete_token":
assert discrete_conditions is not None
discrete_conditions_tensor = [[maps["tuple2idx"][symbol] for symbol in condition_sample] \
for condition_sample in discrete_conditions]
discrete_conditions_tensor = torch.LongTensor(discrete_conditions_tensor).t().to(device)
batch_size = discrete_conditions_tensor.size(1)
elif conditioning in ("continuous_token", "continuous_concat"):
batch_size = len(continuous_conditions)
# will be used to penalize repeats
repeat_counts = [0 for _ in range(batch_size)]
exclude_symbols = [symbol for symbol in maps["tuple2idx"].keys() if symbol[0] == "<"]
# will have generated symbols and indices
gen_song_tensor = torch.LongTensor([]).to(device)
if not isinstance(primers, list):
primers = [[primers]]
primer_inds = [[maps["tuple2idx"][symbol] for symbol in primer] \
for primer in primers]
gen_inds = torch.LongTensor(primer_inds)
null_conditions_tensor = torch.FloatTensor([np.nan, np.nan]).to(device)
if len(primers) == 1:
gen_inds = gen_inds.repeat(batch_size, 1)
null_conditions_tensor = null_conditions_tensor.repeat(batch_size, 1)
if conditioning == "continuous_token":
max_input_len -= 2
conditions_tensor = continuous_conditions
elif conditioning == "continuous_concat":
conditions_tensor = continuous_conditions
elif conditioning == "discrete_token":
max_input_len -= discrete_conditions_tensor.size(0)
conditions_tensor = null_conditions_tensor
else:
conditions_tensor = null_conditions_tensor
if varying_condition is not None:
varying_condition[0] = varying_condition[0].to(device)
varying_condition[1] = varying_condition[1].to(device)
gen_inds = gen_inds.t().to(device)
print("▶ midi_emotion.generate starting")
with torch.no_grad():
pbar = tqdm(total=gen_len, desc="Generating tokens", leave=True)
i = 0
while i < gen_len:
i += 1
pbar.update(1)
gen_song_tensor = torch.cat((gen_song_tensor, gen_inds), 0)
input_ = gen_song_tensor
if len(gen_song_tensor) > max_input_len:
input_ = input_[-max_input_len:, :]
if conditioning == "discrete_token":
# concat with conditions
input_ = torch.cat((discrete_conditions_tensor, input_), 0)
# INTERPOLATED CONDITIONS
if varying_condition is not None:
valences = varying_condition[0][:, i-1]
arousals = varying_condition[1][:, i-1]
conditions_tensor = torch.cat([valences[:, None], arousals[:, None]], dim=-1)
# Run model
with torch.cuda.amp.autocast(enabled=amp):
input_ = input_.t()
output = model(input_, conditions_tensor)
output = output.permute((1, 0, 2))
# Process output, get predicted token
output = output[-1, :, :] # Select last timestep
output[output != output] = 0 # zeroing nans
if torch.all(output == 0) and verbose:
# if everything becomes zero
print("All predictions were NaN during generation")
output = torch.ones(output.shape).to(device)
# exclude certain symbols
for symbol_exclude in exclude_symbols:
try:
idx_exclude = maps["tuple2idx"][symbol_exclude]
output[:, idx_exclude] = -float("inf")
except:
pass
effective_temps = []
for j in range(batch_size):
gen_idx = gen_inds[0, j].item()
gen_tuple = maps["idx2tuple"][gen_idx]
effective_temp = temperatures[1]
if isinstance(gen_tuple, tuple):
gen_event = maps["idx2event"][gen_tuple[0]]
if "TIMESHIFT" in gen_event:
# switch from rest temperature to note temperature
effective_temp = temperatures[0]
effective_temps.append(effective_temp)
temp_tensor = torch.Tensor([effective_temps]).to(device)
output = F.log_softmax(output, dim=-1)
# Add repeat penalty to temperature
if penalty_coeff > 0:
repeat_counts_array = torch.Tensor(repeat_counts).to(device)
temp_multiplier = torch.maximum(torch.zeros_like(repeat_counts_array, device=device),
torch.log((repeat_counts_array+1)/4)*penalty_coeff)
repeat_penalties = temp_multiplier * temp_tensor
temp_tensor += repeat_penalties
# Apply temperature
output /= temp_tensor.t()
# top-k
if top_k <= 0 or top_k > output.size(-1):
top_k_eff = output.size(-1)
else:
top_k_eff = top_k
output, top_inds = torch.topk(output, top_k_eff)
# top-p
if top_p > 0 and top_p < 1:
cumulative_probs = torch.cumsum(F.softmax(output, dim=-1), dim=-1)
remove_inds = cumulative_probs > top_p
remove_inds[:, 0] = False # at least keep top value
output[remove_inds] = -float("inf")
output = F.softmax(output, dim=-1)
# Sample from probabilities
inds_sampled = torch.multinomial(output, 1, replacement=True)
gen_inds = top_inds.gather(1, inds_sampled).t()
# Update repeat counts
num_choices = torch.sum((output > 0).int(), -1)
for j in range(batch_size):
if num_choices[j] <= 2: repeat_counts[j] += 1
else: repeat_counts[j] = repeat_counts[j] // 2
pbar.close()
print("▶ token generation finished")
# Convert to midi and save
print("\nConverting to MIDI...")
# If there are less than n instruments, repeat generation for specific condition
redo_primers, redo_discrete_conditions, redo_continuous_conditions = [], [], []
for i in range(gen_song_tensor.size(-1)):
if short_filename:
out_file_path = f"{i}"
else:
if step is None:
now = datetime.datetime.now()
out_file_path = now.strftime("%Y_%m_%d_%H_%M_%S")
else:
out_file_path = step
out_file_path += f"_{i}"
if seed > 0:
out_file_path += f"_s{seed}"
if continuous_conditions is not None:
condition = continuous_conditions[i, :].tolist()
# convert to string
condition = [str(round(c, 2)).replace(".", "") for c in condition]
out_file_path += f"_V{condition[0]}_A{condition[1]}"
out_file_path += ".mid"
out_path_mid = os.path.join(out_dir, out_file_path)
symbols = ind_tensor_to_str(gen_song_tensor[:, i], maps["idx2tuple"], maps["idx2event"])
n_instruments = get_n_instruments(symbols)
if n_instruments >= min_n_instruments:
mid = ind_tensor_to_mid(gen_song_tensor[:, i], maps["idx2tuple"], maps["idx2event"], verbose=False)
out_path_txt = "txt_" + out_file_path.replace(".mid", ".txt")
out_path_txt = os.path.join(out_dir, out_path_txt)
out_path_inds = "inds_" + out_file_path.replace(".mid", ".pt")
out_path_inds = os.path.join(out_dir, out_path_inds)
if not debug:
mid.write(out_path_mid)
if verbose:
print(f"Saved to {out_path_mid}")
else:
print(f"Only has {n_instruments} instruments, not saving.")
if conditioning == "none":
redo_primers.append(primers[i])
redo_discrete_conditions = None
redo_continuous_conditions = None
elif conditioning == "discrete_token":
redo_discrete_conditions.append(discrete_conditions[i])
redo_continuous_conditions = None
redo_primers = primers
else:
redo_discrete_conditions = None
redo_continuous_conditions.append(continuous_conditions[i, :].tolist())
redo_primers = primers
return redo_primers, redo_discrete_conditions, redo_continuous_conditions
if __name__ == '__main__':
script_dir = os.path.dirname(os.path.abspath(__file__))
code_model_dir = os.path.abspath(os.path.join(script_dir, 'model'))
code_utils_dir = os.path.join(code_model_dir, 'utils')
sys.path.extend([code_model_dir, code_utils_dir])
parser = ArgumentParser()
parser.add_argument('--model_dir', type=str, help='Directory with model', required=True)
parser.add_argument('--no_cuda', action='store_true', help="Use CPU")
parser.add_argument('--num_runs', type=int, help='Number of runs', default=1)
parser.add_argument('--gen_len', type=int, help='Max generation len', default=4096)
parser.add_argument('--max_input_len', type=int, help='Max input len', default=1216)
parser.add_argument('--temp', type=float, nargs='+', help='Generation temperature', default=[1.2, 1.2])
parser.add_argument('--topk', type=int, help='Top-k sampling', default=-1)
parser.add_argument('--topp', type=float, help='Top-p sampling', default=0.7)
parser.add_argument('--debug', action='store_true', help="Do not save anything")
parser.add_argument('--seed', type=int, default=0, help="Random seed")
parser.add_argument('--no_amp', action='store_true', help="Disable automatic mixed precision")
parser.add_argument("--conditioning", type=str, required=True,
choices=["none", "discrete_token", "continuous_token",
"continuous_concat"], help='Conditioning type')
parser.add_argument('--penalty_coeff', type=float, default=0.5,
help="Coefficient for penalizing repeating notes")
parser.add_argument("--quiet", action='store_true', help="Not verbose")
parser.add_argument("--short_filename", action='store_true')
parser.add_argument('--batch_size', type=int, help='Batch size', default=4)
parser.add_argument('--min_n_instruments', type=int, help='Minimum number of instruments', default=1)
parser.add_argument('--valence', type=float, help='Conditioning valence value', default=[None], nargs='+')
parser.add_argument('--arousal', type=float, help='Conditioning arousal value', default=[None], nargs='+')
parser.add_argument("--batch_gen_dir", type=str, default="")
args = parser.parse_args()
assert len(args.valence) == len(args.arousal), "Lengths of valence and arousal must be equal"
assert (args.conditioning == "none") == (args.valence == [None] or args.arousal == [None]), \
"If conditioning is used, specify valence and arousal; if not, don't"
if args.seed > 0:
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
main_output_dir = "../output"
assert os.path.exists(os.path.join(main_output_dir, args.model_dir))
midi_output_dir = os.path.join(main_output_dir, args.model_dir, "generations", "inference")
new_dir = ""
if args.batch_gen_dir != "":
new_dir = new_dir + "_" + args.batch_gen_dir
if new_dir != "":
midi_output_dir = os.path.join(midi_output_dir, new_dir)
if not args.debug:
os.makedirs(midi_output_dir, exist_ok=True)
model_fp = os.path.join(main_output_dir, args.model_dir, 'model.pt')
mappings_fp = os.path.join(main_output_dir, args.model_dir, 'mappings.pt')
config_fp = os.path.join(main_output_dir, args.model_dir, 'model_config.pt')
if os.path.exists(mappings_fp):
maps = torch.load(mappings_fp)
else:
raise ValueError("Mapping file not found.")
start_symbol = "<START>"
n_emotion_bins = 5
valence_symbols, arousal_symbols = [], []
emotion_bins = np.linspace(-1-1e-12, 1+1e-12, num=n_emotion_bins+1)
if n_emotion_bins % 2 == 0:
bin_ids = list(range(-n_emotion_bins//2, 0)) + list(range(1, n_emotion_bins//2+1))
else:
bin_ids = list(range(-(n_emotion_bins-1)//2, (n_emotion_bins-1)//2 + 1))
for bin_id in bin_ids:
valence_symbols.append(f"<V{bin_id}>")
arousal_symbols.append(f"<A{bin_id}>")
device = torch.device('cuda' if not args.no_cuda and torch.cuda.is_available() else 'cpu')
verbose = not args.quiet
if verbose:
if device == torch.device("cuda"):
print("Using GPU")
else:
print("Using CPU")
# Load model
config = torch.load(config_fp)
model, _ = build_model(None, load_config_dict=config)
model = model.to(device)
if os.path.exists(model_fp):
model.load_state_dict(torch.load(model_fp, map_location=device))
elif os.path.exists(model_fp.replace("best_", "")):
model.load_state_dict(torch.load(model_fp.replace("best_", ""), map_location=device))
else:
raise ValueError("Model not found")
# Process conditions
null_condition = torch.FloatTensor([np.nan, np.nan]).to(device)
varying_condition = None
label_conditions = None
conditions = []
if args.valence == [None]:
conditions = None
elif len(args.valence) == 1:
for _ in range(args.batch_size):
conditions.append([args.valence[0], args.arousal[0]])
else:
for i in range(len(args.valence)):
conditions.append([args.valence[i], args.arousal[i]])
primers = [["<START>"]]
continuous_conditions = conditions
if args.conditioning == "discrete_token":
discrete_conditions = []
for condition in conditions:
valence_val, arousal_val = condition
valence_symbol = valence_symbols[np.searchsorted(
emotion_bins, valence_val, side="right") - 1]
arousal_symbol = arousal_symbols[np.searchsorted(
emotion_bins, arousal_val, side="right") - 1]
discrete_conditions.append([valence_symbol, arousal_symbol])
conditions = null_condition
elif args.conditioning == "none":
discrete_conditions = None
primers = [["<START>"] for _ in range(args.batch_size)]
elif args.conditioning in ["continuous_token", "continuous_concat"]:
primers = [["<START>"]]
discrete_conditions = None
for i in range(args.num_runs):
primers_run = deepcopy(primers)
discrete_conditions_run = deepcopy(discrete_conditions)
continuous_conditions_run = deepcopy(continuous_conditions)
while not (primers_run == [] or discrete_conditions_run == [] or continuous_conditions_run == []):
primers_run, discrete_conditions_run, continuous_conditions_run = generate(
model, maps, device,
midi_output_dir, args.conditioning, discrete_conditions=discrete_conditions_run,
min_n_instruments=args.min_n_instruments,continuous_conditions=continuous_conditions_run,
penalty_coeff=args.penalty_coeff, short_filename=args.short_filename, top_p=args.topp,
gen_len=args.gen_len, max_input_len=args.max_input_len,
amp=not args.no_amp, primers=primers_run, temperatures=args.temp, top_k=args.topk,
debug=args.debug, verbose=not args.quiet, seed=args.seed)