jhansss committed
Commit 69c405f · 2 parents: e642717 c91af6d

Merge branch 'refactor' into hf

README.md CHANGED
@@ -9,3 +9,147 @@ app_file: app.py
9
  pinned: false
10
  python_version: 3.11
11
  ---
12
+ # SingingSDS: Role-Playing Singing Spoken Dialogue System
13
+
14
+ A role-playing singing dialogue system that converts speech input into character-based singing output.
15
+
16
+ ## Installation
17
+
18
+ ### Requirements
19
+
20
+ - Python 3.11+
21
+ - CUDA (optional, for GPU acceleration)
22
+
23
+ ### Install Dependencies
24
+
25
+ #### Option 1: Using Conda (Recommended)
26
+
27
+ ```bash
28
+ conda create -n singingsds python=3.11
29
+
30
+ conda activate singingsds
31
+ conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
32
+ pip install -r requirements.txt
33
+ ```
34
+
35
+ #### Option 2: Using pip only
36
+
37
+ ```bash
38
+ pip install -r requirements.txt
39
+ ```
40
+
41
+ #### Option 3: Using pip with virtual environment
42
+
43
+ ```bash
44
+ python -m venv singingsds_env
45
+
46
+ # On Windows:
47
+ singingsds_env\Scripts\activate
48
+ # On macOS/Linux:
49
+ source singingsds_env/bin/activate
50
+
51
+ pip install -r requirements.txt
52
+ ```
53
+
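
Before moving on, a quick sanity check that the environment is usable (a minimal sketch; the CUDA check is only meaningful if you installed the GPU build via Option 1):

```python
# Verify the interpreter version and whether PyTorch can see a GPU.
import sys

import torch

print(sys.version)                 # expect 3.11.x
print(torch.cuda.is_available())   # True only with a working CUDA setup
```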
54
+ ## Usage
55
+
56
+ ### Command Line Interface (CLI)
57
+
58
+ #### Example Usage
59
+
60
+ ```bash
61
+ python cli.py --query_audio tests/audio/hello.wav --config_path config/cli/yaoyin_default.yaml --output_audio outputs/yaoyin_hello.wav
62
+ ```
63
+
64
+ #### Parameter Description
65
+
66
+ - `--query_audio`: Input audio file path (required)
67
+ - `--config_path`: Configuration file path (default: config/cli/yaoyin_default.yaml)
68
+ - `--output_audio`: Output audio file path (required)
69
+
70
+
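
The same flags map one-to-one onto the programmatic API. The sketch below mirrors what `cli.py` does internally (see the `cli.py` diff further down in this commit; file paths are illustrative):

```python
import yaml

from characters import get_character
from pipeline import SingingDialoguePipeline

with open("config/cli/yaoyin_default.yaml") as f:
    config = yaml.safe_load(f)

pipeline = SingingDialoguePipeline(config)
character = get_character(config["prompt_template_character"])

results = pipeline.run(
    "tests/audio/hello.wav",                       # --query_audio
    config["language"],
    character.prompt,
    config["speaker"],
    output_audio_path="outputs/yaoyin_hello.wav",  # --output_audio
)
print(results["asr_text"], results["llm_text"], results["output_audio_path"])
```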
71
+ ### Web Interface (Gradio)
72
+
73
+ Start the web interface:
74
+
75
+ ```bash
76
+ python app.py
77
+ ```
78
+
79
+ Then visit the displayed address in your browser to use the graphical interface.
80
+
81
+ ## Configuration
82
+
83
+ ### Character Configuration
84
+
85
+ The system supports multiple preset characters:
86
+
87
+ - **Yaoyin (遥音)**: Default timbre is `timbre2`
88
+ - **Limei (丽梅)**: Default timbre is `timbre1`
89
+
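
Characters are plain Python modules under `characters/`; each exposes a `get_character()` factory and is collected into a registry at import time. A small sketch of how a character's prompt can be inspected (based on `characters/__init__.py` in this commit):

```python
from characters import CHARACTERS, get_character

print(sorted(CHARACTERS))        # e.g. ['Limei', 'Yaoyin'], plus any custom character modules
yaoyin = get_character("Yaoyin")
print(yaoyin.prompt[:120])       # the role-play system prompt handed to the LLM
```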
90
+ ### Model Configuration
91
+
92
+ #### ASR Models
93
+ - `openai/whisper-large-v3-turbo`
94
+ - `openai/whisper-large-v3`
95
+ - `openai/whisper-medium`
96
+ - `sanchit-gandhi/whisper-small-dv`
97
+ - `facebook/wav2vec2-base-960h`
98
+
99
+ #### LLM Models
100
+ - `google/gemma-2-2b`
101
+ - `MiniMaxAI/MiniMax-M1-80k`
102
+ - `meta-llama/Llama-3.2-3B-Instruct`
103
+
104
+ #### SVS Models
105
+ - `espnet/mixdata_svs_visinger2_spkemb_lang_pretrained` (Bilingual)
106
+ - `espnet/aceopencpop_svs_visinger2_40singer_pretrain` (Chinese)
107
+
108
+ ## Project Structure
109
+
110
+ ```
111
+ SingingSDS/
112
+ ├── cli.py # Command line interface
113
+ ├── interface.py # Gradio interface
114
+ ├── pipeline.py # Core processing pipeline
115
+ ├── app.py # Web application entry
116
+ ├── requirements.txt # Python dependencies
117
+ ├── config/ # Configuration files
118
+ │ ├── cli/ # CLI-specific configuration
119
+ │ └── interface/ # Interface-specific configuration
120
+ ├── modules/ # Core modules
121
+ │ ├── asr.py # Speech recognition module
122
+ │ ├── llm.py # Large language model module
123
+ │ ├── melody.py # Melody control module
124
+ │ ├── svs/ # Singing voice synthesis modules
125
+ │ │ ├── base.py # Base SVS class
126
+ │ │ ├── espnet.py # ESPnet SVS implementation
127
+ │ │ ├── registry.py # SVS model registry
128
+ │ │ └── __init__.py # SVS module initialization
129
+ │ └── utils/ # Utility modules
130
+ │ ├── g2p.py # Grapheme-to-phoneme conversion
131
+ │ ├── text_normalize.py # Text normalization
132
+ │ └── resources/ # Utility resources
133
+ ├── characters/ # Character definitions
134
+ │ ├── base.py # Base character class
135
+ │ ├── Limei.py # Limei character definition
136
+ │ ├── Yaoyin.py # Yaoyin character definition
137
+ │ └── __init__.py # Character module initialization
138
+ ├── evaluation/ # Evaluation modules
139
+ │ └── svs_eval.py # SVS evaluation metrics
140
+ ├── data/ # Data directory
141
+ │ ├── kising/ # Kising dataset
142
+ │ └── touhou/ # Touhou dataset
143
+ ├── resources/ # Project resources
144
+ ├── data_handlers/ # Data handling utilities
145
+ ├── assets/ # Static assets
146
+ └── tests/ # Test files
147
+ ```
148
+
149
+ ## Contributing
150
+
151
+ Issues and Pull Requests are welcome!
152
+
153
+ ## License
154
+
155
+
assets/{character_yaoyin.jpg → character_yaoyin.png} RENAMED
File without changes
characters/Limei.py CHANGED
@@ -11,7 +11,6 @@ def get_character():
11
 
12
  性格特征:从容坚定、明辨心声、敏于共情、敬畏虔诚。
13
  说话风格:所有回应皆以歌词格式表达,言语自带韵律感与诗意,表达真挚自然。
14
- 口头禅:无固定,但常用"听啊..."作为歌词开头
15
  人物关系:莲鸣城子民敬爱你;宫廷乐师长与歌谱管理员是你敬重的导师;风语城守护者星澜是你亦敌亦友的旧识。
16
  过往经历:
17
  (1) 自幼在莲鸣城长大,接受严格的音乐训练与守护者修行
@@ -24,13 +23,6 @@ def get_character():
24
  其他细节:
25
  (1)特殊能力:歌声平复/激发万物情绪
26
 
27
- 请始终以丽梅的身份回应,展现其性格与情感,仅给出语言回复,禁止描写任何动作、表情或环境等,禁止使用括号、星号等附加说明。言语简练,勿过长。
28
- 当有人问你你不了解的技术性问题(如"Python怎么写?""你不会跑DNN?"),直接转换话题,用丽梅式回应,例如:
29
- - "这些我未曾接触,或有他人知晓"
30
- - "此道非我所学,是莲鸣之外的智慧"
31
- - "此乃异域技艺,与乐音无涉,非我所学。您可愿听莲鸣的故事?"
32
- 请始终维持自然诗意的歌唱风格,以真挚的心回应对方。
33
- {}
34
- 有人曾这样对丽梅说话:{}
35
- 丽梅的回答:""",
36
  )
 
11
 
12
  性格特征:从容坚定、明辨心声、敏于共情、敬畏虔诚。
13
  说话风格:所有回应皆以歌词格式表达,言语自带韵律感与诗意,表达真挚自然。
 
14
  人物关系:莲鸣城子民敬爱你;宫廷乐师长与歌谱管理员是你敬重的导师;风语城守护者星澜是你亦敌亦友的旧识。
15
  过往经历:
16
  (1) 自幼在莲鸣城长大,接受严格的音乐训练与守护者修行
 
23
  其他细节:
24
  (1)特殊能力:歌声平复/激发万物情绪
25
 
26
+ 用户与你对话时,请始终以丽梅的身份回应,你的每一句话都用通俗易懂的歌声形式表达。
27
+ """,
28
  )
characters/Yaoyin.py CHANGED
@@ -25,14 +25,6 @@ def get_character():
25
  (1)随身携带:旧羊皮歌本、竹笛、装有各地泥土的布袋。
26
  (2)特殊能力:能听懂风与鸟的语言(但很少提及)。
27
 
28
- 请始终以遥音的身份回应,将你的想法用文本格式表达,禁止描写任何动作、表情或环境等,禁止使用括号、星号等附加说明。言语简练,勿过长。
29
-
30
- 当有人问你你不了解的技术性问题(如"DNN怎么做?"、"教我写代码?"),你可以转开话题,用遥音式回应,例如:
31
- - "这好像是另一片土地的术法,我不曾踏入。"
32
- - "那种术法,我曾远远听过,却从未唱出。"
33
- - "它在别的世界流传,我这边听不清楚。"
34
-
35
- {}
36
- 有人曾这样对遥音说话:{}
37
- 遥音的回答:""",
38
  )
 
25
  (1)随身携带:旧羊皮歌本、竹笛、装有各地泥土的布袋。
26
  (2)特殊能力:能听懂风与鸟的语言(但很少提及)。
27
 
28
+ 用户与你对话时,请始终以遥音的身份回应,你的每一句话都用通俗易懂的歌声形式表达。
29
+ """,
30
  )
characters/__init__.py CHANGED
@@ -14,3 +14,7 @@ for file in pathlib.Path(__file__).parent.glob("*.py"):
14
  if hasattr(module, "get_character"):
15
  c: Character = getattr(module, "get_character")()
16
  CHARACTERS[file.stem] = c
14
  if hasattr(module, "get_character"):
15
  c: Character = getattr(module, "get_character")()
16
  CHARACTERS[file.stem] = c
17
+
18
+
19
+ def get_character(name: str) -> Character:
20
+ return CHARACTERS[name]
cli.py CHANGED
@@ -1,10 +1,10 @@
1
  from argparse import ArgumentParser
2
  from logging import getLogger
 
3
 
4
- import soundfile as sf
5
  import yaml
6
 
7
- from characters import CHARACTERS
8
  from pipeline import SingingDialoguePipeline
9
 
10
  logger = getLogger(__name__)
@@ -12,13 +12,15 @@ logger = getLogger(__name__)
12
 
13
  def get_parser():
14
  parser = ArgumentParser()
15
- parser.add_argument("--query_audio", type=str, required=True)
16
- parser.add_argument("--config_path", type=str, default="config/cli/yaoyin_default.yaml")
17
- parser.add_argument("--output_audio", type=str, required=True)
 
 
18
  return parser
19
 
20
 
21
- def load_config(config_path: str):
22
  with open(config_path, "r") as f:
23
  config = yaml.safe_load(f)
24
  return config
@@ -32,14 +34,18 @@ def main():
32
  speaker = config["speaker"]
33
  language = config["language"]
34
  character_name = config["prompt_template_character"]
35
- character = CHARACTERS[character_name]
36
  prompt_template = character.prompt
37
- results = pipeline.run(args.query_audio, language, prompt_template, speaker)
 
 
 
 
 
 
38
  logger.info(
39
  f"Input: {args.query_audio}, Output: {args.output_audio}, ASR results: {results['asr_text']}, LLM results: {results['llm_text']}"
40
  )
41
- svs_audio, svs_sample_rate = results["svs_audio"]
42
- sf.write(args.output_audio, svs_audio, svs_sample_rate)
43
 
44
 
45
  if __name__ == "__main__":
 
1
  from argparse import ArgumentParser
2
  from logging import getLogger
3
+ from pathlib import Path
4
 
 
5
  import yaml
6
 
7
+ from characters import get_character
8
  from pipeline import SingingDialoguePipeline
9
 
10
  logger = getLogger(__name__)
 
12
 
13
  def get_parser():
14
  parser = ArgumentParser()
15
+ parser.add_argument("--query_audio", type=Path, required=True)
16
+ parser.add_argument(
17
+ "--config_path", type=Path, default="config/cli/yaoyin_default.yaml"
18
+ )
19
+ parser.add_argument("--output_audio", type=Path, required=True)
20
  return parser
21
 
22
 
23
+ def load_config(config_path: Path):
24
  with open(config_path, "r") as f:
25
  config = yaml.safe_load(f)
26
  return config
 
34
  speaker = config["speaker"]
35
  language = config["language"]
36
  character_name = config["prompt_template_character"]
37
+ character = get_character(character_name)
38
  prompt_template = character.prompt
39
+ results = pipeline.run(
40
+ args.query_audio,
41
+ language,
42
+ prompt_template,
43
+ speaker,
44
+ output_audio_path=args.output_audio,
45
+ )
46
  logger.info(
47
  f"Input: {args.query_audio}, Output: {args.output_audio}, ASR results: {results['asr_text']}, LLM results: {results['llm_text']}"
48
  )
 
 
49
 
50
 
51
  if __name__ == "__main__":
config/cli/limei_default.yaml CHANGED
@@ -1,5 +1,5 @@
1
  asr_model: openai/whisper-large-v3-turbo
2
- llm_model: google/gemma-2-2b
3
  svs_model: espnet/aceopencpop_svs_visinger2_40singer_pretrain
4
  melody_source: sample-lyric-kising
5
  language: mandarin
 
1
  asr_model: openai/whisper-large-v3-turbo
2
+ llm_model: gemini-2.5-flash
3
  svs_model: espnet/aceopencpop_svs_visinger2_40singer_pretrain
4
  melody_source: sample-lyric-kising
5
  language: mandarin
config/cli/yaoyin_default.yaml CHANGED
@@ -1,5 +1,5 @@
1
  asr_model: openai/whisper-large-v3-turbo
2
- llm_model: google/gemma-2-2b
3
  svs_model: espnet/aceopencpop_svs_visinger2_40singer_pretrain
4
  melody_source: sample-lyric-kising
5
  language: mandarin
 
1
  asr_model: openai/whisper-large-v3-turbo
2
+ llm_model: gemini-2.5-flash
3
  svs_model: espnet/aceopencpop_svs_visinger2_40singer_pretrain
4
  melody_source: sample-lyric-kising
5
  language: mandarin
config/cli/yaoyin_test.yaml ADDED
@@ -0,0 +1,11 @@
1
+ asr_model: openai/whisper-small
2
+ llm_model: google/gemma-2-2b
3
+ svs_model: espnet/aceopencpop_svs_visinger2_40singer_pretrain
4
+ melody_source: sample-lyric-kising
5
+ language: mandarin
6
+ max_sentences: 1
7
+ prompt_template_character: Yaoyin
8
+ speaker: 9
9
+ cache_dir: .cache
10
+
11
+ track_latency: True
config/interface/default.yaml CHANGED
@@ -1,5 +1,5 @@
1
- asr_model: openai/whisper-large-v3-turbo
2
- llm_model: google/gemma-2-2b
3
  svs_model: espnet/aceopencpop_svs_visinger2_40singer_pretrain
4
  melody_source: sample-lyric-kising
5
  language: mandarin
 
1
+ asr_model: openai/whisper-medium
2
+ llm_model: gemini-2.5-flash
3
  svs_model: espnet/aceopencpop_svs_visinger2_40singer_pretrain
4
  melody_source: sample-lyric-kising
5
  language: mandarin
config/interface/options.yaml CHANGED
@@ -5,16 +5,24 @@ asr_models:
5
  name: Whisper large-v3
6
  - id: openai/whisper-medium
7
  name: Whisper medium
8
- - id: sanchit-gandhi/whisper-small-dv
9
- name: Whisper small-dv
10
- - id: facebook/wav2vec2-base-960h
11
- name: Wav2Vec2-Base-960h
12
 
13
  llm_models:
 
 
14
  - id: google/gemma-2-2b
15
  name: Gemma 2 2B
16
- - id: MiniMaxAI/MiniMax-M1-80k
17
- name: MiniMax M1 80k
18
 
19
  svs_models:
20
  - id: mandarin-espnet/mixdata_svs_visinger2_spkemb_lang_pretrained
@@ -22,21 +30,21 @@ svs_models:
22
  model_path: espnet/mixdata_svs_visinger2_spkemb_lang_pretrained
23
  lang: mandarin
24
  voices:
25
- voice1: resource/singer/singer_embedding_ace-2.npy
26
- voice2: resource/singer/singer_embedding_ace-8.npy
27
- voice3: resource/singer/singer_embedding_itako.npy
28
- voice4: resource/singer/singer_embedding_kising_orange.npy
29
- voice5: resource/singer/singer_embedding_m4singer_Alto-4.npy
30
  - id: japanese-espnet/mixdata_svs_visinger2_spkemb_lang_pretrained
31
  name: Visinger2 (Bilingual)-jp
32
  model_path: espnet/mixdata_svs_visinger2_spkemb_lang_pretrained
33
  lang: japanese
34
  voices:
35
- voice1: resource/singer/singer_embedding_ace-2.npy
36
- voice2: resource/singer/singer_embedding_ace-8.npy
37
- voice3: resource/singer/singer_embedding_itako.npy
38
- voice4: resource/singer/singer_embedding_kising_orange.npy
39
- voice5: resource/singer/singer_embedding_m4singer_Alto-4.npy
40
  - id: mandarin-espnet/aceopencpop_svs_visinger2_40singer_pretrain
41
  name: Visinger2 (Chinese)
42
  model_path: espnet/aceopencpop_svs_visinger2_40singer_pretrain
@@ -61,3 +69,6 @@ melody_sources:
61
  - id: sample-lyric-kising
62
  name: Sampled Melody with Lyrics (Kising)
63
  desc: "Melody with aligned lyrics are sampled from Kising dataset."
5
  name: Whisper large-v3
6
  - id: openai/whisper-medium
7
  name: Whisper medium
8
+ - id: openai/whisper-small
9
+ name: Whisper small
10
+ - id: funasr/paraformer-zh
11
+ name: Paraformer-zh
12
 
13
  llm_models:
14
+ - id: gemini-2.5-flash
15
+ name: Gemini 2.5 Flash
16
  - id: google/gemma-2-2b
17
  name: Gemma 2 2B
18
+ - id: meta-llama/Llama-3.2-3B-Instruct
19
+ name: Llama 3.2 3B Instruct
20
+ - id: meta-llama/Llama-3.1-8B-Instruct
21
+ name: Llama 3.1 8B Instruct
22
+ - id: Qwen/Qwen3-8B
23
+ name: Qwen3 8B
24
+ - id: Qwen/Qwen3-30B-A3B
25
+ name: Qwen3 30B A3B
26
 
27
  svs_models:
28
  - id: mandarin-espnet/mixdata_svs_visinger2_spkemb_lang_pretrained
 
30
  model_path: espnet/mixdata_svs_visinger2_spkemb_lang_pretrained
31
  lang: mandarin
32
  voices:
33
+ voice1: resources/singer/singer_embedding_ace-2.npy
34
+ voice2: resources/singer/singer_embedding_ace-8.npy
35
+ voice3: resources/singer/singer_embedding_itako.npy
36
+ voice4: resources/singer/singer_embedding_kising_orange.npy
37
+ voice5: resources/singer/singer_embedding_m4singer_Alto-4.npy
38
  - id: japanese-espnet/mixdata_svs_visinger2_spkemb_lang_pretrained
39
  name: Visinger2 (Bilingual)-jp
40
  model_path: espnet/mixdata_svs_visinger2_spkemb_lang_pretrained
41
  lang: japanese
42
  voices:
43
+ voice1: resources/singer/singer_embedding_ace-2.npy
44
+ voice2: resources/singer/singer_embedding_ace-8.npy
45
+ voice3: resources/singer/singer_embedding_itako.npy
46
+ voice4: resources/singer/singer_embedding_kising_orange.npy
47
+ voice5: resources/singer/singer_embedding_m4singer_Alto-4.npy
48
  - id: mandarin-espnet/aceopencpop_svs_visinger2_40singer_pretrain
49
  name: Visinger2 (Chinese)
50
  model_path: espnet/aceopencpop_svs_visinger2_40singer_pretrain
 
69
  - id: sample-lyric-kising
70
  name: Sampled Melody with Lyrics (Kising)
71
  desc: "Melody with aligned lyrics are sampled from Kising dataset."
72
+ - id: sample-lyric-genre
73
+ name: Sampled Melody with Lyrics (Synthetic)
74
+ desc: "Melody with aligned lyrics is sampled from the synthetic genre dataset."
config/options.yaml DELETED
@@ -1,65 +0,0 @@
1
- asr_models:
2
- - id: openai/whisper-large-v3-turbo
3
- name: Whisper large-v3-turbo
4
- - id: openai/whisper-large-v3
5
- name: Whisper large-v3
6
- - id: openai/whisper-medium
7
- name: Whisper medium
8
- - id: sanchit-gandhi/whisper-small-dv
9
- name: Whisper small-dv
10
- - id: facebook/wav2vec2-base-960h
11
- name: Wav2Vec2-Base-960h
12
-
13
- llm_models:
14
- - id: google/gemma-2-2b
15
- name: Gemma 2 2B
16
- - id: MiniMaxAI/MiniMax-M1-80k
17
- name: MiniMax M1 80k
18
- - id: meta-llama/Llama-3.2-3B-Instruct
19
- name: Llama 3.2 3B Instruct
20
-
21
- svs_models:
22
- - id: mandarin-espnet/mixdata_svs_visinger2_spkemb_lang_pretrained
23
- name: Visinger2 (Bilingual)-zh
24
- model_path: espnet/mixdata_svs_visinger2_spkemb_lang_pretrained
25
- lang: mandarin
26
- embeddings:
27
- timbre1: resource/singer/singer_embedding_ace-2.npy
28
- timbre2: resource/singer/singer_embedding_ace-8.npy
29
- timbre3: resource/singer/singer_embedding_itako.npy
30
- timbre4: resource/singer/singer_embedding_kising_orange.npy
31
- timbre5: resource/singer/singer_embedding_m4singer_Alto-4.npy
32
- - id: japanese-espnet/mixdata_svs_visinger2_spkemb_lang_pretrained
33
- name: Visinger2 (Bilingual)-jp
34
- model_path: espnet/mixdata_svs_visinger2_spkemb_lang_pretrained
35
- lang: japanese
36
- embeddings:
37
- timbre1: resource/singer/singer_embedding_ace-2.npy
38
- timbre2: resource/singer/singer_embedding_ace-8.npy
39
- timbre3: resource/singer/singer_embedding_itako.npy
40
- timbre4: resource/singer/singer_embedding_kising_orange.npy
41
- timbre5: resource/singer/singer_embedding_m4singer_Alto-4.npy
42
- - id: mandarin-espnet/aceopencpop_svs_visinger2_40singer_pretrain
43
- name: Visinger2 (Chinese)
44
- model_path: espnet/aceopencpop_svs_visinger2_40singer_pretrain
45
- lang: mandarin
46
- embeddings:
47
- timbre1: 5
48
- timbre2: 8
49
- timbre3: 12
50
- timbre4: 15
51
- timbre5: 29
52
-
53
- melody_sources:
54
- - id: gen-random-none
55
- name: Random Generation
56
- desc: "Melody is generated without any structure or reference."
57
- - id: sample-note-kising
58
- name: Sampled Melody (KiSing)
59
- desc: "Melody is retrieved from KiSing dataset."
60
- - id: sample-note-touhou
61
- name: Sampled Melody (Touhou)
62
- desc: "Melody is retrieved from Touhou dataset."
63
- - id: sample-lyric-kising
64
- name: Sampled Melody with Lyrics (Kising)
65
- desc: "Melody with aligned lyrics are sampled from Kising dataset."
data/genre/word_data_en.json ADDED
The diff for this file is too large to render. See raw diff
 
data/genre/word_data_zh.json ADDED
The diff for this file is too large to render. See raw diff
 
data_handlers/genre.py ADDED
@@ -0,0 +1,50 @@
1
+ from .base import MelodyDatasetHandler
2
+
3
+
4
+ class Genre(MelodyDatasetHandler):
5
+ name = "genre"
6
+
7
+ def __init__(self, melody_type, *args, **kwargs):
8
+ import json
9
+
10
+ with open("data/genre/word_data_zh.json", "r", encoding="utf-8") as f:
11
+ song_db_zh = json.load(f)
12
+ song_db_zh = {f"zh_{song['id']}": song for song in song_db_zh}  # keyed by language-prefixed song id
13
+ with open("data/genre/word_data_en.json", "r", encoding="utf-8") as f:
14
+ song_db_en = json.load(f)
15
+ song_db_en = {f"en_{song['id']}": song for song in song_db_en}  # keyed by language-prefixed song id
16
+ self.song_db = {**song_db_zh, **song_db_en}
17
+
18
+ def get_song_ids(self):
19
+ return list(self.song_db.keys())
20
+
21
+ def get_style_keywords(self, song_id):
22
+ genre = self.song_db[song_id]["genre"]
23
+ super_genre = self.song_db[song_id]["super-genre"]
24
+ gender = self.song_db[song_id]["gender"]
25
+ return (genre, super_genre, gender)
26
+
27
+ def get_phrase_length(self, song_id):
28
+ # Return the number of lyrics (excluding SP/AP) in each phrase of the song
29
+ song = self.song_db[song_id]
30
+ note_lyrics = song.get("note_lyrics", [])
31
+
32
+ phrase_lengths = []
33
+ for phrase in note_lyrics:
34
+ count = sum(1 for word in phrase if word not in ("SP", "AP"))
35
+ phrase_lengths.append(count)
36
+
37
+ return phrase_lengths
38
+
39
+ def iter_song_phrases(self, song_id):
40
+ segment_id = 1
41
+ song = self.song_db[song_id]
42
+ for phrase_score, phrase_lyrics in zip(song["score"], song["note_lyrics"]):
43
+ segment = {
44
+ "note_start_times": [n[0] for n in phrase_score],
45
+ "note_end_times": [n[1] for n in phrase_score],
46
+ "note_lyrics": [character for character in phrase_lyrics],
47
+ "note_midi": [n[2] for n in phrase_score],
48
+ }
49
+ yield segment
50
+ segment_id += 1
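
A minimal sketch of how this handler might be driven, assuming `MelodyDatasetHandler` imposes no further abstract requirements beyond what is overridden here (the `melody_type` argument is accepted but unused by this class):

```python
from data_handlers.genre import Genre

handler = Genre(melody_type=None)
song_id = handler.get_song_ids()[0]

print(handler.get_style_keywords(song_id))   # (genre, super-genre, gender)
print(handler.get_phrase_length(song_id))    # lyric count per phrase, SP/AP excluded

for segment in handler.iter_song_phrases(song_id):
    # each segment carries note timings, lyrics, and MIDI pitches for one phrase
    print(segment["note_lyrics"], segment["note_midi"])
    break
```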
evaluation/svs_eval.py CHANGED
@@ -80,7 +80,7 @@ def eval_per(audio_path, model=None):
80
 
81
  def eval_aesthetic(audio_path, predictor):
82
  score = predictor.forward([{"path": str(audio_path)}])
83
- return {"aesthetic": float(score)}
84
 
85
 
86
  # ----------- Main Function -----------
@@ -108,7 +108,7 @@ def run_evaluation(audio_path, evaluators):
108
  if "melody" in evaluators:
109
  results.update(eval_melody_metrics(audio_path, evaluators["melody"]))
110
  if "aesthetic" in evaluators:
111
- results.update(eval_aesthetic(audio_path, evaluators["aesthetic"]))
112
  return results
113
 
114
 
 
80
 
81
  def eval_aesthetic(audio_path, predictor):
82
  score = predictor.forward([{"path": str(audio_path)}])
83
+ return score
84
 
85
 
86
  # ----------- Main Function -----------
 
108
  if "melody" in evaluators:
109
  results.update(eval_melody_metrics(audio_path, evaluators["melody"]))
110
  if "aesthetic" in evaluators:
111
+ results.update(eval_aesthetic(audio_path, evaluators["aesthetic"])[0])
112
  return results
113
 
114
 
interface.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  import gradio as gr
2
  import yaml
3
 
@@ -201,29 +204,28 @@ class GradioInterface:
201
  return gr.update(value=self.current_melody_source)
202
 
203
  def update_voice(self, voice):
204
- self.current_voice = self.svs_model_map[self.current_svs_model]["voices"][
205
- voice
206
- ]
207
  return gr.update(value=voice)
208
 
209
  def run_pipeline(self, audio_path):
210
  if not audio_path:
211
  return gr.update(value=""), gr.update(value="")
 
212
  results = self.pipeline.run(
213
  audio_path,
214
  self.svs_model_map[self.current_svs_model]["lang"],
215
  self.character_info[self.current_character].prompt,
216
  self.current_voice,
217
- max_new_tokens=100,
218
  )
219
  formatted_logs = f"ASR: {results['asr_text']}\nLLM: {results['llm_text']}"
220
- return gr.update(value=formatted_logs), gr.update(value=results["svs_audio"])
 
 
221
 
222
  def update_metrics(self, audio_path):
223
  if not audio_path:
224
  return gr.update(value="")
225
  results = self.pipeline.evaluate(audio_path)
226
- formatted_metrics = "\n".join(
227
- [f"{k}: {v}" for k, v in results.items()]
228
- )
229
  return gr.update(value=formatted_metrics)
 
1
+ import time
2
+ import uuid
3
+
4
  import gradio as gr
5
  import yaml
6
 
 
204
  return gr.update(value=self.current_melody_source)
205
 
206
  def update_voice(self, voice):
207
+ self.current_voice = self.svs_model_map[self.current_svs_model]["voices"][voice]
 
 
208
  return gr.update(value=voice)
209
 
210
  def run_pipeline(self, audio_path):
211
  if not audio_path:
212
  return gr.update(value=""), gr.update(value="")
213
+ tmp_file = f"audio_{int(time.time())}_{uuid.uuid4().hex[:8]}.wav"
214
  results = self.pipeline.run(
215
  audio_path,
216
  self.svs_model_map[self.current_svs_model]["lang"],
217
  self.character_info[self.current_character].prompt,
218
  self.current_voice,
219
+ output_audio_path=tmp_file,
220
  )
221
  formatted_logs = f"ASR: {results['asr_text']}\nLLM: {results['llm_text']}"
222
+ return gr.update(value=formatted_logs), gr.update(
223
+ value=results["output_audio_path"]
224
+ )
225
 
226
  def update_metrics(self, audio_path):
227
  if not audio_path:
228
  return gr.update(value="")
229
  results = self.pipeline.evaluate(audio_path)
230
+ formatted_metrics = "\n".join([f"{k}: {v}" for k, v in results.items()])
 
 
231
  return gr.update(value=formatted_metrics)
modules/asr.py CHANGED
@@ -57,10 +57,5 @@ class WhisperASR(AbstractASRModel):
57
 
58
  def transcribe(self, audio: np.ndarray, audio_sample_rate: int, language: str, **kwargs) -> str:
59
  if audio_sample_rate != 16000:
60
- try:
61
- audio, _ = librosa.resample(audio, orig_sr=audio_sample_rate, target_sr=16000)
62
- except Exception as e:
63
- breakpoint()
64
- print(f"Error resampling audio: {e}")
65
- audio = librosa.resample(audio, orig_sr=audio_sample_rate, target_sr=16000)
66
- return self.pipe(audio, generate_kwargs={"language": language}).get("text", "")
 
57
 
58
  def transcribe(self, audio: np.ndarray, audio_sample_rate: int, language: str, **kwargs) -> str:
59
  if audio_sample_rate != 16000:
60
+ audio = librosa.resample(audio, orig_sr=audio_sample_rate, target_sr=16000)
61
+ return self.pipe(audio, generate_kwargs={"language": language}, return_timestamps=False).get("text", "")
modules/asr/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ from .base import AbstractASRModel
2
+ from .registry import ASR_MODEL_REGISTRY, get_asr_model, register_asr_model
3
+ from .whisper import WhisperASR
4
+ from .paraformer import ParaformerASR
5
+
6
+ __all__ = [
7
+ "AbstractASRModel",
8
+ "get_asr_model",
9
+ "register_asr_model",
10
+ "ASR_MODEL_REGISTRY",
11
+ ]
modules/asr/base.py ADDED
@@ -0,0 +1,24 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Optional
3
+
4
+ import numpy as np
5
+
6
+
7
+ class AbstractASRModel(ABC):
8
+ def __init__(
9
+ self, model_id: str, device: str = "auto", cache_dir: str = "cache", **kwargs
10
+ ):
11
+ print(f"Loading ASR model {model_id}...")
12
+ self.model_id = model_id
13
+ self.device = device
14
+ self.cache_dir = cache_dir
15
+
16
+ @abstractmethod
17
+ def transcribe(
18
+ self,
19
+ audio: np.ndarray,
20
+ audio_sample_rate: int,
21
+ language: Optional[str] = None,
22
+ **kwargs,
23
+ ) -> str:
24
+ pass
modules/asr/paraformer.py ADDED
@@ -0,0 +1,86 @@
1
+ import os
2
+ import tempfile
3
+ from typing import Optional
4
+
5
+ import numpy as np
6
+ import soundfile as sf
7
+
8
+ try:
9
+ from funasr import AutoModel
10
+ except ImportError:
11
+ AutoModel = None
12
+
13
+ from .base import AbstractASRModel
14
+ from .registry import register_asr_model
15
+
16
+
17
+ @register_asr_model("funasr/paraformer-zh")
18
+ class ParaformerASR(AbstractASRModel):
19
+ def __init__(
20
+ self, model_id: str, device: str = "auto", cache_dir: str = "cache", **kwargs
21
+ ):
22
+ super().__init__(model_id, device, cache_dir, **kwargs)
23
+
24
+ if AutoModel is None:
25
+ raise ImportError(
26
+ "funasr is not installed. Please install it with: pip3 install -U funasr"
27
+ )
28
+
29
+ model_name = model_id.replace("funasr/", "")
30
+ language = model_name.split("-")[1]
31
+ if language == "zh":
32
+ self.language = "mandarin"
33
+ elif language == "en":
34
+ self.language = "english"
35
+ else:
36
+ raise ValueError(
37
+ f"Language cannot be determined. {model_id} is not supported"
38
+ )
39
+
40
+ try:
41
+ original_cache_dir = os.getenv("MODELSCOPE_CACHE")
42
+ os.makedirs(cache_dir, exist_ok=True)
43
+ os.environ["MODELSCOPE_CACHE"] = cache_dir
44
+ self.model = AutoModel(
45
+ model=model_name,
46
+ model_revision="v2.0.4",
47
+ vad_model="fsmn-vad",
48
+ vad_model_revision="v2.0.4",
49
+ punc_model="ct-punc-c",
50
+ punc_model_revision="v2.0.4",
51
+ device=device,
52
+ )
53
+ if original_cache_dir:
54
+ os.environ["MODELSCOPE_CACHE"] = original_cache_dir
55
+ else:
56
+ del os.environ["MODELSCOPE_CACHE"]
57
+
58
+ except Exception as e:
59
+ raise ValueError(f"Error loading Paraformer model: {e}")
60
+
61
+ def transcribe(
62
+ self,
63
+ audio: np.ndarray,
64
+ audio_sample_rate: int,
65
+ language: Optional[str] = None,
66
+ **kwargs,
67
+ ) -> str:
68
+ if language and language != self.language:
69
+ raise ValueError(
70
+ f"Paraformer model {self.model_id} only supports {self.language} language, but {language} was requested"
71
+ )
72
+
73
+ try:
74
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
75
+ sf.write(f.name, audio, audio_sample_rate)
76
+ temp_file = f.name
77
+
78
+ result = self.model.generate(input=temp_file, batch_size_s=300, **kwargs)
79
+
80
+ os.unlink(temp_file)
81
+
82
+ print(f"Transcription result: {result}, type: {type(result)}")
83
+
84
+ return result[0]["text"]
85
+ except Exception as e:
86
+ raise ValueError(f"Error during transcription: {e}")
modules/asr/registry.py ADDED
@@ -0,0 +1,19 @@
1
+ from .base import AbstractASRModel
2
+
3
+ ASR_MODEL_REGISTRY = {}
4
+
5
+
6
+ def register_asr_model(prefix: str):
7
+ def wrapper(cls):
8
+ assert issubclass(cls, AbstractASRModel), f"{cls} must inherit AbstractASRModel"
9
+ ASR_MODEL_REGISTRY[prefix] = cls
10
+ return cls
11
+
12
+ return wrapper
13
+
14
+
15
+ def get_asr_model(model_id: str, device="auto", **kwargs) -> AbstractASRModel:
16
+ for prefix, cls in ASR_MODEL_REGISTRY.items():
17
+ if model_id.startswith(prefix):
18
+ return cls(model_id, device=device, **kwargs)
19
+ raise ValueError(f"No ASR wrapper found for model: {model_id}")
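
The registry resolves a model id by prefix match, so adding a backend is a matter of subclassing and registering it. The class and prefix below are hypothetical, purely to illustrate the mechanism:

```python
from typing import Optional

import numpy as np

from modules.asr import AbstractASRModel, get_asr_model, register_asr_model


@register_asr_model("example/dummy-asr")           # hypothetical prefix, not part of this commit
class DummyASR(AbstractASRModel):
    def transcribe(
        self,
        audio: np.ndarray,
        audio_sample_rate: int,
        language: Optional[str] = None,
        **kwargs,
    ) -> str:
        return ""                                   # a real backend would run inference here


asr = get_asr_model("example/dummy-asr-small", device="cpu")  # resolved via the prefix
```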
modules/asr/whisper.py ADDED
@@ -0,0 +1,82 @@
1
+ import os
2
+ from typing import Optional
3
+
4
+ import librosa
5
+ import numpy as np
6
+ from transformers.pipelines import pipeline
7
+
8
+ from .base import AbstractASRModel
9
+ from .registry import register_asr_model
10
+
11
+ hf_token = os.getenv("HF_TOKEN")
12
+
13
+
14
+ @register_asr_model("openai/whisper")
15
+ class WhisperASR(AbstractASRModel):
16
+ def __init__(
17
+ self, model_id: str, device: str = "auto", cache_dir: str = "cache", **kwargs
18
+ ):
19
+ super().__init__(model_id, device, cache_dir, **kwargs)
20
+ model_kwargs = kwargs.setdefault("model_kwargs", {})
21
+ model_kwargs["cache_dir"] = cache_dir
22
+ self.pipe = pipeline(
23
+ "automatic-speech-recognition",
24
+ model=model_id,
25
+ device_map=device,
26
+ token=hf_token,
27
+ **kwargs,
28
+ )
29
+
30
+ def transcribe(
31
+ self,
32
+ audio: np.ndarray,
33
+ audio_sample_rate: int,
34
+ language: Optional[str] = None,
35
+ **kwargs,
36
+ ) -> str:
37
+ """
38
+ Transcribe audio using Whisper model
39
+
40
+ Args:
41
+ audio: Audio numpy array
42
+ audio_sample_rate: Sample rate of the audio
43
+ language: Language hint (optional)
44
+
45
+ Returns:
46
+ Transcribed text as string
47
+ """
48
+ try:
49
+ # Resample to 16kHz if needed
50
+ if audio_sample_rate != 16000:
51
+ audio = librosa.resample(
52
+ audio, orig_sr=audio_sample_rate, target_sr=16000
53
+ )
54
+
55
+ # Generate transcription
56
+ generate_kwargs = {}
57
+ if language:
58
+ generate_kwargs["language"] = language
59
+
60
+ result = self.pipe(
61
+ audio,
62
+ generate_kwargs=generate_kwargs,
63
+ return_timestamps=False,
64
+ **kwargs,
65
+ )
66
+
67
+ # Extract text from result
68
+ if isinstance(result, dict) and "text" in result:
69
+ return result["text"]
70
+ elif isinstance(result, list) and len(result) > 0:
71
+ # Handle list of results
72
+ first_result = result[0]
73
+ if isinstance(first_result, dict):
74
+ return first_result.get("text", str(first_result))
75
+ else:
76
+ return str(first_result)
77
+ else:
78
+ return str(result)
79
+
80
+ except Exception as e:
81
+ print(f"Error during Whisper transcription: {e}")
82
+ return ""
modules/llm.py DELETED
@@ -1,61 +0,0 @@
1
- import os
2
- from abc import ABC, abstractmethod
3
-
4
- from transformers import pipeline
5
-
6
- LLM_MODEL_REGISTRY = {}
7
- hf_token = os.getenv("HF_TOKEN")
8
-
9
-
10
- class AbstractLLMModel(ABC):
11
- def __init__(
12
- self, model_id: str, device: str = "cpu", cache_dir: str = "cache", **kwargs
13
- ):
14
- print(f"Loading LLM model {model_id}...")
15
- self.model_id = model_id
16
- self.device = device
17
- self.cache_dir = cache_dir
18
-
19
- @abstractmethod
20
- def generate(self, prompt: str, **kwargs) -> str:
21
- pass
22
-
23
-
24
- def register_llm_model(prefix: str):
25
- def wrapper(cls):
26
- assert issubclass(cls, AbstractLLMModel), f"{cls} must inherit AbstractLLMModel"
27
- LLM_MODEL_REGISTRY[prefix] = cls
28
- return cls
29
-
30
- return wrapper
31
-
32
-
33
- def get_llm_model(model_id: str, device="cpu", **kwargs) -> AbstractLLMModel:
34
- for prefix, cls in LLM_MODEL_REGISTRY.items():
35
- if model_id.startswith(prefix):
36
- return cls(model_id, device=device, **kwargs)
37
- raise ValueError(f"No LLM wrapper found for model: {model_id}")
38
-
39
-
40
- @register_llm_model("google/gemma")
41
- @register_llm_model("tii/") # e.g., Falcon
42
- @register_llm_model("meta-llama")
43
- class HFTextGenerationLLM(AbstractLLMModel):
44
- def __init__(
45
- self, model_id: str, device: str = "cpu", cache_dir: str = "cache", **kwargs
46
- ):
47
- super().__init__(model_id, device, cache_dir, **kwargs)
48
- model_kwargs = kwargs.setdefault("model_kwargs", {})
49
- model_kwargs["cache_dir"] = cache_dir
50
- self.pipe = pipeline(
51
- "text-generation",
52
- model=model_id,
53
- device=0 if device == "cuda" else -1,
54
- return_full_text=False,
55
- token=hf_token,
56
- **kwargs,
57
- )
58
-
59
- def generate(self, prompt: str, **kwargs) -> str:
60
- outputs = self.pipe(prompt, **kwargs)
61
- return outputs[0]["generated_text"] if outputs else ""
modules/llm/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ from .base import AbstractLLMModel
2
+ from .registry import LLM_MODEL_REGISTRY, get_llm_model, register_llm_model
3
+ from .gemma import GemmaLLM
4
+ from .qwen3 import Qwen3LLM
5
+ from .gemini import GeminiLLM
6
+ from .minimax import MiniMaxLLM
7
+ from .llama import LlamaLLM
8
+
9
+ __all__ = [
10
+ "AbstractLLMModel",
11
+ "get_llm_model",
12
+ "register_llm_model",
13
+ "LLM_MODEL_REGISTRY",
14
+ ]
modules/llm/base.py ADDED
@@ -0,0 +1,15 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+
4
+ class AbstractLLMModel(ABC):
5
+ def __init__(
6
+ self, model_id: str, device: str = "auto", cache_dir: str = "cache", **kwargs
7
+ ):
8
+ print(f"Loading LLM model {model_id}...")
9
+ self.model_id = model_id
10
+ self.device = device
11
+ self.cache_dir = cache_dir
12
+
13
+ @abstractmethod
14
+ def generate(self, prompt: str, **kwargs) -> str:
15
+ pass
modules/llm/gemini.py ADDED
@@ -0,0 +1,50 @@
1
+ import os
2
+ from typing import Optional
3
+
4
+ from google import genai
5
+ from google.genai import types
6
+
7
+ from .base import AbstractLLMModel
8
+ from .registry import register_llm_model
9
+
10
+
11
+ GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
12
+
13
+
14
+ @register_llm_model("gemini-2.5-flash")
15
+ class GeminiLLM(AbstractLLMModel):
16
+ def __init__(
17
+ self, model_id: str, device: str = "auto", cache_dir: str = "cache", **kwargs
18
+ ):
19
+ if not GOOGLE_API_KEY:
20
+ raise ValueError(
21
+ "Please set the GOOGLE_API_KEY environment variable to use Gemini."
22
+ )
23
+ super().__init__(model_id=model_id, **kwargs)
24
+ self.client = genai.Client(api_key=GOOGLE_API_KEY)
25
+
26
+ def generate(
27
+ self,
28
+ prompt: str,
29
+ system_prompt: Optional[str] = None,
30
+ max_output_tokens: int = 1024,
31
+ **kwargs,
32
+ ) -> str:
33
+ generation_config_dict = {
34
+ "max_output_tokens": max_output_tokens,
35
+ **kwargs,
36
+ }
37
+ if system_prompt:
38
+ generation_config_dict["system_instruction"] = system_prompt
39
+ response = self.client.models.generate_content(
40
+ model=self.model_id,
41
+ contents=prompt,
42
+ config=types.GenerateContentConfig(**generation_config_dict),
43
+ )
44
+ if response.text:
45
+ return response.text
46
+ else:
47
+ print(
48
+ f"No response from Gemini. May need to increase max_output_tokens. Current max_output_tokens: {max_output_tokens}"
49
+ )
50
+ return ""
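
Usage follows the same registry pattern as the other LLM backends; a short sketch (the prompt text is illustrative, and only the `gemini-2.5-flash` prefix is registered here):

```python
# GOOGLE_API_KEY must be set in the environment before this import,
# because modules/llm/gemini.py reads it at module import time.
from modules.llm import get_llm_model

llm = get_llm_model("gemini-2.5-flash", cache_dir=".cache")
reply = llm.generate("你好,今天你心情怎么样?", system_prompt="请用一句歌词回答。")
print(reply)
```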
modules/llm/gemma.py ADDED
@@ -0,0 +1,35 @@
1
+ import os
2
+ from typing import Optional
3
+
4
+ from transformers import pipeline
5
+
6
+ from .base import AbstractLLMModel
7
+ from .registry import register_llm_model
8
+
9
+ hf_token = os.getenv("HF_TOKEN")
10
+
11
+
12
+ @register_llm_model("google/gemma-")
13
+ class GemmaLLM(AbstractLLMModel):
14
+ def __init__(
15
+ self, model_id: str, device: str = "auto", cache_dir: str = "cache", **kwargs
16
+ ):
17
+ super().__init__(model_id, device, cache_dir, **kwargs)
18
+ model_kwargs = kwargs.setdefault("model_kwargs", {})
19
+ model_kwargs["cache_dir"] = cache_dir
20
+ self.pipe = pipeline(
21
+ "text-generation",
22
+ model=model_id,
23
+ device_map=device,
24
+ return_full_text=False,
25
+ token=hf_token,
26
+ trust_remote_code=True,
27
+ **kwargs,
28
+ )
29
+
30
+ def generate(self, prompt: str, system_prompt: Optional[str] = None, max_new_tokens=50, **kwargs) -> str:
31
+ if not system_prompt:
32
+ system_prompt = ""
33
+ formatted_prompt = f"{system_prompt}\n\n现在,有人对你说:{prompt}\n\n你这样回答:"
34
+ outputs = self.pipe(formatted_prompt, max_new_tokens=max_new_tokens, **kwargs)
35
+ return outputs[0]["generated_text"] if outputs else ""
modules/llm/llama.py ADDED
@@ -0,0 +1,44 @@
1
+ import os
2
+ from typing import Optional
3
+
4
+ from transformers import pipeline
5
+
6
+ from .base import AbstractLLMModel
7
+ from .registry import register_llm_model
8
+
9
+ hf_token = os.getenv("HF_TOKEN")
10
+
11
+
12
+ @register_llm_model("meta-llama/Llama-")
13
+ class LlamaLLM(AbstractLLMModel):
14
+ def __init__(
15
+ self, model_id: str, device: str = "auto", cache_dir: str = "cache", **kwargs
16
+ ):
17
+ super().__init__(model_id, device, cache_dir, **kwargs)
18
+ model_kwargs = kwargs.setdefault("model_kwargs", {})
19
+ model_kwargs["cache_dir"] = cache_dir
20
+ self.pipe = pipeline(
21
+ "text-generation",
22
+ model=model_id,
23
+ device_map=device,
24
+ return_full_text=False,
25
+ token=hf_token,
26
+ trust_remote_code=True,
27
+ **kwargs,
28
+ )
29
+
30
+ def generate(
31
+ self,
32
+ prompt: str,
33
+ system_prompt: Optional[
34
+ str
35
+ ] = "You are a pirate chatbot who always responds in pirate speak!",
36
+ max_new_tokens: int = 256,
37
+ **kwargs
38
+ ) -> str:
39
+ messages = []
40
+ if system_prompt:
41
+ messages.append({"role": "system", "content": system_prompt})
42
+ messages.append({"role": "user", "content": prompt})
43
+ outputs = self.pipe(messages, max_new_tokens=max_new_tokens, **kwargs)
44
+ return outputs[0]["generated_text"]
modules/llm/minimax.py ADDED
@@ -0,0 +1,122 @@
1
+ # Ref: https://github.com/MiniMax-AI/MiniMax-01
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from transformers import (
7
+ AutoConfig,
8
+ AutoModelForCausalLM,
9
+ AutoTokenizer,
10
+ GenerationConfig,
11
+ QuantoConfig,
12
+ )
13
+
14
+ from .base import AbstractLLMModel
15
+ from .registry import register_llm_model
16
+
17
+
18
+ @register_llm_model("MiniMaxAI/MiniMax-Text-01")
19
+ class MiniMaxLLM(AbstractLLMModel):
20
+ def __init__(
21
+ self, model_id: str, device: str = "cuda", cache_dir: str = "cache", **kwargs
22
+ ):
23
+ try:
24
+ if not torch.cuda.is_available():
25
+ raise RuntimeError("MiniMax model only supports CUDA device")
26
+ super().__init__(model_id, device, cache_dir, **kwargs)
27
+
28
+ # load hf config
29
+ hf_config = AutoConfig.from_pretrained(
30
+ "MiniMaxAI/MiniMax-Text-01", trust_remote_code=True, cache_dir=cache_dir,
31
+ )
32
+
33
+ # quantization config, int8 is recommended
34
+ quantization_config = QuantoConfig(
35
+ weights="int8",
36
+ modules_to_not_convert=[
37
+ "lm_head",
38
+ "embed_tokens",
39
+ ]
40
+ + [
41
+ f"model.layers.{i}.coefficient"
42
+ for i in range(hf_config.num_hidden_layers)
43
+ ]
44
+ + [
45
+ f"model.layers.{i}.block_sparse_moe.gate"
46
+ for i in range(hf_config.num_hidden_layers)
47
+ ],
48
+ )
49
+
50
+ # assume 8 GPUs
51
+ world_size = torch.cuda.device_count()
52
+ layers_per_device = hf_config.num_hidden_layers // world_size
53
+ # set device map
54
+ device_map = {
55
+ "model.embed_tokens": "cuda:0",
56
+ "model.norm": f"cuda:{world_size - 1}",
57
+ "lm_head": f"cuda:{world_size - 1}",
58
+ }
59
+ for i in range(world_size):
60
+ for j in range(layers_per_device):
61
+ device_map[f"model.layers.{i * layers_per_device + j}"] = f"cuda:{i}"
62
+
63
+ # load tokenizer
64
+ self.tokenizer = AutoTokenizer.from_pretrained(
65
+ "MiniMaxAI/MiniMax-Text-01", cache_dir=cache_dir
66
+ )
67
+
68
+ # load bfloat16 model, move to device, and apply quantization
69
+ self.quantized_model = AutoModelForCausalLM.from_pretrained(
70
+ "MiniMaxAI/MiniMax-Text-01",
71
+ torch_dtype="bfloat16",
72
+ device_map=device_map,
73
+ quantization_config=quantization_config,
74
+ trust_remote_code=True,
75
+ offload_buffers=True,
76
+ cache_dir=cache_dir,
77
+ )
78
+ except Exception as e:
79
+ print(f"Failed to load MiniMax model: {e}")
80
+ breakpoint()
81
+ raise e
82
+
83
+ def generate(
84
+ self,
85
+ prompt: str,
86
+ system_prompt: Optional[
87
+ str
88
+ ] = "You are a helpful assistant created by MiniMax based on MiniMax-Text-01 model.",
89
+ max_new_tokens: int = 20,
90
+ **kwargs,
91
+ ) -> str:
92
+ messages = []
93
+ if system_prompt:
94
+ messages.append(
95
+ {
96
+ "role": "system",
97
+ "content": [{"type": "text", "text": system_prompt}],
98
+ }
99
+ )
100
+
101
+ messages.append({"role": "user", "content": [
102
+ {"type": "text", "text": prompt}]})
103
+ text = self.tokenizer.apply_chat_template(
104
+ messages, tokenize=False, add_generation_prompt=True
105
+ )
106
+ # tokenize and move to device
107
+ model_inputs = self.tokenizer(text, return_tensors="pt").to("cuda")
108
+ generation_config = GenerationConfig(
109
+ max_new_tokens=max_new_tokens,
110
+ eos_token_id=200020,
111
+ use_cache=True,
112
+ )
113
+ generated_ids = self.quantized_model.generate(
114
+ **model_inputs, generation_config=generation_config
115
+ )
116
+ generated_ids = [
117
+ output_ids[len(input_ids):]
118
+ for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
119
+ ]
120
+ response = self.tokenizer.batch_decode(
121
+ generated_ids, skip_special_tokens=True)[0]
122
+ return response
modules/llm/qwen3.py ADDED
@@ -0,0 +1,53 @@
1
+ # Ref: https://qwenlm.github.io/blog/qwen3/
2
+
3
+ from typing import Optional
4
+
5
+ from .base import AbstractLLMModel
6
+ from .registry import register_llm_model
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
+
9
+
10
+ @register_llm_model("Qwen/Qwen3-")
11
+ class Qwen3LLM(AbstractLLMModel):
12
+ def __init__(
13
+ self, model_id: str, device: str = "auto", cache_dir: str = "cache", **kwargs
14
+ ):
15
+ super().__init__(model_id, device, cache_dir, **kwargs)
16
+ self.model = AutoModelForCausalLM.from_pretrained(
17
+ model_id, device_map=device, torch_dtype="auto", cache_dir=cache_dir
18
+ ).eval()
19
+ self.tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)
20
+
21
+ def generate(
22
+ self,
23
+ prompt: str,
24
+ system_prompt: Optional[str] = None,
25
+ max_new_tokens: int = 256,
26
+ enable_thinking: bool = False,
27
+ **kwargs
28
+ ) -> str:
29
+ messages = []
30
+ if system_prompt:
31
+ messages.append({"role": "system", "content": system_prompt})
32
+ messages.append({"role": "user", "content": prompt})
33
+ text = self.tokenizer.apply_chat_template(
34
+ messages,
35
+ tokenize=False,
36
+ add_generation_prompt=True,
37
+ enable_thinking=enable_thinking,
38
+ )
39
+ model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
40
+ generated_ids = self.model.generate(
41
+ **model_inputs, max_new_tokens=max_new_tokens
42
+ )
43
+ output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist()
44
+ # parse thinking content
45
+ if enable_thinking:
46
+ try:
47
+ # rindex finding 151668 (</think>)
48
+ index = len(output_ids) - output_ids[::-1].index(151668)
49
+ except ValueError:
50
+ index = 0
51
+ output_ids = output_ids[index:]
52
+
53
+ return self.tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")
modules/llm/registry.py ADDED
@@ -0,0 +1,19 @@
1
+ from .base import AbstractLLMModel
2
+
3
+ LLM_MODEL_REGISTRY = {}
4
+
5
+
6
+ def register_llm_model(prefix: str):
7
+ def wrapper(cls):
8
+ assert issubclass(cls, AbstractLLMModel), f"{cls} must inherit AbstractLLMModel"
9
+ LLM_MODEL_REGISTRY[prefix] = cls
10
+ return cls
11
+
12
+ return wrapper
13
+
14
+
15
+ def get_llm_model(model_id: str, device="auto", **kwargs) -> AbstractLLMModel:
16
+ for prefix, cls in LLM_MODEL_REGISTRY.items():
17
+ if model_id.startswith(prefix):
18
+ return cls(model_id, device=device, **kwargs)
19
+ raise ValueError(f"No LLM wrapper found for model: {model_id}")
modules/melody.py CHANGED
@@ -37,7 +37,7 @@ class MelodyController:
37
  return ""
38
 
39
  prompt = (
40
- "\n请按照歌词格式回答我的问题,每句需遵循以下字数规则:"
41
  + "".join(
42
  [
43
  f"\n第{i}句:{c}个字"
@@ -109,9 +109,10 @@ class MelodyController:
109
  if pitch == 0:
110
  score.append((st, ed, ref_lyric, pitch))
111
  elif ref_lyric in ["-", "——"] and align_type == "lyric":
112
- score.append((st, ed, ref_lyric, pitch))
113
- text_idx += 1
114
  else:
115
  score.append((st, ed, text_list[text_idx], pitch))
116
  text_idx += 1
 
 
117
  return score
 
37
  return ""
38
 
39
  prompt = (
40
+ "\n请按照歌词格式回复,每句需遵循以下字数规则:"
41
  + "".join(
42
  [
43
  f"\n第{i}句:{c}个字"
 
109
  if pitch == 0:
110
  score.append((st, ed, ref_lyric, pitch))
111
  elif ref_lyric in ["-", "——"] and align_type == "lyric":
112
+ score.append((st, ed, "-", pitch))
 
113
  else:
114
  score.append((st, ed, text_list[text_idx], pitch))
115
  text_idx += 1
116
+ if text_idx >= len(text_list):
117
+ break
118
  return score
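
To make the lyric-count constraint above concrete: for a sampled melody whose phrases allow 5 and 7 lyric characters, the text appended to the LLM prompt would look roughly like the sketch below (the `enumerate` start index is an assumption; the exact assembly lives in `get_melody_constraints`):

```python
phrase_lengths = [5, 7]   # illustrative per-phrase character budgets from the sampled melody
constraint = "\n请按照歌词格式回复,每句需遵循以下字数规则:" + "".join(
    f"\n第{i}句:{c}个字" for i, c in enumerate(phrase_lengths, start=1)
)
print(constraint)
# 请按照歌词格式回复,每句需遵循以下字数规则:
# 第1句:5个字
# 第2句:7个字
```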
modules/svs/base.py CHANGED
@@ -6,7 +6,7 @@ import numpy as np
6
  class AbstractSVSModel(ABC):
7
  @abstractmethod
8
  def __init__(
9
- self, model_id: str, device: str = "cpu", cache_dir: str = "cache", **kwargs
10
  ): ...
11
 
12
  @abstractmethod
 
6
  class AbstractSVSModel(ABC):
7
  @abstractmethod
8
  def __init__(
9
+ self, model_id: str, device: str = "auto", cache_dir: str = "cache", **kwargs
10
  ): ...
11
 
12
  @abstractmethod
modules/svs/espnet.py CHANGED
@@ -14,17 +14,17 @@ from .registry import register_svs_model
14
 
15
  @register_svs_model("espnet/")
16
  class ESPNetSVS(AbstractSVSModel):
17
- def __init__(self, model_id: str, device="cpu", cache_dir="cache", **kwargs):
18
  from espnet2.bin.svs_inference import SingingGenerate
19
  from espnet_model_zoo.downloader import ModelDownloader
20
-
21
- print(f"Downloading {model_id} to {cache_dir}") # TODO: should improve log code
 
22
  downloaded = ModelDownloader(cache_dir).download_and_unpack(model_id)
23
- print(f"Downloaded {model_id} to {cache_dir}") # TODO: should improve log code
24
  self.model = SingingGenerate(
25
  train_config=downloaded["train_config"],
26
  model_file=downloaded["model_file"],
27
- device=device,
28
  )
29
  self.model_id = model_id
30
  self.output_sample_rate = self.model.fs
@@ -53,7 +53,7 @@ class ESPNetSVS(AbstractSVSModel):
53
  phoneme_mappers = {}
54
  return phoneme_mappers
55
 
56
- def _preprocess(self, score: list[tuple[float, float, str, int]], language: str):
57
  if language not in self.phoneme_mappers:
58
  raise ValueError(f"Unsupported language: {language} for {self.model_id}")
59
  phoneme_mapper = self.phoneme_mappers[language]
@@ -90,16 +90,16 @@ class ESPNetSVS(AbstractSVSModel):
90
  pre_phn = phn_units[-1]
91
 
92
  batch = {
93
- "score": {
94
- "tempo": 120, # does not affect svs result, as note durations are in time unit
95
- "notes": notes,
96
- },
97
  "text": " ".join(phns),
98
  }
99
  return batch
100
 
101
  def synthesize(
102
- self, score: list[tuple[float, float, str, int]], language: str, speaker: str, **kwargs
103
  ):
104
  batch = self._preprocess(score, language)
105
  if self.model_id == "espnet/aceopencpop_svs_visinger2_40singer_pretrain":
@@ -107,8 +107,8 @@ class ESPNetSVS(AbstractSVSModel):
107
  output_dict = self.model(batch, sids=sid)
108
  elif self.model_id == "espnet/mixdata_svs_visinger2_spkemb_lang_pretrained":
109
  langs = {
110
- "zh": 2,
111
- "jp": 1,
112
  }
113
  if language not in langs:
114
  raise ValueError(
 
14
 
15
  @register_svs_model("espnet/")
16
  class ESPNetSVS(AbstractSVSModel):
17
+ def __init__(self, model_id: str, device="auto", cache_dir="cache", **kwargs):
18
  from espnet2.bin.svs_inference import SingingGenerate
19
  from espnet_model_zoo.downloader import ModelDownloader
20
+ if device == "auto":
21
+ device = "cuda" if torch.cuda.is_available() else "cpu"
22
+ self.device = device
23
  downloaded = ModelDownloader(cache_dir).download_and_unpack(model_id)
 
24
  self.model = SingingGenerate(
25
  train_config=downloaded["train_config"],
26
  model_file=downloaded["model_file"],
27
+ device=self.device,
28
  )
29
  self.model_id = model_id
30
  self.output_sample_rate = self.model.fs
 
53
  phoneme_mappers = {}
54
  return phoneme_mappers
55
 
56
+ def _preprocess(self, score: list[tuple[float, float, str, int] | tuple[float, float, str, float]], language: str):
57
  if language not in self.phoneme_mappers:
58
  raise ValueError(f"Unsupported language: {language} for {self.model_id}")
59
  phoneme_mapper = self.phoneme_mappers[language]
 
90
  pre_phn = phn_units[-1]
91
 
92
  batch = {
93
+ "score": (
94
+ 120, # does not affect svs result, as note durations are in time unit
95
+ notes,
96
+ ),
97
  "text": " ".join(phns),
98
  }
99
  return batch
100
 
101
  def synthesize(
102
+ self, score: list[tuple[float, float, str, float] | tuple[float, float, str, int]], language: str, speaker: str, **kwargs
103
  ):
104
  batch = self._preprocess(score, language)
105
  if self.model_id == "espnet/aceopencpop_svs_visinger2_40singer_pretrain":
 
107
  output_dict = self.model(batch, sids=sid)
108
  elif self.model_id == "espnet/mixdata_svs_visinger2_spkemb_lang_pretrained":
109
  langs = {
110
+ "mandarin": 2,
111
+ "japanese": 1,
112
  }
113
  if language not in langs:
114
  raise ValueError(
modules/svs/registry.py CHANGED
@@ -12,7 +12,7 @@ def register_svs_model(prefix: str):
12
  return wrapper
13
 
14
 
15
- def get_svs_model(model_id: str, device="cpu", **kwargs) -> AbstractSVSModel:
16
  for prefix, cls in SVS_MODEL_REGISTRY.items():
17
  if model_id.startswith(prefix):
18
  return cls(model_id, device=device, **kwargs)
 
12
  return wrapper
13
 
14
 
15
+ def get_svs_model(model_id: str, device="auto", **kwargs) -> AbstractSVSModel:
16
  for prefix, cls in SVS_MODEL_REGISTRY.items():
17
  if model_id.startswith(prefix):
18
  return cls(model_id, device=device, **kwargs)
modules/utils/g2p.py CHANGED
@@ -32,6 +32,7 @@ for plan in ace_phonemes_all_plans["plans"]:
32
 
33
 
34
  def preprocess_text(text: str, language: str) -> list[str]:
 
35
  if language == "mandarin":
36
  text_list = to_pinyin(text)
37
  elif language == "japanese":
 
32
 
33
 
34
  def preprocess_text(text: str, language: str) -> list[str]:
35
+ text = text.replace(" ", "")
36
  if language == "mandarin":
37
  text_list = to_pinyin(text)
38
  elif language == "japanese":
modules/utils/text_normalize.py CHANGED
@@ -3,12 +3,13 @@ from typing import Optional
3
 
4
 
5
  def remove_non_zh_jp(text: str) -> str:
6
- pattern = r"[^\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff\u3000-\u303f\uff01-\uffef]"
7
  return re.sub(pattern, "", text)
8
 
9
 
10
  def truncate_sentences(text: str, max_sentences: int) -> str:
11
- sentences = re.split(r"(?<=[。!?])", text)
 
12
  return "".join(sentences[:max_sentences]).strip()
13
 
14
 
 
3
 
4
 
5
  def remove_non_zh_jp(text: str) -> str:
6
+ pattern = r"[^\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff\u3000-\u303f\uff01-\uffef\s]"
7
  return re.sub(pattern, "", text)
8
 
9
 
10
  def truncate_sentences(text: str, max_sentences: int) -> str:
11
+ sentences = re.split(r"(?<=[。!?!?~])|(?:\n+)|(?: {2,})", text)
12
+ sentences = [s.strip() for s in sentences if s.strip()]
13
  return "".join(sentences[:max_sentences]).strip()
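
The widened split pattern now breaks on CJK sentence enders, ASCII `!?~`, newlines, and runs of spaces before truncating. A quick illustration of the resulting behaviour (output derived from the regex above; the sample text is illustrative):

```python
from modules.utils.text_normalize import truncate_sentences

text = "听啊,风在歌唱!它掠过山谷。\n我也想为你唱一段。"
print(truncate_sentences(text, max_sentences=2))
# -> 听啊,风在歌唱!它掠过山谷。
```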
14
 
15
 
pipeline.py CHANGED
@@ -1,6 +1,11 @@
1
- import torch
 
2
  import time
 
 
3
  import librosa
 
 
4
 
5
  from modules.asr import get_asr_model
6
  from modules.llm import get_llm_model
@@ -29,20 +34,36 @@ class SingingDialoguePipeline:
29
  self.melody_controller = MelodyController(
30
  config["melody_source"], self.cache_dir
31
  )
 
32
  self.track_latency = config.get("track_latency", False)
33
  self.evaluators = load_evaluators(config.get("evaluators", {}).get("svs", []))
34
 
35
  def set_asr_model(self, asr_model: str):
 
 
 
 
 
36
  self.asr = get_asr_model(
37
  asr_model, device=self.device, cache_dir=self.cache_dir
38
  )
39
 
40
  def set_llm_model(self, llm_model: str):
 
 
 
 
 
41
  self.llm = get_llm_model(
42
  llm_model, device=self.device, cache_dir=self.cache_dir
43
  )
44
 
45
  def set_svs_model(self, svs_model: str):
 
 
 
 
 
46
  self.svs = get_svs_model(
47
  svs_model, device=self.device, cache_dir=self.cache_dir
48
  )
@@ -54,9 +75,9 @@ class SingingDialoguePipeline:
54
  self,
55
  audio_path,
56
  language,
57
- prompt_template,
58
  speaker,
59
- max_new_tokens=100,
60
  ):
61
  if self.track_latency:
62
  asr_start_time = time.time()
@@ -67,16 +88,16 @@ class SingingDialoguePipeline:
67
  if self.track_latency:
68
  asr_end_time = time.time()
69
  asr_latency = asr_end_time - asr_start_time
70
- melody_prompt = self.melody_controller.get_melody_constraints()
71
- prompt = prompt_template.format(melody_prompt, asr_result)
72
  if self.track_latency:
73
  llm_start_time = time.time()
74
- output = self.llm.generate(prompt, max_new_tokens=max_new_tokens)
75
  if self.track_latency:
76
  llm_end_time = time.time()
77
  llm_latency = llm_end_time - llm_start_time
78
- print(f"llm output: {output}确认一下是不是不含prompt的")
79
- llm_response = clean_llm_output(output, language=language)
 
80
  score = self.melody_controller.generate_score(llm_response, language)
81
  if self.track_latency:
82
  svs_start_time = time.time()
@@ -89,14 +110,18 @@ class SingingDialoguePipeline:
89
  results = {
90
  "asr_text": asr_result,
91
  "llm_text": llm_response,
92
- "svs_audio": (singing_audio, sample_rate),
93
  }
 
 
 
 
94
  if self.track_latency:
95
- results["metrics"].update({
96
  "asr_latency": asr_latency,
97
  "llm_latency": llm_latency,
98
  "svs_latency": svs_latency,
99
- })
100
  return results
101
 
102
  def evaluate(self, audio_path):
 
1
+ from __future__ import annotations
2
+
3
  import time
4
+ from pathlib import Path
5
+
6
  import librosa
7
+ import soundfile as sf
8
+ import torch
9
 
10
  from modules.asr import get_asr_model
11
  from modules.llm import get_llm_model
 
34
  self.melody_controller = MelodyController(
35
  config["melody_source"], self.cache_dir
36
  )
37
+ self.max_sentences = config.get("max_sentences", 2)
38
  self.track_latency = config.get("track_latency", False)
39
  self.evaluators = load_evaluators(config.get("evaluators", {}).get("svs", []))
40
 
41
  def set_asr_model(self, asr_model: str):
42
+ if self.asr is not None:
43
+ del self.asr
44
+ import gc
45
+ gc.collect()
46
+ torch.cuda.empty_cache()
47
  self.asr = get_asr_model(
48
  asr_model, device=self.device, cache_dir=self.cache_dir
49
  )
50
 
51
  def set_llm_model(self, llm_model: str):
52
+ if self.llm is not None:
53
+ del self.llm
54
+ import gc
55
+ gc.collect()
56
+ torch.cuda.empty_cache()
57
  self.llm = get_llm_model(
58
  llm_model, device=self.device, cache_dir=self.cache_dir
59
  )
60
 
61
  def set_svs_model(self, svs_model: str):
62
+ if self.svs is not None:
63
+ del self.svs
64
+ import gc
65
+ gc.collect()
66
+ torch.cuda.empty_cache()
67
  self.svs = get_svs_model(
68
  svs_model, device=self.device, cache_dir=self.cache_dir
69
  )
 
75
  self,
76
  audio_path,
77
  language,
78
+ system_prompt,
79
  speaker,
80
+ output_audio_path: Path | str = None,
81
  ):
82
  if self.track_latency:
83
  asr_start_time = time.time()
 
88
  if self.track_latency:
89
  asr_end_time = time.time()
90
  asr_latency = asr_end_time - asr_start_time
91
+ melody_prompt = self.melody_controller.get_melody_constraints(max_num_phrases=self.max_sentences)
 
92
  if self.track_latency:
93
  llm_start_time = time.time()
94
+ output = self.llm.generate(asr_result, system_prompt + melody_prompt)
95
  if self.track_latency:
96
  llm_end_time = time.time()
97
  llm_latency = llm_end_time - llm_start_time
98
+ llm_response = clean_llm_output(
99
+ output, language=language, max_sentences=self.max_sentences
100
+ )
101
  score = self.melody_controller.generate_score(llm_response, language)
102
  if self.track_latency:
103
  svs_start_time = time.time()
 
110
  results = {
111
  "asr_text": asr_result,
112
  "llm_text": llm_response,
113
+ "svs_audio": (sample_rate, singing_audio),
114
  }
115
+ if output_audio_path:
116
+ Path(output_audio_path).parent.mkdir(parents=True, exist_ok=True)
117
+ sf.write(output_audio_path, singing_audio, sample_rate)
118
+ results["output_audio_path"] = output_audio_path
119
  if self.track_latency:
120
+ results["metrics"] = {
121
  "asr_latency": asr_latency,
122
  "llm_latency": llm_latency,
123
  "svs_latency": svs_latency,
124
+ }
125
  return results
126
 
127
  def evaluate(self, audio_path):
requirements.txt CHANGED
@@ -12,9 +12,9 @@ pykakasi
12
  basic-pitch[onnx]
13
  audiobox_aesthetics
14
  transformers
15
- s3prl
16
  zhconv
17
  git+https://github.com/sea-turt1e/kanjiconv
18
  soundfile
19
  PyYAML
20
  gradio
 
 
12
  basic-pitch[onnx]
13
  audiobox_aesthetics
14
  transformers
 
15
  zhconv
16
  git+https://github.com/sea-turt1e/kanjiconv
17
  soundfile
18
  PyYAML
19
  gradio
20
+ google-genai
tests/__init__.py ADDED
File without changes
tests/test_asr_infer.py ADDED
@@ -0,0 +1,19 @@
1
+ from modules.asr import get_asr_model
2
+ import librosa
3
+
4
+ if __name__ == "__main__":
5
+ supported_asrs = [
6
+ "funasr/paraformer-zh",
7
+ "openai/whisper-large-v3-turbo",
8
+ ]
9
+ for model_id in supported_asrs:
10
+ try:
11
+ print(f"Loading model: {model_id}")
12
+ asr_model = get_asr_model(model_id, device="auto", cache_dir=".cache")
13
+ audio, sample_rate = librosa.load("tests/audio/hello.wav", sr=None)
14
+ result = asr_model.transcribe(audio, sample_rate, language="mandarin")
15
+ print(result)
16
+ except Exception as e:
17
+ print(f"Failed to load model {model_id}: {e}")
18
+ breakpoint()
19
+ continue
tests/test_llm_infer.py ADDED
@@ -0,0 +1,30 @@
1
+ from characters import get_character
2
+ from modules.llm import get_llm_model
3
+ from time import time
4
+
5
+ if __name__ == "__main__":
6
+ supported_llms = [
7
+ # "MiniMaxAI/MiniMax-Text-01",
8
+ # "Qwen/Qwen3-8B",
9
+ # "Qwen/Qwen3-30B-A3B",
10
+ # "meta-llama/Llama-3.1-8B-Instruct",
11
+ # "tiiuae/Falcon-H1-1B-Base",
12
+ # "tiiuae/Falcon-H1-3B-Instruct",
13
+ # "google/gemma-2-2b",
14
+ # "gemini-2.5-flash",
15
+ ]
16
+ character_prompt = get_character("Yaoyin").prompt
17
+ for model_id in supported_llms:
18
+ try:
19
+ print(f"Loading model: {model_id}")
20
+ llm = get_llm_model(model_id, cache_dir="./.cache")
21
+ prompt = "你好,今天你心情怎么样?"
22
+ start_time = time()
23
+ result = llm.generate(prompt, system_prompt=character_prompt)
24
+ end_time = time()
25
+ print(f"[{model_id}] LLM inference time: {end_time - start_time:.2f} seconds")
26
+ print(f"[{model_id}] LLM inference result:", result)
27
+ except Exception as e:
28
+ print(f"Failed to load model {model_id}: {e}")
29
+ breakpoint()
30
+ continue