Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -32,9 +32,9 @@ from x_transformer_2_3_1 import TransformerWrapper, AutoregressiveWrapper, Decod
|
|
32 |
SEP = '=' * 70
|
33 |
PDT = timezone('US/Pacific')
|
34 |
|
35 |
-
MODEL_CHECKPOINT = '
|
36 |
SOUDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'
|
37 |
-
NUM_OUT_BATCHES =
|
38 |
PREVIEW_LENGTH = 120 # in tokens
|
39 |
|
40 |
# -----------------------------
|
@@ -77,15 +77,15 @@ dtype = 'bfloat16'
|
|
77 |
ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
|
78 |
ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
|
79 |
|
80 |
-
SEQ_LEN =
|
81 |
-
PAD_IDX =
|
82 |
|
83 |
model = TransformerWrapper(
|
84 |
num_tokens=PAD_IDX + 1,
|
85 |
max_seq_len=SEQ_LEN,
|
86 |
attn_layers=Decoder(
|
87 |
dim=2048,
|
88 |
-
depth=
|
89 |
heads=32,
|
90 |
rotary_pos_emb=True,
|
91 |
attn_flash=True
|
@@ -246,14 +246,6 @@ def generate_music_and_state(input_midi, num_prime_tokens, num_gen_tokens, num_m
|
|
246 |
final_composition = load_midi(input_midi)[:num_prime_tokens]
|
247 |
midi_fname, midi_score = save_midi(final_composition)
|
248 |
# Use the last note's time as a marker.
|
249 |
-
TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(
|
250 |
-
midi_score,
|
251 |
-
output_signature='Orpheus Music Transformer',
|
252 |
-
output_file_name=midi_fname,
|
253 |
-
track_name='Project Los Angeles',
|
254 |
-
list_of_MIDI_patches=[0]*16,
|
255 |
-
verbose=False
|
256 |
-
)
|
257 |
block_lines.append(midi_score[-1][1] / 1000 if final_composition else 0)
|
258 |
|
259 |
batched_gen_tokens = generate_music(final_composition, num_gen_tokens, num_mem_tokens,
|
@@ -264,16 +256,10 @@ def generate_music_and_state(input_midi, num_prime_tokens, num_gen_tokens, num_m
|
|
264 |
preview_tokens = final_composition[-PREVIEW_LENGTH:]
|
265 |
midi_fname, midi_score = save_midi(preview_tokens + tokens, batch_number=i)
|
266 |
plot_kwargs = {'plot_title': f'Batch # {i}', 'return_plt': True}
|
|
|
267 |
if len(final_composition) > PREVIEW_LENGTH:
|
268 |
plot_kwargs['preview_length_in_notes'] = len([t for t in preview_tokens if t > 256])
|
269 |
-
TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(
|
270 |
-
midi_score,
|
271 |
-
output_signature='Orpheus Music Transformer',
|
272 |
-
output_file_name=midi_fname,
|
273 |
-
track_name='Project Los Angeles',
|
274 |
-
list_of_MIDI_patches=[0]*16,
|
275 |
-
verbose=False
|
276 |
-
)
|
277 |
midi_plot = TMIDIX.plot_ms_SONG(midi_score, **plot_kwargs)
|
278 |
midi_audio = midi_to_colab_audio(midi_fname + '.mid',
|
279 |
soundfont_path=SOUDFONT_PATH,
|
|
|
32 |
SEP = '=' * 70
|
33 |
PDT = timezone('US/Pacific')
|
34 |
|
35 |
+
MODEL_CHECKPOINT = 'Orpheus_Music_Transformer_Trained_Model_26002_steps_0.4232_loss_0.877_acc.pth'
|
36 |
SOUDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'
|
37 |
+
NUM_OUT_BATCHES = 8
|
38 |
PREVIEW_LENGTH = 120 # in tokens
|
39 |
|
40 |
# -----------------------------
|
|
|
77 |
ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
|
78 |
ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
|
79 |
|
80 |
+
SEQ_LEN = 8192
|
81 |
+
PAD_IDX = 18819
|
82 |
|
83 |
model = TransformerWrapper(
|
84 |
num_tokens=PAD_IDX + 1,
|
85 |
max_seq_len=SEQ_LEN,
|
86 |
attn_layers=Decoder(
|
87 |
dim=2048,
|
88 |
+
depth=8,
|
89 |
heads=32,
|
90 |
rotary_pos_emb=True,
|
91 |
attn_flash=True
|
|
|
246 |
final_composition = load_midi(input_midi)[:num_prime_tokens]
|
247 |
midi_fname, midi_score = save_midi(final_composition)
|
248 |
# Use the last note's time as a marker.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
249 |
block_lines.append(midi_score[-1][1] / 1000 if final_composition else 0)
|
250 |
|
251 |
batched_gen_tokens = generate_music(final_composition, num_gen_tokens, num_mem_tokens,
|
|
|
256 |
preview_tokens = final_composition[-PREVIEW_LENGTH:]
|
257 |
midi_fname, midi_score = save_midi(preview_tokens + tokens, batch_number=i)
|
258 |
plot_kwargs = {'plot_title': f'Batch # {i}', 'return_plt': True}
|
259 |
+
|
260 |
if len(final_composition) > PREVIEW_LENGTH:
|
261 |
plot_kwargs['preview_length_in_notes'] = len([t for t in preview_tokens if t > 256])
|
262 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
263 |
midi_plot = TMIDIX.plot_ms_SONG(midi_score, **plot_kwargs)
|
264 |
midi_audio = midi_to_colab_audio(midi_fname + '.mid',
|
265 |
soundfont_path=SOUDFONT_PATH,
|