File size: 4,239 Bytes
b5e825c
 
91394e0
b5e825c
 
91394e0
b5e825c
 
91394e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a23964
91394e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1b8d35
b5e825c
 
91394e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5e825c
 
 
91394e0
 
 
 
f1b8d35
91394e0
 
 
 
 
 
 
b5e825c
91394e0
b5e825c
 
 
 
91394e0
1a42cf5
91394e0
 
 
1a42cf5
91394e0
 
05779d3
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from __future__ import annotations

import time
from pathlib import Path

import librosa
import soundfile as sf
import torch

from modules.asr import get_asr_model
from modules.llm import get_llm_model
from modules.svs import get_svs_model
from evaluation.svs_eval import load_evaluators, run_evaluation
from modules.melody import MelodyController
from modules.utils.text_normalize import clean_llm_output


class SingingDialoguePipeline:
    """End-to-end spoken-dialogue-to-singing pipeline.

    Wires three stages together: ASR (speech -> text), an LLM (text ->
    lyric response, optionally melody-constrained), and SVS (score ->
    singing audio). Optionally tracks per-stage latency and holds a set
    of SVS evaluators for post-hoc scoring.
    """

    def __init__(self, config: dict):
        """Build all pipeline stages from *config*.

        Required config keys: ``cache_dir``, ``asr_model``, ``llm_model``,
        ``svs_model``, ``melody_source``.
        Optional keys: ``device`` (defaults to CUDA when available),
        ``max_sentences`` (default 2), ``track_latency`` (default False),
        ``evaluators`` (dict; only its ``"svs"`` list is used here).
        """
        # Prefer an explicit device from config; otherwise auto-detect.
        self.device = config.get(
            "device", "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.cache_dir = config["cache_dir"]
        self.asr = get_asr_model(
            config["asr_model"], device=self.device, cache_dir=self.cache_dir
        )
        self.llm = get_llm_model(
            config["llm_model"], device=self.device, cache_dir=self.cache_dir
        )
        self.svs = get_svs_model(
            config["svs_model"], device=self.device, cache_dir=self.cache_dir
        )
        self.melody_controller = MelodyController(
            config["melody_source"], self.cache_dir
        )
        # Cap on how many sentences of LLM output are kept for singing.
        self.max_sentences = config.get("max_sentences", 2)
        self.track_latency = config.get("track_latency", False)
        self.evaluators = load_evaluators(config.get("evaluators", {}).get("svs", []))

    def set_asr_model(self, asr_model: str) -> None:
        """Hot-swap the ASR stage, reusing the pipeline's device/cache."""
        self.asr = get_asr_model(
            asr_model, device=self.device, cache_dir=self.cache_dir
        )

    def set_llm_model(self, llm_model: str) -> None:
        """Hot-swap the LLM stage, reusing the pipeline's device/cache."""
        self.llm = get_llm_model(
            llm_model, device=self.device, cache_dir=self.cache_dir
        )

    def set_svs_model(self, svs_model: str) -> None:
        """Hot-swap the SVS stage, reusing the pipeline's device/cache."""
        self.svs = get_svs_model(
            svs_model, device=self.device, cache_dir=self.cache_dir
        )

    def set_melody_controller(self, melody_source: str) -> None:
        """Replace the melody controller with one built from *melody_source*."""
        self.melody_controller = MelodyController(melody_source, self.cache_dir)

    def run(
        self,
        audio_path,
        language,
        prompt_template,
        speaker,
        output_audio_path: Path | str | None = None,
        max_new_tokens: int = 50,
    ) -> dict:
        """Run the full ASR -> LLM -> SVS chain on one input utterance.

        Args:
            audio_path: Path to the input speech audio (any format librosa
                can decode; resampled to 16 kHz for ASR).
            language: Language code passed to ASR, LLM cleanup, and SVS.
            prompt_template: Format string receiving two positional fields:
                the melody-constraint prompt and the ASR transcript.
            speaker: Speaker identity forwarded to the SVS model.
            output_audio_path: If given, the synthesized audio is written
                there (parent directories are created as needed).
            max_new_tokens: Generation budget for the LLM.

        Returns:
            A dict with ``asr_text``, ``llm_text``, and ``svs_audio`` as
            ``(sample_rate, waveform)``; plus ``output_audio_path`` when a
            file was written, and ``metrics`` (per-stage latencies in
            seconds) when ``track_latency`` is enabled.
        """
        # --- ASR stage ---
        if self.track_latency:
            asr_start_time = time.time()
        audio_array, audio_sample_rate = librosa.load(audio_path, sr=16000)
        asr_result = self.asr.transcribe(
            audio_array, audio_sample_rate=audio_sample_rate, language=language
        )
        if self.track_latency:
            asr_latency = time.time() - asr_start_time

        # --- LLM stage: build prompt from melody constraints + transcript ---
        melody_prompt = self.melody_controller.get_melody_constraints()
        prompt = prompt_template.format(melody_prompt, asr_result)
        if self.track_latency:
            llm_start_time = time.time()
        output = self.llm.generate(prompt, max_new_tokens=max_new_tokens)
        if self.track_latency:
            llm_latency = time.time() - llm_start_time
        # Trim/normalize raw LLM output before turning it into a score.
        llm_response = clean_llm_output(
            output, language=language, max_sentences=self.max_sentences
        )

        # --- SVS stage: lyrics -> musical score -> singing audio ---
        score = self.melody_controller.generate_score(llm_response, language)
        if self.track_latency:
            svs_start_time = time.time()
        singing_audio, sample_rate = self.svs.synthesize(
            score, language=language, speaker=speaker
        )
        if self.track_latency:
            svs_latency = time.time() - svs_start_time

        results = {
            "asr_text": asr_result,
            "llm_text": llm_response,
            "svs_audio": (sample_rate, singing_audio),
        }
        if output_audio_path:
            Path(output_audio_path).parent.mkdir(parents=True, exist_ok=True)
            sf.write(output_audio_path, singing_audio, sample_rate)
            results["output_audio_path"] = output_audio_path
        if self.track_latency:
            results["metrics"] = {
                "asr_latency": asr_latency,
                "llm_latency": llm_latency,
                "svs_latency": svs_latency,
            }
        return results

    def evaluate(self, audio_path):
        """Score *audio_path* with the configured SVS evaluators."""
        return run_evaluation(audio_path, self.evaluators)