avans06 committed on
Commit
b114cd4
·
1 Parent(s): 80ab93c

feat: Implement stereo audio to MIDI transcription


This commit introduces a new stereo processing workflow for audio-to-MIDI transcription, preserving the spatial information of stereo recordings. The previous implementation was limited to mono processing.
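
In outline, the new path loads the file without downmixing, writes each channel to a temporary WAV, transcribes the two channels independently, and merges the resulting MIDI files with hard left/right panning. A minimal sketch of that flow, with `transcribe()` standing in for the selected transcription model (basic-pitch or the ByteDance piano model) and placeholder file names:

```python
import librosa
import soundfile as sf

# Load without downmixing; for a stereo file `audio` has shape (2, N).
audio, sr = librosa.load("input.wav", sr=None, mono=False)

# Split the channels and write each to its own temporary WAV file.
sf.write("left.wav", audio[0], sr)
sf.write("right.wav", audio[1], sr)

# Transcribe each channel independently, then merge the two MIDIs;
# merge_midis() pans the left track hard left and the right track hard right.
midi_left = transcribe("left.wav")    # placeholder for the selected model call
midi_right = transcribe("right.wav")  # placeholder for the selected model call
merged_midi_path = merge_midis(midi_left, midi_right, "merged.mid")
```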

Scale MIDI velocities by 0.8 in Stereo Transcription to avoid loudness/clipping after merge

Applied `scale_instrument_velocity(scale=0.8)` during Stereo Transcription to prevent excessive loudness caused by summing left and right channel MIDI tracks. This helps maintain a more natural dynamic range, avoiding clipping and ensuring more consistent perceived volume after rendering to WAV/FLAC.
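
The scaling itself mirrors the `scale_instrument_velocity` helper added in app.py (in the commit it is applied to each channel's instruments inside `merge_midis` before the merged file is written); a short sketch with illustrative file names:

```python
import pretty_midi

def scale_instrument_velocity(instrument, scale=0.8):
    # Scale each note's velocity and clamp to the valid MIDI range 1..127.
    for note in instrument.notes:
        note.velocity = max(1, min(127, int(note.velocity * scale)))

midi = pretty_midi.PrettyMIDI("merged.mid")
for instrument in midi.instruments:
    scale_instrument_velocity(instrument, scale=0.8)
midi.write("merged_scaled.mid")
```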

Files changed (2)
  1. app.py +289 -36
  2. requirements.txt +2 -0
app.py CHANGED
@@ -1,14 +1,15 @@
1
  # =================================================================
2
  #
3
- # Merged and Integrated Script for Audio/MIDI Processing and Rendering
4
  #
5
  # This script combines two functionalities:
6
  # 1. Transcribing audio to MIDI using two methods:
7
  # a) A general-purpose model (basic-pitch by Spotify).
8
  # b) A model specialized for solo piano (ByteDance).
 
9
  # 2. Applying advanced transformations and re-rendering MIDI files using:
10
- # a) Standard SoundFonts via FluidSynth.
11
- # b) A custom 8-bit style synthesizer for a chiptune sound.
12
  #
13
  # The user can upload an audio file (e.g., WAV, MP3) or a MIDI file.
14
  # - If an audio file is uploaded, it is first transcribed to MIDI using the selected method.
@@ -29,7 +30,7 @@
29
  #
30
  # pip install gradio torch pytz numpy scipy matplotlib networkx scikit-learn
31
  # pip install piano_transcription_inference huggingface_hub
32
- # pip install basic-pitch pretty_midi librosa
33
  #
34
  # =================================================================
35
  # Core modules:
@@ -42,6 +43,9 @@ import os
42
  import hashlib
43
  import time as reqtime
44
  import copy
45
 
46
  import torch
47
  import gradio as gr
@@ -60,7 +64,7 @@ import basic_pitch
60
  from basic_pitch.inference import predict
61
  from basic_pitch import ICASSP_2022_MODEL_PATH
62
 
63
- # --- Imports for 8-bit Synthesizer ---
64
  import pretty_midi
65
  import numpy as np
66
  from scipy import signal
@@ -158,18 +162,36 @@ def prepare_soundfonts():
158
  return ordered_soundfont_map
159
 
160
  # =================================================================================================
161
- # === 8-bit Style Synthesizer ===
162
  # =================================================================================================
163
  def synthesize_8bit_style(midi_data, waveform_type, envelope_type, decay_time_s, pulse_width, vibrato_rate, vibrato_depth, bass_boost_level, fs=44100):
164
  """
165
  Synthesizes an 8-bit style audio waveform from a PrettyMIDI object.
166
  This function generates waveforms manually instead of using a synthesizer like FluidSynth.
167
  Includes an optional sub-octave bass booster with adjustable level.
168
  """
169
  total_duration = midi_data.get_end_time()
170
- waveform = np.zeros(int(total_duration * fs) + fs)
171
 
172
- for instrument in midi_data.instruments:
173
  for note in instrument.notes:
174
  freq = pretty_midi.note_number_to_hz(note.pitch)
175
  note_duration = note.end - note.start
@@ -222,13 +244,162 @@ def synthesize_8bit_style(midi_data, waveform_type, envelope_type, decay_time_s,
222
 
223
  start_sample = int(note.start * fs)
224
  end_sample = start_sample + num_samples
225
- if end_sample > len(waveform):
226
- end_sample = len(waveform)
227
  note_waveform = note_waveform[:end_sample-start_sample]
228
 
229
- waveform[start_sample:end_sample] += note_waveform
230
 
231
- return waveform
232
 
233
  # =================================================================================================
234
  # === Stage 1: Audio to MIDI Transcription Functions ===
@@ -254,7 +425,7 @@ def TranscribePianoAudio(input_file):
254
  # Use os.path.join to create a platform-independent directory path
255
  output_dir = os.path.join("output", "transcribed_piano_")
256
  out_mid_path = os.path.join(output_dir, fn1 + '.mid')
257
-
258
  # Check for the directory's existence and create it if necessary
259
  if not os.path.exists(output_dir):
260
  os.makedirs(output_dir)
@@ -412,7 +583,7 @@ def Render_MIDI(input_midi_path,
412
  escore = TMIDIX.merge_escore_notes(escore, merge_threshold=merge_misaligned_notes)
413
 
414
  escore = TMIDIX.augment_enhanced_score_notes(escore, timings_divider=1)
415
-
416
  first_note_index = [e[0] for e in raw_score[1]].index('note')
417
  cscore = TMIDIX.chordify_score([1000, escore])
418
 
@@ -420,7 +591,7 @@ def Render_MIDI(input_midi_path,
420
 
421
  aux_escore_notes = TMIDIX.augment_enhanced_score_notes(escore, sort_drums_last=True)
422
  song_description = TMIDIX.escore_notes_to_text_description(aux_escore_notes)
423
-
424
  print('Done!')
425
  print('=' * 70)
426
  print('Input MIDI metadata:', meta_data[:5])
@@ -472,7 +643,7 @@ def Render_MIDI(input_midi_path,
472
 
473
  if render_transpose_to_C4:
474
  output_score = TMIDIX.transpose_escore_notes_to_pitch(output_score, 60) # C4 is MIDI pitch 60
475
-
476
  if render_align == "Start Times":
477
  output_score = TMIDIX.recalculate_score_timings(output_score)
478
  output_score = TMIDIX.align_escore_notes_to_bars(output_score)
@@ -573,11 +744,12 @@ def Render_MIDI(input_midi_path,
573
  s8bit_bass_boost_level,
574
  fs=srate
575
  )
576
- # Normalize audio
577
  peak_val = np.max(np.abs(audio))
578
  if peak_val > 0:
579
  audio /= peak_val
580
- audio = (audio * 32767).astype(np.int16)
 
581
  except Exception as e:
582
  print(f"Error during 8-bit synthesis: {e}")
583
  return [None] * 7
@@ -603,7 +775,7 @@ def Render_MIDI(input_midi_path,
603
  with open(midi_to_render_path, 'rb') as f:
604
  midi_file_content = f.read()
605
 
606
- audio = midi_to_colab_audio(midi_file_content,
607
  soundfont_path=soundfont_path, # Use the dynamically found path
608
  sample_rate=srate,
609
  output_for_gradio=True
@@ -619,7 +791,7 @@ def Render_MIDI(input_midi_path,
619
 
620
  output_midi_summary = str(meta_data)
621
 
622
- return new_md5_hash, fn1, output_midi_summary, midi_to_render_path, (srate, audio), output_plot, song_description
623
 
624
  # =================================================================================================
625
  # === Main Application Logic ===
@@ -627,6 +799,7 @@ def Render_MIDI(input_midi_path,
627
 
628
  def process_and_render_file(input_file,
629
  # --- Transcription params ---
 
630
  transcription_method,
631
  onset_thresh, frame_thresh, min_note_len, min_freq, max_freq, infer_onsets_bool, melodia_trick_bool, multiple_bends_bool,
632
  # --- MIDI rendering params ---
@@ -645,14 +818,18 @@ def process_and_render_file(input_file,
645
  start_time = reqtime.time()
646
  if input_file is None:
647
  # Return a list of updates to clear all output fields
648
- num_outputs = 7
649
- return [gr.update(value=None)] * num_outputs
650
 
651
  # The input_file from gr.Audio(type="filepath") is now the direct path (a string),
652
  # not a temporary file object. We no longer need to access the .name attribute.
653
  input_file_path = input_file
654
  filename = os.path.basename(input_file_path)
655
  print(f"Processing new file: {filename}")
656
 
657
  # --- Step 1: Check file type and transcribe if necessary ---
658
  if filename.lower().endswith(('.mid', '.midi', '.kar')):
@@ -660,17 +837,86 @@ def process_and_render_file(input_file,
660
  midi_path_for_rendering = input_file_path
661
  else: #if filename.lower().endswith(('.wav', '.mp3'))
662
  print("Audio file detected. Starting transcription...")
663
- try:
664
- if transcription_method == "General Purpose":
665
- midi_path_for_rendering = TranscribeGeneralAudio(
666
- input_file_path, onset_thresh, frame_thresh, min_note_len,
667
- min_freq, max_freq, infer_onsets_bool, melodia_trick_bool, multiple_bends_bool
668
- )
669
- else: # Piano-Specific
670
- midi_path_for_rendering = TranscribePianoAudio(input_file_path)
671
- except Exception as e:
672
- print(f"An error occurred during transcription: {e}")
673
- raise gr.Error(f"Transcription Failed: {e}")
674
 
675
  # --- Step 2: Render the MIDI file with selected options ---
676
  print(f"Proceeding to render MIDI file: {os.path.basename(midi_path_for_rendering)}")
@@ -696,7 +942,7 @@ def update_ui_visibility(transcription_method, soundfont_choice):
696
  """
697
  is_general = (transcription_method == "General Purpose")
698
  is_8bit = (soundfont_choice == SYNTH_8_BIT_LABEL)
699
-
700
  return {
701
  general_transcription_settings: gr.update(visible=is_general),
702
  synth_8bit_settings: gr.update(visible=is_8bit),
@@ -751,8 +997,14 @@ if __name__ == "__main__":
751
  value="General Purpose",
752
  info="Choose 'General Purpose' for most music (vocals, etc.). Choose 'Piano-Specific' only for solo piano recordings."
753
  )
754
 
755
- # --- General Purpose (basic-pitch) Settings ---
756
  with gr.Accordion("General Purpose Transcription Settings", open=True) as general_transcription_settings:
757
  onset_threshold = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="On-set Threshold", info="Sensitivity for detecting note beginnings. Higher is stricter.")
758
  frame_threshold = gr.Slider(0.0, 1.0, value=0.3, step=0.05, label="Frame Threshold", info="Sensitivity for detecting active notes. Higher is stricter.")
@@ -775,7 +1027,7 @@ if __name__ == "__main__":
775
  # --- Dynamically create the list of choices ---
776
  soundfont_choices = [SYNTH_8_BIT_LABEL] + list(soundfonts_dict.keys())
777
  # Set a safe default value
778
- default_sf_choice = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7" if "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7" in soundfonts_dict else soundfont_choices[0]
779
 
780
  soundfont_bank = gr.Dropdown(
781
  soundfont_choices,
@@ -831,6 +1083,7 @@ if __name__ == "__main__":
831
  # --- Define all input components for the click event ---
832
  all_inputs = [
833
  input_file,
 
834
  transcription_method,
835
  onset_threshold, frame_threshold, minimum_note_length, minimum_frequency, maximum_frequency,
836
  infer_onsets, melodia_trick, multiple_pitch_bends,
 
1
  # =================================================================
2
  #
3
+ # Merged and Integrated Script for Audio/MIDI Processing and Rendering (Stereo Enhanced)
4
  #
5
  # This script combines two functionalities:
6
  # 1. Transcribing audio to MIDI using two methods:
7
  # a) A general-purpose model (basic-pitch by Spotify).
8
  # b) A model specialized for solo piano (ByteDance).
9
+ # - Includes stereo processing by splitting channels, transcribing independently, and merging MIDI.
10
  # 2. Applying advanced transformations and re-rendering MIDI files using:
11
+ # a) Standard SoundFonts via FluidSynth (produces stereo audio).
12
+ # b) A custom 8-bit style synthesizer for a chiptune sound (updated for stereo output).
13
  #
14
  # The user can upload a Audio (e.g., WAV, MP3), or MIDI file.
15
  # - If an audio file is uploaded, it is first transcribed to MIDI using the selected method.
 
30
  #
31
  # pip install gradio torch pytz numpy scipy matplotlib networkx scikit-learn
32
  # pip install piano_transcription_inference huggingface_hub
33
+ # pip install basic-pitch pretty_midi librosa soundfile
34
  #
35
  # =================================================================
36
  # Core modules:
 
43
  import hashlib
44
  import time as reqtime
45
  import copy
46
+ import librosa
47
+ import pyloudnorm as pyln
48
+ import soundfile as sf
49
 
50
  import torch
51
  import gradio as gr
 
64
  from basic_pitch.inference import predict
65
  from basic_pitch import ICASSP_2022_MODEL_PATH
66
 
67
+ # --- Imports for 8-bit Synthesizer & MIDI Merging ---
68
  import pretty_midi
69
  import numpy as np
70
  from scipy import signal
 
162
  return ordered_soundfont_map
163
 
164
  # =================================================================================================
165
+ # === 8-bit Style Synthesizer (Stereo Enabled) ===
166
  # =================================================================================================
167
  def synthesize_8bit_style(midi_data, waveform_type, envelope_type, decay_time_s, pulse_width, vibrato_rate, vibrato_depth, bass_boost_level, fs=44100):
168
  """
169
  Synthesizes an 8-bit style audio waveform from a PrettyMIDI object.
170
  This function generates waveforms manually instead of using a synthesizer like FluidSynth.
171
  Includes an optional sub-octave bass booster with adjustable level.
172
+ Instruments are panned based on their order in the MIDI file.
173
+ Instrument 1 -> Left, Instrument 2 -> Right.
174
  """
175
  total_duration = midi_data.get_end_time()
176
+ # Initialize a stereo waveform buffer (2 channels: Left, Right)
177
+ waveform = np.zeros((2, int(total_duration * fs) + fs))
178
+
179
+ num_instruments = len(midi_data.instruments)
180
+
181
+ for i, instrument in enumerate(midi_data.instruments):
182
+ # --- Panning Logic ---
183
+ # Default to center-panned mono
184
+ pan_l, pan_r = 0.707, 0.707
185
+ if num_instruments == 2:
186
+ if i == 0: # First instrument panned left
187
+ pan_l, pan_r = 1.0, 0.0
188
+ elif i == 1: # Second instrument panned right
189
+ pan_l, pan_r = 0.0, 1.0
190
+ elif num_instruments > 2:
191
+ if i == 0: pan_l, pan_r = 1.0, 0.0 # Left
192
+ elif i == 1: pan_l, pan_r = 0.0, 1.0 # Right
193
+ # Other instruments remain centered
194
 
 
195
  for note in instrument.notes:
196
  freq = pretty_midi.note_number_to_hz(note.pitch)
197
  note_duration = note.end - note.start
 
244
 
245
  start_sample = int(note.start * fs)
246
  end_sample = start_sample + num_samples
247
+ if end_sample > waveform.shape[1]:
248
+ end_sample = waveform.shape[1]
249
  note_waveform = note_waveform[:end_sample-start_sample]
250
 
251
+ # Add the mono note waveform to the stereo buffer with panning
252
+ waveform[0, start_sample:end_sample] += note_waveform * pan_l
253
+ waveform[1, start_sample:end_sample] += note_waveform * pan_r
254
+
255
+ return waveform # Returns a (2, N) numpy array
256
+
257
+
258
+ def analyze_midi_velocity(midi_path):
259
+ midi = pretty_midi.PrettyMIDI(midi_path)
260
+ all_velocities = []
261
+
262
+ print(f"Analyzing velocity for MIDI: {midi_path}")
263
+ for i, instrument in enumerate(midi.instruments):
264
+ velocities = [note.velocity for note in instrument.notes]
265
+ all_velocities.extend(velocities)
266
+
267
+ if velocities:
268
+ print(f"Instrument {i} ({instrument.name}):")
269
+ print(f" Notes count: {len(velocities)}")
270
+ print(f" Velocity min: {min(velocities)}")
271
+ print(f" Velocity max: {max(velocities)}")
272
+ print(f" Velocity mean: {np.mean(velocities):.2f}")
273
+ else:
274
+ print(f"Instrument {i} ({instrument.name}): no notes found.")
275
+
276
+ if all_velocities:
277
+ print("\nOverall MIDI velocity stats:")
278
+ print(f" Total notes: {len(all_velocities)}")
279
+ print(f" Velocity min: {min(all_velocities)}")
280
+ print(f" Velocity max: {max(all_velocities)}")
281
+ print(f" Velocity mean: {np.mean(all_velocities):.2f}")
282
+ else:
283
+ print("No notes found in this MIDI.")
284
+
285
+
286
+ def scale_instrument_velocity(instrument, scale=0.8):
287
+ for note in instrument.notes:
288
+ note.velocity = max(1, min(127, int(note.velocity * scale)))
289
+
290
+
291
+ def normalize_loudness(audio_data, sample_rate, target_lufs=-23.0):
292
+ """
293
+ Normalizes the audio data to a target integrated loudness (LUFS).
294
+ This provides more consistent perceived volume than peak normalization.
295
+
296
+ Args:
297
+ audio_data (np.ndarray): The audio signal.
298
+ sample_rate (int): The sample rate of the audio.
299
+ target_lufs (float): The target loudness in LUFS. Defaults to -23.0,
300
+ a common standard for broadcast.
301
+
302
+ Returns:
303
+ np.ndarray: The loudness-normalized audio data.
304
+ """
305
+ try:
306
+ # 1. Measure the integrated loudness of the input audio
307
+ meter = pyln.Meter(sample_rate) # create meter
308
+ loudness = meter.integrated_loudness(audio_data)
309
+
310
+ # 2. Calculate the gain needed to reach the target loudness
311
+ # The gain is applied in the linear domain, so we convert from dB
312
+ loudness_gain_db = target_lufs - loudness
313
+ loudness_gain_linear = 10.0 ** (loudness_gain_db / 20.0)
314
+
315
+ # 3. Apply the gain
316
+ normalized_audio = audio_data * loudness_gain_linear
317
+
318
+ # 4. Final safety check: peak normalize to prevent clipping, just in case
319
+ # the loudness normalization results in peaks > 1.0
320
+ peak_val = np.max(np.abs(normalized_audio))
321
+ if peak_val > 1.0:
322
+ normalized_audio /= peak_val
323
+ print(f"Warning: Loudness normalization resulted in clipping. Audio was peak-normalized as a safeguard.")
324
+
325
+ print(f"Audio normalized from {loudness:.2f} LUFS to target {target_lufs} LUFS.")
326
+ return normalized_audio
327
+
328
+ except Exception as e:
329
+ print(f"Loudness normalization failed: {e}. Falling back to original audio.")
330
+ return audio_data
331
+
332
+
333
+ # =================================================================================================
334
+ # === MIDI Merging Function ===
335
+ # =================================================================================================
336
+ def merge_midis(midi_path_left, midi_path_right, output_path):
337
+ """
338
+ Merges two MIDI files into a single MIDI file. This robust version iterates
339
+ through ALL instruments in both MIDI files, ensuring no data is lost if the
340
+ source files are multi-instrumental.
341
+
342
+ It applies hard-left panning (Pan=0) to every instrument from the left MIDI
343
+ and hard-right panning (Pan=127) to every instrument from the right MIDI.
344
+ """
345
+ try:
346
+ analyze_midi_velocity(midi_path_left)
347
+ analyze_midi_velocity(midi_path_right)
348
+ midi_left = pretty_midi.PrettyMIDI(midi_path_left)
349
+ midi_right = pretty_midi.PrettyMIDI(midi_path_right)
350
+
351
+ merged_midi = pretty_midi.PrettyMIDI()
352
+
353
+ # --- Process ALL instruments from the left channel MIDI ---
354
+ if midi_left.instruments:
355
+ print(f"Found {len(midi_left.instruments)} instrument(s) in the left channel MIDI.")
356
+ # Use a loop to iterate through every instrument
357
+ for instrument in midi_left.instruments:
358
+ scale_instrument_velocity(instrument, scale=0.8)
359
+ # To avoid confusion, we can prefix the instrument name
360
+ instrument.name = f"Left - {instrument.name if instrument.name else 'Instrument'}"
361
+
362
+ # Create and add the Pan Left control change
363
+ # Create a Control Change event for Pan (controller number 10).
364
+ # Set its value to 0 for hard left panning.
365
+ # Add it at the very beginning of the track (time=0.0).
366
+ pan_left = pretty_midi.ControlChange(number=10, value=0, time=0.0)
367
+ # Use insert() to ensure the pan event is the very first one
368
+ instrument.control_changes.insert(0, pan_left)
369
+
370
+ # Append the fully processed instrument to the merged MIDI
371
+ merged_midi.instruments.append(instrument)
372
+
373
+ # --- Process ALL instruments from the right channel MIDI ---
374
+ if midi_right.instruments:
375
+ print(f"Found {len(midi_right.instruments)} instrument(s) in the right channel MIDI.")
376
+ # Use a loop here as well
377
+ for instrument in midi_right.instruments:
378
+ scale_instrument_velocity(instrument, scale=0.8)
379
+ instrument.name = f"Right - {instrument.name if instrument.name else 'Instrument'}"
380
+
381
+ # Create and add the Pan Right control change
382
+ # Create a Control Change event for Pan (controller number 10).
383
+ # Set its value to 127 for hard right panning.
384
+ # Add it at the very beginning of the track (time=0.0).
385
+ pan_right = pretty_midi.ControlChange(number=10, value=127, time=0.0)
386
+ instrument.control_changes.insert(0, pan_right)
387
+
388
+ merged_midi.instruments.append(instrument)
389
 
390
+ merged_midi.write(output_path)
391
+ print(f"Successfully merged all instruments and panned into '{os.path.basename(output_path)}'")
392
+ analyze_midi_velocity(output_path)
393
+ return output_path
394
+
395
+ except Exception as e:
396
+ print(f"Error merging MIDI files: {e}")
397
+ # Fallback logic remains the same
398
+ if os.path.exists(midi_path_left):
399
+ print("Fallback: Using only the left channel MIDI.")
400
+ return midi_path_left
401
+ return None
402
+
403
 
404
  # =================================================================================================
405
  # === Stage 1: Audio to MIDI Transcription Functions ===
 
425
  # Use os.path.join to create a platform-independent directory path
426
  output_dir = os.path.join("output", "transcribed_piano_")
427
  out_mid_path = os.path.join(output_dir, fn1 + '.mid')
428
+
429
  # Check for the directory's existence and create it if necessary
430
  if not os.path.exists(output_dir):
431
  os.makedirs(output_dir)
 
583
  escore = TMIDIX.merge_escore_notes(escore, merge_threshold=merge_misaligned_notes)
584
 
585
  escore = TMIDIX.augment_enhanced_score_notes(escore, timings_divider=1)
586
+
587
  first_note_index = [e[0] for e in raw_score[1]].index('note')
588
  cscore = TMIDIX.chordify_score([1000, escore])
589
 
 
591
 
592
  aux_escore_notes = TMIDIX.augment_enhanced_score_notes(escore, sort_drums_last=True)
593
  song_description = TMIDIX.escore_notes_to_text_description(aux_escore_notes)
594
+
595
  print('Done!')
596
  print('=' * 70)
597
  print('Input MIDI metadata:', meta_data[:5])
 
643
 
644
  if render_transpose_to_C4:
645
  output_score = TMIDIX.transpose_escore_notes_to_pitch(output_score, 60) # C4 is MIDI pitch 60
646
+
647
  if render_align == "Start Times":
648
  output_score = TMIDIX.recalculate_score_timings(output_score)
649
  output_score = TMIDIX.align_escore_notes_to_bars(output_score)
 
744
  s8bit_bass_boost_level,
745
  fs=srate
746
  )
747
+ # Normalize and prepare for Gradio
748
  peak_val = np.max(np.abs(audio))
749
  if peak_val > 0:
750
  audio /= peak_val
751
+ # Transpose from (2, N) to (N, 2) and convert to int16 for Gradio
752
+ audio_out = (audio.T * 32767).astype(np.int16)
753
  except Exception as e:
754
  print(f"Error during 8-bit synthesis: {e}")
755
  return [None] * 7
 
775
  with open(midi_to_render_path, 'rb') as f:
776
  midi_file_content = f.read()
777
 
778
+ audio_out = midi_to_colab_audio(midi_file_content,
779
  soundfont_path=soundfont_path, # Use the dynamically found path
780
  sample_rate=srate,
781
  output_for_gradio=True
 
791
 
792
  output_midi_summary = str(meta_data)
793
 
794
+ return new_md5_hash, fn1, output_midi_summary, midi_to_render_path, (srate, audio_out), output_plot, song_description
795
 
796
  # =================================================================================================
797
  # === Main Application Logic ===
 
799
 
800
  def process_and_render_file(input_file,
801
  # --- Transcription params ---
802
+ enable_stereo_processing,
803
  transcription_method,
804
  onset_thresh, frame_thresh, min_note_len, min_freq, max_freq, infer_onsets_bool, melodia_trick_bool, multiple_bends_bool,
805
  # --- MIDI rendering params ---
 
818
  start_time = reqtime.time()
819
  if input_file is None:
820
  # Return a list of updates to clear all output fields
821
+ return [gr.update(value=None)] * 7
 
822
 
823
  # The input_file from gr.Audio(type="filepath") is now the direct path (a string),
824
  # not a temporary file object. We no longer need to access the .name attribute.
825
  input_file_path = input_file
826
  filename = os.path.basename(input_file_path)
827
  print(f"Processing new file: {filename}")
828
+
829
+ try:
830
+ audio_data, native_sample_rate = librosa.load(input_file_path, sr=None, mono=False)
831
+ except Exception as e:
832
+ raise gr.Error(f"Failed to load audio file: {e}")
833
 
834
  # --- Step 1: Check file type and transcribe if necessary ---
835
  if filename.lower().endswith(('.mid', '.midi', '.kar')):
 
837
  midi_path_for_rendering = input_file_path
838
  else: #if filename.lower().endswith(('.wav', '.mp3'))
839
  print("Audio file detected. Starting transcription...")
840
+
841
+ base_name = os.path.splitext(filename)[0]
842
+ temp_dir = "output/temp_normalized"
843
+ os.makedirs(temp_dir, exist_ok=True)
844
+
845
+ # === STEREO PROCESSING LOGIC ===
846
+ if enable_stereo_processing:
847
+ if audio_data.ndim != 2 or audio_data.shape[0] != 2:
848
+ print("Warning: Audio is not stereo or could not be loaded as stereo. Falling back to mono transcription.")
849
+ enable_stereo_processing = False # Disable stereo processing if audio is not stereo
850
+
851
+ if enable_stereo_processing:
852
+ print("Stereo processing enabled. Splitting channels...")
853
+ try:
854
+ left_channel = audio_data[0]
855
+ right_channel = audio_data[1]
856
+
857
+ normalized_left = normalize_loudness(left_channel, native_sample_rate)
858
+ normalized_right = normalize_loudness(right_channel, native_sample_rate)
859
+
860
+ temp_left_wav_path = os.path.join(temp_dir, f"{base_name}_left.wav")
861
+ temp_right_wav_path = os.path.join(temp_dir, f"{base_name}_right.wav")
862
+
863
+ sf.write(temp_left_wav_path, normalized_left, native_sample_rate)
864
+ sf.write(temp_right_wav_path, normalized_right, native_sample_rate)
865
+
866
+ print(f"Saved left channel to: {temp_left_wav_path}")
867
+ print(f"Saved right channel to: {temp_right_wav_path}")
868
+
869
+ print("Transcribing left channel...")
870
+ if transcription_method == "General Purpose":
871
+ midi_path_left = TranscribeGeneralAudio(temp_left_wav_path, onset_thresh, frame_thresh, min_note_len, min_freq, max_freq, infer_onsets_bool, melodia_trick_bool, multiple_bends_bool)
872
+ else:
873
+ midi_path_left = TranscribePianoAudio(temp_left_wav_path)
874
+
875
+ print("Transcribing right channel...")
876
+ if transcription_method == "General Purpose":
877
+ midi_path_right = TranscribeGeneralAudio(temp_right_wav_path, onset_thresh, frame_thresh, min_note_len, min_freq, max_freq, infer_onsets_bool, melodia_trick_bool, multiple_bends_bool)
878
+ else:
879
+ midi_path_right = TranscribePianoAudio(temp_right_wav_path)
880
+
881
+ if midi_path_left and midi_path_right:
882
+ merged_midi_path = os.path.join(temp_dir, f"{base_name}_merged.mid")
883
+ midi_path_for_rendering = merge_midis(midi_path_left, midi_path_right, merged_midi_path)
884
+ elif midi_path_left:
885
+ print("Warning: Right channel transcription failed. Using left channel only.")
886
+ midi_path_for_rendering = midi_path_left
887
+ elif midi_path_right:
888
+ print("Warning: Left channel transcription failed. Using right channel only.")
889
+ midi_path_for_rendering = midi_path_right
890
+ else:
891
+ raise gr.Error("Both left and right channel transcriptions failed.")
892
+
893
+ except Exception as e:
894
+ print(f"An error occurred during stereo processing: {e}")
895
+ raise gr.Error(f"Stereo Processing Failed: {e}")
896
+ else:
897
+ print("Stereo processing disabled. Using standard mono transcription.")
898
+ if audio_data.ndim == 1:
899
+ mono_signal = audio_data
900
+ else:
901
+ mono_signal = np.mean(audio_data, axis=0)
902
+
903
+ normalized_mono = normalize_loudness(mono_signal, native_sample_rate)
904
+
905
+ temp_mono_wav_path = os.path.join(temp_dir, f"{base_name}_mono.wav")
906
+ sf.write(temp_mono_wav_path, normalized_mono, native_sample_rate)
907
+
908
+ try:
909
+ if transcription_method == "General Purpose":
910
+ midi_path_for_rendering = TranscribeGeneralAudio(
911
+ temp_mono_wav_path, onset_thresh, frame_thresh, min_note_len,
912
+ min_freq, max_freq, infer_onsets_bool, melodia_trick_bool, multiple_bends_bool
913
+ )
914
+ else: # Piano-Specific
915
+ midi_path_for_rendering = TranscribePianoAudio(temp_mono_wav_path)
916
+ analyze_midi_velocity(midi_path_for_rendering)
917
+ except Exception as e:
918
+ print(f"An error occurred during transcription: {e}")
919
+ raise gr.Error(f"Transcription Failed: {e}")
920
 
921
  # --- Step 2: Render the MIDI file with selected options ---
922
  print(f"Proceeding to render MIDI file: {os.path.basename(midi_path_for_rendering)}")
 
942
  """
943
  is_general = (transcription_method == "General Purpose")
944
  is_8bit = (soundfont_choice == SYNTH_8_BIT_LABEL)
945
+
946
  return {
947
  general_transcription_settings: gr.update(visible=is_general),
948
  synth_8bit_settings: gr.update(visible=is_8bit),
 
997
  value="General Purpose",
998
  info="Choose 'General Purpose' for most music (vocals, etc.). Choose 'Piano-Specific' only for solo piano recordings."
999
  )
1000
+
1001
+ # --- Stereo Processing Checkbox ---
1002
+ enable_stereo_processing = gr.Checkbox(
1003
+ label="Enable Stereo Transcription",
1004
+ value=False,
1005
+ info="If checked, left/right audio channels are transcribed separately and merged. Doubles processing time."
1006
+ )
1007
 
 
1008
  with gr.Accordion("General Purpose Transcription Settings", open=True) as general_transcription_settings:
1009
  onset_threshold = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="On-set Threshold", info="Sensitivity for detecting note beginnings. Higher is stricter.")
1010
  frame_threshold = gr.Slider(0.0, 1.0, value=0.3, step=0.05, label="Frame Threshold", info="Sensitivity for detecting active notes. Higher is stricter.")
 
1027
  # --- Dynamically create the list of choices ---
1028
  soundfont_choices = [SYNTH_8_BIT_LABEL] + list(soundfonts_dict.keys())
1029
  # Set a safe default value
1030
+ default_sf_choice = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7" if "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7" in soundfonts_dict else (soundfont_choices[0] if soundfont_choices else "")
1031
 
1032
  soundfont_bank = gr.Dropdown(
1033
  soundfont_choices,
 
1083
  # --- Define all input components for the click event ---
1084
  all_inputs = [
1085
  input_file,
1086
+ enable_stereo_processing,
1087
  transcription_method,
1088
  onset_threshold, frame_threshold, minimum_note_length, minimum_frequency, maximum_frequency,
1089
  infer_onsets, melodia_trick, multiple_pitch_bends,
requirements.txt CHANGED
@@ -16,6 +16,8 @@ networkx
16
  scikit-learn
17
  psutil
18
  pretty_midi
19
  piano_transcription_inference
20
 
21
  basic-pitch @ git+https://github.com/avan06/basic-pitch; sys_platform != 'linux'
 
16
  scikit-learn
17
  psutil
18
  pretty_midi
19
+ soundfile
20
+ pyloudnorm
21
  piano_transcription_inference
22
 
23
  basic-pitch @ git+https://github.com/avan06/basic-pitch; sys_platform != 'linux'