# avans06 — "Integrated and Enhanced: Audio-to-MIDI and Advanced MIDI Renderer" (commit adcbc9f)
# src/piano_transcription/utils.py
import os
from pathlib import Path
from typing import Final
# Import for the new downloader
from huggingface_hub import hf_hub_download
# Imports for the patch
import numpy as np
import librosa
import audioread
from piano_transcription_inference import utilities
# --- Constants ---
# By convention, uppercase variables are treated as constants and should not be modified.
# Using typing.Final to indicate to static type checkers that these should not be reassigned.
# Checkpoint filename inside the Hub repo below; it is also the filename the
# checkpoint gets inside the local 'models' directory after download.
MODEL_NAME: Final[str] = "CRNN_note_F1=0.9677_pedal_F1=0.9186.pth"
# Hugging Face Hub repository ID that hosts MODEL_NAME.
REPO_ID: Final[str] = "Genius-Society/piano_trans"
# --- Model Download Function ---
def download_model_from_hf_if_needed():
    """
    Ensure the model checkpoint is available locally, downloading it if needed.

    ``MODEL_NAME`` is fetched from the Hugging Face Hub repository ``REPO_ID``
    into the ``models`` directory that sits next to this package's parent
    directory (``src/models/`` when this file lives in
    ``src/piano_transcription/``). ``hf_hub_download`` handles caching and
    existence checks itself.

    Errors are reported on stdout instead of being raised, so application
    startup continues even when the download fails (best-effort behaviour).

    Returns:
        The local path of the checkpoint (as returned by ``hf_hub_download``)
        on success, or ``None`` if the download failed.
    """
    # Assuming this utils.py is in 'src/piano_transcription/', models are in 'src/models/'
    model_dir = Path(__file__).parent.parent / "models"

    print(f"Checking for model '{MODEL_NAME}' from Hugging Face Hub repo '{REPO_ID}'...")
    try:
        # `local_dir` makes the file land directly in model_dir; the returned
        # value is the real on-disk location, so report that instead of a
        # path reconstructed by hand.
        model_path = hf_hub_download(
            repo_id=REPO_ID,
            filename=MODEL_NAME,
            local_dir=model_dir,
        )
    except OSError as e:
        # Hub HTTP errors, connection failures and missing local files are
        # all OSError subclasses; this is the "network / repo / filename"
        # class of failure.
        print("Error downloading from Hugging Face Hub. Please check your network connection and the repo/filename.")
        print(f"Details: {e}")
        # You might want to exit or raise the exception if the model is critical
        return None
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        return None

    print(f"Model is available at '{model_path}'")
    return model_path
# --- Monkey Patching Function ---
def _fixed_load_audio(path, sr=22050, mono=True, offset=0.0, duration=None,
                      dtype=np.float32, res_type='kaiser_best',
                      backends=(audioread.ffdec.FFmpegAudioFile,)):
    """
    Decode an audio file with audioread and return ``(samples, sample_rate)``.

    Drop-in replacement for ``piano_transcription_inference.utilities.load_audio``
    that calls the current librosa function paths (``librosa.util.buf_to_float``,
    ``librosa.to_mono``, and ``librosa.resample`` with keyword sample rates).

    Args:
        path: Path of the audio file to decode.
        sr: Target sample rate. If ``None``, the file's native rate is kept.
        mono: Downmix multi-channel audio to a single channel.
        offset: Start reading this many seconds into the file.
        duration: Read at most this many seconds; ``None`` reads to the end.
        dtype: dtype of the returned samples.
        res_type: Resampling method forwarded to ``librosa.resample``.
        backends: audioread backend classes to try. Defaults to FFmpeg only.
            The default is a tuple so it cannot be mutated across calls.

    Returns:
        ``(y, sr)``. ``y`` is a contiguous array of ``dtype``. It is 1-D for
        mono input or when ``mono`` is True, and ``(n_channels, n_samples)``
        for multi-channel input with ``mono=False``. It is empty when nothing
        was decoded. ``sr`` is the sample rate of ``y``: the native rate when
        the ``sr`` argument is ``None``, otherwise the requested rate.
    """
    frames = []
    with audioread.audio_open(os.path.realpath(path), backends=backends) as input_file:
        sr_native = input_file.samplerate
        n_channels = input_file.channels

        # Positions are counted in interleaved samples, hence `* n_channels`.
        s_start = int(np.round(sr_native * offset)) * n_channels
        if duration is None:
            s_end = np.inf
        else:
            s_end = s_start + (int(np.round(sr_native * duration)) * n_channels)

        n = 0
        for frame in input_file:
            frame = librosa.util.buf_to_float(frame, dtype=dtype)
            n_prev = n
            n = n + len(frame)
            if n < s_start:
                # The requested offset lies beyond this frame; keep reading.
                continue
            if s_end < n_prev:
                # Already past the requested end; stop reading.
                break
            if s_end < n:
                # The end falls inside this frame: crop its tail.
                frame = frame[:s_end - n_prev]
            if n_prev <= s_start <= n:
                # The start falls inside this frame: crop its head.
                frame = frame[(s_start - n_prev):]
            frames.append(frame)

    if frames:
        y = np.concatenate(frames)
        if n_channels > 1:
            y = y.reshape((-1, n_channels)).T
            if mono:
                y = librosa.to_mono(y)
        if sr is not None:
            y = librosa.resample(y, orig_sr=sr_native, target_sr=sr, res_type=res_type)
    else:
        y = np.empty(0, dtype=dtype)

    if sr is None:
        # Report the native rate even when nothing was decoded; previously an
        # empty read with sr=None returned None as the sample rate.
        sr = sr_native

    y = np.ascontiguousarray(y, dtype=dtype)
    return (y, sr)
def apply_monkey_patch():
    """
    Install ``_fixed_load_audio`` in place of the library's ``load_audio``.

    Rebinds ``piano_transcription_inference.utilities.load_audio``. The package
    also exposes ``load_audio`` at top level (its README imports it via
    ``from piano_transcription_inference import load_audio``). That attribute
    is a separate binding made at import time, so it is rebound too when
    present. Otherwise callers using the package-level name would keep the
    unpatched function.

    Note: modules that already ran ``from ... import load_audio`` before this
    call hold their own reference and are not affected.
    """
    # The package is already imported by the module-level `utilities` import,
    # so this is just a lookup in sys.modules.
    import piano_transcription_inference

    print("Applying librosa compatibility patch...")
    utilities.load_audio = _fixed_load_audio
    if hasattr(piano_transcription_inference, "load_audio"):
        piano_transcription_inference.load_audio = _fixed_load_audio
# --- Main Initializer ---
def initialize_app():
    """
    Run the one-time startup steps for the application.

    Call this once when the app starts. It first makes sure the transcription
    checkpoint has been fetched from the Hugging Face Hub, then installs the
    librosa-compatible replacement for the library's ``load_audio``.
    """
    print("--- Initializing Application ---")
    startup_steps = (
        download_model_from_hf_if_needed,  # fetch the checkpoint into models/
        apply_monkey_patch,                # swap in the fixed load_audio
    )
    for step in startup_steps:
        step()
    print("--- Initialization Complete ---")