asigalov61 commited on
Commit
419ed32
·
verified ·
1 Parent(s): 19ed5a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -21
app.py CHANGED
@@ -32,9 +32,9 @@ from x_transformer_2_3_1 import TransformerWrapper, AutoregressiveWrapper, Decod
32
  SEP = '=' * 70
33
  PDT = timezone('US/Pacific')
34
 
35
- MODEL_CHECKPOINT = 'Orpheus_Music_Transformer_No_Velocity_Trained_Model_21113_steps_0.3454_loss_0.895_acc.pth'
36
  SOUDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'
37
- NUM_OUT_BATCHES = 12
38
  PREVIEW_LENGTH = 120 # in tokens
39
 
40
  # -----------------------------
@@ -77,15 +77,15 @@ dtype = 'bfloat16'
77
  ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
78
  ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
79
 
80
- SEQ_LEN = 4096
81
- PAD_IDX = 384
82
 
83
  model = TransformerWrapper(
84
  num_tokens=PAD_IDX + 1,
85
  max_seq_len=SEQ_LEN,
86
  attn_layers=Decoder(
87
  dim=2048,
88
- depth=16,
89
  heads=32,
90
  rotary_pos_emb=True,
91
  attn_flash=True
@@ -246,14 +246,6 @@ def generate_music_and_state(input_midi, num_prime_tokens, num_gen_tokens, num_m
246
  final_composition = load_midi(input_midi)[:num_prime_tokens]
247
  midi_fname, midi_score = save_midi(final_composition)
248
  # Use the last note's time as a marker.
249
- TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(
250
- midi_score,
251
- output_signature='Orpheus Music Transformer',
252
- output_file_name=midi_fname,
253
- track_name='Project Los Angeles',
254
- list_of_MIDI_patches=[0]*16,
255
- verbose=False
256
- )
257
  block_lines.append(midi_score[-1][1] / 1000 if final_composition else 0)
258
 
259
  batched_gen_tokens = generate_music(final_composition, num_gen_tokens, num_mem_tokens,
@@ -264,16 +256,10 @@ def generate_music_and_state(input_midi, num_prime_tokens, num_gen_tokens, num_m
264
  preview_tokens = final_composition[-PREVIEW_LENGTH:]
265
  midi_fname, midi_score = save_midi(preview_tokens + tokens, batch_number=i)
266
  plot_kwargs = {'plot_title': f'Batch # {i}', 'return_plt': True}
 
267
  if len(final_composition) > PREVIEW_LENGTH:
268
  plot_kwargs['preview_length_in_notes'] = len([t for t in preview_tokens if t > 256])
269
- TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(
270
- midi_score,
271
- output_signature='Orpheus Music Transformer',
272
- output_file_name=midi_fname,
273
- track_name='Project Los Angeles',
274
- list_of_MIDI_patches=[0]*16,
275
- verbose=False
276
- )
277
  midi_plot = TMIDIX.plot_ms_SONG(midi_score, **plot_kwargs)
278
  midi_audio = midi_to_colab_audio(midi_fname + '.mid',
279
  soundfont_path=SOUDFONT_PATH,
 
32
  SEP = '=' * 70
33
  PDT = timezone('US/Pacific')
34
 
35
+ MODEL_CHECKPOINT = 'Orpheus_Music_Transformer_Trained_Model_26002_steps_0.4232_loss_0.877_acc.pth'
36
  SOUDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'
37
+ NUM_OUT_BATCHES = 8
38
  PREVIEW_LENGTH = 120 # in tokens
39
 
40
  # -----------------------------
 
77
  ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
78
  ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
79
 
80
+ SEQ_LEN = 8192
81
+ PAD_IDX = 18819
82
 
83
  model = TransformerWrapper(
84
  num_tokens=PAD_IDX + 1,
85
  max_seq_len=SEQ_LEN,
86
  attn_layers=Decoder(
87
  dim=2048,
88
+ depth=8,
89
  heads=32,
90
  rotary_pos_emb=True,
91
  attn_flash=True
 
246
  final_composition = load_midi(input_midi)[:num_prime_tokens]
247
  midi_fname, midi_score = save_midi(final_composition)
248
  # Use the last note's time as a marker.
 
 
 
 
 
 
 
 
249
  block_lines.append(midi_score[-1][1] / 1000 if final_composition else 0)
250
 
251
  batched_gen_tokens = generate_music(final_composition, num_gen_tokens, num_mem_tokens,
 
256
  preview_tokens = final_composition[-PREVIEW_LENGTH:]
257
  midi_fname, midi_score = save_midi(preview_tokens + tokens, batch_number=i)
258
  plot_kwargs = {'plot_title': f'Batch # {i}', 'return_plt': True}
259
+
260
  if len(final_composition) > PREVIEW_LENGTH:
261
  plot_kwargs['preview_length_in_notes'] = len([t for t in preview_tokens if t > 256])
262
+
 
 
 
 
 
 
 
263
  midi_plot = TMIDIX.plot_ms_SONG(midi_score, **plot_kwargs)
264
  midi_audio = midi_to_colab_audio(midi_fname + '.mid',
265
  soundfont_path=SOUDFONT_PATH,