#==================================================================== # https://huggingface.co/spaces/asigalov61/Orpheus-Music-Transformer #==================================================================== """ Orpheus Music Transformer Gradio App - Single Model, Simplified Version SOTA 8k multi-instrumental music transformer trained on 2.31M+ high-quality MIDIs Using one model which was trained for 3 full epochs" """ import os os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" import time as reqtime import datetime from pytz import timezone import torch import matplotlib.pyplot as plt import gradio as gr import spaces from huggingface_hub import hf_hub_download import TMIDIX from midi_to_colab_audio import midi_to_colab_audio from x_transformer_2_3_1 import TransformerWrapper, AutoregressiveWrapper, Decoder # ----------------------------- # CONFIGURATION & GLOBALS # ----------------------------- SEP = '=' * 70 PDT = timezone('US/Pacific') MODEL_CHECKPOINT = 'Orpheus_Music_Transformer_Trained_Model_26002_steps_0.4232_loss_0.877_acc.pth' SOUDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2' NUM_OUT_BATCHES = 8 PREVIEW_LENGTH = 120 # in tokens # ----------------------------- # PRINT START-UP INFO # ----------------------------- def print_sep(): print(SEP) print_sep() print("Orpheus Music Transformer Gradio App") print_sep() print("Loading modules...") # ----------------------------- # ENVIRONMENT & PyTorch Settings # ----------------------------- os.environ['USE_FLASH_ATTENTION'] = '1' torch.set_float32_matmul_precision('high') torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True torch.backends.cuda.enable_mem_efficient_sdp(True) torch.backends.cuda.enable_math_sdp(True) torch.backends.cuda.enable_flash_sdp(True) torch.backends.cuda.enable_cudnn_sdp(True) print_sep() print("PyTorch version:", torch.__version__) print("Done loading modules!") print_sep() # ----------------------------- # MODEL INITIALIZATION # ----------------------------- print_sep() print("Instantiating model...") device_type = 'cuda' dtype = 'bfloat16' ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) SEQ_LEN = 8192 PAD_IDX = 18819 model = TransformerWrapper( num_tokens=PAD_IDX + 1, max_seq_len=SEQ_LEN, attn_layers=Decoder( dim=2048, depth=8, heads=32, rotary_pos_emb=True, attn_flash=True ) ) model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX) print_sep() print("Loading model checkpoint...") checkpoint = hf_hub_download( repo_id='asigalov61/Orpheus-Music-Transformer', filename=MODEL_CHECKPOINT ) model.load_state_dict(torch.load(checkpoint, map_location='cuda', weights_only=True)) model = torch.compile(model, mode='max-autotune') print_sep() print("Done!") print("Model will use", dtype, "precision...") print_sep() model.cuda() model.eval() # ----------------------------- # HELPER FUNCTIONS # ----------------------------- def render_midi_output(final_composition): """Generate MIDI score, plot, and audio from final composition.""" fname, midi_score = save_midi(final_composition) time_val = midi_score[-1][1] / 1000 # seconds marker from last note midi_plot = TMIDIX.plot_ms_SONG( midi_score, plot_title='Orpheus Music Transformer Composition', block_lines_times_list=[], return_plt=True ) midi_audio = midi_to_colab_audio( fname + '.mid', soundfont_path=SOUDFONT_PATH, sample_rate=16000, output_for_gradio=True ) return (16000, midi_audio), midi_plot, fname + '.mid', time_val # ----------------------------- # MIDI PROCESSING FUNCTIONS # ----------------------------- def load_midi(input_midi): """Process the input MIDI file and create a token sequence using without velocity logic.""" raw_score = TMIDIX.midi2single_track_ms_score(input_midi.name) escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True, apply_sustain=True) escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes[0], sort_drums_last=True) dscore = TMIDIX.delta_score_notes(escore_notes) dcscore = TMIDIX.chordify_score([d[1:] for d in dscore]) melody_chords = [18816] #======================================================= # MAIN PROCESSING CYCLE #======================================================= for i, c in enumerate(dcscore): delta_time = c[0][0] melody_chords.append(delta_time) for e in c: #======================================================= # Durations dur = max(1, min(255, e[1])) # Patches pat = max(0, min(128, e[5])) # Pitches ptc = max(1, min(127, e[3])) # Velocities # Calculating octo-velocity vel = max(8, min(127, e[4])) velocity = round(vel / 15)-1 #======================================================= # FINAL NOTE SEQ #======================================================= # Writing final note pat_ptc = (128 * pat) + ptc dur_vel = (8 * dur) + velocity melody_chords.extend([pat_ptc+256, dur_vel+16768]) return melody_chords def save_midi(tokens): """Convert token sequence back to a MIDI score and write it using TMIDIX (without velocity). The output MIDI file name incorporates a date-time stamp. """ time = 0 dur = 1 vel = 90 pitch = 60 channel = 0 patch = 0 patches = [-1] * 16 channels = [0] * 16 channels[9] = 1 song_f = [] for ss in tokens: if 0 <= ss < 256: time += ss * 16 if 256 <= ss < 16768: patch = (ss-256) // 128 if patch < 128: if patch not in patches: if 0 in channels: cha = channels.index(0) channels[cha] = 1 else: cha = 15 patches[cha] = patch channel = patches.index(patch) else: channel = patches.index(patch) if patch == 128: channel = 9 pitch = (ss-256) % 128 if 16768 <= ss < 18816: dur = ((ss-16768) // 8) * 16 vel = (((ss-16768) % 8)+1) * 15 song_f.append(['note', time, dur, channel, pitch, vel ]) patches = [0 if x==-1 else x for x in patches] # Generate a time stamp using the PDT timezone. timestamp = datetime.datetime.now(PDT).strftime("%Y%m%d_%H%M%S") fname = f"Orpheus-Music-Transformer-Composition" TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter( song_f, output_signature='Orpheus Music Transformer', output_file_name=fname, track_name='Project Los Angeles', list_of_MIDI_patches=patches, verbose=False ) return fname, song_f # ----------------------------- # MUSIC GENERATION FUNCTION (Combined) # ----------------------------- @spaces.GPU def generate_music(prime, num_gen_tokens, num_mem_tokens, num_gen_batches, model_temperature): """Generate music tokens given prime tokens and parameters.""" inputs = prime[-num_mem_tokens:] if prime else [18816] print("Generating...") inp = torch.LongTensor([inputs] * num_gen_batches).cuda() with ctx: out = model.generate( inp, num_gen_tokens, temperature=model_temperature, return_prime=False, verbose=False ) print("Done!") print_sep() return out.tolist() def generate_music_and_state(input_midi, num_prime_tokens, num_gen_tokens, num_mem_tokens, model_temperature, final_composition, generated_batches, block_lines): """ Generate tokens using the model, update the composition state, and prepare outputs. This function combines seed loading, token generation, and UI output packaging. """ print_sep() print("Request start time:", datetime.datetime.now(PDT).strftime("%Y-%m-%d %H:%M:%S")) print('=' * 70) if input_midi is not None: fn = os.path.basename(input_midi.name) fn1 = fn.split('.')[0] print('Input file name:', fn) print('Num prime tokens:', num_prime_tokens) print('Num gen tokens:', num_gen_tokens) print('Num mem tokens:', num_mem_tokens) print('Model temp:', model_temperature) print('=' * 70) # Load seed from MIDI if there is no existing composition. if not final_composition and input_midi is not None: final_composition = load_midi(input_midi)[:num_prime_tokens] midi_fname, midi_score = save_midi(final_composition) # Use the last note's time as a marker. block_lines.append(midi_score[-1][1] / 1000 if final_composition else 0) batched_gen_tokens = generate_music(final_composition, num_gen_tokens, num_mem_tokens, NUM_OUT_BATCHES, model_temperature) output_batches = [] for i, tokens in enumerate(batched_gen_tokens): preview_tokens = final_composition[-PREVIEW_LENGTH:] midi_fname, midi_score = save_midi(preview_tokens + tokens) plot_kwargs = {'plot_title': f'Batch # {i}', 'return_plt': True} if len(final_composition) > PREVIEW_LENGTH: plot_kwargs['preview_length_in_notes'] = len([t for t in preview_tokens if 256 <= t < 16768]) midi_plot = TMIDIX.plot_ms_SONG(midi_score, **plot_kwargs) midi_audio = midi_to_colab_audio(midi_fname + '.mid', soundfont_path=SOUDFONT_PATH, sample_rate=16000, output_for_gradio=True) output_batches.append([(16000, midi_audio), midi_plot, tokens]) # Update generated_batches (for use by add/remove functions) generated_batches = batched_gen_tokens print("Request end time:", datetime.datetime.now(PDT).strftime("%Y-%m-%d %H:%M:%S")) print_sep() # Flatten outputs: states then audio and plots for each batch. outputs_flat = [] for batch in output_batches: outputs_flat.extend([batch[0], batch[1]]) return [final_composition, generated_batches, block_lines] + outputs_flat # ----------------------------- # BATCH HANDLING FUNCTIONS # ----------------------------- def add_batch(batch_number, final_composition, generated_batches, block_lines): """Add tokens from the specified batch to the final composition and update outputs.""" if generated_batches: final_composition.extend(generated_batches[batch_number]) midi_fname, midi_score = save_midi(final_composition) block_lines.append(midi_score[-1][1] / 1000 if final_composition else 0) TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter( midi_score, output_signature='Orpheus Music Transformer', output_file_name=midi_fname, track_name='Project Los Angeles', list_of_MIDI_patches=[0]*16, verbose=False ) midi_plot = TMIDIX.plot_ms_SONG( midi_score, plot_title='Orpheus Music Transformer Composition', block_lines_times_list=block_lines[:-1], return_plt=True ) midi_audio = midi_to_colab_audio(midi_fname + '.mid', soundfont_path=SOUDFONT_PATH, sample_rate=16000, output_for_gradio=True) print("Added batch #", batch_number) print_sep() return (16000, midi_audio), midi_plot, midi_fname + '.mid', final_composition, generated_batches, block_lines else: return None, None, None, [], [], [] def remove_batch(batch_number, num_tokens, final_composition, generated_batches, block_lines): """Remove tokens from the final composition and update outputs.""" if final_composition and len(final_composition) > num_tokens: final_composition = final_composition[:-num_tokens] if block_lines: block_lines.pop() midi_fname, midi_score = save_midi(final_composition) midi_plot = TMIDIX.plot_ms_SONG( midi_score, plot_title='Orpheus Music Transformer Composition', block_lines_times_list=block_lines[:-1], return_plt=True ) midi_audio = midi_to_colab_audio(midi_fname + '.mid', soundfont_path=SOUDFONT_PATH, sample_rate=16000, output_for_gradio=True) print("Removed batch #", batch_number) print_sep() return (16000, midi_audio), midi_plot, midi_fname + '.mid', final_composition, generated_batches, block_lines else: return None, None, None, [], [], [] def clear(): """Clear outputs and reset state.""" return None, None, None, [], [] def reset(final_composition=[], generated_batches=[], block_lines=[]): """Reset composition state.""" return [], [], [] # ----------------------------- # GRADIO INTERFACE SETUP # ----------------------------- with gr.Blocks() as demo: gr.Markdown("

Orpheus Music Transformer

") gr.Markdown("

SOTA 8k multi-instrumental music transformer trained on 2.31M+ high-quality MIDIs

") gr.Markdown("

This is a WIP preview. Please check back for final release soon.

") gr.HTML(""" Check out Godzilla MIDI Dataset on Hugging Face

Duplicate in Hugging Face

for faster execution and endless generation! """) gr.Markdown("## Key Features") gr.Markdown(""" - **Efficient Architecture with RoPE**: Compact and very fast 479M full attention autoregressive transformer with RoPE. - **Extended Sequence Length**: 8k tokens that comfortably fit most music compositions and facilitate long-term music structure generation. - **Premium Training Data**: Exclusively trained on high-quality MIDIs from the Godzilla MIDI dataset. - **Optimized MIDI Encoding**: Extremely efficient MIDI representation using only 3 tokens per note and 7 tokens per tri-chord. - **Distinct Encoding Order**: Features a unique duration/velocity last MIDI encoding order for refined musical expression. - **Full-Range Instrumental Learning**: True full-range MIDI instruments encoding enabling the model to learn each instrument separately. - **Natural Composition Endings**: Outro tokens that help generate smooth and natural musical conclusions. """) # Global state variables for composition final_composition = gr.State([]) generated_batches = gr.State([]) block_lines = gr.State([]) gr.Markdown("## Upload seed MIDI or click 'Generate' for a random output") input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"]) input_midi.upload(reset, [final_composition, generated_batches, block_lines], [final_composition, generated_batches, block_lines]) gr.Markdown("## Generate") num_prime_tokens = gr.Slider(16, 7168, value=7168, step=1, label="Number of prime tokens") num_gen_tokens = gr.Slider(16, 1024, value=512, step=1, label="Number of tokens to generate") num_mem_tokens = gr.Slider(16, 8192, value=8192, step=1, label="Number of memory tokens") model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature") generate_btn = gr.Button("Generate", variant="primary") gr.Markdown("## Batch Previews") outputs = [final_composition, generated_batches, block_lines] # Two outputs (audio and plot) for each batch for i in range(NUM_OUT_BATCHES): with gr.Tab(f"Batch # {i}"): audio_output = gr.Audio(label=f"Batch # {i} MIDI Audio", format="mp3") plot_output = gr.Plot(label=f"Batch # {i} MIDI Plot") outputs.extend([audio_output, plot_output]) generate_btn.click( generate_music_and_state, [input_midi, num_prime_tokens, num_gen_tokens, num_mem_tokens, model_temperature, final_composition, generated_batches, block_lines], outputs ) gr.Markdown("## Add/Remove Batch") batch_number = gr.Slider(0, NUM_OUT_BATCHES - 1, value=0, step=1, label="Batch number to add/remove") add_btn = gr.Button("Add batch", variant="primary") remove_btn = gr.Button("Remove batch", variant="stop") clear_btn = gr.ClearButton() final_audio_output = gr.Audio(label="Final MIDI audio", format="mp3") final_plot_output = gr.Plot(label="Final MIDI plot") final_file_output = gr.File(label="Final MIDI file") add_btn.click( add_batch, [batch_number, final_composition, generated_batches, block_lines], [final_audio_output, final_plot_output, final_file_output, final_composition, generated_batches, block_lines] ) remove_btn.click( remove_batch, [batch_number, num_gen_tokens, final_composition, generated_batches, block_lines], [final_audio_output, final_plot_output, final_file_output, final_composition, generated_batches, block_lines] ) clear_btn.click(clear, inputs=None, outputs=[final_audio_output, final_plot_output, final_file_output, final_composition, block_lines]) demo.launch()