# SingingSDS / evaluation / svs_eval.py

import librosa
import numpy as np
import torch


# ----------- Initialization -----------

def init_singmos():
    print("[Init] Loading SingMOS...")
    return torch.hub.load(
        "South-Twilight/SingMOS:v0.2.0", "singing_ssl_mos", trust_repo=True
    )


def init_basic_pitch():
    print("[Init] Loading BasicPitch...")
    from basic_pitch.inference import predict
    return predict


def init_per():
    return None  # TODO: implement PER evaluation


def init_audiobox_aesthetics():
    print("[Init] Loading AudioboxAesthetics...")
    from audiobox_aesthetics.infer import initialize_predictor
    predictor = initialize_predictor()
    return predictor


# ----------- Evaluation -----------

def eval_singmos(audio_path, predictor):
    # Load at 44.1 kHz, then resample to the 16 kHz rate used by the SSL-based MOS predictor.
    audio_array, sr = librosa.load(audio_path, sr=44100)
    wav = librosa.resample(audio_array, orig_sr=sr, target_sr=16000)
    wav_tensor = torch.from_numpy(wav).unsqueeze(0)
    length_tensor = torch.tensor([wav_tensor.shape[1]])
    score = predictor(wav_tensor, length_tensor)
    return {"singmos": float(score)}


def eval_melody_metrics(audio_path, pitch_extractor):
    model_output, midi_data, note_events = pitch_extractor(audio_path)
    metrics = {}
    assert (
        len(midi_data.instruments) == 1
    ), f"Detected {len(midi_data.instruments)} instruments for {audio_path}"
    midi_notes = midi_data.instruments[0].notes
    melody = [note.pitch for note in midi_notes]
    if len(melody) == 0:
        print(f"No notes detected in {audio_path}")
        return {}
    # Adjacent-note intervals in semitones (MIDI pitch numbers).
    intervals = [abs(melody[i + 1] - melody[i]) for i in range(len(melody) - 1)]
    metrics["pitch_range"] = max(melody) - min(melody)
    if len(intervals) > 0:
        metrics["interval_mean"] = np.mean(intervals)
        metrics["interval_std"] = np.std(intervals)
        # Intervals larger than 5 semitones (a perfect fourth) count as large jumps.
        metrics["interval_large_jump_ratio"] = np.mean([i > 5 for i in intervals])
        metrics["dissonance_rate"] = compute_dissonance_rate(intervals)
    return metrics


def compute_dissonance_rate(intervals, dissonant_intervals={1, 2, 6, 10, 11}):
    dissonant = [i % 12 in dissonant_intervals for i in intervals]
    return np.mean(dissonant) if intervals else np.nan
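
# Worked example (illustrative, not executed): for intervals [1, 7, 11], the pitch
# classes 1 and 11 fall in the dissonant set while 7 (a perfect fifth) does not, so
# compute_dissonance_rate([1, 7, 11]) == 2 / 3.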


def eval_per(audio_path, model=None):
    audio_array, sr = librosa.load(audio_path, sr=16000)
    # TODO: implement PER evaluation
    return {}


def eval_aesthetic(audio_path, predictor):
    score = predictor.forward([{"path": str(audio_path)}])
    return score
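
# Assumption about the audiobox_aesthetics API: predictor.forward([...]) returns a list
# with one result dict per input item (the aesthetics axes), which is why
# run_evaluation below takes element [0] before merging it into the results dict.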


# ----------- Main Function -----------

def load_evaluators(config):
    loaded = {}
    if "singmos" in config:
        loaded["singmos"] = init_singmos()
    if "melody" in config:
        loaded["melody"] = init_basic_pitch()
    if "per" in config:
        loaded["per"] = init_per()
    if "aesthetic" in config:
        loaded["aesthetic"] = init_audiobox_aesthetics()
    return loaded
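
# Usage sketch (hypothetical path; any subset of "singmos", "melody", "per", "aesthetic"):
#   evaluators = load_evaluators(["singmos", "melody"])
#   run_evaluation("clip.wav", evaluators)  # -> {"singmos": ..., "pitch_range": ..., ...}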


def run_evaluation(audio_path, evaluators):
    results = {}
    if "singmos" in evaluators:
        results.update(eval_singmos(audio_path, evaluators["singmos"]))
    if "per" in evaluators:
        results.update(eval_per(audio_path, evaluators["per"]))
    if "melody" in evaluators:
        results.update(eval_melody_metrics(audio_path, evaluators["melody"]))
    if "aesthetic" in evaluators:
        results.update(eval_aesthetic(audio_path, evaluators["aesthetic"])[0])
    return results


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--wav_path", type=str, required=True)
    parser.add_argument("--results_csv", type=str, required=True)
    parser.add_argument("--evaluators", type=str, default="singmos,melody,aesthetic")
    args = parser.parse_args()

    evaluators = load_evaluators(args.evaluators.split(","))
    results = run_evaluation(args.wav_path, evaluators)
    print(results)

    # Append one CSV row per file: write the header on first use, and require that
    # later runs produce the same set of metric columns.
    with open(args.results_csv, "a") as f:
        header = "file," + ",".join(results.keys()) + "\n"
        if f.tell() == 0:
            f.write(header)
        else:
            with open(args.results_csv, "r") as f2:
                file_header = f2.readline()
            if file_header != header:
                raise ValueError(f"Header mismatch: {file_header} vs {header}")
        line = (
            ",".join([str(args.wav_path)] + [str(v) for v in results.values()]) + "\n"
        )
        f.write(line)
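
# Example invocation (hypothetical paths):
#   python svs_eval.py --wav_path sample.wav --results_csv results.csv \
#       --evaluators singmos,melody,aesthetic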