File size: 3,641 Bytes
1ad80f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ec2d7e
 
 
 
 
 
 
 
 
 
 
2548eae
 
1ad80f6
2548eae
1ad80f6
 
2548eae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ad80f6
 
2548eae
 
 
1ad80f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2548eae
1ad80f6
 
 
 
 
 
2548eae
1ec2d7e
2548eae
 
 
1ad80f6
 
 
2548eae
1ad80f6
2548eae
 
 
1ec2d7e
 
 
 
1ad80f6
 
 
2548eae
1ad80f6
 
 
 
 
 
2548eae
1ad80f6
2548eae
 
 
1ad80f6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import librosa
import numpy as np
import torch


def singmos_warmup():
    """Load the SingMOS predictor through ``torch.hub``.

    Requests the ``singing_ssl_mos`` entry point of the
    ``South-Twilight/SingMOS`` repository at tag ``v0.2.0`` with
    ``trust_repo=True``.

    Returns:
        Whatever ``torch.hub.load`` returns for that entry point; the rest of
        this file calls it as ``predictor(waveform_batch, lengths)``.
    """
    hub_repo = "South-Twilight/SingMOS:v0.2.0"
    hub_entry = "singing_ssl_mos"
    return torch.hub.load(hub_repo, hub_entry, trust_repo=True)


def singmos_evaluation(predictor, wav_info, fs):
    """Run the SingMOS predictor on a waveform.

    Args:
        predictor: Callable invoked as ``predictor(batch, lengths)``.
        wav_info: Waveform as a numpy array sampled at ``fs``.
            NOTE(review): the length is read from axis 1 after a batch axis is
            prepended, so a 1-D mono signal is assumed -- confirm with callers.
        fs: Sampling rate of ``wav_info`` in Hz.

    Returns:
        The predictor's output, unmodified.
    """
    # The signal is always resampled to 16 kHz before scoring.
    resampled = librosa.resample(wav_info, orig_sr=fs, target_sr=16000)
    # Prepend a batch axis of size one.
    batch = torch.from_numpy(resampled)[None]
    lengths = torch.tensor([batch.size(1)])
    return predictor(batch, lengths)


def initialize_audiobox_predictor():
    """Build the audiobox-aesthetics predictor.

    Returns:
        The object produced by
        ``audiobox_aesthetics.infer.initialize_predictor()``.
    """
    # Function-scope import: the package is only needed once this is called.
    from audiobox_aesthetics.infer import initialize_predictor as build_predictor

    return build_predictor()


def audiobox_aesthetics_evaluation(predictor, audio_path):
    """Score a single audio file with the audiobox-aesthetics predictor.

    Args:
        predictor: Object exposing ``forward(batch)``, where ``batch`` is a
            list of ``{"path": ...}`` dictionaries.
        audio_path: Location of the audio file; it is converted with ``str``.

    Returns:
        The value returned by ``predictor.forward`` for the one-item batch.
    """
    request = {"path": str(audio_path)}
    return predictor.forward([request])


def score_extract_warmpup():
    """Return the ``basic_pitch.inference.predict`` callable without calling it.

    Returns:
        The ``predict`` function, used elsewhere in this file as
        ``score_extractor(audio_path)``.
    """
    # Function-scope import: basic_pitch is only needed once this is called.
    from basic_pitch import inference

    return inference.predict


def score_metric_evaluation(score_extractor, audio_path):
    """Compute melodic statistics from the notes transcribed for an audio file.

    Args:
        score_extractor: Callable taking ``audio_path`` and returning a
            ``(model_output, midi_data, note_events)`` triple. ``midi_data``
            must expose ``instruments``; each instrument has a ``notes``
            sequence whose items carry a ``pitch`` attribute.
        audio_path: Audio file handed to ``score_extractor``.

    Returns:
        dict: Empty when no notes were transcribed. Otherwise it holds
        ``pitch_range`` and, when there are at least two notes,
        ``interval_mean``, ``interval_std``, ``interval_large_jump_ratio``
        (share of intervals larger than 5) and ``dissonance_rate``.

    Raises:
        AssertionError: If more than one instrument is detected.
    """
    _, midi_data, _ = score_extractor(audio_path)
    instruments = midi_data.instruments

    if len(instruments) > 1:
        # Raised explicitly (not via ``assert``) so the check survives
        # ``python -O``; the exception type is kept for existing callers.
        raise AssertionError(
            f"Detected {len(instruments)} instruments for {audio_path}"
        )

    # A transcription without any instrument is treated like one without
    # notes, so it reaches the "no notes" path below instead of failing.
    # NOTE(review): presumably what basic_pitch yields for silent input --
    # confirm against its MIDI construction.
    if len(instruments) == 0:
        melody = []
    else:
        melody = [note.pitch for note in instruments[0].notes]

    if not melody:
        print(f"No notes detected in {audio_path}")
        return {}

    metrics = {"pitch_range": max(melody) - min(melody)}

    # Absolute pitch difference between each pair of consecutive notes.
    intervals = [abs(nxt - cur) for cur, nxt in zip(melody, melody[1:])]
    if intervals:
        metrics["interval_mean"] = np.mean(intervals)
        metrics["interval_std"] = np.std(intervals)
        metrics["interval_large_jump_ratio"] = np.mean([i > 5 for i in intervals])
        metrics["dissonance_rate"] = compute_dissonance_rate(intervals)
    return metrics


def compute_dissonance_rate(intervals, dissonant_intervals=frozenset({1, 2, 6, 10, 11})):
    """Return the share of intervals whose value modulo 12 is dissonant.

    Args:
        intervals: Iterable of integer interval sizes (list, tuple or 1-D
            numpy array). Each value is reduced modulo 12 before the lookup.
        dissonant_intervals: Container of residues (0-11) counted as
            dissonant. The default is ``{1, 2, 6, 10, 11}`` -- presumably
            seconds, the tritone and sevenths in semitones.

    Returns:
        The mean of the per-interval dissonance flags as a numpy float, or
        ``np.nan`` when ``intervals`` is empty.
    """
    flags = [step % 12 in dissonant_intervals for step in intervals]
    # Test the derived list rather than ``intervals`` itself: truth-testing a
    # numpy array with more than one element raises ValueError.
    if not flags:
        return np.nan
    return np.mean(flags)


def main():
    """Evaluate one wav file and append its metrics as a row of a CSV file.

    Command-line arguments:
        --wav_path: Audio file to evaluate (required).
        --results_csv: CSV file the result row is appended to (required). Its
            parent directory is created when missing.

    Raises:
        ValueError: If the CSV file already has a header that differs from
            the columns produced by this run.
    """
    import argparse
    import csv
    from pathlib import Path

    parser = argparse.ArgumentParser()
    # Both paths are dereferenced unconditionally below, so argparse should
    # reject a missing one instead of failing later on ``None``.
    parser.add_argument(
        "--wav_path",
        type=Path,
        required=True,
        help="Path to the wav file",
    )
    parser.add_argument(
        "--results_csv",
        type=Path,
        required=True,
        help="csv file to save the results",
    )

    args = parser.parse_args()

    args.results_csv.parent.mkdir(parents=True, exist_ok=True)

    # sr=None keeps the file's native sampling rate.
    y, fs = librosa.load(args.wav_path, sr=None)

    # warmup
    predictor = singmos_warmup()
    score_extractor = score_extract_warmpup()
    aesthetic_predictor = initialize_audiobox_predictor()

    # evaluate the audio
    metrics = {}

    # singmos evaluation
    score = singmos_evaluation(predictor, y, fs)
    # NOTE(review): the predictor presumably returns a single-element tensor;
    # store a plain float so the CSV cell is a number rather than a tensor
    # repr. Anything that cannot be converted is kept unchanged.
    try:
        metrics["singmos"] = float(score)
    except (TypeError, ValueError, RuntimeError):
        metrics["singmos"] = score

    # score metric evaluation
    score_results = score_metric_evaluation(score_extractor, args.wav_path)
    metrics.update(score_results)

    # audiobox aesthetics evaluation
    score_results = audiobox_aesthetics_evaluation(aesthetic_predictor, args.wav_path)
    metrics.update(score_results[0])

    # save results
    header = ["file", *metrics]
    row = [str(args.wav_path), *(str(value) for value in metrics.values())]

    results_csv = args.results_csv
    is_new_file = not results_csv.exists() or results_csv.stat().st_size == 0
    if not is_new_file:
        # Refuse to append a row whose columns differ from the existing file.
        with open(results_csv, newline="", encoding="utf-8") as existing:
            file_header = next(csv.reader(existing), [])
        if file_header != header:
            raise ValueError(f"Header mismatch: {file_header} vs {header}")

    # csv.writer quotes fields containing commas (e.g. in the wav path), which
    # a bare ",".join would not; "\n" keeps the previous line ending.
    with open(results_csv, "a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f, lineterminator="\n")
        if is_new_file:
            writer.writerow(header)
        writer.writerow(row)


if __name__ == "__main__":
    main()