File size: 4,583 Bytes
adcbc9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# src/piano_transcription/utils.py

import os
from pathlib import Path
from typing import Final

# Import for the new downloader
from huggingface_hub import hf_hub_download

# Imports for the patch
import numpy as np
import librosa
import audioread
from piano_transcription_inference import utilities

# --- Constants ---
# Location of the model on the Hugging Face Hub: the repository id and the
# file name fetched from it by download_model_from_hf_if_needed().
# Annotated as Final so static type checkers report any reassignment.
REPO_ID: Final[str] = "Genius-Society/piano_trans"
MODEL_NAME: Final[str] = "CRNN_note_F1=0.9677_pedal_F1=0.9186.pth"


# --- Model Download Function ---

def download_model_from_hf_if_needed():
    """
    Make sure the model file is available locally, downloading it if needed.

    Fetches ``MODEL_NAME`` from the Hugging Face Hub repository ``REPO_ID``
    into the ``models`` directory that sits next to this file's parent
    directory (``<this file's dir>/../models``). ``hf_hub_download`` handles
    caching and existence checks itself, so calling this repeatedly is cheap.

    Failures are reported on stdout and swallowed (best-effort). Callers must
    not assume the file exists afterwards; check the return value instead.

    Returns:
        str | None: the local file path returned by ``hf_hub_download`` on
        success, or ``None`` if the download failed.
    """
    # Assuming this utils.py is in 'src/piano_transcription/', models are in 'src/models/'
    model_dir = Path(__file__).parent.parent / "models"

    print(f"Checking for model '{MODEL_NAME}' from Hugging Face Hub repo '{REPO_ID}'...")

    try:
        # `local_dir` places the file under model_dir rather than only in the
        # shared HF cache; the return value is the file's actual local path.
        model_path = hf_hub_download(
            repo_id=REPO_ID,
            filename=MODEL_NAME,
            local_dir=model_dir,
        )
    except OSError as e:
        # In requests-based huggingface_hub releases, connection/HTTP errors
        # (requests' RequestException is an IOError) and the "offline and not
        # in cache" error (a FileNotFoundError subclass) all derive from
        # OSError. The original caught AttributeError here, which none of
        # these are.
        print("Error downloading from Hugging Face Hub. Please check your network connection and the repo/filename.")
        print(f"Details: {e}")
        return None
    except Exception as e:
        # Deliberate best-effort: report and continue instead of crashing
        # application start-up.
        print(f"An unexpected error occurred: {e}")
        return None

    print(f"Model is available at '{model_path}'")
    return model_path


# --- Monkey Patching Function ---

def _fixed_load_audio(path, sr=22050, mono=True, offset=0.0, duration=None,
                      dtype=np.float32, res_type='kaiser_best',
                      backends=(audioread.ffdec.FFmpegAudioFile,)):
    """
    Decode an audio file with audioread, then optionally crop, downmix and resample it.

    Drop-in replacement for ``piano_transcription_inference.utilities.load_audio``
    that calls the current librosa APIs (``librosa.util.buf_to_float`` and
    ``librosa.resample`` with keyword ``orig_sr``/``target_sr``).

    Args:
        path: audio file path; resolved with ``os.path.realpath`` before opening.
        sr: target sample rate. ``None`` keeps the file's native rate.
        mono: when True, multi-channel audio is downmixed with ``librosa.to_mono``.
        offset: start position in seconds.
        duration: number of seconds to read; ``None`` reads to the end.
        dtype: numpy dtype of the returned samples.
        res_type: resampling method forwarded to ``librosa.resample``.
        backends: iterable of audioread backend classes to try. The default
            is FFmpeg only. It is a tuple rather than a list so that no
            mutable default is shared between calls.

    Returns:
        tuple: ``(y, sr)``. ``y`` is a C-contiguous array of ``dtype``. It is
        1-D for mono output and shaped ``(n_channels, n_samples)`` when
        ``mono`` is False and the file has several channels. ``sr`` is the
        sample rate of ``y``.
    """
    y = []
    with audioread.audio_open(os.path.realpath(path), backends=backends) as input_file:
        sr_native = input_file.samplerate
        n_channels = input_file.channels
        # Positions are counted in interleaved samples (all channels
        # together), hence the multiplication by n_channels.
        s_start = int(np.round(sr_native * offset)) * n_channels
        if duration is None:
            s_end = np.inf
        else:
            s_end = s_start + (int(np.round(sr_native * duration)) * n_channels)
        n = 0  # interleaved samples decoded so far
        for frame in input_file:
            frame = librosa.util.buf_to_float(frame, dtype=dtype)
            n_prev = n
            n = n + len(frame)
            if n < s_start:
                # Requested start lies beyond this frame; keep reading.
                continue
            if s_end < n_prev:
                # Already past the requested end; stop decoding.
                break
            if s_end < n:
                # Requested end falls inside this frame: trim its tail.
                frame = frame[:s_end - n_prev]
            if n_prev <= s_start <= n:
                # Requested start falls inside this frame: trim its head.
                frame = frame[(s_start - n_prev):]
            y.append(frame)
    if y:
        y = np.concatenate(y)
        if n_channels > 1:
            # De-interleave into (n_channels, n_samples).
            y = y.reshape((-1, n_channels)).T
            if mono:
                y = librosa.to_mono(y)
        if sr is not None:
            y = librosa.resample(y, orig_sr=sr_native, target_sr=sr, res_type=res_type)
    if sr is None:
        # Report the native rate even when nothing was decoded. Previously
        # an empty read with sr=None returned a rate of None.
        sr = sr_native
    y = np.ascontiguousarray(y, dtype=dtype)
    return (y, sr)


def apply_monkey_patch():
    """
    Point ``piano_transcription_inference``'s ``load_audio`` at ``_fixed_load_audio``.

    Rebinds ``utilities.load_audio``. If the package object also has a
    ``load_audio`` attribute, that attribute is rebound too, so callers that
    reach the loader through either path get the patched version.

    Names bound with ``from ... import load_audio`` *before* this runs still
    refer to the original function. Call this first (e.g. via
    ``initialize_app``).
    """
    # Already importable: the module-level `from piano_transcription_inference
    # import utilities` succeeded.
    import piano_transcription_inference

    print("Applying librosa compatibility patch...")
    utilities.load_audio = _fixed_load_audio
    # NOTE(review): the package presumably re-exports load_audio (its README
    # shows `from piano_transcription_inference import load_audio`). That
    # binding is a separate reference, so patching `utilities` alone does not
    # affect it. The hasattr guard keeps this safe if no re-export exists.
    if hasattr(piano_transcription_inference, "load_audio"):
        piano_transcription_inference.load_audio = _fixed_load_audio


# --- Main Initializer ---

def initialize_app():
    """
    Run the one-time start-up sequence; call this before anything else.

    Steps, in order:
      1. Fetch the model from the Hugging Face Hub if it is not already local.
      2. Install the librosa-compatibility replacement for ``load_audio``.

    A banner line is printed before and after the sequence.
    """
    print("--- Initializing Application ---")
    for startup_step in (download_model_from_hf_if_needed, apply_monkey_patch):
        startup_step()
    print("--- Initialization Complete ---")