Spaces:
Paused
Paused
v1.2
Browse files- app.py +48 -27
- javascript/app.js +4 -3
- midi_tokenizer.py +146 -35
app.py
CHANGED
|
@@ -111,16 +111,19 @@ def create_msg(name, data):
|
|
| 111 |
return {"name": name, "data": data, "uuid": uuid.uuid4().hex}
|
| 112 |
|
| 113 |
|
| 114 |
-
def send_msgs(msgs, msgs_history):
|
|
|
|
|
|
|
| 115 |
msgs_history.append(msgs)
|
| 116 |
-
if len(msgs_history) >
|
| 117 |
-
msgs_history
|
| 118 |
return json.dumps(msgs_history)
|
| 119 |
|
| 120 |
|
| 121 |
-
def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc):
|
| 122 |
msgs_history = []
|
| 123 |
mid_seq = []
|
|
|
|
| 124 |
gen_events = int(gen_events)
|
| 125 |
max_len = gen_events
|
| 126 |
|
|
@@ -129,6 +132,8 @@ def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, te
|
|
| 129 |
if tab == 0:
|
| 130 |
i = 0
|
| 131 |
mid = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
|
|
|
|
|
|
|
| 132 |
patches = {}
|
| 133 |
if instruments is None:
|
| 134 |
instruments = []
|
|
@@ -151,10 +156,10 @@ def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, te
|
|
| 151 |
max_len += len(mid)
|
| 152 |
for token_seq in mid:
|
| 153 |
mid_seq.append(token_seq.tolist())
|
| 154 |
-
init_msgs = [create_msg("visualizer_clear",
|
| 155 |
for tokens in mid_seq:
|
| 156 |
init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
|
| 157 |
-
yield mid_seq, None, None, send_msgs(init_msgs, msgs_history)
|
| 158 |
model = models[model_name]
|
| 159 |
generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
|
| 160 |
disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
|
|
@@ -163,22 +168,31 @@ def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, te
|
|
| 163 |
token_seq = token_seq.tolist()
|
| 164 |
mid_seq.append(token_seq)
|
| 165 |
event = tokenizer.tokens2event(token_seq)
|
| 166 |
-
yield mid_seq, None, None, send_msgs([create_msg("visualizer_append", event), create_msg("progress", [i + 1, gen_events])], msgs_history)
|
| 167 |
mid = tokenizer.detokenize(mid_seq)
|
| 168 |
with open(f"output.mid", 'wb') as f:
|
| 169 |
f.write(MIDI.score2midi(mid))
|
| 170 |
audio = synthesis(MIDI.score2opus(mid), soundfont_path)
|
| 171 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
|
| 173 |
|
| 174 |
-
def cancel_run(mid_seq
|
| 175 |
if mid_seq is None:
|
| 176 |
return None, None, []
|
| 177 |
mid = tokenizer.detokenize(mid_seq)
|
| 178 |
with open(f"output.mid", 'wb') as f:
|
| 179 |
f.write(MIDI.score2midi(mid))
|
| 180 |
audio = synthesis(MIDI.score2opus(mid), soundfont_path)
|
| 181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
|
| 183 |
|
| 184 |
def load_javascript(dir="javascript"):
|
|
@@ -200,6 +214,7 @@ def load_javascript(dir="javascript"):
|
|
| 200 |
|
| 201 |
|
| 202 |
def hf_hub_download_retry(repo_id, filename):
|
|
|
|
| 203 |
retry = 0
|
| 204 |
err = None
|
| 205 |
while retry < 30:
|
|
@@ -246,9 +261,9 @@ if __name__ == "__main__":
|
|
| 246 |
"Demo for [SkyTNT/midi-model](https://github.com/SkyTNT/midi-model)\n\n"
|
| 247 |
"[Open In Colab]"
|
| 248 |
"(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
|
| 249 |
-
" for faster running and longer generation"
|
|
|
|
| 250 |
)
|
| 251 |
-
js_msg_history_state = gr.State(value=[])
|
| 252 |
js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
|
| 253 |
js_msg.change(None, [js_msg], [], js="""
|
| 254 |
(msg_json) =>{
|
|
@@ -262,19 +277,25 @@ if __name__ == "__main__":
|
|
| 262 |
tab_select = gr.State(value=0)
|
| 263 |
with gr.Tabs():
|
| 264 |
with gr.TabItem("instrument prompt") as tab1:
|
| 265 |
-
input_instruments = gr.Dropdown(label="instruments (auto if empty)", choices=list(patch2number.keys()),
|
| 266 |
multiselect=True, max_choices=15, type="value")
|
| 267 |
-
input_drum_kit = gr.Dropdown(label="drum kit", choices=list(drum_kits2number.keys()), type="value",
|
| 268 |
value="None")
|
|
|
|
|
|
|
|
|
|
| 269 |
example1 = gr.Examples([
|
| 270 |
[[], "None"],
|
| 271 |
[["Acoustic Grand"], "None"],
|
| 272 |
-
[[
|
| 273 |
-
|
| 274 |
-
[[
|
| 275 |
-
|
| 276 |
-
[[
|
| 277 |
-
|
|
|
|
|
|
|
|
|
|
| 278 |
"Electric Bass(finger)"], "Standard"]
|
| 279 |
], [input_instruments, input_drum_kit])
|
| 280 |
with gr.TabItem("midi prompt") as tab2:
|
|
@@ -292,19 +313,19 @@ if __name__ == "__main__":
|
|
| 292 |
with gr.Accordion("options", open=False):
|
| 293 |
input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
|
| 294 |
input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.98)
|
| 295 |
-
input_top_k = gr.Slider(label="top k", minimum=1, maximum=128, step=1, value=
|
| 296 |
input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
|
| 297 |
-
example3 = gr.Examples([[1, 0.98,
|
| 298 |
run_btn = gr.Button("generate", variant="primary")
|
| 299 |
stop_btn = gr.Button("stop and output")
|
| 300 |
output_midi_seq = gr.State()
|
| 301 |
output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
|
| 302 |
output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
|
| 303 |
output_midi = gr.File(label="output midi", file_types=[".mid"])
|
| 304 |
-
run_event = run_btn.click(run, [input_model, tab_select, input_instruments, input_drum_kit,
|
| 305 |
-
input_midi_events, input_gen_events, input_temp,
|
| 306 |
-
input_allow_cc],
|
| 307 |
-
[output_midi_seq, output_midi, output_audio, js_msg
|
| 308 |
concurrency_limit=3)
|
| 309 |
-
stop_btn.click(cancel_run, [output_midi_seq
|
| 310 |
app.launch(server_port=opt.port, share=opt.share, inbrowser=True)
|
|
|
|
| 111 |
return {"name": name, "data": data, "uuid": uuid.uuid4().hex}
|
| 112 |
|
| 113 |
|
| 114 |
+
def send_msgs(msgs, msgs_history=None):
|
| 115 |
+
if msgs_history is None:
|
| 116 |
+
msgs_history = []
|
| 117 |
msgs_history.append(msgs)
|
| 118 |
+
if len(msgs_history) > 25:
|
| 119 |
+
msgs_history= msgs_history[1:]
|
| 120 |
return json.dumps(msgs_history)
|
| 121 |
|
| 122 |
|
| 123 |
+
def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc):
|
| 124 |
msgs_history = []
|
| 125 |
mid_seq = []
|
| 126 |
+
bpm = int(bpm)
|
| 127 |
gen_events = int(gen_events)
|
| 128 |
max_len = gen_events
|
| 129 |
|
|
|
|
| 132 |
if tab == 0:
|
| 133 |
i = 0
|
| 134 |
mid = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
|
| 135 |
+
if bpm != 0:
|
| 136 |
+
mid.append(tokenizer.event2tokens(["set_tempo",0,0,0, bpm]))
|
| 137 |
patches = {}
|
| 138 |
if instruments is None:
|
| 139 |
instruments = []
|
|
|
|
| 156 |
max_len += len(mid)
|
| 157 |
for token_seq in mid:
|
| 158 |
mid_seq.append(token_seq.tolist())
|
| 159 |
+
init_msgs = [create_msg("visualizer_clear", False)]
|
| 160 |
for tokens in mid_seq:
|
| 161 |
init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
|
| 162 |
+
yield mid_seq, None, None, send_msgs(init_msgs, msgs_history)
|
| 163 |
model = models[model_name]
|
| 164 |
generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
|
| 165 |
disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
|
|
|
|
| 168 |
token_seq = token_seq.tolist()
|
| 169 |
mid_seq.append(token_seq)
|
| 170 |
event = tokenizer.tokens2event(token_seq)
|
| 171 |
+
yield mid_seq, None, None, send_msgs([create_msg("visualizer_append", event), create_msg("progress", [i + 1, gen_events])], msgs_history)
|
| 172 |
mid = tokenizer.detokenize(mid_seq)
|
| 173 |
with open(f"output.mid", 'wb') as f:
|
| 174 |
f.write(MIDI.score2midi(mid))
|
| 175 |
audio = synthesis(MIDI.score2opus(mid), soundfont_path)
|
| 176 |
+
# resend all msgs
|
| 177 |
+
msgs = [create_msg("visualizer_end", None), create_msg("visualizer_clear", True)]
|
| 178 |
+
for tokens in mid_seq:
|
| 179 |
+
msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
|
| 180 |
+
msgs.append(create_msg("visualizer_end", None))
|
| 181 |
+
yield mid_seq, "output.mid", (44100, audio), send_msgs(msgs)
|
| 182 |
|
| 183 |
|
| 184 |
+
def cancel_run(mid_seq):
|
| 185 |
if mid_seq is None:
|
| 186 |
return None, None, []
|
| 187 |
mid = tokenizer.detokenize(mid_seq)
|
| 188 |
with open(f"output.mid", 'wb') as f:
|
| 189 |
f.write(MIDI.score2midi(mid))
|
| 190 |
audio = synthesis(MIDI.score2opus(mid), soundfont_path)
|
| 191 |
+
msgs = [create_msg("visualizer_end", None), create_msg("visualizer_clear", True)]
|
| 192 |
+
for tokens in mid_seq:
|
| 193 |
+
msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
|
| 194 |
+
msgs.append(create_msg("visualizer_end", None))
|
| 195 |
+
return "output.mid", (44100, audio), send_msgs(msgs)
|
| 196 |
|
| 197 |
|
| 198 |
def load_javascript(dir="javascript"):
|
|
|
|
| 214 |
|
| 215 |
|
| 216 |
def hf_hub_download_retry(repo_id, filename):
|
| 217 |
+
print(f"downloading {repo_id} {filename}")
|
| 218 |
retry = 0
|
| 219 |
err = None
|
| 220 |
while retry < 30:
|
|
|
|
| 261 |
"Demo for [SkyTNT/midi-model](https://github.com/SkyTNT/midi-model)\n\n"
|
| 262 |
"[Open In Colab]"
|
| 263 |
"(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
|
| 264 |
+
" for faster running and longer generation\n\n"
|
| 265 |
+
"**Update v1.2**: Optimise the tokenizer and dataset"
|
| 266 |
)
|
|
|
|
| 267 |
js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
|
| 268 |
js_msg.change(None, [js_msg], [], js="""
|
| 269 |
(msg_json) =>{
|
|
|
|
| 277 |
tab_select = gr.State(value=0)
|
| 278 |
with gr.Tabs():
|
| 279 |
with gr.TabItem("instrument prompt") as tab1:
|
| 280 |
+
input_instruments = gr.Dropdown(label="🪗instruments (auto if empty)", choices=list(patch2number.keys()),
|
| 281 |
multiselect=True, max_choices=15, type="value")
|
| 282 |
+
input_drum_kit = gr.Dropdown(label="🥁drum kit", choices=list(drum_kits2number.keys()), type="value",
|
| 283 |
value="None")
|
| 284 |
+
input_bpm = gr.Slider(label="BPM (beats per minute, auto if 0)", minimum=0, maximum=255,
|
| 285 |
+
step=1,
|
| 286 |
+
value=0)
|
| 287 |
example1 = gr.Examples([
|
| 288 |
[[], "None"],
|
| 289 |
[["Acoustic Grand"], "None"],
|
| 290 |
+
[['Acoustic Grand', 'SynthStrings 2', 'SynthStrings 1', 'Pizzicato Strings',
|
| 291 |
+
'Pad 2 (warm)', 'Tremolo Strings', 'String Ensemble 1'], "Orchestra"],
|
| 292 |
+
[['Trumpet', 'Oboe', 'Trombone', 'String Ensemble 1', 'Clarinet',
|
| 293 |
+
'French Horn', 'Pad 4 (choir)', 'Bassoon', 'Flute'], "None"],
|
| 294 |
+
[['Flute', 'French Horn', 'Clarinet', 'String Ensemble 2', 'English Horn', 'Bassoon',
|
| 295 |
+
'Oboe', 'Pizzicato Strings'], "Orchestra"],
|
| 296 |
+
[['Electric Piano 2', 'Lead 5 (charang)', 'Electric Bass(pick)', 'Lead 2 (sawtooth)',
|
| 297 |
+
'Pad 1 (new age)', 'Orchestra Hit', 'Cello', 'Electric Guitar(clean)'], "Standard"],
|
| 298 |
+
[["Electric Guitar(clean)", "Electric Guitar(muted)", "Overdriven Guitar", "Distortion Guitar",
|
| 299 |
"Electric Bass(finger)"], "Standard"]
|
| 300 |
], [input_instruments, input_drum_kit])
|
| 301 |
with gr.TabItem("midi prompt") as tab2:
|
|
|
|
| 313 |
with gr.Accordion("options", open=False):
|
| 314 |
input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
|
| 315 |
input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.98)
|
| 316 |
+
input_top_k = gr.Slider(label="top k", minimum=1, maximum=128, step=1, value=20)
|
| 317 |
input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
|
| 318 |
+
example3 = gr.Examples([[1, 0.98, 20], [1, 0.98, 12]], [input_temp, input_top_p, input_top_k])
|
| 319 |
run_btn = gr.Button("generate", variant="primary")
|
| 320 |
stop_btn = gr.Button("stop and output")
|
| 321 |
output_midi_seq = gr.State()
|
| 322 |
output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
|
| 323 |
output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
|
| 324 |
output_midi = gr.File(label="output midi", file_types=[".mid"])
|
| 325 |
+
run_event = run_btn.click(run, [input_model, tab_select, input_instruments, input_drum_kit, input_bpm,
|
| 326 |
+
input_midi, input_midi_events, input_gen_events, input_temp,
|
| 327 |
+
input_top_p, input_top_k, input_allow_cc],
|
| 328 |
+
[output_midi_seq, output_midi, output_audio, js_msg],
|
| 329 |
concurrency_limit=3)
|
| 330 |
+
stop_btn.click(cancel_run, [output_midi_seq], [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
|
| 331 |
app.launch(server_port=opt.port, share=opt.share, inbrowser=True)
|
javascript/app.js
CHANGED
|
@@ -146,13 +146,14 @@ class MidiVisualizer extends HTMLElement{
|
|
| 146 |
this.setPlayTime(0);
|
| 147 |
}
|
| 148 |
|
| 149 |
-
clearMidiEvents(){
|
| 150 |
this.pause()
|
| 151 |
this.midiEvents = [];
|
| 152 |
this.activeNotes = [];
|
| 153 |
this.midiTimes = [];
|
| 154 |
this.t1 = 0
|
| 155 |
-
|
|
|
|
| 156 |
this.setPlayTime(0);
|
| 157 |
this.totalTimeMs = 0;
|
| 158 |
this.playTimeMs = 0
|
|
@@ -426,7 +427,7 @@ customElements.define('midi-visualizer', MidiVisualizer);
|
|
| 426 |
handled_msgs.push(msg.uuid);
|
| 427 |
switch (msg.name) {
|
| 428 |
case "visualizer_clear":
|
| 429 |
-
midi_visualizer.clearMidiEvents();
|
| 430 |
createProgressBar(midi_visualizer_container_inited)
|
| 431 |
break;
|
| 432 |
case "visualizer_append":
|
|
|
|
| 146 |
this.setPlayTime(0);
|
| 147 |
}
|
| 148 |
|
| 149 |
+
clearMidiEvents(keepColor=false){
|
| 150 |
this.pause()
|
| 151 |
this.midiEvents = [];
|
| 152 |
this.activeNotes = [];
|
| 153 |
this.midiTimes = [];
|
| 154 |
this.t1 = 0
|
| 155 |
+
if (!keepColor)
|
| 156 |
+
this.colorMap.clear()
|
| 157 |
this.setPlayTime(0);
|
| 158 |
this.totalTimeMs = 0;
|
| 159 |
this.playTimeMs = 0
|
|
|
|
| 427 |
handled_msgs.push(msg.uuid);
|
| 428 |
switch (msg.name) {
|
| 429 |
case "visualizer_clear":
|
| 430 |
+
midi_visualizer.clearMidiEvents(msg.data);
|
| 431 |
createProgressBar(midi_visualizer_container_inited)
|
| 432 |
break;
|
| 433 |
case "visualizer_append":
|
midi_tokenizer.py
CHANGED
|
@@ -42,22 +42,48 @@ class MIDITokenizer:
|
|
| 42 |
tempo = int((60 / bpm) * 10 ** 6)
|
| 43 |
return tempo
|
| 44 |
|
| 45 |
-
def tokenize(self, midi_score, add_bos_eos=True):
|
| 46 |
ticks_per_beat = midi_score[0]
|
| 47 |
event_list = {}
|
| 48 |
for track_idx, track in enumerate(midi_score[1:129]):
|
| 49 |
last_notes = {}
|
|
|
|
|
|
|
|
|
|
| 50 |
for event in track:
|
|
|
|
|
|
|
| 51 |
t = round(16 * event[1] / ticks_per_beat) # quantization
|
| 52 |
new_event = [event[0], t // 16, t % 16, track_idx] + event[2:]
|
| 53 |
if event[0] == "note":
|
| 54 |
new_event[4] = max(1, round(16 * new_event[4] / ticks_per_beat))
|
| 55 |
elif event[0] == "set_tempo":
|
| 56 |
-
new_event[4]
|
|
|
|
|
|
|
|
|
|
| 57 |
if event[0] == "note":
|
| 58 |
key = tuple(new_event[:4] + new_event[5:-1])
|
| 59 |
else:
|
| 60 |
key = tuple(new_event[:-1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
if event[0] == "note": # to eliminate note overlap due to quantization
|
| 62 |
cp = tuple(new_event[5:7])
|
| 63 |
if cp in last_notes:
|
|
@@ -71,21 +97,39 @@ class MIDITokenizer:
|
|
| 71 |
event_list = list(event_list.values())
|
| 72 |
event_list = sorted(event_list, key=lambda e: e[1:4])
|
| 73 |
midi_seq = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
last_t1 = 0
|
| 76 |
for event in event_list:
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
tokens = [self.event_ids[name]] + [self.parameter_ids[p][params[i]]
|
| 85 |
-
for i, p in enumerate(self.events[name])]
|
| 86 |
-
tokens += [self.pad_id] * (self.max_token_seq - len(tokens))
|
| 87 |
-
midi_seq.append(tokens)
|
| 88 |
-
last_t1 = cur_t1
|
| 89 |
|
| 90 |
if add_bos_eos:
|
| 91 |
bos = [self.bos_id] + [self.pad_id] * (self.max_token_seq - 1)
|
|
@@ -96,6 +140,8 @@ class MIDITokenizer:
|
|
| 96 |
def event2tokens(self, event):
|
| 97 |
name = event[0]
|
| 98 |
params = event[1:]
|
|
|
|
|
|
|
| 99 |
tokens = [self.event_ids[name]] + [self.parameter_ids[p][params[i]]
|
| 100 |
for i, p in enumerate(self.events[name])]
|
| 101 |
tokens += [self.pad_id] * (self.max_token_seq - len(tokens))
|
|
@@ -120,14 +166,10 @@ class MIDITokenizer:
|
|
| 120 |
t1 = 0
|
| 121 |
for tokens in midi_seq:
|
| 122 |
if tokens[0] in self.id_events:
|
| 123 |
-
|
| 124 |
-
if
|
| 125 |
continue
|
| 126 |
-
|
| 127 |
-
params = [params[i] - self.parameter_ids[p][0] for i, p in enumerate(self.events[name])]
|
| 128 |
-
if not all([0 <= params[i] < self.event_parameters[p] for i, p in enumerate(self.events[name])]):
|
| 129 |
-
continue
|
| 130 |
-
event = [name] + params
|
| 131 |
if name == "set_tempo":
|
| 132 |
event[4] = self.bpm2tempo(event[4])
|
| 133 |
if event[0] == "note":
|
|
@@ -183,7 +225,7 @@ class MIDITokenizer:
|
|
| 183 |
return img
|
| 184 |
|
| 185 |
def augment(self, midi_seq, max_pitch_shift=4, max_vel_shift=10, max_cc_val_shift=10, max_bpm_shift=10,
|
| 186 |
-
max_track_shift=
|
| 187 |
pitch_shift = random.randint(-max_pitch_shift, max_pitch_shift)
|
| 188 |
vel_shift = random.randint(-max_vel_shift, max_vel_shift)
|
| 189 |
cc_val_shift = random.randint(-max_cc_val_shift, max_cc_val_shift)
|
|
@@ -239,16 +281,85 @@ class MIDITokenizer:
|
|
| 239 |
midi_seq_new.append(tokens_new)
|
| 240 |
return midi_seq_new
|
| 241 |
|
| 242 |
-
def
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
tempo = int((60 / bpm) * 10 ** 6)
|
| 43 |
return tempo
|
| 44 |
|
| 45 |
+
def tokenize(self, midi_score, add_bos_eos=True, cc_eps=4, tempo_eps=4):
|
| 46 |
ticks_per_beat = midi_score[0]
|
| 47 |
event_list = {}
|
| 48 |
for track_idx, track in enumerate(midi_score[1:129]):
|
| 49 |
last_notes = {}
|
| 50 |
+
patch_dict = {}
|
| 51 |
+
control_dict = {}
|
| 52 |
+
last_tempo = 0
|
| 53 |
for event in track:
|
| 54 |
+
if event[0] not in self.events:
|
| 55 |
+
continue
|
| 56 |
t = round(16 * event[1] / ticks_per_beat) # quantization
|
| 57 |
new_event = [event[0], t // 16, t % 16, track_idx] + event[2:]
|
| 58 |
if event[0] == "note":
|
| 59 |
new_event[4] = max(1, round(16 * new_event[4] / ticks_per_beat))
|
| 60 |
elif event[0] == "set_tempo":
|
| 61 |
+
if new_event[4] == 0: # invalid tempo
|
| 62 |
+
continue
|
| 63 |
+
bpm = int(self.tempo2bpm(new_event[4]))
|
| 64 |
+
new_event[4] = min(bpm, 255)
|
| 65 |
if event[0] == "note":
|
| 66 |
key = tuple(new_event[:4] + new_event[5:-1])
|
| 67 |
else:
|
| 68 |
key = tuple(new_event[:-1])
|
| 69 |
+
if event[0] == "patch_change":
|
| 70 |
+
c, p = event[2:]
|
| 71 |
+
last_p = patch_dict.setdefault(c, None)
|
| 72 |
+
if last_p == p:
|
| 73 |
+
continue
|
| 74 |
+
patch_dict[c] = p
|
| 75 |
+
elif event[0] == "control_change":
|
| 76 |
+
c, cc, v = event[2:]
|
| 77 |
+
last_v = control_dict.setdefault((c, cc), 0)
|
| 78 |
+
if abs(last_v - v) < cc_eps:
|
| 79 |
+
continue
|
| 80 |
+
control_dict[(c, cc)] = v
|
| 81 |
+
elif event[0] == "set_tempo":
|
| 82 |
+
tempo = new_event[-1]
|
| 83 |
+
if abs(last_tempo - tempo) < tempo_eps:
|
| 84 |
+
continue
|
| 85 |
+
last_tempo = tempo
|
| 86 |
+
|
| 87 |
if event[0] == "note": # to eliminate note overlap due to quantization
|
| 88 |
cp = tuple(new_event[5:7])
|
| 89 |
if cp in last_notes:
|
|
|
|
| 97 |
event_list = list(event_list.values())
|
| 98 |
event_list = sorted(event_list, key=lambda e: e[1:4])
|
| 99 |
midi_seq = []
|
| 100 |
+
setup_events = {}
|
| 101 |
+
notes_in_setup = False
|
| 102 |
+
for i, event in enumerate(event_list): # optimise setup
|
| 103 |
+
new_event = [*event]
|
| 104 |
+
if event[0] != "note":
|
| 105 |
+
new_event[1] = 0
|
| 106 |
+
new_event[2] = 0
|
| 107 |
+
has_next = False
|
| 108 |
+
has_pre = False
|
| 109 |
+
if i < len(event_list) - 1:
|
| 110 |
+
next_event = event_list[i + 1]
|
| 111 |
+
has_next = event[1] + event[2] == next_event[1] + next_event[2]
|
| 112 |
+
if notes_in_setup and i > 0:
|
| 113 |
+
pre_event = event_list[i - 1]
|
| 114 |
+
has_pre = event[1] + event[2] == pre_event[1] + pre_event[2]
|
| 115 |
+
if (event[0] == "note" and not has_next) or (notes_in_setup and not has_pre) :
|
| 116 |
+
event_list = sorted(setup_events.values(), key=lambda e: 1 if e[0] == "note" else 0) + event_list[i:]
|
| 117 |
+
break
|
| 118 |
+
else:
|
| 119 |
+
if event[0] == "note":
|
| 120 |
+
notes_in_setup = True
|
| 121 |
+
key = tuple(event[3:-1])
|
| 122 |
+
setup_events[key] = new_event
|
| 123 |
|
| 124 |
last_t1 = 0
|
| 125 |
for event in event_list:
|
| 126 |
+
cur_t1 = event[1]
|
| 127 |
+
event[1] = event[1] - last_t1
|
| 128 |
+
tokens = self.event2tokens(event)
|
| 129 |
+
if not tokens:
|
| 130 |
+
continue
|
| 131 |
+
midi_seq.append(tokens)
|
| 132 |
+
last_t1 = cur_t1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
|
| 134 |
if add_bos_eos:
|
| 135 |
bos = [self.bos_id] + [self.pad_id] * (self.max_token_seq - 1)
|
|
|
|
| 140 |
def event2tokens(self, event):
|
| 141 |
name = event[0]
|
| 142 |
params = event[1:]
|
| 143 |
+
if not all([0 <= params[i] < self.event_parameters[p] for i, p in enumerate(self.events[name])]):
|
| 144 |
+
return []
|
| 145 |
tokens = [self.event_ids[name]] + [self.parameter_ids[p][params[i]]
|
| 146 |
for i, p in enumerate(self.events[name])]
|
| 147 |
tokens += [self.pad_id] * (self.max_token_seq - len(tokens))
|
|
|
|
| 166 |
t1 = 0
|
| 167 |
for tokens in midi_seq:
|
| 168 |
if tokens[0] in self.id_events:
|
| 169 |
+
event = self.tokens2event(tokens)
|
| 170 |
+
if not event:
|
| 171 |
continue
|
| 172 |
+
name = event[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
if name == "set_tempo":
|
| 174 |
event[4] = self.bpm2tempo(event[4])
|
| 175 |
if event[0] == "note":
|
|
|
|
| 225 |
return img
|
| 226 |
|
| 227 |
def augment(self, midi_seq, max_pitch_shift=4, max_vel_shift=10, max_cc_val_shift=10, max_bpm_shift=10,
|
| 228 |
+
max_track_shift=0, max_channel_shift=16):
|
| 229 |
pitch_shift = random.randint(-max_pitch_shift, max_pitch_shift)
|
| 230 |
vel_shift = random.randint(-max_vel_shift, max_vel_shift)
|
| 231 |
cc_val_shift = random.randint(-max_cc_val_shift, max_cc_val_shift)
|
|
|
|
| 281 |
midi_seq_new.append(tokens_new)
|
| 282 |
return midi_seq_new
|
| 283 |
|
| 284 |
+
def check_quality(self, midi_seq, alignment_min=0.4, tonality_min=0.8, piano_max=0.7, notes_bandwidth_min=3, notes_density_max=30, notes_density_min=2.5, total_notes_max=10000, total_notes_min=500, note_window_size=16):
|
| 285 |
+
total_notes = 0
|
| 286 |
+
channels = []
|
| 287 |
+
time_hist = [0] * 16
|
| 288 |
+
note_windows = {}
|
| 289 |
+
notes_sametime = []
|
| 290 |
+
notes_density_list = []
|
| 291 |
+
tonality_list = []
|
| 292 |
+
notes_bandwidth_list = []
|
| 293 |
+
instruments = {}
|
| 294 |
+
piano_channels = []
|
| 295 |
+
undef_instrument = False
|
| 296 |
+
abs_t1 = 0
|
| 297 |
+
last_t = 0
|
| 298 |
+
for tsi, tokens in enumerate(midi_seq):
|
| 299 |
+
event = self.tokens2event(tokens)
|
| 300 |
+
if not event:
|
| 301 |
+
continue
|
| 302 |
+
t1, t2, tr = event[1:4]
|
| 303 |
+
abs_t1 += t1
|
| 304 |
+
t = abs_t1 * 16 + t2
|
| 305 |
+
c = None
|
| 306 |
+
if event[0] == "note":
|
| 307 |
+
d, c, p, v = event[4:]
|
| 308 |
+
total_notes += 1
|
| 309 |
+
time_hist[t2] += 1
|
| 310 |
+
if c != 9: # ignore drum channel
|
| 311 |
+
if c not in instruments:
|
| 312 |
+
undef_instrument = True
|
| 313 |
+
note_windows.setdefault(abs_t1 // note_window_size, []).append(p)
|
| 314 |
+
if last_t != t:
|
| 315 |
+
notes_sametime = [(et, p_) for et, p_ in notes_sametime if et > last_t]
|
| 316 |
+
notes_sametime_p = [p_ for _, p_ in notes_sametime]
|
| 317 |
+
if len(notes_sametime) > 0:
|
| 318 |
+
notes_bandwidth_list.append(max(notes_sametime_p) - min(notes_sametime_p))
|
| 319 |
+
notes_sametime.append((t + d - 1, p))
|
| 320 |
+
elif event[0] == "patch_change":
|
| 321 |
+
c, p = event[4:]
|
| 322 |
+
instruments[c] = p
|
| 323 |
+
if p == 0 and c not in piano_channels:
|
| 324 |
+
piano_channels.append(c)
|
| 325 |
+
if c is not None and c not in channels:
|
| 326 |
+
channels.append(c)
|
| 327 |
+
last_t = t
|
| 328 |
+
reasons = []
|
| 329 |
+
if total_notes < total_notes_min:
|
| 330 |
+
reasons.append("total_min")
|
| 331 |
+
if total_notes > total_notes_max:
|
| 332 |
+
reasons.append("total_max")
|
| 333 |
+
if undef_instrument:
|
| 334 |
+
reasons.append("undef_instr")
|
| 335 |
+
if len(note_windows) == 0 and total_notes > 0:
|
| 336 |
+
reasons.append("drum_only")
|
| 337 |
+
if reasons:
|
| 338 |
+
return False, reasons
|
| 339 |
+
time_hist = sorted(time_hist, reverse=True)
|
| 340 |
+
alignment = sum(time_hist[:2]) / total_notes
|
| 341 |
+
for notes in note_windows.values():
|
| 342 |
+
key_hist = [0] * 12
|
| 343 |
+
for p in notes:
|
| 344 |
+
key_hist[p % 12] += 1
|
| 345 |
+
key_hist = sorted(key_hist, reverse=True)
|
| 346 |
+
tonality_list.append(sum(key_hist[:7]) / len(notes))
|
| 347 |
+
notes_density_list.append(len(notes) / note_window_size)
|
| 348 |
+
tonality_list = sorted(tonality_list)
|
| 349 |
+
tonality = sum(tonality_list)/len(tonality_list)
|
| 350 |
+
notes_bandwidth = sum(notes_bandwidth_list)/len(notes_bandwidth_list) if notes_bandwidth_list else 0
|
| 351 |
+
notes_density = max(notes_density_list) if notes_density_list else 0
|
| 352 |
+
piano_ratio = len(piano_channels) / len(channels)
|
| 353 |
+
if len(channels) <=3: # ignore piano threshold if it is a piano solo midi
|
| 354 |
+
piano_max = 1
|
| 355 |
+
if alignment < alignment_min: # check weather the notes align to the bars (because some midi files are recorded)
|
| 356 |
+
reasons.append("alignment")
|
| 357 |
+
if tonality < tonality_min: # check whether the music is tonal
|
| 358 |
+
reasons.append("tonality")
|
| 359 |
+
if notes_bandwidth < notes_bandwidth_min: # check whether music is melodic line only
|
| 360 |
+
reasons.append("bandwidth")
|
| 361 |
+
if not notes_density_min < notes_density < notes_density_max:
|
| 362 |
+
reasons.append("density")
|
| 363 |
+
if piano_ratio > piano_max: # check whether most instruments is piano (because some midi files don't have instruments assigned correctly)
|
| 364 |
+
reasons.append("piano")
|
| 365 |
+
return not reasons, reasons
|