Spaces:
Sleeping
Sleeping
Refactor CLI to support multiple query audio inputs
Browse files
cli.py
CHANGED
@@ -12,11 +12,12 @@ logger = getLogger(__name__)
|
|
12 |
|
13 |
def get_parser():
|
14 |
parser = ArgumentParser()
|
15 |
-
parser.add_argument("--
|
16 |
parser.add_argument(
|
17 |
"--config_path", type=Path, default="config/cli/yaoyin_default.yaml"
|
18 |
)
|
19 |
-
parser.add_argument("--
|
|
|
20 |
return parser
|
21 |
|
22 |
|
@@ -36,16 +37,36 @@ def main():
|
|
36 |
character_name = config["prompt_template_character"]
|
37 |
character = get_character(character_name)
|
38 |
prompt_template = character.prompt
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
|
51 |
if __name__ == "__main__":
|
|
|
12 |
|
13 |
def get_parser():
|
14 |
parser = ArgumentParser()
|
15 |
+
parser.add_argument("--query_audios", nargs="+", type=Path, required=True)
|
16 |
parser.add_argument(
|
17 |
"--config_path", type=Path, default="config/cli/yaoyin_default.yaml"
|
18 |
)
|
19 |
+
parser.add_argument("--output_audio_folder", type=Path, required=True)
|
20 |
+
parser.add_argument("--eval_results_csv", type=Path, required=True)
|
21 |
return parser
|
22 |
|
23 |
|
|
|
37 |
character_name = config["prompt_template_character"]
|
38 |
character = get_character(character_name)
|
39 |
prompt_template = character.prompt
|
40 |
+
args.output_audio_folder.mkdir(parents=True, exist_ok=True)
|
41 |
+
args.eval_results_csv.parent.mkdir(parents=True, exist_ok=True)
|
42 |
+
with open(args.eval_results_csv, "a") as f:
|
43 |
+
f.write(
|
44 |
+
f"query_audio,asr_model,llm_model,svs_model,melody_source,language,speaker,output_audio,asr_text,llm_text,metrics\n"
|
45 |
+
)
|
46 |
+
try:
|
47 |
+
for query_audio in args.query_audios:
|
48 |
+
output_audio = args.output_audio_folder / f"{query_audio.stem}_response.wav"
|
49 |
+
results = pipeline.run(
|
50 |
+
query_audio,
|
51 |
+
language,
|
52 |
+
prompt_template,
|
53 |
+
speaker,
|
54 |
+
output_audio_path=output_audio,
|
55 |
+
)
|
56 |
+
metrics = pipeline.evaluate(output_audio, **results)
|
57 |
+
metrics.update(results.get("metrics", {}))
|
58 |
+
metrics_str = ",".join([f"{metrics[k]}" for k in sorted(metrics.keys())])
|
59 |
+
logger.info(
|
60 |
+
f"Input: {query_audio}, Output: {output_audio}, ASR results: {results['asr_text']}, LLM results: {results['llm_text']}"
|
61 |
+
)
|
62 |
+
with open(args.eval_results_csv, "a") as f:
|
63 |
+
f.write(
|
64 |
+
f"{query_audio},{config['asr_model']},{config['llm_model']},{config['svs_model']},{config['melody_source']},{config['language']},{config['speaker']},{output_audio},{results['asr_text']},{results['llm_text']},{metrics_str}\n"
|
65 |
+
)
|
66 |
+
except Exception as e:
|
67 |
+
logger.error(f"Error in main: {e}")
|
68 |
+
breakpoint()
|
69 |
+
raise e
|
70 |
|
71 |
|
72 |
if __name__ == "__main__":
|