Spaces:
Running
Running
fix streaming
Browse files- app.py +14 -12
- javascript/app.js +3 -6
app.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import argparse
|
| 2 |
import glob
|
| 3 |
import os.path
|
|
|
|
| 4 |
import uuid
|
| 5 |
|
| 6 |
import gradio as gr
|
|
@@ -113,21 +114,15 @@ def generate(model, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
|
|
| 113 |
|
| 114 |
|
| 115 |
def create_msg(name, data):
|
| 116 |
-
return {"name": name, "data": data
|
| 117 |
|
| 118 |
|
| 119 |
-
def send_msgs(msgs
|
| 120 |
-
|
| 121 |
-
msgs_history = []
|
| 122 |
-
msgs_history.append(msgs)
|
| 123 |
-
if len(msgs_history) > 25:
|
| 124 |
-
msgs_history= msgs_history[1:]
|
| 125 |
-
return json.dumps(msgs_history)
|
| 126 |
|
| 127 |
|
| 128 |
def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events, seed, seed_rand,
|
| 129 |
gen_events, temp, top_p, top_k, allow_cc):
|
| 130 |
-
msgs_history = []
|
| 131 |
mid_seq = []
|
| 132 |
bpm = int(bpm)
|
| 133 |
gen_events = int(gen_events)
|
|
@@ -167,16 +162,23 @@ def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events, seed, see
|
|
| 167 |
init_msgs = [create_msg("visualizer_clear", False)]
|
| 168 |
for tokens in mid_seq:
|
| 169 |
init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
|
| 170 |
-
yield mid_seq, None, None, seed, send_msgs(init_msgs
|
| 171 |
model = models[model_name]
|
| 172 |
midi_generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
|
| 173 |
disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
|
| 174 |
disable_channels=disable_channels, generator=generator)
|
|
|
|
|
|
|
| 175 |
for i, token_seq in enumerate(midi_generator):
|
| 176 |
token_seq = token_seq.tolist()
|
| 177 |
mid_seq.append(token_seq)
|
| 178 |
-
|
| 179 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
mid = tokenizer.detokenize(mid_seq)
|
| 181 |
with open(f"output.mid", 'wb') as f:
|
| 182 |
f.write(MIDI.score2midi(mid))
|
|
|
|
| 1 |
import argparse
|
| 2 |
import glob
|
| 3 |
import os.path
|
| 4 |
+
import time
|
| 5 |
import uuid
|
| 6 |
|
| 7 |
import gradio as gr
|
|
|
|
| 114 |
|
| 115 |
|
| 116 |
def create_msg(name, data):
|
| 117 |
+
return {"name": name, "data": data}
|
| 118 |
|
| 119 |
|
| 120 |
+
def send_msgs(msgs):
|
| 121 |
+
return json.dumps(msgs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
|
| 124 |
def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events, seed, seed_rand,
|
| 125 |
gen_events, temp, top_p, top_k, allow_cc):
|
|
|
|
| 126 |
mid_seq = []
|
| 127 |
bpm = int(bpm)
|
| 128 |
gen_events = int(gen_events)
|
|
|
|
| 162 |
init_msgs = [create_msg("visualizer_clear", False)]
|
| 163 |
for tokens in mid_seq:
|
| 164 |
init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
|
| 165 |
+
yield mid_seq, None, None, seed, send_msgs(init_msgs)
|
| 166 |
model = models[model_name]
|
| 167 |
midi_generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
|
| 168 |
disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
|
| 169 |
disable_channels=disable_channels, generator=generator)
|
| 170 |
+
t = time.time()
|
| 171 |
+
events = []
|
| 172 |
for i, token_seq in enumerate(midi_generator):
|
| 173 |
token_seq = token_seq.tolist()
|
| 174 |
mid_seq.append(token_seq)
|
| 175 |
+
events.append(tokenizer.tokens2event(token_seq))
|
| 176 |
+
ct = time.time()
|
| 177 |
+
if ct - t > 0.2:
|
| 178 |
+
yield mid_seq, None, None, seed, send_msgs([create_msg("visualizer_append", events), create_msg("progress", [i + 1, gen_events])])
|
| 179 |
+
t = ct
|
| 180 |
+
events = []
|
| 181 |
+
|
| 182 |
mid = tokenizer.detokenize(mid_seq)
|
| 183 |
with open(f"output.mid", 'wb') as f:
|
| 184 |
f.write(MIDI.score2midi(mid))
|
javascript/app.js
CHANGED
|
@@ -420,18 +420,16 @@ customElements.define('midi-visualizer', MidiVisualizer);
|
|
| 420 |
}
|
| 421 |
}
|
| 422 |
})
|
| 423 |
-
let handled_msgs = [];
|
| 424 |
function handleMsg(msg){
|
| 425 |
-
if(handled_msgs.indexOf(msg.uuid)!== -1)
|
| 426 |
-
return;
|
| 427 |
-
handled_msgs.push(msg.uuid);
|
| 428 |
switch (msg.name) {
|
| 429 |
case "visualizer_clear":
|
| 430 |
midi_visualizer.clearMidiEvents(false);
|
| 431 |
createProgressBar(midi_visualizer_container_inited)
|
| 432 |
break;
|
| 433 |
case "visualizer_append":
|
| 434 |
-
|
|
|
|
|
|
|
| 435 |
break;
|
| 436 |
case "progress":
|
| 437 |
let progress = msg.data[0]
|
|
@@ -446,7 +444,6 @@ customElements.define('midi-visualizer', MidiVisualizer);
|
|
| 446 |
midi_visualizer.finishAppendMidiEvent()
|
| 447 |
midi_visualizer.setPlayTime(0);
|
| 448 |
removeProgressBar(midi_visualizer_container_inited);
|
| 449 |
-
handled_msgs = []
|
| 450 |
break;
|
| 451 |
default:
|
| 452 |
}
|
|
|
|
| 420 |
}
|
| 421 |
}
|
| 422 |
})
|
|
|
|
| 423 |
function handleMsg(msg){
|
|
|
|
|
|
|
|
|
|
| 424 |
switch (msg.name) {
|
| 425 |
case "visualizer_clear":
|
| 426 |
midi_visualizer.clearMidiEvents(false);
|
| 427 |
createProgressBar(midi_visualizer_container_inited)
|
| 428 |
break;
|
| 429 |
case "visualizer_append":
|
| 430 |
+
msg.data.forEach( value => {
|
| 431 |
+
midi_visualizer.appendMidiEvent(value);
|
| 432 |
+
})
|
| 433 |
break;
|
| 434 |
case "progress":
|
| 435 |
let progress = msg.data[0]
|
|
|
|
| 444 |
midi_visualizer.finishAppendMidiEvent()
|
| 445 |
midi_visualizer.setPlayTime(0);
|
| 446 |
removeProgressBar(midi_visualizer_container_inited);
|
|
|
|
| 447 |
break;
|
| 448 |
default:
|
| 449 |
}
|