asigalov61 committed on
Commit
9ee602d
·
verified ·
1 Parent(s): 340c7af

Upload app.py

Files changed (1)
  1. app.py +451 -0
app.py ADDED
@@ -0,0 +1,451 @@
+ #====================================================================
+ # https://huggingface.co/spaces/asigalov61/Orpheus-Music-Transformer
+ #====================================================================
+
+ """
+ Orpheus Music Transformer Gradio App - Single Model, Simplified Version
+ SOTA 8k multi-instrumental music transformer trained on 2.31M+ high-quality MIDIs
+ Uses a single model that was trained for 3 full epochs
+ """
+
+ import os
+
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+
+ import time as reqtime
+ import datetime
+ from pytz import timezone
+
+ import torch
+ import matplotlib.pyplot as plt
+ import gradio as gr
+ import spaces
+
+ from huggingface_hub import hf_hub_download
+ import TMIDIX
+ from midi_to_colab_audio import midi_to_colab_audio
+ from x_transformer_2_3_1 import TransformerWrapper, AutoregressiveWrapper, Decoder
+
+ # -----------------------------
+ # CONFIGURATION & GLOBALS
+ # -----------------------------
+ SEP = '=' * 70
+ PDT = timezone('US/Pacific')
+
+ MODEL_CHECKPOINT = 'Orpheus_Music_Transformer_No_Velocity_Trained_Model_21113_steps_0.3454_loss_0.895_acc.pth'
+ SOUNDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'
+ NUM_OUT_BATCHES = 12  # number of candidate continuations generated per request
+ PREVIEW_LENGTH = 120  # in tokens
+
+ # -----------------------------
+ # PRINT START-UP INFO
+ # -----------------------------
+ def print_sep():
+     print(SEP)
+
+ print_sep()
+ print("Orpheus Music Transformer Gradio App")
+ print_sep()
+ print("Loading modules...")
+
+ # -----------------------------
+ # ENVIRONMENT & PyTorch Settings
+ # -----------------------------
+ os.environ['USE_FLASH_ATTENTION'] = '1'
+
+ torch.set_float32_matmul_precision('high')
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+ torch.backends.cuda.enable_mem_efficient_sdp(True)
+ torch.backends.cuda.enable_math_sdp(True)
+ torch.backends.cuda.enable_flash_sdp(True)
+ torch.backends.cuda.enable_cudnn_sdp(True)
+
+ print_sep()
+ print("PyTorch version:", torch.__version__)
+ print("Done loading modules!")
+ print_sep()
+
+ # -----------------------------
+ # MODEL INITIALIZATION
+ # -----------------------------
+ print_sep()
+ print("Instantiating model...")
+
+ device_type = 'cuda'
+ dtype = 'bfloat16'
+ ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
+ ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
+
+ SEQ_LEN = 4096
+ PAD_IDX = 384
+
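+ # Token vocabulary layout, as implied by load_midi()/save_midi() below:
+ #   0..127   -> inter-chord time deltas, in 32 ms steps
+ #   128..255 -> note durations (value - 128), in 32 ms steps
+ #   256..383 -> MIDI pitches (value - 256)
+ #   384      -> padding index
+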
+ model = TransformerWrapper(
+     num_tokens=PAD_IDX + 1,
+     max_seq_len=SEQ_LEN,
+     attn_layers=Decoder(
+         dim=2048,
+         depth=16,
+         heads=32,
+         rotary_pos_emb=True,
+         attn_flash=True
+     )
+ )
+ model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
+
+ print_sep()
+ print("Loading model checkpoint...")
+ checkpoint = hf_hub_download(
+     repo_id='asigalov61/Orpheus-Music-Transformer',
+     filename=MODEL_CHECKPOINT
+ )
+ model.load_state_dict(torch.load(checkpoint, map_location='cuda', weights_only=True))
+ model = torch.compile(model, mode='max-autotune')
+ print_sep()
+ print("Done!")
+ print("Model will use", dtype, "precision...")
+ print_sep()
+
+ model.cuda()
+ model.eval()
+
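+ # Note: torch.compile with mode='max-autotune' trades a longer first-call
+ # compilation for faster steady-state inference, so the first generation
+ # request after startup will be noticeably slower than subsequent ones.
+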
+ # -----------------------------
+ # HELPER FUNCTIONS
+ # -----------------------------
+ def render_midi_output(final_composition):
+     """Generate MIDI score, plot, and audio from the final composition.
+
+     Note: this helper is currently not wired into the Gradio UI below.
+     """
+     fname, midi_score = save_midi(final_composition)
+     time_val = midi_score[-1][1] / 1000  # last note's start time, in seconds
+     midi_plot = TMIDIX.plot_ms_SONG(
+         midi_score,
+         plot_title='Orpheus Music Transformer Composition',
+         block_lines_times_list=[],
+         return_plt=True
+     )
+     midi_audio = midi_to_colab_audio(
+         fname + '.mid',
+         soundfont_path=SOUNDFONT_PATH,
+         sample_rate=16000,
+         output_for_gradio=True
+     )
+     return (16000, midi_audio), midi_plot, fname + '.mid', time_val
+
133
+ # -----------------------------
134
+ # MIDI PROCESSING FUNCTIONS
135
+ # -----------------------------
136
+ def load_midi(input_midi):
137
+ """Process the input MIDI file and create a token sequence using without velocity logic."""
138
+ raw_score = TMIDIX.midi2single_track_ms_score(input_midi.name)
139
+ escore_notes = TMIDIX.advanced_score_processor(
140
+ raw_score, return_enhanced_score_notes=True, apply_sustain=True
141
+ )[0]
142
+ sp_escore_notes = TMIDIX.solo_piano_escore_notes(escore_notes)
143
+ zscore = TMIDIX.recalculate_score_timings(sp_escore_notes)
144
+ zscore = TMIDIX.augment_enhanced_score_notes(zscore, timings_divider=32)
145
+ fscore = TMIDIX.fix_escore_notes_durations(zscore)
146
+ cscore = TMIDIX.chordify_score([1000, fscore])
147
+
148
+ score = []
149
+ prev_chord = cscore[0]
150
+ for chord in cscore:
151
+ # Time difference token.
152
+ score.append(max(0, min(127, chord[0][1] - prev_chord[0][1])))
153
+ for note in chord:
154
+ score.extend([
155
+ max(1, min(127, note[2])) + 128,
156
+ max(1, min(127, note[4])) + 256
157
+ ])
158
+ prev_chord = chord
159
+ return score
160
+
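+ # Worked example of the encoding above (illustrative values): a single note
+ # takes 3 tokens and a 3-note chord takes 7 tokens, matching the
+ # "3 tokens per note / 7 tokens per tri-chord" claim in the UI text below.
+ #
+ #   C major triad (C4/E4/G4), 500 ms after the previous chord, ~1 s long:
+ #     [16,          # time delta: 500 ms / 32 ms ~ 16
+ #      159, 316,    # duration 31 (~1 s) + 128, pitch 60 + 256
+ #      159, 320,    # duration + 128, pitch 64 + 256
+ #      159, 323]    # duration + 128, pitch 67 + 256
+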
+ def save_midi(tokens, batch_number=None):
+     """Convert a token sequence back to a MIDI score and write it with TMIDIX
+     (no-velocity encoding). The output file uses a fixed base name.
+     """
+     song_events = []
+     time_marker = 0
+     duration = 0
+     pitch = 0
+     patches = [0] * 16
+
+     for token in tokens:
+         if 0 <= token < 128:
+             time_marker += token * 32      # time-delta token, back to ms
+         elif 128 <= token < 256:
+             duration = (token - 128) * 32  # duration token, back to ms
+         elif 256 <= token < 384:
+             pitch = token - 256            # pitch token completes a note
+             song_events.append(['note', time_marker, duration, 0, pitch, max(40, pitch), 0])
+     # No velocity tokens are used; velocity is derived from pitch instead.
+
+     # Per-batch, time-stamped file names are currently disabled in favor of a
+     # single fixed output name (batch_number is therefore unused):
+     # timestamp = datetime.datetime.now(PDT).strftime("%Y%m%d_%H%M%S")
+     # if batch_number is None:
+     #     fname = f"Orpheus-Music-Transformer-Music-Composition_{timestamp}"
+     # else:
+     #     fname = f"Orpheus-Music-Transformer-Music-Composition_{timestamp}_Batch_{batch_number}"
+
+     fname = "Orpheus-Music-Transformer-Music-Composition"
+
+     TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(
+         song_events,
+         output_signature='Orpheus Music Transformer',
+         output_file_name=fname,
+         track_name='Project Los Angeles',
+         list_of_MIDI_patches=patches,
+         verbose=False
+     )
+     return fname, song_events
+
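+ # Decoding is the inverse of load_midi's encoding: e.g. (illustrative values)
+ # tokens [16, 159, 316] become a single note event
+ #   ['note', 512, 992, 0, 60, 60, 0]
+ # i.e. onset 16 * 32 = 512 ms, duration 31 * 32 = 992 ms, channel 0,
+ # pitch 60, velocity max(40, 60) = 60, patch 0.
+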
+ # -----------------------------
+ # MUSIC GENERATION FUNCTIONS
+ # -----------------------------
+ @spaces.GPU
+ def generate_music(prime, num_gen_tokens, num_mem_tokens, num_gen_batches, model_temperature):
+     """Generate batches of continuation tokens from the given prime tokens."""
+     # Keep at most num_mem_tokens of context; fall back to a zero time token.
+     inputs = prime[-num_mem_tokens:] if prime else [0]
+     print("Generating...")
+     inp = torch.LongTensor([inputs] * num_gen_batches).cuda()
+     with ctx:
+         out = model.generate(
+             inp,
+             num_gen_tokens,
+             temperature=model_temperature,
+             return_prime=False,
+             verbose=False
+         )
+     print("Done!")
+     print_sep()
+     return out.tolist()
+
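+ # Illustrative call (hypothetical values): continue a middle-C seed with 32
+ # new tokens in 2 parallel batches; returns 2 token lists, prime tokens
+ # excluded since return_prime=False.
+ #
+ #   batches = generate_music([16, 159, 316], num_gen_tokens=32,
+ #                            num_mem_tokens=4096, num_gen_batches=2,
+ #                            model_temperature=0.9)
+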
+ def generate_music_and_state(input_midi, num_prime_tokens, num_gen_tokens, num_mem_tokens,
+                              model_temperature, final_composition, generated_batches, block_lines):
+     """
+     Generate tokens with the model, update the composition state, and prepare outputs.
+     This function combines seed loading, token generation, and UI output packaging.
+     """
+     print_sep()
+     print("Request start time:", datetime.datetime.now(PDT).strftime("%Y-%m-%d %H:%M:%S"))
+
+     print_sep()
+     if input_midi is not None:
+         fn = os.path.basename(input_midi.name)
+         print('Input file name:', fn)
+
+     print('Num prime tokens:', num_prime_tokens)
+     print('Num gen tokens:', num_gen_tokens)
+     print('Num mem tokens:', num_mem_tokens)
+     print('Model temp:', model_temperature)
+     print_sep()
+
+     # Load the seed from the MIDI file if there is no existing composition.
+     if not final_composition and input_midi is not None:
+         final_composition = load_midi(input_midi)[:num_prime_tokens]
+         midi_fname, midi_score = save_midi(final_composition)
+         TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(
+             midi_score,
+             output_signature='Orpheus Music Transformer',
+             output_file_name=midi_fname,
+             track_name='Project Los Angeles',
+             list_of_MIDI_patches=[0] * 16,
+             verbose=False
+         )
+         # Use the last note's time (in seconds) as a block marker.
+         block_lines.append(midi_score[-1][1] / 1000 if final_composition else 0)
+
+     batched_gen_tokens = generate_music(final_composition, num_gen_tokens, num_mem_tokens,
+                                         NUM_OUT_BATCHES, model_temperature)
+
+     output_batches = []
+     for i, tokens in enumerate(batched_gen_tokens):
+         preview_tokens = final_composition[-PREVIEW_LENGTH:]
+         midi_fname, midi_score = save_midi(preview_tokens + tokens, batch_number=i)
+         plot_kwargs = {'plot_title': f'Batch # {i}', 'return_plt': True}
+         if len(final_composition) > PREVIEW_LENGTH:
+             # Count pitch tokens in the preview so the plot can mark where
+             # the newly generated notes begin.
+             plot_kwargs['preview_length_in_notes'] = len([t for t in preview_tokens if t > 256])
+         TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(
+             midi_score,
+             output_signature='Orpheus Music Transformer',
+             output_file_name=midi_fname,
+             track_name='Project Los Angeles',
+             list_of_MIDI_patches=[0] * 16,
+             verbose=False
+         )
+         midi_plot = TMIDIX.plot_ms_SONG(midi_score, **plot_kwargs)
+         midi_audio = midi_to_colab_audio(midi_fname + '.mid',
+                                          soundfont_path=SOUNDFONT_PATH,
+                                          sample_rate=16000,
+                                          output_for_gradio=True)
+         output_batches.append([(16000, midi_audio), midi_plot, tokens])
+
+     # Update generated_batches (used by the add/remove handlers below).
+     generated_batches = batched_gen_tokens
+
+     print("Request end time:", datetime.datetime.now(PDT).strftime("%Y-%m-%d %H:%M:%S"))
+     print_sep()
+
+     # Flatten outputs: the three state values, then audio and plot per batch.
+     outputs_flat = []
+     for batch in output_batches:
+         outputs_flat.extend([batch[0], batch[1]])
+     return [final_composition, generated_batches, block_lines] + outputs_flat
+
+ # -----------------------------
+ # BATCH HANDLING FUNCTIONS
+ # -----------------------------
+ def add_batch(batch_number, final_composition, generated_batches, block_lines):
+     """Append the chosen batch's tokens to the final composition and update outputs."""
+     if generated_batches:
+         final_composition.extend(generated_batches[batch_number])
+         midi_fname, midi_score = save_midi(final_composition)
+         block_lines.append(midi_score[-1][1] / 1000 if final_composition else 0)
+         TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(
+             midi_score,
+             output_signature='Orpheus Music Transformer',
+             output_file_name=midi_fname,
+             track_name='Project Los Angeles',
+             list_of_MIDI_patches=[0] * 16,
+             verbose=False
+         )
+         midi_plot = TMIDIX.plot_ms_SONG(
+             midi_score,
+             plot_title='Orpheus Music Transformer Composition',
+             block_lines_times_list=block_lines[:-1],
+             return_plt=True
+         )
+         midi_audio = midi_to_colab_audio(midi_fname + '.mid',
+                                          soundfont_path=SOUNDFONT_PATH,
+                                          sample_rate=16000,
+                                          output_for_gradio=True)
+         print("Added batch #", batch_number)
+         print_sep()
+         return (16000, midi_audio), midi_plot, midi_fname + '.mid', final_composition, generated_batches, block_lines
+     else:
+         return None, None, None, [], [], []
+
+ def remove_batch(batch_number, num_tokens, final_composition, generated_batches, block_lines):
+     """Remove the last num_tokens tokens from the final composition and update outputs."""
+     if final_composition and len(final_composition) > num_tokens:
+         final_composition = final_composition[:-num_tokens]
+         if block_lines:
+             block_lines.pop()
+         midi_fname, midi_score = save_midi(final_composition)
+         TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(
+             midi_score,
+             output_signature='Orpheus Music Transformer',
+             output_file_name=midi_fname,
+             track_name='Project Los Angeles',
+             list_of_MIDI_patches=[0] * 16,
+             verbose=False
+         )
+         midi_plot = TMIDIX.plot_ms_SONG(
+             midi_score,
+             plot_title='Orpheus Music Transformer Composition',
+             block_lines_times_list=block_lines[:-1],
+             return_plt=True
+         )
+         midi_audio = midi_to_colab_audio(midi_fname + '.mid',
+                                          soundfont_path=SOUNDFONT_PATH,
+                                          sample_rate=16000,
+                                          output_for_gradio=True)
+         print("Removed batch #", batch_number)
+         print_sep()
+         return (16000, midi_audio), midi_plot, midi_fname + '.mid', final_composition, generated_batches, block_lines
+     else:
+         return None, None, None, [], [], []
+
+ def clear():
+     """Clear the final outputs and reset the composition and block-line state."""
+     return None, None, None, [], []
+
+ def reset(final_composition=[], generated_batches=[], block_lines=[]):
+     """Reset all composition state."""
+     return [], [], []
+
+ # -----------------------------
+ # GRADIO INTERFACE SETUP
+ # -----------------------------
+ with gr.Blocks() as demo:
+
+     gr.Markdown("<h1 style='text-align: left; margin-bottom: 1rem'>Orpheus Music Transformer</h1>")
+     gr.Markdown("<h1 style='text-align: left; margin-bottom: 1rem'>SOTA 8k multi-instrumental music transformer trained on 2.31M+ high-quality MIDIs</h1>")
+     gr.HTML("""
+     Check out the <a href="https://huggingface.co/datasets/projectlosangeles/Godzilla-MIDI-Dataset">Godzilla MIDI Dataset</a> on Hugging Face
+     <p>
+     <a href="https://huggingface.co/spaces/asigalov61/Orpheus-Music-Transformer?duplicate=true">
+     <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-md.svg" alt="Duplicate in Hugging Face">
+     </a>
+     </p>
+     Duplicate this Space for faster execution and endless generation!
+     """)
+
+     gr.Markdown("## Key Features")
+     gr.Markdown("""
+     - **Efficient Architecture with RoPE**: Compact and very fast 479M-parameter full-attention autoregressive transformer with RoPE.
+     - **Extended Sequence Length**: 8k tokens, which comfortably fit most music compositions and facilitate long-term music structure generation.
+     - **Premium Training Data**: Exclusively trained on high-quality MIDIs from the Godzilla MIDI Dataset.
+     - **Optimized MIDI Encoding**: Extremely efficient MIDI representation using only 3 tokens per note and 7 tokens per tri-chord.
+     - **Distinct Encoding Order**: Features a unique duration/velocity-last MIDI encoding order for refined musical expression.
+     - **Full-Range Instrumental Learning**: True full-range MIDI instrument encoding that enables the model to learn each instrument separately.
+     - **Natural Composition Endings**: Outro tokens that help generate smooth and natural musical conclusions.
+     """)
+
+     # Global state variables for the composition
+     final_composition = gr.State([])
+     generated_batches = gr.State([])
+     block_lines = gr.State([])
+
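+     # Note: gr.State values are tracked per browser session, so each visitor
+     # builds an independent composition; the empty lists above are only the
+     # initial defaults.
+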
+     gr.Markdown("## Upload a seed MIDI or click 'Generate' for a random output")
+     input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])
+     input_midi.upload(reset, [final_composition, generated_batches, block_lines],
+                       [final_composition, generated_batches, block_lines])
+
+     gr.Markdown("## Generate")
+     num_prime_tokens = gr.Slider(15, 3072, value=3072, step=1, label="Number of prime tokens")
+     num_gen_tokens = gr.Slider(15, 1024, value=512, step=1, label="Number of tokens to generate")
+     num_mem_tokens = gr.Slider(15, 4096, value=4096, step=1, label="Number of memory tokens")
+     model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
+     generate_btn = gr.Button("Generate", variant="primary")
+
+     gr.Markdown("## Batch Previews")
+     outputs = [final_composition, generated_batches, block_lines]
+     # Two outputs (audio and plot) for each batch.
+     for i in range(NUM_OUT_BATCHES):
+         with gr.Tab(f"Batch # {i}"):
+             audio_output = gr.Audio(label=f"Batch # {i} MIDI Audio", format="mp3")
+             plot_output = gr.Plot(label=f"Batch # {i} MIDI Plot")
+             outputs.extend([audio_output, plot_output])
+     generate_btn.click(
+         generate_music_and_state,
+         [input_midi, num_prime_tokens, num_gen_tokens, num_mem_tokens, model_temperature,
+          final_composition, generated_batches, block_lines],
+         outputs
+     )
+
+     gr.Markdown("## Add/Remove Batch")
+     batch_number = gr.Slider(0, NUM_OUT_BATCHES - 1, value=0, step=1, label="Batch number to add/remove")
+     add_btn = gr.Button("Add batch", variant="primary")
+     remove_btn = gr.Button("Remove batch", variant="stop")
+     clear_btn = gr.ClearButton()
+
+     final_audio_output = gr.Audio(label="Final MIDI audio", format="mp3")
+     final_plot_output = gr.Plot(label="Final MIDI plot")
+     final_file_output = gr.File(label="Final MIDI file")
+
+     add_btn.click(
+         add_batch,
+         [batch_number, final_composition, generated_batches, block_lines],
+         [final_audio_output, final_plot_output, final_file_output, final_composition, generated_batches, block_lines]
+     )
+     remove_btn.click(
+         remove_batch,
+         [batch_number, num_gen_tokens, final_composition, generated_batches, block_lines],
+         [final_audio_output, final_plot_output, final_file_output, final_composition, generated_batches, block_lines]
+     )
+     # generated_batches is not cleared here, matching clear()'s five return values.
+     clear_btn.click(clear, inputs=None,
+                     outputs=[final_audio_output, final_plot_output, final_file_output, final_composition, block_lines])
+
+ demo.launch()