jhansss committed on
Commit
b5e825c
·
1 Parent(s): 1a42cf5

support output file path in pipeline.py

Browse files
Files changed (3) hide show
  1. cli.py +7 -5
  2. interface.py +11 -8
  3. pipeline.py +16 -4
cli.py CHANGED
@@ -2,7 +2,6 @@ from argparse import ArgumentParser
2
  from logging import getLogger
3
  from pathlib import Path
4
 
5
- import soundfile as sf
6
  import yaml
7
 
8
  from characters import CHARACTERS
@@ -37,13 +36,16 @@ def main():
37
  character_name = config["prompt_template_character"]
38
  character = CHARACTERS[character_name]
39
  prompt_template = character.prompt
40
- results = pipeline.run(args.query_audio, language, prompt_template, speaker)
 
 
 
 
 
 
41
  logger.info(
42
  f"Input: {args.query_audio}, Output: {args.output_audio}, ASR results: {results['asr_text']}, LLM results: {results['llm_text']}"
43
  )
44
- svs_audio, svs_sample_rate = results["svs_audio"]
45
- args.output_audio.parent.mkdir(parents=True, exist_ok=True)
46
- sf.write(args.output_audio, svs_audio, svs_sample_rate)
47
 
48
 
49
  if __name__ == "__main__":
 
2
  from logging import getLogger
3
  from pathlib import Path
4
 
 
5
  import yaml
6
 
7
  from characters import CHARACTERS
 
36
  character_name = config["prompt_template_character"]
37
  character = CHARACTERS[character_name]
38
  prompt_template = character.prompt
39
+ results = pipeline.run(
40
+ args.query_audio,
41
+ language,
42
+ prompt_template,
43
+ speaker,
44
+ output_audio_path=args.output_audio,
45
+ )
46
  logger.info(
47
  f"Input: {args.query_audio}, Output: {args.output_audio}, ASR results: {results['asr_text']}, LLM results: {results['llm_text']}"
48
  )
 
 
 
49
 
50
 
51
  if __name__ == "__main__":
interface.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  import gradio as gr
2
  import yaml
3
 
@@ -201,29 +204,29 @@ class GradioInterface:
201
  return gr.update(value=self.current_melody_source)
202
 
203
  def update_voice(self, voice):
204
- self.current_voice = self.svs_model_map[self.current_svs_model]["voices"][
205
- voice
206
- ]
207
  return gr.update(value=voice)
208
 
209
  def run_pipeline(self, audio_path):
210
  if not audio_path:
211
  return gr.update(value=""), gr.update(value="")
 
212
  results = self.pipeline.run(
213
  audio_path,
214
  self.svs_model_map[self.current_svs_model]["lang"],
215
  self.character_info[self.current_character].prompt,
216
  self.current_voice,
217
- max_new_tokens=100,
 
218
  )
219
  formatted_logs = f"ASR: {results['asr_text']}\nLLM: {results['llm_text']}"
220
- return gr.update(value=formatted_logs), gr.update(value=results["svs_audio"])
 
 
221
 
222
  def update_metrics(self, audio_path):
223
  if not audio_path:
224
  return gr.update(value="")
225
  results = self.pipeline.evaluate(audio_path)
226
- formatted_metrics = "\n".join(
227
- [f"{k}: {v}" for k, v in results.items()]
228
- )
229
  return gr.update(value=formatted_metrics)
 
1
+ import time
2
+ import uuid
3
+
4
  import gradio as gr
5
  import yaml
6
 
 
204
  return gr.update(value=self.current_melody_source)
205
 
206
def update_voice(self, voice):
    """Record the selected voice for the active SVS model and echo it back to the UI."""
    # Resolve the display name to the model-specific voice id before storing it.
    voices = self.svs_model_map[self.current_svs_model]["voices"]
    self.current_voice = voices[voice]
    return gr.update(value=voice)
209
 
210
def run_pipeline(self, audio_path):
    """Run the ASR -> LLM -> SVS pipeline on the recorded audio.

    Returns a pair of gr.update values: the formatted ASR/LLM log text and
    the path of the synthesized audio file for the output player.
    """
    if not audio_path:
        return gr.update(value=""), gr.update(value="")
    # Fix: write the synthesized audio into the system temp directory instead
    # of the current working directory, where the timestamp+uuid files would
    # otherwise accumulate indefinitely.
    import os
    import tempfile

    tmp_file = os.path.join(
        tempfile.gettempdir(),
        f"audio_{int(time.time())}_{uuid.uuid4().hex[:8]}.wav",
    )
    results = self.pipeline.run(
        audio_path,
        self.svs_model_map[self.current_svs_model]["lang"],
        self.character_info[self.current_character].prompt,
        self.current_voice,
        output_audio_path=tmp_file,
        max_new_tokens=50,
    )
    formatted_logs = f"ASR: {results['asr_text']}\nLLM: {results['llm_text']}"
    return gr.update(value=formatted_logs), gr.update(
        value=results["output_audio_path"]
    )
226
 
227
def update_metrics(self, audio_path):
    """Evaluate the given audio and return a gr.update showing one metric per line."""
    if not audio_path:
        return gr.update(value="")
    metrics = self.pipeline.evaluate(audio_path)
    lines = [f"{name}: {score}" for name, score in metrics.items()]
    return gr.update(value="\n".join(lines))
pipeline.py CHANGED
@@ -1,6 +1,11 @@
1
- import torch
 
2
  import time
 
 
3
  import librosa
 
 
4
 
5
  from modules.asr import get_asr_model
6
  from modules.llm import get_llm_model
@@ -57,7 +62,8 @@ class SingingDialoguePipeline:
57
  language,
58
  prompt_template,
59
  speaker,
60
- max_new_tokens=100,
 
61
  ):
62
  if self.track_latency:
63
  asr_start_time = time.time()
@@ -76,7 +82,9 @@ class SingingDialoguePipeline:
76
  if self.track_latency:
77
  llm_end_time = time.time()
78
  llm_latency = llm_end_time - llm_start_time
79
- llm_response = clean_llm_output(output, language=language, max_sentences=self.max_sentences)
 
 
80
  score = self.melody_controller.generate_score(llm_response, language)
81
  if self.track_latency:
82
  svs_start_time = time.time()
@@ -89,8 +97,12 @@ class SingingDialoguePipeline:
89
  results = {
90
  "asr_text": asr_result,
91
  "llm_text": llm_response,
92
- "svs_audio": (singing_audio, sample_rate),
93
  }
 
 
 
 
94
  if self.track_latency:
95
  results["metrics"] = {
96
  "asr_latency": asr_latency,
 
1
+ from __future__ import annotations
2
+
3
  import time
4
+ from pathlib import Path
5
+
6
  import librosa
7
+ import soundfile as sf
8
+ import torch
9
 
10
  from modules.asr import get_asr_model
11
  from modules.llm import get_llm_model
 
62
  language,
63
  prompt_template,
64
  speaker,
65
+ output_audio_path: Path | str = None,
66
+ max_new_tokens=50,
67
  ):
68
  if self.track_latency:
69
  asr_start_time = time.time()
 
82
  if self.track_latency:
83
  llm_end_time = time.time()
84
  llm_latency = llm_end_time - llm_start_time
85
+ llm_response = clean_llm_output(
86
+ output, language=language, max_sentences=self.max_sentences
87
+ )
88
  score = self.melody_controller.generate_score(llm_response, language)
89
  if self.track_latency:
90
  svs_start_time = time.time()
 
97
  results = {
98
  "asr_text": asr_result,
99
  "llm_text": llm_response,
100
+ "svs_audio": (sample_rate, singing_audio),
101
  }
102
+ if output_audio_path:
103
+ Path(output_audio_path).parent.mkdir(parents=True, exist_ok=True)
104
+ sf.write(output_audio_path, singing_audio, sample_rate)
105
+ results["output_audio_path"] = output_audio_path
106
  if self.track_latency:
107
  results["metrics"] = {
108
  "asr_latency": asr_latency,