Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
import argparse
|
| 2 |
import glob
|
| 3 |
import json
|
| 4 |
import os.path
|
|
@@ -11,6 +10,7 @@ import torch
|
|
| 11 |
import torch.nn.functional as F
|
| 12 |
|
| 13 |
import gradio as gr
|
|
|
|
| 14 |
|
| 15 |
from x_transformer import *
|
| 16 |
import tqdm
|
|
@@ -24,7 +24,7 @@ in_space = os.getenv("SYSTEM") == "spaces"
|
|
| 24 |
|
| 25 |
# =================================================================================================
|
| 26 |
|
| 27 |
-
@
|
| 28 |
def GenerateMIDI(num_tok, idrums, iinstr):
|
| 29 |
print('=' * 70)
|
| 30 |
print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
|
@@ -83,6 +83,38 @@ def GenerateMIDI(num_tok, idrums, iinstr):
|
|
| 83 |
|
| 84 |
yield output, None, None, [create_msg("visualizer_clear", None)]
|
| 85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
outy = start_tokens
|
| 87 |
|
| 88 |
ctime = 0
|
|
@@ -201,42 +233,6 @@ if __name__ == "__main__":
|
|
| 201 |
print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
| 202 |
print('=' * 70)
|
| 203 |
|
| 204 |
-
parser = argparse.ArgumentParser()
|
| 205 |
-
parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
|
| 206 |
-
parser.add_argument("--port", type=int, default=7860, help="gradio server port")
|
| 207 |
-
opt = parser.parse_args()
|
| 208 |
-
|
| 209 |
-
print('Loading model...')
|
| 210 |
-
|
| 211 |
-
SEQ_LEN = 2048
|
| 212 |
-
|
| 213 |
-
# instantiate the model
|
| 214 |
-
|
| 215 |
-
model = TransformerWrapper(
|
| 216 |
-
num_tokens=3088,
|
| 217 |
-
max_seq_len=SEQ_LEN,
|
| 218 |
-
attn_layers=Decoder(dim=1024, depth=16, heads=8)
|
| 219 |
-
)
|
| 220 |
-
|
| 221 |
-
model = AutoregressiveWrapper(model)
|
| 222 |
-
|
| 223 |
-
model = torch.nn.DataParallel(model)
|
| 224 |
-
|
| 225 |
-
model.cpu()
|
| 226 |
-
print('=' * 70)
|
| 227 |
-
|
| 228 |
-
print('Loading model checkpoint...')
|
| 229 |
-
|
| 230 |
-
model.load_state_dict(
|
| 231 |
-
torch.load('Allegro_Music_Transformer_Tiny_Trained_Model_80000_steps_0.9457_loss_0.7443_acc.pth',
|
| 232 |
-
map_location='cpu'))
|
| 233 |
-
print('=' * 70)
|
| 234 |
-
|
| 235 |
-
model.eval()
|
| 236 |
-
|
| 237 |
-
print('Done!')
|
| 238 |
-
print('=' * 70)
|
| 239 |
-
|
| 240 |
load_javascript()
|
| 241 |
app = gr.Blocks()
|
| 242 |
with app:
|
|
@@ -267,4 +263,4 @@ if __name__ == "__main__":
|
|
| 267 |
[output_midi_seq, output_midi, output_audio, js_msg])
|
| 268 |
interrupt_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio, js_msg],
|
| 269 |
cancels=run_event, queue=False)
|
| 270 |
-
app.queue(
|
|
|
|
|
|
|
| 1 |
import glob
|
| 2 |
import json
|
| 3 |
import os.path
|
|
|
|
| 10 |
import torch.nn.functional as F
|
| 11 |
|
| 12 |
import gradio as gr
|
| 13 |
+
import spaces
|
| 14 |
|
| 15 |
from x_transformer import *
|
| 16 |
import tqdm
|
|
|
|
| 24 |
|
| 25 |
# =================================================================================================
|
| 26 |
|
| 27 |
+
@spaces.GPU
|
| 28 |
def GenerateMIDI(num_tok, idrums, iinstr):
|
| 29 |
print('=' * 70)
|
| 30 |
print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
|
|
|
| 83 |
|
| 84 |
yield output, None, None, [create_msg("visualizer_clear", None)]
|
| 85 |
|
| 86 |
+
|
| 87 |
+
print('Loading model...')
|
| 88 |
+
|
| 89 |
+
SEQ_LEN = 2048
|
| 90 |
+
|
| 91 |
+
# instantiate the model
|
| 92 |
+
|
| 93 |
+
model = TransformerWrapper(
|
| 94 |
+
num_tokens=3088,
|
| 95 |
+
max_seq_len=SEQ_LEN,
|
| 96 |
+
attn_layers=Decoder(dim=1024, depth=16, heads=8)
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
model = AutoregressiveWrapper(model)
|
| 100 |
+
|
| 101 |
+
model = torch.nn.DataParallel(model)
|
| 102 |
+
|
| 103 |
+
model.cpu()
|
| 104 |
+
print('=' * 70)
|
| 105 |
+
|
| 106 |
+
print('Loading model checkpoint...')
|
| 107 |
+
|
| 108 |
+
model.load_state_dict(
|
| 109 |
+
torch.load('Allegro_Music_Transformer_Tiny_Trained_Model_80000_steps_0.9457_loss_0.7443_acc.pth',
|
| 110 |
+
map_location='cpu'))
|
| 111 |
+
print('=' * 70)
|
| 112 |
+
|
| 113 |
+
model.eval()
|
| 114 |
+
|
| 115 |
+
print('Done!')
|
| 116 |
+
print('=' * 70)
|
| 117 |
+
|
| 118 |
outy = start_tokens
|
| 119 |
|
| 120 |
ctime = 0
|
|
|
|
| 233 |
print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
| 234 |
print('=' * 70)
|
| 235 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
load_javascript()
|
| 237 |
app = gr.Blocks()
|
| 238 |
with app:
|
|
|
|
| 263 |
[output_midi_seq, output_midi, output_audio, js_msg])
|
| 264 |
interrupt_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio, js_msg],
|
| 265 |
cancels=run_event, queue=False)
|
| 266 |
+
app.queue().launch()
|