Spaces:
Running
Running
| import argparse | |
| import glob | |
| import os.path | |
| import gradio as gr | |
| import pickle | |
| import tqdm | |
| import json | |
| import MIDI | |
| from midi_synthesizer import synthesis | |
| import copy | |
| from collections import Counter | |
| import random | |
| import statistics | |
| import matplotlib.pyplot as plt | |
#==========================================================================================================
# True when the SYSTEM environment variable equals "spaces" (presumably set by the hosting
# Space environment — confirm). Not read anywhere else in this file.
in_space = os.environ.get("SYSTEM") == "spaces"
#==========================================================================================================
def _note_events(score):
    """Return every 'note' event from all tracks of a MIDI.py score or ms_score.

    Element 0 of the score is a timing value, not a track, so only elements 1..
    are scanned. Events keep their original track-by-track order.
    """
    return [event for track in score[1:] for event in track if event[0] == 'note']


def _transposed_pitches_counts(notes):
    """Histogram note pitches for every transposition from -6 to +5 semitones.

    Returns 12 lists of [pitch, count] pairs, each sorted by pitch in descending
    order; index 6 is the untransposed (shift 0) histogram. Notes on channel 9
    (drums) are never transposed: their pitch maps to (pitch % 128) + 128 for
    every shift.
    """
    histograms = []
    for shift in range(-6, 6):
        counts = Counter(
            (e[4] % 128) + (128 if e[3] == 9 else shift)
            for e in notes
        )
        histograms.append(sorted(([pitch, n] for pitch, n in counts.items()),
                                 key=lambda pair: pair[0], reverse=True))
    return histograms


def _timing_stats(ms_notes):
    """Summarise onset gaps and note durations of millisecond-time note events.

    Returns (source_times, source_durs), each ordered [average, median, mode].
    The onset-gap series starts with 0 for the first note and then holds every
    non-zero gap between consecutive distinct onsets. Averages and medians are
    truncated to int.
    """
    # Stable sort by onset: statistics.mode (3.8+) returns the first mode it
    # meets, so the element order decides ties.
    ordered = sorted(ms_notes, key=lambda e: e[1])
    onsets = [e[1] for e in ordered]
    times = [0] + [later - earlier for earlier, later in zip(onsets, onsets[1:]) if later != earlier]
    durs = [e[2] for e in ordered]
    source_times = [int(sum(times) / len(times)), int(statistics.median(times)), statistics.mode(times)]
    source_durs = [int(sum(durs) / len(durs)), int(statistics.median(durs)), statistics.mode(durs)]
    return source_times, source_durs


def _min_max_ratio(a, b):
    """Return min(a, b) / max(a, b): 1.0 for equal values, smaller as they diverge.

    Two zeros count as a perfect match (1.0) instead of raising ZeroDivisionError.
    """
    hi = max(a, b)
    if hi == 0:
        return 1.0 if a == b else 0.0
    return min(a, b) / hi


def _avg_min_max_ratio(source, match):
    """Average _min_max_ratio over position-wise pairs of two stat lists (0.0 if nothing to compare)."""
    ratios = [_min_max_ratio(a, b) for a, b in zip(source, match)]
    if not ratios:
        return 0.0
    return sum(ratios) / len(ratios)


def _trim_counts(counts, cutoff_ratio):
    """Sort [pitch, count] pairs by count (descending, stable) and drop the rare ones.

    Pairs whose count is below cutoff_ratio * (largest count) are removed. The
    input list is left unmodified. Returns (trimmed_pairs, sum_of_their_counts);
    an empty input yields ([], 0).
    """
    ordered = sorted(counts, key=lambda pair: pair[1], reverse=True)
    if not ordered:
        return [], 0
    threshold = ordered[0][1] * cutoff_ratio
    trimmed = [pair for pair in ordered if pair[1] >= threshold]
    return trimmed, sum(pair[1] for pair in trimmed)


def _same_pitches_ratio(cand_trim, src_trim, same_pitches, skip_exact_matches):
    """Score how much the candidate's and the source's pitch sets overlap.

    When every source pitch is shared, the score is shared / len(candidate pitches);
    otherwise it is shared / max(len(candidate), len(source)). With
    skip_exact_matches, a perfect 1 is reported as 0 so a candidate with exactly
    the same pitch set is not picked.
    """
    num_same = len(same_pitches)
    if num_same == len(src_trim):
        ratio = num_same / len(cand_trim)
    else:
        ratio = num_same / max(len(cand_trim), len(src_trim))
    if skip_exact_matches and ratio == 1:
        ratio = 0
    return ratio


def _counts_ratio(cand_trim, cand_total, src_trim, src_total, same_pitches):
    """Average, over the shared pitches, how similar their relative frequencies are.

    Each count is divided by its own side's total count and the two shares are
    compared with _min_max_ratio (summed in descending pitch order). Returns 0
    when no pitch is shared.
    """
    if not same_pitches:
        return 0
    cand_share = {pair[0]: pair[1] / cand_total for pair in cand_trim if pair[0] in same_pitches}
    src_share = {pair[0]: pair[1] / src_total for pair in src_trim if pair[0] in same_pitches}
    ratios = [_min_max_ratio(cand_share[p], src_share[p]) for p in sorted(same_pitches, reverse=True)]
    return sum(ratios) / len(ratios)


def _render_match(md, midi_path='MIDI-Match-Sample.mid'):
    """Build the Gradio outputs for one meta-data entry.

    md[1][:16] are rendered as "key:value" lines, md[1][16][1] is used as the
    score's ticks value and md[1][17:-1] as its events. The score is written to
    midi_path and synthesised with the module-level soundfont_path.

    Returns (metadata_text, midi_path, (44100, audio), plt).
    """
    mid_seq = md[1][17:-1]
    mid_seq_ticks = md[1][16][1]
    txt_mdata = ''.join(str(m[0]) + ':' + str(m[1]) + chr(10) for m in md[1][:16])

    # One scatter colour per MIDI channel 0-15 (e[3] is the channel).
    colors = ['red', 'yellow', 'green', 'cyan',
              'blue', 'pink', 'orange', 'purple',
              'gray', 'white', 'gold', 'silver',
              'lightgreen', 'indigo', 'maroon', 'turquoise']
    notes = [e for e in mid_seq if e[0] == 'note']

    plt.close()
    plt.figure(figsize=(14, 5))
    ax = plt.axes(title='MIDI Search Plot')
    ax.set_facecolor('black')
    plt.scatter([e[1] for e in notes], [e[4] for e in notes], c=[colors[e[3]] for e in notes])
    plt.xlabel("Time")
    plt.ylabel("Pitch")

    # Write to the same path that is handed to the gr.File output.
    with open(midi_path, 'wb') as f:
        f.write(MIDI.score2midi([mid_seq_ticks, mid_seq]))
    audio = synthesis(MIDI.score2opus([mid_seq_ticks, mid_seq]), soundfont_path)
    return txt_mdata, midi_path, (44100, audio), plt


def match_midi(midi, progress=gr.Progress()):
    """Find the meta-data entry whose MIDI best matches an uploaded MIDI file.

    Generator used as a Gradio event handler. It reads the module-level
    ``meta_data`` list and ``soundfont_path`` (both assigned in the ``__main__``
    section), scores every meta-data entry against ``midi`` and renders the
    best-scoring one.

    Args:
        midi: raw bytes of the uploaded MIDI file (the input gr.File uses type="binary").
        progress: Gradio progress tracker wrapped around the meta-data scan.

    Yields:
        One tuple (metadata_text, midi_file_path, (44100, audio), plt).

    Raises:
        gr.Error: if the uploaded file contains no note events.
    """
    print('=' * 70)
    print('Loading MIDI file...')
    #==================================================
    notes = _note_events(MIDI.midi2score(midi))
    ms_notes = _note_events(MIDI.midi2ms_score(midi))
    if not notes or not ms_notes:
        raise gr.Error('The uploaded MIDI file does not contain any notes.')
    mult_pitches_counts = _transposed_pitches_counts(notes)
    source_times, source_durs = _timing_stats(ms_notes)
    #==================================================
    print('=' * 70)
    print('Done!')
    print('=' * 70)
    #==========================================================================================================
    #@title MIDI Pitches Search
    #@markdown Match ratio control option
    maximum_match_ratio_to_search_for = 1 #@param {type:"slider", min:0, max:1, step:0.01}
    #@markdown MIDI pitches search options
    pitches_counts_cutoff_threshold_ratio = 0 #@param {type:"slider", min:0, max:1, step:0.05}
    search_transposed_pitches = False #@param {type:"boolean"}
    skip_exact_matches = True #@param {type:"boolean"}
    #@markdown Additional search options
    add_pitches_counts_ratios = False #@param {type:"boolean"}
    add_timings_ratios = False #@param {type:"boolean"}
    add_durations_ratios = False #@param {type:"boolean"}
    print('=' * 70)
    print('MIDI Pitches Search')
    print('=' * 70)

    # Index 6 of mult_pitches_counts is the untransposed source.
    search_pitches = mult_pitches_counts if search_transposed_pitches else [mult_pitches_counts[6]]
    # The source side does not depend on the candidate, so trim it once up front.
    trimmed_search = [_trim_counts(m, pitches_counts_cutoff_threshold_ratio) for m in search_pitches]

    final_ratios = []
    for d in progress.tqdm(meta_data):
        trimmed_p_counts, total_p_counts = _trim_counts(d[1][10][1], pitches_counts_cutoff_threshold_ratio)
        # Timing/duration similarity must use THIS candidate's stats (d), and does
        # not depend on the transposition, so compute it once per candidate.
        atrat = _avg_min_max_ratio(source_times, d[1][3][1]) if add_timings_ratios else 0
        adrat = _avg_min_max_ratio(source_durs, d[1][4][1]) if add_durations_ratios else 0

        avg_ratios_list = []
        for trimmed_pitches_counts, total_pitches_counts in trimmed_search:
            same_pitches = {p[0] for p in trimmed_p_counts} & {p[0] for p in trimmed_pitches_counts}
            r_list = [_same_pitches_ratio(trimmed_p_counts, trimmed_pitches_counts,
                                          same_pitches, skip_exact_matches)]
            if add_pitches_counts_ratios:
                r_list.append(_counts_ratio(trimmed_p_counts, total_p_counts,
                                            trimmed_pitches_counts, total_pitches_counts,
                                            same_pitches))
            if add_timings_ratios:
                r_list.append(atrat)
            if add_durations_ratios:
                r_list.append(adrat)
            avg_ratios_list.append(sum(r_list) / len(r_list))

        # Best score over all searched transpositions; scores above the cap are discarded.
        final_ratio = max(avg_ratios_list)
        if final_ratio > maximum_match_ratio_to_search_for:
            final_ratio = 0
        final_ratios.append(final_ratio)
    #===================================================
    max_ratio = max(final_ratios)
    max_ratio_index = final_ratios.index(max_ratio)
    print('FOUND')
    print('=' * 70)
    print('Match ratio', max_ratio)
    print('MIDI file name', meta_data[max_ratio_index][0])
    print('=' * 70)
    #==========================================================================================================
    yield _render_match(meta_data[max_ratio_index])
#==========================================================================================================
if __name__ == "__main__":
    # Command-line options. --max-gen is parsed but not read anywhere in this file.
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", default=False, action="store_true", help="share gradio app")
    parser.add_argument("--port", default=7860, type=int, help="gradio server port")
    parser.add_argument("--max-gen", default=1024, type=int, help="max")
    opt = parser.parse_args()

    # Module-level globals read by match_midi(); they must stay at module scope.
    soundfont_path = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"
    meta_data_path = "meta-data/LAMD_META_10000.pickle"

    print('Loading meta-data...')
    # NOTE(review): pickle.load can execute arbitrary code; acceptable only because this
    # path presumably points at a file bundled with the app — never load untrusted pickles.
    with open(meta_data_path, 'rb') as fh:
        meta_data = pickle.load(fh)
    print('Done!')

    with gr.Blocks() as app:
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>MIDI Match</h1>")
        gr.Markdown(
            "\n\nMIDI Match\n\n"
            "Demo for [MIDI Match](https://github.com/asigalov61)\n\n"
            "[Open In Colab](https://colab.research.google.com/github/asigalov61/MIDI-Match/blob/main/demo.ipynb)"
            " for faster running and longer generation"
        )
        gr.Markdown("# Upload any MIDI file to find its closest match")

        input_midi = gr.File(label="input midi", file_types=[".midi", ".mid"], type="binary")
        output_plot = gr.Plot(label="output midi match sample plot")
        output_audio = gr.Audio(label="output midi match sample audio", format="mp3", elem_id="midi_audio")
        output_midi = gr.File(label="output midi match sample file", file_types=[".mid"])
        output_midi_seq = gr.Textbox(label="output midi match metadata")

        # Output order must match the tuple yielded by match_midi().
        match_outputs = [output_midi_seq, output_midi, output_audio, output_plot]
        run_event = input_midi.upload(match_midi, [input_midi], match_outputs)

    app.queue(1).launch(server_port=opt.port, share=opt.share, inbrowser=True)