jhansss committed on
Commit
2ce9d86
·
1 Parent(s): 11e246d

Refactor CLI to support multiple query audio inputs

Browse files
Files changed (1) hide show
  1. cli.py +33 -12
cli.py CHANGED
@@ -12,11 +12,12 @@ logger = getLogger(__name__)
12
 
13
  def get_parser():
14
  parser = ArgumentParser()
15
- parser.add_argument("--query_audio", type=Path, required=True)
16
  parser.add_argument(
17
  "--config_path", type=Path, default="config/cli/yaoyin_default.yaml"
18
  )
19
- parser.add_argument("--output_audio", type=Path, required=True)
 
20
  return parser
21
 
22
 
@@ -36,16 +37,36 @@ def main():
36
  character_name = config["prompt_template_character"]
37
  character = get_character(character_name)
38
  prompt_template = character.prompt
39
- results = pipeline.run(
40
- args.query_audio,
41
- language,
42
- prompt_template,
43
- speaker,
44
- output_audio_path=args.output_audio,
45
- )
46
- logger.info(
47
- f"Input: {args.query_audio}, Output: {args.output_audio}, ASR results: {results['asr_text']}, LLM results: {results['llm_text']}"
48
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
 
51
  if __name__ == "__main__":
 
12
 
13
  def get_parser():
14
  parser = ArgumentParser()
15
+ parser.add_argument("--query_audios", nargs="+", type=Path, required=True)
16
  parser.add_argument(
17
  "--config_path", type=Path, default="config/cli/yaoyin_default.yaml"
18
  )
19
+ parser.add_argument("--output_audio_folder", type=Path, required=True)
20
+ parser.add_argument("--eval_results_csv", type=Path, required=True)
21
  return parser
22
 
23
 
 
37
  character_name = config["prompt_template_character"]
38
  character = get_character(character_name)
39
  prompt_template = character.prompt
40
+ args.output_audio_folder.mkdir(parents=True, exist_ok=True)
41
+ args.eval_results_csv.parent.mkdir(parents=True, exist_ok=True)
42
+ with open(args.eval_results_csv, "a") as f:
43
+ f.write(
44
+ f"query_audio,asr_model,llm_model,svs_model,melody_source,language,speaker,output_audio,asr_text,llm_text,metrics\n"
45
+ )
46
+ try:
47
+ for query_audio in args.query_audios:
48
+ output_audio = args.output_audio_folder / f"{query_audio.stem}_response.wav"
49
+ results = pipeline.run(
50
+ query_audio,
51
+ language,
52
+ prompt_template,
53
+ speaker,
54
+ output_audio_path=output_audio,
55
+ )
56
+ metrics = pipeline.evaluate(output_audio, **results)
57
+ metrics.update(results.get("metrics", {}))
58
+ metrics_str = ",".join([f"{metrics[k]}" for k in sorted(metrics.keys())])
59
+ logger.info(
60
+ f"Input: {query_audio}, Output: {output_audio}, ASR results: {results['asr_text']}, LLM results: {results['llm_text']}"
61
+ )
62
+ with open(args.eval_results_csv, "a") as f:
63
+ f.write(
64
+ f"{query_audio},{config['asr_model']},{config['llm_model']},{config['svs_model']},{config['melody_source']},{config['language']},{config['speaker']},{output_audio},{results['asr_text']},{results['llm_text']},{metrics_str}\n"
65
+ )
66
+ except Exception as e:
67
+ logger.error(f"Error in main: {e}")
68
+ breakpoint()
69
+ raise e
70
 
71
 
72
  if __name__ == "__main__":