projectlosangeles committed
Commit 50c36d3 · verified · 1 Parent(s): c5864c4

Upload app.py

Files changed (1):
  1. app.py +328 -395

app.py CHANGED
@@ -1,144 +1,115 @@
- #====================================================================
- # https://huggingface.co/spaces/asigalov61/Orpheus-Music-Transformer
- #====================================================================
- 
- """
- Orpheus Music Transformer Gradio App - Single Model, Simplified Version
- SOTA 8k multi-instrumental music transformer trained on 2.31M+ high-quality MIDIs
- Using one model which was trained for ~2 epochs
- """
- 
- import os
- 
- os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
- 
  import time as reqtime
  import datetime
  from pytz import timezone
 
  import torch
- import matplotlib.pyplot as plt
- import gradio as gr
- import spaces
 
  from huggingface_hub import hf_hub_download
  import TMIDIX
  from midi_to_colab_audio import midi_to_colab_audio
- from x_transformer_2_3_1 import TransformerWrapper, AutoregressiveWrapper, Decoder, top_p
 
  import random
 
- # -----------------------------
- # CONFIGURATION & GLOBALS
- # -----------------------------
- SEP = '=' * 70
- PDT = timezone('US/Pacific')
- 
- MODEL_CHECKPOINT = 'Orpheus_Music_Transformer_Trained_Model_60667_steps_0.2238_loss_0.9375_acc.pth'
- SOUDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'
- NUM_OUT_BATCHES = 12
- PREVIEW_LENGTH = 120 # in tokens
- 
- # -----------------------------
- # PRINT START-UP INFO
- # -----------------------------
- def print_sep():
-     print(SEP)
- 
- print_sep()
- print("Orpheus Music Transformer Gradio App")
- print_sep()
- print("Loading modules...")
- 
- # -----------------------------
- # ENVIRONMENT & PyTorch Settings
- # -----------------------------
- os.environ['USE_FLASH_ATTENTION'] = '1'
- 
- torch.set_float32_matmul_precision('high')
- torch.backends.cuda.matmul.allow_tf32 = True
- torch.backends.cudnn.allow_tf32 = True
- torch.backends.cuda.enable_mem_efficient_sdp(True)
- torch.backends.cuda.enable_math_sdp(True)
- torch.backends.cuda.enable_flash_sdp(True)
- torch.backends.cuda.enable_cudnn_sdp(True)
- 
- print_sep()
- print("PyTorch version:", torch.__version__)
- print("Done loading modules!")
- print_sep()
- 
- # -----------------------------
- # MODEL INITIALIZATION
- # -----------------------------
- print_sep()
- print("Instantiating model...")
 
  device_type = 'cuda'
  dtype = 'bfloat16'
  ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
  ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
 
- SEQ_LEN = 8192
  PAD_IDX = 18819
 
- model = TransformerWrapper(
-     num_tokens=PAD_IDX + 1,
-     max_seq_len=SEQ_LEN,
-     attn_layers=Decoder(
-         dim=2048,
-         depth=8,
-         heads=32,
-         rotary_pos_emb=True,
-         attn_flash=True
-     )
- )
  model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
 
- print_sep()
- print("Loading model checkpoint...")
- checkpoint = hf_hub_download(
-     repo_id='asigalov61/Orpheus-Music-Transformer',
-     filename=MODEL_CHECKPOINT
- )
- model.load_state_dict(torch.load(checkpoint, map_location='cuda', weights_only=True))
  model = torch.compile(model, mode='max-autotune')
- print_sep()
- print("Done!")
- print("Model will use", dtype, "precision...")
- print_sep()
 
- model.cuda()
  model.eval()
 
- # -----------------------------
- # HELPER FUNCTIONS
- # -----------------------------
- def render_midi_output(final_composition):
-     """Generate MIDI score, plot, and audio from final composition."""
-     fname, midi_score = save_midi(final_composition)
-     time_val = midi_score[-1][1] / 1000 # seconds marker from last note
-     midi_plot = TMIDIX.plot_ms_SONG(
-         midi_score,
-         plot_title='Orpheus Music Transformer Composition',
-         block_lines_times_list=[],
-         return_plt=True
-     )
-     midi_audio = midi_to_colab_audio(
-         fname + '.mid',
-         soundfont_path=SOUDFONT_PATH,
-         sample_rate=16000,
-         output_for_gradio=True
-     )
-     return (16000, midi_audio), midi_plot, fname + '.mid', time_val
 
- # -----------------------------
- # MIDI PROCESSING FUNCTIONS
- # -----------------------------
  def load_midi(input_midi):
-     """Process the input MIDI file and create a token sequence."""
-     raw_score = TMIDIX.midi2single_track_ms_score(input_midi.name)
- 
      escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True, apply_sustain=True)
 
      if escore_notes:
@@ -176,6 +147,7 @@ def load_midi(input_midi):
 
              # Velocities
              # Calculating octo-velocity
+ 
              vel = max(8, min(127, e[4]))
              velocity = round(vel / 15)-1
 
@@ -187,324 +159,285 @@
              pat_ptc = (128 * pat) + ptc
              dur_vel = (8 * dur) + velocity
 
-             melody_chords.extend([pat_ptc+256, dur_vel+16768])
- 
-         return melody_chords
- 
-     else:
-         return [18816]
- 
- def save_midi(tokens):
-     """Convert token sequence back to a MIDI score and write it using TMIDIX."""
- 
-     time = 0
-     dur = 1
-     vel = 90
-     pitch = 60
-     channel = 0
-     patch = 0
- 
-     patches = [-1] * 16
- 
-     channels = [0] * 16
-     channels[9] = 1
- 
-     song_f = []
- 
-     for ss in tokens:
- 
-         if 0 <= ss < 256:
- 
-             time += ss * 16
- 
-         if 256 <= ss < 16768:
- 
-             patch = (ss-256) // 128
 
-             if patch < 128:
 
-                 if patch not in patches:
-                     if 0 in channels:
-                         cha = channels.index(0)
-                         channels[cha] = 1
-                     else:
-                         cha = 15
 
-                     patches[cha] = patch
-                     channel = patches.index(patch)
-                 else:
-                     channel = patches.index(patch)
 
-             if patch == 128:
-                 channel = 9
 
-             pitch = (ss-256) % 128
 
-         if 16768 <= ss < 18816:
 
-             dur = ((ss-16768) // 8) * 16
-             vel = (((ss-16768) % 8)+1) * 15
 
-             song_f.append(['note', time, dur, channel, pitch, vel, patch])
 
-     patches = [0 if x==-1 else x for x in patches]
 
-     output_score, patches, overflow_patches = TMIDIX.patch_enhanced_score_notes(song_f)
 
-     # Generate a time stamp using the PDT timezone.
-     timestamp = datetime.datetime.now(PDT).strftime("%Y%m%d_%H%M%S")
- 
-     fname = f"Orpheus-Music-Transformer-Composition"
 
-     TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(
-         output_score,
-         output_signature='Orpheus Music Transformer',
-         output_file_name=fname,
-         track_name='Project Los Angeles',
-         list_of_MIDI_patches=patches,
-         verbose=False
-     )
-     return fname, output_score
- 
- # -----------------------------
- # MUSIC GENERATION FUNCTION (Combined)
- # -----------------------------
- @spaces.GPU
- def generate_music(prime, num_gen_tokens, num_mem_tokens, num_gen_batches, model_temperature, model_top_p):
-     """Generate music tokens given prime tokens and parameters."""
-     inputs = prime[-num_mem_tokens:] if prime else [18816]
-     print("Generating...")
-     inp = torch.LongTensor([inputs] * num_gen_batches).cuda()
-     with ctx:
-         out = model.generate(
-             inp,
-             num_gen_tokens,
-             filter_logits_fn=top_p,
-             filter_kwargs={'thres': model_top_p},
-             temperature=model_temperature,
-             eos_token=18818,
-             return_prime=False,
-             verbose=False
-         )
 
-     print("Done!")
-     print_sep()
-     return out.tolist()
- 
- def generate_music_and_state(input_midi, num_prime_tokens, num_gen_tokens, num_mem_tokens,
-                              model_temperature, model_top_p, add_drums, add_outro, final_composition, generated_batches, block_lines):
-     """
-     Generate tokens using the model, update the composition state, and prepare outputs.
-     This function combines seed loading, token generation, and UI output packaging.
-     """
-     print_sep()
-     print("Request start time:", datetime.datetime.now(PDT).strftime("%Y-%m-%d %H:%M:%S"))
- 
-     print('=' * 70)
-     if input_midi is not None:
-         fn = os.path.basename(input_midi.name)
-         fn1 = fn.split('.')[0]
-         print('Input file name:', fn)
- 
-     print('Num prime tokens:', num_prime_tokens)
-     print('Num gen tokens:', num_gen_tokens)
-     print('Num mem tokens:', num_mem_tokens)
 
-     print('Model temp:', model_temperature)
-     print('Model top p:', model_top_p)
- 
-     print('Add drums:', add_drums)
-     print('Add outro:', add_outro)
-     print('=' * 70)
- 
-     # Load seed from MIDI if there is no existing composition.
-     if not final_composition and input_midi is not None:
-         final_composition = load_midi(input_midi)[:num_prime_tokens]
-         midi_fname, midi_score = save_midi(final_composition)
-         # Use the last note's time as a marker.
-         block_lines.append(midi_score[-1][1] / 1000 if final_composition else 0)
- 
-     if add_outro:
-         final_composition.append(18817) # Outro token
- 
-     if add_drums:
-         drum_pitch = random.choice([36, 38])
-         final_composition.extend([(128*128)+drum_pitch+256]) # Drum token
- 
-     batched_gen_tokens = generate_music(final_composition, num_gen_tokens, num_mem_tokens,
-                                         NUM_OUT_BATCHES, model_temperature, model_top_p)
- 
-     output_batches = []
-     for i, tokens in enumerate(batched_gen_tokens):
-         preview_tokens = final_composition[-PREVIEW_LENGTH:]
-         midi_fname, midi_score = save_midi(preview_tokens + tokens)
-         plot_kwargs = {'plot_title': f'Batch # {i}', 'return_plt': True}
 
-         if len(final_composition) > PREVIEW_LENGTH:
-             plot_kwargs['preview_length_in_notes'] = len([t for t in preview_tokens if 256 <= t < 16768])
- 
-         midi_plot = TMIDIX.plot_ms_SONG(midi_score, **plot_kwargs)
-         midi_audio = midi_to_colab_audio(midi_fname + '.mid',
-                                          soundfont_path=SOUDFONT_PATH,
-                                          sample_rate=16000,
-                                          output_for_gradio=True)
-         output_batches.append([(16000, midi_audio), midi_plot, tokens])
 
-     # Update generated_batches (for use by add/remove functions)
-     generated_batches = batched_gen_tokens
 
-     print("Request end time:", datetime.datetime.now(PDT).strftime("%Y-%m-%d %H:%M:%S"))
-     print_sep()
 
-     # Flatten outputs: states then audio and plots for each batch.
-     outputs_flat = []
-     for batch in output_batches:
-         outputs_flat.extend([batch[0], batch[1]])
-     return [final_composition, generated_batches, block_lines] + outputs_flat
- 
- # -----------------------------
- # BATCH HANDLING FUNCTIONS
- # -----------------------------
- def add_batch(batch_number, final_composition, generated_batches, block_lines):
-     """Add tokens from the specified batch to the final composition and update outputs."""
-     if generated_batches:
-         final_composition.extend(generated_batches[batch_number])
-         midi_fname, midi_score = save_midi(final_composition)
-         block_lines.append(midi_score[-1][1] / 1000 if final_composition else 0)
-         midi_plot = TMIDIX.plot_ms_SONG(
-             midi_score,
-             plot_title='Orpheus Music Transformer Composition',
-             block_lines_times_list=block_lines[:-1],
-             return_plt=True
-         )
-         midi_audio = midi_to_colab_audio(midi_fname + '.mid',
-                                          soundfont_path=SOUDFONT_PATH,
-                                          sample_rate=16000,
-                                          output_for_gradio=True)
-         print("Added batch #", batch_number)
-         print_sep()
-         return (16000, midi_audio), midi_plot, midi_fname + '.mid', final_composition, generated_batches, block_lines
-     else:
-         return None, None, None, [], [], []
- 
- def remove_batch(batch_number, num_tokens, final_composition, generated_batches, block_lines):
-     """Remove tokens from the final composition and update outputs."""
-     if final_composition and len(final_composition) > num_tokens:
-         final_composition = final_composition[:-num_tokens]
-         if block_lines:
-             block_lines.pop()
-         midi_fname, midi_score = save_midi(final_composition)
-         midi_plot = TMIDIX.plot_ms_SONG(
-             midi_score,
-             plot_title='Orpheus Music Transformer Composition',
-             block_lines_times_list=block_lines[:-1],
-             return_plt=True
-         )
-         midi_audio = midi_to_colab_audio(midi_fname + '.mid',
-                                          soundfont_path=SOUDFONT_PATH,
-                                          sample_rate=16000,
-                                          output_for_gradio=True)
-         print("Removed batch #", batch_number)
-         print_sep()
-         return (16000, midi_audio), midi_plot, midi_fname + '.mid', final_composition, generated_batches, block_lines
      else:
-         return None, None, None, [], [], []
 
- def clear():
-     """Clear outputs and reset state."""
-     return None, None, None, [], []
 
- def reset(final_composition=[], generated_batches=[], block_lines=[]):
-     """Reset composition state."""
-     return [], [], []
 
- # -----------------------------
- # GRADIO INTERFACE SETUP
- # -----------------------------
  with gr.Blocks() as demo:
 
-     gr.Markdown("<h1 style='text-align: left; margin-bottom: 1rem'>Orpheus Music Transformer</h1>")
-     gr.Markdown("<h1 style='text-align: left; margin-bottom: 1rem'>SOTA 8k multi-instrumental music transformer trained on 2.31M+ high-quality MIDIs</h1>")
- 
-     gr.HTML("""
-     Check out <a href="https://huggingface.co/datasets/projectlosangeles/Godzilla-MIDI-Dataset">Godzilla MIDI Dataset</a> on Hugging Face
-     <p>
-     <a href="https://huggingface.co/spaces/asigalov61/Orpheus-Music-Transformer?duplicate=true">
-     <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-md.svg" alt="Duplicate in Hugging Face">
-     </a>
-     </p>
-     for faster execution and endless generation!
-     """)
- 
-     gr.Markdown("## Key Features")
-     gr.Markdown("""
-     - **Efficient Architecture with RoPE**: Compact and very fast 479M full attention autoregressive transformer with RoPE.
-     - **Extended Sequence Length**: 8k tokens that comfortably fit most music compositions and facilitate long-term music structure generation.
-     - **Premium Training Data**: Trained solely on the highest-quality MIDIs from the Godzilla MIDI dataset.
-     - **Optimized MIDI Encoding**: Extremely efficient MIDI representation using only 3 tokens per note and 7 tokens per tri-chord.
-     - **Distinct Encoding Order**: Features a unique duration/velocity last MIDI encoding order for refined musical expression.
-     - **Full-Range Instrumental Learning**: True full-range MIDI instruments encoding enabling the model to learn each instrument separately.
-     - **Natural Composition Endings**: Outro tokens that help generate smooth and natural musical conclusions.
-     """)
- 
-     # Global state variables for composition
-     final_composition = gr.State([])
-     generated_batches = gr.State([])
-     block_lines = gr.State([])
- 
-     gr.Markdown("## Upload seed MIDI or click 'Generate' for random output")
-     input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])
-     input_midi.upload(reset, [final_composition, generated_batches, block_lines],
-                       [final_composition, generated_batches, block_lines])
- 
-     gr.Markdown("## Generate")
-     num_prime_tokens = gr.Slider(16, 7168, value=7168, step=1, label="Number of prime tokens")
-     num_gen_tokens = gr.Slider(16, 1024, value=512, step=1, label="Number of tokens to generate")
-     num_mem_tokens = gr.Slider(16, 8192, value=8192, step=1, label="Number of memory tokens")
      model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
-     model_top_p = gr.Slider(0.1, 0.99, value=0.96, step=0.01, label="Model sampling top p value")
-     add_drums = gr.Checkbox(value=False, label="Add drums")
-     add_outro = gr.Checkbox(value=False, label="Add an outro")
      generate_btn = gr.Button("Generate", variant="primary")
 
-     gr.Markdown("## Batch Previews")
-     outputs = [final_composition, generated_batches, block_lines]
-     # Two outputs (audio and plot) for each batch
-     for i in range(NUM_OUT_BATCHES):
-         with gr.Tab(f"Batch # {i}"):
-             audio_output = gr.Audio(label=f"Batch # {i} MIDI Audio", format="mp3")
-             plot_output = gr.Plot(label=f"Batch # {i} MIDI Plot")
-             outputs.extend([audio_output, plot_output])
-     generate_btn.click(
-         generate_music_and_state,
-         [input_midi, num_prime_tokens, num_gen_tokens, num_mem_tokens, model_temperature, model_top_p, add_drums, add_outro,
-          final_composition, generated_batches, block_lines],
-         outputs
      )
 
-     gr.Markdown("## Add/Remove Batch")
-     batch_number = gr.Slider(0, NUM_OUT_BATCHES - 1, value=0, step=1, label="Batch number to add/remove")
-     add_btn = gr.Button("Add batch", variant="primary")
-     remove_btn = gr.Button("Remove batch", variant="stop")
-     clear_btn = gr.ClearButton()
- 
-     final_audio_output = gr.Audio(label="Final MIDI audio", format="mp3")
-     final_plot_output = gr.Plot(label="Final MIDI plot")
-     final_file_output = gr.File(label="Final MIDI file")
- 
-     add_btn.click(
-         add_batch,
-         [batch_number, final_composition, generated_batches, block_lines],
-         [final_audio_output, final_plot_output, final_file_output, final_composition, generated_batches, block_lines]
-     )
-     remove_btn.click(
-         remove_batch,
-         [batch_number, num_gen_tokens, final_composition, generated_batches, block_lines],
-         [final_audio_output, final_plot_output, final_file_output, final_composition, generated_batches, block_lines]
-     )
-     clear_btn.click(clear, inputs=None,
-                     outputs=[final_audio_output, final_plot_output, final_file_output, final_composition, block_lines])
 
- demo.launch()

+ #============================================================================================
+ # https://huggingface.co/spaces/projectlosangeles/Orpheus-Drums-Transformer
+ #============================================================================================
+ 
+ print('=' * 70)
+ print('Orpheus Drums Transformer Gradio App')
+ 
+ print('=' * 70)
+ print('Loading core Orpheus Drums Transformer modules...')
+ 
+ import os
+ import copy
+ 
  import time as reqtime
  import datetime
  from pytz import timezone
 
+ print('=' * 70)
+ print('Loading main Orpheus Drums Transformer modules...')
+ 
+ os.environ['USE_FLASH_ATTENTION'] = '1'
+ 
  import torch
+ 
+ torch.set_float32_matmul_precision('high')
+ torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
+ torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
+ torch.backends.cuda.enable_flash_sdp(True)
 
  from huggingface_hub import hf_hub_download
+ 
  import TMIDIX
+ 
  from midi_to_colab_audio import midi_to_colab_audio
+ 
+ from x_transformer_2_3_1 import *
 
  import random
 
+ import tqdm
 
+ print('=' * 70)
+ print('Loading aux Orpheus Drums Transformer modules...')
 
+ import matplotlib.pyplot as plt
 
+ import gradio as gr
+ import spaces
 
+ print('=' * 70)
+ print('PyTorch version:', torch.__version__)
+ print('=' * 70)
+ print('Done!')
+ print('Enjoy! :)')
+ print('=' * 70)
+ 
+ #==================================================================================
+ 
+ MODEL_CHECKPOINT = 'Orpheus_Bridge_Music_Transformer_Trained_Model_19571_steps_0.9396_loss_0.7365_acc.pth'
+ 
+ SOUDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'
+ 
+ #==================================================================================
+ 
+ print('=' * 70)
+ print('Instantiating model...')
 
  device_type = 'cuda'
  dtype = 'bfloat16'
+ 
  ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
  ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
 
+ SEQ_LEN = 1668
  PAD_IDX = 18819
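+ 
+ # Context budget: the prompt built in Generate_Music_Bridge below
+ # ([18815] + 512 left-chunk tokens + [18816] + 512 right-chunk tokens + [18817])
+ # is 1027 tokens, and adding the 641 generated bridge tokens exactly
+ # fills SEQ_LEN = 1668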
 
+ model = TransformerWrapper(num_tokens = PAD_IDX+1,
+                            max_seq_len = SEQ_LEN,
+                            attn_layers = Decoder(dim = 2048,
+                                                  depth = 8,
+                                                  heads = 32,
+                                                  rotary_pos_emb = True,
+                                                  attn_flash = True
+                                                  )
+                            )
+ 
  model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
 
+ print('=' * 70)
+ print('Loading model checkpoint...')
+ 
+ model_checkpoint = hf_hub_download(repo_id='asigalov61/Orpheus-Music-Transformer', filename=MODEL_CHECKPOINT)
+ 
+ model.load_state_dict(torch.load(model_checkpoint, map_location=device_type, weights_only=True))
+ 
  model = torch.compile(model, mode='max-autotune')
 
+ model.to(device_type)
  model.eval()
 
+ print('=' * 70)
+ print('Done!')
+ print('=' * 70)
+ print('Model will use', dtype, 'precision...')
+ print('=' * 70)
+ 
+ #==================================================================================
 
  def load_midi(input_midi):
+ 
+     raw_score = TMIDIX.midi2single_track_ms_score(input_midi)
+ 
      escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True, apply_sustain=True)
 
      if escore_notes:
 
              # Velocities
              # Calculating octo-velocity
+ 
              vel = max(8, min(127, e[4]))
              velocity = round(vel / 15)-1
 
              pat_ptc = (128 * pat) + ptc
              dur_vel = (8 * dur) + velocity
 
+             melody_chords.extend([pat_ptc+256, dur_vel+16768]) # 18816
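+ 
+             # Per-note encoding: at most 3 tokens per note -- a time-delta token
+             # (0..255, shared by notes that sound together), a patch/pitch token
+             # (256..16767) and a duration/velocity token (16768 and up)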
+ 
+ 
+         print('Done!')
+         print('=' * 70)
+         print('Score has', len(melody_chords), 'tokens')
+         print('=' * 70)
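+ 
+         # Only scores longer than SEQ_LEN (1668) tokens are usable, which
+         # guarantees the full 1536-token chunk needed for bridge generation;
+         # this is roughly the "at least 800 MIDI pitches" requirement noted
+         # in the UI below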
+ 
+         if len(melody_chords) > SEQ_LEN:
+             return melody_chords
+ 
+         else:
+             return None
+ 
+     else:
+         return None
+ 
+ #==================================================================================
+ 
+ @spaces.GPU
+ def Generate_Music_Bridge(input_midi,
+                           model_temperature,
+                           model_sampling_top_p
+                           ):
+ 
+     #===============================================================================
+ 
+     print('=' * 70)
+     print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
+     start_time = reqtime.time()
+     print('=' * 70)
+ 
+     print('=' * 70)
+     print('Requested settings:')
+     print('=' * 70)
+     fn = os.path.basename(input_midi)
+     fn1 = fn.split('.')[0]
+     print('Input MIDI file name:', fn)
+     print('Model temperature:', model_temperature)
+     print('Model top p:', model_sampling_top_p)
+ 
+     print('=' * 70)
 
+     #==================================================================
 
+     if input_midi is not None:
 
+         print('Loading MIDI...')
 
+         score = load_midi(input_midi.name)
 
+         if score is not None:
+ 
+             print('Sample score tokens', score[:10])
 
+             #==================================================================
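+ 
+             # Bridge setup: the first and last 512 tokens of a 1536-token chunk
+             # serve as left/right contexts; the model is prompted with
+             # [18815] + left + [18816] + right + [18817] and asked to generate
+             # the section that connects them (bridge_chunk marks the 640-token
+             # middle region, overlapping each context by 64 tokens)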
+ 
+             full_chunk = score[:1536]
+             left_chunk = full_chunk[:512]
+             right_chunk = full_chunk[-512:]
+ 
+             bridge_chunk = full_chunk[448:1088]
+ 
+             seq = [18815] + left_chunk + [18816] + right_chunk + [18817]
+ 
+             #==================================================================
+ 
+             print('=' * 70)
+             print('Generating...')
+ 
+             x = torch.LongTensor(seq).to(device_type)
+ 
+             with ctx:
+                 out = model.generate(x,
+                                      641,
+                                      temperature=model_temperature,
+                                      filter_logits_fn=top_p,
+                                      filter_kwargs={'thres': model_sampling_top_p},
+                                      return_prime=False,
+                                      eos_token=18818,
+                                      verbose=False)
+ 
+             y = out.tolist()
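+ 
+             # Stitching: drop the first and last 64 generated tokens (mirroring
+             # bridge_chunk's 64-token overlap with the contexts) and splice the
+             # remainder between the untouched left and right chunks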
+ 
+             final_song = left_chunk + y[64:-64] + right_chunk
+ 
+             #==================================================================
+ 
+             print('=' * 70)
+             print('Done!')
+             print('=' * 70)
+ 
+             #===============================================================================
+ 
+             print('Rendering results...')
+ 
+             print('=' * 70)
+             print('Sample INTs', final_song[:15])
+             print('=' * 70)
+ 
+             song_f = []
+ 
+             if len(final_song) != 0:
+ 
+                 time = 0
+                 dur = 1
+                 vel = 90
+                 pitch = 60
+                 channel = 0
+                 patch = 0
+ 
+                 patches = [-1] * 16
+ 
+                 channels = [0] * 16
+                 channels[9] = 1
+ 
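+                 # Token decoding: 0..255 advance time (16 ms steps),
+                 # 256..16767 set patch and pitch, and 16768..18815 set
+                 # duration and velocity and emit the note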
+                 for ss in final_song:
+ 
+                     if 0 <= ss < 256:
+ 
+                         time += ss * 16
+ 
+                     if 256 <= ss < 16768:
+ 
+                         patch = (ss-256) // 128
+ 
+                         if patch < 128:
+ 
+                             if patch not in patches:
+                                 if 0 in channels:
+                                     cha = channels.index(0)
+                                     channels[cha] = 1
+                                 else:
+                                     cha = 15
+ 
+                                 patches[cha] = patch
+                                 channel = patches.index(patch)
+                             else:
+                                 channel = patches.index(patch)
+ 
+                         if patch == 128:
+                             channel = 9
+ 
+                         pitch = (ss-256) % 128
+ 
+                     if 16768 <= ss < 18816:
+ 
+                         dur = ((ss-16768) // 8) * 16
+                         vel = (((ss-16768) % 8)+1) * 15
+ 
+                         song_f.append(['note', time, dur, channel, pitch, vel, patch])
+ 
+                 patches = [0 if x==-1 else x for x in patches]
 
+                 output_score, patches, overflow_patches = TMIDIX.patch_enhanced_score_notes(song_f)
 
+                 fn1 = "Orpheus-Drums-Transformer-Composition"
+ 
+                 detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(output_score,
+                                                                           output_signature = 'Orpheus Drums Transformer',
+                                                                           output_file_name = fn1,
+                                                                           track_name='Project Los Angeles',
+                                                                           list_of_MIDI_patches=patches
+                                                                           )
+ 
+                 new_fn = fn1+'.mid'
+ 
+                 audio = midi_to_colab_audio(new_fn,
+                                             soundfont_path=SOUDFONT_PATH,
+                                             sample_rate=16000,
+                                             volume_scale=10,
+                                             output_for_gradio=True
+                                             )
+ 
+                 print('Done!')
+                 print('=' * 70)
+ 
+                 #========================================================
+ 
+                 output_midi = str(new_fn)
+                 output_audio = (16000, audio)
+                 output_plot = TMIDIX.plot_ms_SONG(song_f, plot_title=output_midi, return_plt=True)
+ 
+                 print('Output MIDI file name:', output_midi)
+                 print('=' * 70)
+ 
+                 #========================================================
 
+         else:
+             return None, None, None
 
+         print('-' * 70)
+         print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
+         print('-' * 70)
+         print('Req execution time:', (reqtime.time() - start_time), 'sec')
 
+         return output_audio, output_plot, output_midi
+ 
      else:
+         return None, None, None
+ 
+ #==================================================================================
+ 
+ PDT = timezone('US/Pacific')
+ 
+ print('=' * 70)
+ print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
+ print('=' * 70)
+ 
+ #==================================================================================
 
  with gr.Blocks() as demo:
 
+     #==================================================================================
+ 
+     gr.Markdown("<h1 style='text-align: left; margin-bottom: 1rem'>Orpheus Drums Transformer</h1>")
+     gr.Markdown("<h1 style='text-align: left; margin-bottom: 1rem'>Seamless music bridge generation with transformers</h1>")
+     gr.HTML("""
+     <p>
+         <a href="https://huggingface.co/spaces/projectlosangeles/Orpheus-Drums-Transformer?duplicate=true">
+             <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-md.svg" alt="Duplicate in Hugging Face">
+         </a>
+     </p>
+ 
+     for faster execution and endless generation!
+     """)
+ 
+     #==================================================================================
+ 
+     gr.Markdown("## Upload source MIDI or select a sample MIDI at the bottom of the page")
+     gr.Markdown("### PLEASE NOTE: The MIDI file MUST HAVE at least 800 MIDI pitches for the demo to work properly!")
+ 
+     input_midi = gr.File(label="Input MIDI",
+                          file_types=[".midi", ".mid", ".kar"]
+                          )
+ 
+     gr.Markdown("## Generation options")
+ 
      model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
+     model_sampling_top_p = gr.Slider(0.1, 0.99, value=0.96, step=0.01, label="Model sampling top p value")
+ 
      generate_btn = gr.Button("Generate", variant="primary")
 
+     gr.Markdown("## Generation results")
+ 
+     output_title = gr.Textbox(label="MIDI melody title")
+     output_audio = gr.Audio(label="MIDI audio", format="wav", elem_id="midi_audio")
+     output_plot = gr.Plot(label="MIDI score plot")
+     output_midi = gr.File(label="MIDI file", file_types=[".mid"])
+ 
+     generate_btn.click(Generate_Music_Bridge,
+                        [input_midi,
+                         model_temperature,
+                         model_sampling_top_p
+                         ],
+                        [output_audio,
+                         output_plot,
+                         output_midi
+                         ]
+                        )
+ 
+     gr.Examples(
+         [["Sharing The Night Together.kar", 0.9, 0.96]
+          ],
+         [input_midi,
+          model_temperature,
+          model_sampling_top_p
+          ],
+         [output_audio,
+          output_plot,
+          output_midi
+          ],
+         Generate_Music_Bridge
      )
+ 
+     #==================================================================================
 
+ demo.launch()
 
+ #==================================================================================