Spaces:
Sleeping
Sleeping
support output file path in pipeline.py
Browse files
- cli.py +7 -5
- interface.py +11 -8
- pipeline.py +16 -4
cli.py
CHANGED
@@ -2,7 +2,6 @@ from argparse import ArgumentParser
|
|
2 |
from logging import getLogger
|
3 |
from pathlib import Path
|
4 |
|
5 |
-
import soundfile as sf
|
6 |
import yaml
|
7 |
|
8 |
from characters import CHARACTERS
|
@@ -37,13 +36,16 @@ def main():
|
|
37 |
character_name = config["prompt_template_character"]
|
38 |
character = CHARACTERS[character_name]
|
39 |
prompt_template = character.prompt
|
40 |
-
results = pipeline.run(
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
logger.info(
|
42 |
f"Input: {args.query_audio}, Output: {args.output_audio}, ASR results: {results['asr_text']}, LLM results: {results['llm_text']}"
|
43 |
)
|
44 |
-
svs_audio, svs_sample_rate = results["svs_audio"]
|
45 |
-
args.output_audio.parent.mkdir(parents=True, exist_ok=True)
|
46 |
-
sf.write(args.output_audio, svs_audio, svs_sample_rate)
|
47 |
|
48 |
|
49 |
if __name__ == "__main__":
|
|
|
2 |
from logging import getLogger
|
3 |
from pathlib import Path
|
4 |
|
|
|
5 |
import yaml
|
6 |
|
7 |
from characters import CHARACTERS
|
|
|
36 |
character_name = config["prompt_template_character"]
|
37 |
character = CHARACTERS[character_name]
|
38 |
prompt_template = character.prompt
|
39 |
+
results = pipeline.run(
|
40 |
+
args.query_audio,
|
41 |
+
language,
|
42 |
+
prompt_template,
|
43 |
+
speaker,
|
44 |
+
output_audio_path=args.output_audio,
|
45 |
+
)
|
46 |
logger.info(
|
47 |
f"Input: {args.query_audio}, Output: {args.output_audio}, ASR results: {results['asr_text']}, LLM results: {results['llm_text']}"
|
48 |
)
|
|
|
|
|
|
|
49 |
|
50 |
|
51 |
if __name__ == "__main__":
|
interface.py
CHANGED
@@ -1,3 +1,6 @@
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
import yaml
|
3 |
|
@@ -201,29 +204,29 @@ class GradioInterface:
|
|
201 |
return gr.update(value=self.current_melody_source)
|
202 |
|
203 |
def update_voice(self, voice):
|
204 |
-
self.current_voice = self.svs_model_map[self.current_svs_model]["voices"][
|
205 |
-
voice
|
206 |
-
]
|
207 |
return gr.update(value=voice)
|
208 |
|
209 |
def run_pipeline(self, audio_path):
|
210 |
if not audio_path:
|
211 |
return gr.update(value=""), gr.update(value="")
|
|
|
212 |
results = self.pipeline.run(
|
213 |
audio_path,
|
214 |
self.svs_model_map[self.current_svs_model]["lang"],
|
215 |
self.character_info[self.current_character].prompt,
|
216 |
self.current_voice,
|
217 |
-
|
|
|
218 |
)
|
219 |
formatted_logs = f"ASR: {results['asr_text']}\nLLM: {results['llm_text']}"
|
220 |
-
return gr.update(value=formatted_logs), gr.update(
|
|
|
|
|
221 |
|
222 |
def update_metrics(self, audio_path):
|
223 |
if not audio_path:
|
224 |
return gr.update(value="")
|
225 |
results = self.pipeline.evaluate(audio_path)
|
226 |
-
formatted_metrics = "\n".join(
|
227 |
-
[f"{k}: {v}" for k, v in results.items()]
|
228 |
-
)
|
229 |
return gr.update(value=formatted_metrics)
|
|
|
1 |
+
import time
|
2 |
+
import uuid
|
3 |
+
|
4 |
import gradio as gr
|
5 |
import yaml
|
6 |
|
|
|
204 |
return gr.update(value=self.current_melody_source)
|
205 |
|
206 |
def update_voice(self, voice):
|
207 |
+
self.current_voice = self.svs_model_map[self.current_svs_model]["voices"][voice]
|
|
|
|
|
208 |
return gr.update(value=voice)
|
209 |
|
210 |
def run_pipeline(self, audio_path):
|
211 |
if not audio_path:
|
212 |
return gr.update(value=""), gr.update(value="")
|
213 |
+
tmp_file = f"audio_{int(time.time())}_{uuid.uuid4().hex[:8]}.wav"
|
214 |
results = self.pipeline.run(
|
215 |
audio_path,
|
216 |
self.svs_model_map[self.current_svs_model]["lang"],
|
217 |
self.character_info[self.current_character].prompt,
|
218 |
self.current_voice,
|
219 |
+
output_audio_path=tmp_file,
|
220 |
+
max_new_tokens=50,
|
221 |
)
|
222 |
formatted_logs = f"ASR: {results['asr_text']}\nLLM: {results['llm_text']}"
|
223 |
+
return gr.update(value=formatted_logs), gr.update(
|
224 |
+
value=results["output_audio_path"]
|
225 |
+
)
|
226 |
|
227 |
def update_metrics(self, audio_path):
|
228 |
if not audio_path:
|
229 |
return gr.update(value="")
|
230 |
results = self.pipeline.evaluate(audio_path)
|
231 |
+
formatted_metrics = "\n".join([f"{k}: {v}" for k, v in results.items()])
|
|
|
|
|
232 |
return gr.update(value=formatted_metrics)
|
pipeline.py
CHANGED
@@ -1,6 +1,11 @@
|
|
1 |
-
import
|
|
|
2 |
import time
|
|
|
|
|
3 |
import librosa
|
|
|
|
|
4 |
|
5 |
from modules.asr import get_asr_model
|
6 |
from modules.llm import get_llm_model
|
@@ -57,7 +62,8 @@ class SingingDialoguePipeline:
|
|
57 |
language,
|
58 |
prompt_template,
|
59 |
speaker,
|
60 |
-
|
|
|
61 |
):
|
62 |
if self.track_latency:
|
63 |
asr_start_time = time.time()
|
@@ -76,7 +82,9 @@ class SingingDialoguePipeline:
|
|
76 |
if self.track_latency:
|
77 |
llm_end_time = time.time()
|
78 |
llm_latency = llm_end_time - llm_start_time
|
79 |
-
llm_response = clean_llm_output(
|
|
|
|
|
80 |
score = self.melody_controller.generate_score(llm_response, language)
|
81 |
if self.track_latency:
|
82 |
svs_start_time = time.time()
|
@@ -89,8 +97,12 @@ class SingingDialoguePipeline:
|
|
89 |
results = {
|
90 |
"asr_text": asr_result,
|
91 |
"llm_text": llm_response,
|
92 |
-
"svs_audio": (
|
93 |
}
|
|
|
|
|
|
|
|
|
94 |
if self.track_latency:
|
95 |
results["metrics"] = {
|
96 |
"asr_latency": asr_latency,
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
import time
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
import librosa
|
7 |
+
import soundfile as sf
|
8 |
+
import torch
|
9 |
|
10 |
from modules.asr import get_asr_model
|
11 |
from modules.llm import get_llm_model
|
|
|
62 |
language,
|
63 |
prompt_template,
|
64 |
speaker,
|
65 |
+
output_audio_path: Path | str = None,
|
66 |
+
max_new_tokens=50,
|
67 |
):
|
68 |
if self.track_latency:
|
69 |
asr_start_time = time.time()
|
|
|
82 |
if self.track_latency:
|
83 |
llm_end_time = time.time()
|
84 |
llm_latency = llm_end_time - llm_start_time
|
85 |
+
llm_response = clean_llm_output(
|
86 |
+
output, language=language, max_sentences=self.max_sentences
|
87 |
+
)
|
88 |
score = self.melody_controller.generate_score(llm_response, language)
|
89 |
if self.track_latency:
|
90 |
svs_start_time = time.time()
|
|
|
97 |
results = {
|
98 |
"asr_text": asr_result,
|
99 |
"llm_text": llm_response,
|
100 |
+
"svs_audio": (sample_rate, singing_audio),
|
101 |
}
|
102 |
+
if output_audio_path:
|
103 |
+
Path(output_audio_path).parent.mkdir(parents=True, exist_ok=True)
|
104 |
+
sf.write(output_audio_path, singing_audio, sample_rate)
|
105 |
+
results["output_audio_path"] = output_audio_path
|
106 |
if self.track_latency:
|
107 |
results["metrics"] = {
|
108 |
"asr_latency": asr_latency,
|