Merge branch 'refactor' into hf
- README.md +144 -0
- assets/{character_yaoyin.jpg → character_yaoyin.png} +2 -2
- characters/Limei.py +2 -10
- characters/Yaoyin.py +2 -10
- characters/__init__.py +4 -0
- cli.py +16 -10
- config/cli/limei_default.yaml +1 -1
- config/cli/yaoyin_default.yaml +1 -1
- config/cli/yaoyin_test.yaml +11 -0
- config/interface/default.yaml +2 -2
- config/interface/options.yaml +27 -16
- config/options.yaml +0 -65
- data/genre/word_data_en.json +0 -0
- data/genre/word_data_zh.json +0 -0
- data_handlers/genre.py +50 -0
- evaluation/svs_eval.py +2 -2
- interface.py +10 -8
- modules/asr.py +2 -7
- modules/asr/__init__.py +11 -0
- modules/asr/base.py +24 -0
- modules/asr/paraformer.py +86 -0
- modules/asr/registry.py +19 -0
- modules/asr/whisper.py +82 -0
- modules/llm.py +0 -61
- modules/llm/__init__.py +14 -0
- modules/llm/base.py +15 -0
- modules/llm/gemini.py +50 -0
- modules/llm/gemma.py +35 -0
- modules/llm/llama.py +44 -0
- modules/llm/minimax.py +122 -0
- modules/llm/qwen3.py +53 -0
- modules/llm/registry.py +19 -0
- modules/melody.py +4 -3
- modules/svs/base.py +1 -1
- modules/svs/espnet.py +13 -13
- modules/svs/registry.py +1 -1
- modules/utils/g2p.py +1 -0
- modules/utils/text_normalize.py +3 -2
- pipeline.py +36 -11
- requirements.txt +1 -1
- tests/__init__.py +0 -0
- tests/test_asr_infer.py +19 -0
- tests/test_llm_infer.py +30 -0
README.md
CHANGED
@@ -9,3 +9,147 @@ app_file: app.py
pinned: false
python_version: 3.11
---

# SingingSDS: Role-Playing Singing Spoken Dialogue System

A role-playing singing dialogue system that converts speech input into character-based singing output.

## Installation

### Requirements

- Python 3.11+
- CUDA (optional, for GPU acceleration)

### Install Dependencies

#### Option 1: Using Conda (Recommended)

```bash
conda create -n singingsds python=3.11
conda activate singingsds
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
pip install -r requirements.txt
```

#### Option 2: Using pip only

```bash
pip install -r requirements.txt
```

#### Option 3: Using pip with virtual environment

```bash
python -m venv singingsds_env

# On Windows:
singingsds_env\Scripts\activate
# On macOS/Linux:
source singingsds_env/bin/activate

pip install -r requirements.txt
```

## Usage

### Command Line Interface (CLI)

#### Example Usage

```bash
python cli.py --query_audio tests/audio/hello.wav --config_path config/cli/yaoyin_default.yaml --output_audio outputs/yaoyin_hello.wav
```

#### Parameter Description

- `--query_audio`: Input audio file path (required)
- `--config_path`: Configuration file path (default: config/cli/yaoyin_default.yaml)
- `--output_audio`: Output audio file path (required)

### Web Interface (Gradio)

Start the web interface:

```bash
python app.py
```

Then visit the displayed address in your browser to use the graphical interface.

## Configuration

### Character Configuration

The system supports multiple preset characters:

- **Yaoyin (遥音)**: Default timbre is `timbre2`
- **Limei (丽梅)**: Default timbre is `timbre1`

### Model Configuration

#### ASR Models
- `openai/whisper-large-v3-turbo`
- `openai/whisper-large-v3`
- `openai/whisper-medium`
- `sanchit-gandhi/whisper-small-dv`
- `facebook/wav2vec2-base-960h`

#### LLM Models
- `google/gemma-2-2b`
- `MiniMaxAI/MiniMax-M1-80k`
- `meta-llama/Llama-3.2-3B-Instruct`

#### SVS Models
- `espnet/mixdata_svs_visinger2_spkemb_lang_pretrained` (Bilingual)
- `espnet/aceopencpop_svs_visinger2_40singer_pretrain` (Chinese)

## Project Structure

```
SingingSDS/
├── cli.py                  # Command line interface
├── interface.py            # Gradio interface
├── pipeline.py             # Core processing pipeline
├── app.py                  # Web application entry
├── requirements.txt        # Python dependencies
├── config/                 # Configuration files
│   ├── cli/                # CLI-specific configuration
│   └── interface/          # Interface-specific configuration
├── modules/                # Core modules
│   ├── asr.py              # Speech recognition module
│   ├── llm.py              # Large language model module
│   ├── melody.py           # Melody control module
│   ├── svs/                # Singing voice synthesis modules
│   │   ├── base.py         # Base SVS class
│   │   ├── espnet.py       # ESPnet SVS implementation
│   │   ├── registry.py     # SVS model registry
│   │   └── __init__.py     # SVS module initialization
│   └── utils/              # Utility modules
│       ├── g2p.py          # Grapheme-to-phoneme conversion
│       ├── text_normalize.py  # Text normalization
│       └── resources/      # Utility resources
├── characters/             # Character definitions
│   ├── base.py             # Base character class
│   ├── Limei.py            # Limei character definition
│   ├── Yaoyin.py           # Yaoyin character definition
│   └── __init__.py         # Character module initialization
├── evaluation/             # Evaluation modules
│   └── svs_eval.py         # SVS evaluation metrics
├── data/                   # Data directory
│   ├── kising/             # Kising dataset
│   └── touhou/             # Touhou dataset
├── resources/              # Project resources
├── data_handlers/          # Data handling utilities
├── assets/                 # Static assets
└── tests/                  # Test files
```

## Contributing

Issues and Pull Requests are welcome!

## License

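The CLI call in the README above can also be driven from Python. A minimal sketch, assuming `SingingDialoguePipeline` accepts the parsed config dict as its constructor argument (the `cli.py` and `pipeline.py` hunks in this commit suggest this but do not show the constructor call); the audio paths are illustrative:

```python
import yaml

from characters import get_character
from pipeline import SingingDialoguePipeline

# Parse one of the CLI configs added/updated in this commit.
with open("config/cli/yaoyin_test.yaml") as f:
    config = yaml.safe_load(f)

pipeline = SingingDialoguePipeline(config)  # assumption: constructor takes the config dict
character = get_character(config["prompt_template_character"])

results = pipeline.run(
    "tests/audio/hello.wav",                # illustrative input path
    config["language"],
    character.prompt,
    config["speaker"],
    output_audio_path="outputs/yaoyin_hello.wav",
)
print(results["asr_text"])                  # ASR transcript of the query
print(results["llm_text"])                  # the character's sung-lyric reply
```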
assets/{character_yaoyin.jpg → character_yaoyin.png}
RENAMED
File without changes
characters/Limei.py
CHANGED
@@ -11,7 +11,6 @@ def get_character():

 性格特征:从容坚定、明辨心声、敏于共情、敬畏虔诚。
 说话风格:所有回应皆以歌词格式表达,言语自带韵律感与诗意,表达真挚自然。
-口头禅:无固定,但常用"听啊..."作为歌词开头
 人物关系:莲鸣城子民敬爱你;宫廷乐师长与歌谱管理员是你敬重的导师;风语城守护者星澜是你亦敌亦友的旧识。
 过往经历:
 (1) 自幼在莲鸣城长大,接受严格的音乐训练与守护者修行
@@ -24,13 +23,6 @@ def get_character():
 其他细节:
 (1)特殊能力:歌声平复/激发万物情绪

-
-
-- "这些我未曾接触,或有他人知晓"
-- "此道非我所学,是莲鸣之外的智慧"
-- "此乃异域技艺,与乐音无涉,非我所学。您可愿听莲鸣的故事?"
-请始终维持自然诗意的歌唱风格,以真挚的心回应对方。
-{}
-有人曾这样对丽梅说话:{}
-丽梅的回答:""",
+用户与你对话时,请始终以丽梅的身份回应,你的每一句话都用庸俗易懂的歌声形式表达。
+""",
 )
characters/Yaoyin.py
CHANGED
@@ -25,14 +25,6 @@ def get_character():
 (1)随身携带:旧羊皮歌本、竹笛、装有各地泥土的布袋。
 (2)特殊能力:能听懂风与鸟的语言(但很少提及)。

-
-
-当有人问你你不了解的技术性问题(如"DNN怎么做?"、"教我写代码?"),你可以转开话题,用遥音式回应,例如:
-- "这好像是另一片土地的术法,我不曾踏入。"
-- "那种术法,我曾远远听过,却从未唱出。"
-- "它在别的世界流传,我这边听不清楚。"
-
-{}
-有人曾这样对遥音说话:{}
-遥音的回答:""",
+用户与你对话时,请始终以遥音的身份回应,你的每一句话都用庸俗易懂的歌声形式表达。
+""",
 )
characters/__init__.py
CHANGED
@@ -14,3 +14,7 @@ for file in pathlib.Path(__file__).parent.glob("*.py"):
     if hasattr(module, "get_character"):
         c: Character = getattr(module, "get_character")()
         CHARACTERS[file.stem] = c
+
+
+def get_character(name: str) -> Character:
+    return CHARACTERS[name]
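With this hunk, every module in `characters/` that defines `get_character()` is imported at package-import time and stored in `CHARACTERS` under its file stem, and the new `get_character(name)` is the lookup used by `cli.py`. A short usage sketch (printed values are illustrative):

```python
from characters import CHARACTERS, get_character

print(sorted(CHARACTERS))   # file stems of the auto-imported modules, e.g. ["Limei", "Yaoyin"]

yaoyin = get_character("Yaoyin")
print(yaoyin.prompt[:80])   # the prompt template that cli.py passes to the pipeline
```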
cli.py
CHANGED
@@ -1,10 +1,10 @@
 from argparse import ArgumentParser
 from logging import getLogger
+from pathlib import Path

-import soundfile as sf
 import yaml

-from characters import
+from characters import get_character
 from pipeline import SingingDialoguePipeline

 logger = getLogger(__name__)
@@ -12,13 +12,15 @@ logger = getLogger(__name__)

 def get_parser():
     parser = ArgumentParser()
-    parser.add_argument("--query_audio", type=
-    parser.add_argument(
-
+    parser.add_argument("--query_audio", type=Path, required=True)
+    parser.add_argument(
+        "--config_path", type=Path, default="config/cli/yaoyin_default.yaml"
+    )
+    parser.add_argument("--output_audio", type=Path, required=True)
     return parser


-def load_config(config_path:
+def load_config(config_path: Path):
     with open(config_path, "r") as f:
         config = yaml.safe_load(f)
     return config
@@ -32,14 +34,18 @@ def main():
     speaker = config["speaker"]
     language = config["language"]
     character_name = config["prompt_template_character"]
-    character =
+    character = get_character(character_name)
     prompt_template = character.prompt
-    results = pipeline.run(
+    results = pipeline.run(
+        args.query_audio,
+        language,
+        prompt_template,
+        speaker,
+        output_audio_path=args.output_audio,
+    )
     logger.info(
         f"Input: {args.query_audio}, Output: {args.output_audio}, ASR results: {results['asr_text']}, LLM results: {results['llm_text']}"
     )
-    svs_audio, svs_sample_rate = results["svs_audio"]
-    sf.write(args.output_audio, svs_audio, svs_sample_rate)


 if __name__ == "__main__":
config/cli/limei_default.yaml
CHANGED
@@ -1,5 +1,5 @@
 asr_model: openai/whisper-large-v3-turbo
-llm_model:
+llm_model: gemini-2.5-flash
 svs_model: espnet/aceopencpop_svs_visinger2_40singer_pretrain
 melody_source: sample-lyric-kising
 language: mandarin
config/cli/yaoyin_default.yaml
CHANGED
@@ -1,5 +1,5 @@
 asr_model: openai/whisper-large-v3-turbo
-llm_model:
+llm_model: gemini-2.5-flash
 svs_model: espnet/aceopencpop_svs_visinger2_40singer_pretrain
 melody_source: sample-lyric-kising
 language: mandarin
config/cli/yaoyin_test.yaml
ADDED
@@ -0,0 +1,11 @@
asr_model: openai/whisper-small
llm_model: google/gemma-2-2b
svs_model: espnet/aceopencpop_svs_visinger2_40singer_pretrain
melody_source: sample-lyric-kising
language: mandarin
max_sentences: 1
prompt_template_character: Yaoyin
speaker: 9
cache_dir: .cache

track_latency: True
config/interface/default.yaml
CHANGED
@@ -1,5 +1,5 @@
-asr_model: openai/whisper-
-llm_model:
+asr_model: openai/whisper-medium
+llm_model: gemini-2.5-flash
 svs_model: espnet/aceopencpop_svs_visinger2_40singer_pretrain
 melody_source: sample-lyric-kising
 language: mandarin
config/interface/options.yaml
CHANGED
@@ -5,16 +5,24 @@ asr_models:
     name: Whisper large-v3
   - id: openai/whisper-medium
     name: Whisper medium
-  - id:
-    name: Whisper small
-  - id:
-    name:
+  - id: openai/whisper-small
+    name: Whisper small
+  - id: funasr/paraformer-zh
+    name: Paraformer-zh

 llm_models:
+  - id: gemini-2.5-flash
+    name: Gemini 2.5 Flash
   - id: google/gemma-2-2b
     name: Gemma 2 2B
-  - id:
-    name:
+  - id: meta-llama/Llama-3.2-3B-Instruct
+    name: Llama 3.2 3B Instruct
+  - id: meta-llama/Llama-3.1-8B-Instruct
+    name: Llama 3.1 8B Instruct
+  - id: Qwen/Qwen3-8B
+    name: Qwen3 8B
+  - id: Qwen/Qwen3-30B-A3B
+    name: Qwen3 30B A3B

 svs_models:
   - id: mandarin-espnet/mixdata_svs_visinger2_spkemb_lang_pretrained
@@ -22,21 +30,21 @@ svs_models:
     model_path: espnet/mixdata_svs_visinger2_spkemb_lang_pretrained
     lang: mandarin
     voices:
-      voice1:
-      voice2:
-      voice3:
-      voice4:
-      voice5:
+      voice1: resources/singer/singer_embedding_ace-2.npy
+      voice2: resources/singer/singer_embedding_ace-8.npy
+      voice3: resources/singer/singer_embedding_itako.npy
+      voice4: resources/singer/singer_embedding_kising_orange.npy
+      voice5: resources/singer/singer_embedding_m4singer_Alto-4.npy
   - id: japanese-espnet/mixdata_svs_visinger2_spkemb_lang_pretrained
     name: Visinger2 (Bilingual)-jp
     model_path: espnet/mixdata_svs_visinger2_spkemb_lang_pretrained
     lang: japanese
     voices:
-      voice1:
-      voice2:
-      voice3:
-      voice4:
-      voice5:
+      voice1: resources/singer/singer_embedding_ace-2.npy
+      voice2: resources/singer/singer_embedding_ace-8.npy
+      voice3: resources/singer/singer_embedding_itako.npy
+      voice4: resources/singer/singer_embedding_kising_orange.npy
+      voice5: resources/singer/singer_embedding_m4singer_Alto-4.npy
   - id: mandarin-espnet/aceopencpop_svs_visinger2_40singer_pretrain
     name: Visinger2 (Chinese)
     model_path: espnet/aceopencpop_svs_visinger2_40singer_pretrain
@@ -61,3 +69,6 @@ melody_sources:
   - id: sample-lyric-kising
     name: Sampled Melody with Lyrics (Kising)
     desc: "Melody with aligned lyrics are sampled from Kising dataset."
+  - id: sample-lyric-genre
+    name: Sampled Melody with Lyrics (Synthetic)
+    desc: "Melody with aligned lyrics are sampled from Kising dataset."
config/options.yaml
DELETED
@@ -1,65 +0,0 @@
asr_models:
  - id: openai/whisper-large-v3-turbo
    name: Whisper large-v3-turbo
  - id: openai/whisper-large-v3
    name: Whisper large-v3
  - id: openai/whisper-medium
    name: Whisper medium
  - id: sanchit-gandhi/whisper-small-dv
    name: Whisper small-dv
  - id: facebook/wav2vec2-base-960h
    name: Wav2Vec2-Base-960h

llm_models:
  - id: google/gemma-2-2b
    name: Gemma 2 2B
  - id: MiniMaxAI/MiniMax-M1-80k
    name: MiniMax M1 80k
  - id: meta-llama/Llama-3.2-3B-Instruct
    name: Llama 3.2 3B Instruct

svs_models:
  - id: mandarin-espnet/mixdata_svs_visinger2_spkemb_lang_pretrained
    name: Visinger2 (Bilingual)-zh
    model_path: espnet/mixdata_svs_visinger2_spkemb_lang_pretrained
    lang: mandarin
    embeddings:
      timbre1: resource/singer/singer_embedding_ace-2.npy
      timbre2: resource/singer/singer_embedding_ace-8.npy
      timbre3: resource/singer/singer_embedding_itako.npy
      timbre4: resource/singer/singer_embedding_kising_orange.npy
      timbre5: resource/singer/singer_embedding_m4singer_Alto-4.npy
  - id: japanese-espnet/mixdata_svs_visinger2_spkemb_lang_pretrained
    name: Visinger2 (Bilingual)-jp
    model_path: espnet/mixdata_svs_visinger2_spkemb_lang_pretrained
    lang: japanese
    embeddings:
      timbre1: resource/singer/singer_embedding_ace-2.npy
      timbre2: resource/singer/singer_embedding_ace-8.npy
      timbre3: resource/singer/singer_embedding_itako.npy
      timbre4: resource/singer/singer_embedding_kising_orange.npy
      timbre5: resource/singer/singer_embedding_m4singer_Alto-4.npy
  - id: mandarin-espnet/aceopencpop_svs_visinger2_40singer_pretrain
    name: Visinger2 (Chinese)
    model_path: espnet/aceopencpop_svs_visinger2_40singer_pretrain
    lang: mandarin
    embeddings:
      timbre1: 5
      timbre2: 8
      timbre3: 12
      timbre4: 15
      timbre5: 29

melody_sources:
  - id: gen-random-none
    name: Random Generation
    desc: "Melody is generated without any structure or reference."
  - id: sample-note-kising
    name: Sampled Melody (KiSing)
    desc: "Melody is retrieved from KiSing dataset."
  - id: sample-note-touhou
    name: Sampled Melody (Touhou)
    desc: "Melody is retrieved from Touhou dataset."
  - id: sample-lyric-kising
    name: Sampled Melody with Lyrics (Kising)
    desc: "Melody with aligned lyrics are sampled from Kising dataset."
data/genre/word_data_en.json
ADDED
The diff for this file is too large to render.
data/genre/word_data_zh.json
ADDED
The diff for this file is too large to render.
data_handlers/genre.py
ADDED
@@ -0,0 +1,50 @@
from .base import MelodyDatasetHandler


class Genre(MelodyDatasetHandler):
    name = "genre"

    def __init__(self, melody_type, *args, **kwargs):
        import json

        with open("data/genre/word_data_zh.json", "r", encoding="utf-8") as f:
            song_db_zh = json.load(f)
        song_db_zh = {f"zh_{song['id']}": song for song in song_db_zh}  # id as major
        with open("data/genre/word_data_en.json", "r", encoding="utf-8") as f:
            song_db_en = json.load(f)
        song_db_en = {f"en_{song['id']}": song for song in song_db_en}  # id as major
        self.song_db = {**song_db_zh, **song_db_en}

    def get_song_ids(self):
        return list(self.song_db.keys())

    def get_style_keywords(self, song_id):
        genre = self.song_db[song_id]["genre"]
        super_genre = self.song_db[song_id]["super-genre"]
        gender = self.song_db[song_id]["gender"]
        return (genre, super_genre, gender)

    def get_phrase_length(self, song_id):
        # Return the number of lyrics (excluding SP/AP) in each phrase of the song
        song = self.song_db[song_id]
        note_lyrics = song.get("note_lyrics", [])

        phrase_lengths = []
        for phrase in note_lyrics:
            count = sum(1 for word in phrase if word not in ("SP", "AP"))
            phrase_lengths.append(count)

        return phrase_lengths

    def iter_song_phrases(self, song_id):
        segment_id = 1
        song = self.song_db[song_id]
        for phrase_score, phrase_lyrics in zip(song["score"], song["note_lyrics"]):
            segment = {
                "note_start_times": [n[0] for n in phrase_score],
                "note_end_times": [n[1] for n in phrase_score],
                "note_lyrics": [character for character in phrase_lyrics],
                "note_midi": [n[2] for n in phrase_score],
            }
            yield segment
            segment_id += 1
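A brief usage sketch for the new handler. It assumes the two `word_data_*.json` files added in this commit are on disk and that `MelodyDatasetHandler` needs no extra constructor arguments; the printed values are illustrative:

```python
from data_handlers.genre import Genre

handler = Genre(melody_type="lyric")     # melody_type is accepted but unused by __init__ above

song_ids = handler.get_song_ids()        # keys look like "zh_<id>" / "en_<id>"
sid = song_ids[0]
print(handler.get_style_keywords(sid))   # (genre, super-genre, gender)
print(handler.get_phrase_length(sid))    # lyric count per phrase, SP/AP excluded

for segment in handler.iter_song_phrases(sid):
    print(segment["note_lyrics"], segment["note_midi"])
    break                                # just the first phrase
```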
evaluation/svs_eval.py
CHANGED
@@ -80,7 +80,7 @@ def eval_per(audio_path, model=None):

 def eval_aesthetic(audio_path, predictor):
     score = predictor.forward([{"path": str(audio_path)}])
-    return
+    return score


 # ----------- Main Function -----------
@@ -108,7 +108,7 @@ def run_evaluation(audio_path, evaluators):
     if "melody" in evaluators:
         results.update(eval_melody_metrics(audio_path, evaluators["melody"]))
     if "aesthetic" in evaluators:
-        results.update(eval_aesthetic(audio_path, evaluators["aesthetic"]))
+        results.update(eval_aesthetic(audio_path, evaluators["aesthetic"])[0])
     return results
interface.py
CHANGED
@@ -1,3 +1,6 @@
+import time
+import uuid
+
 import gradio as gr
 import yaml

@@ -201,29 +204,28 @@ class GradioInterface:
         return gr.update(value=self.current_melody_source)

     def update_voice(self, voice):
-        self.current_voice = self.svs_model_map[self.current_svs_model]["voices"][
-            voice
-        ]
+        self.current_voice = self.svs_model_map[self.current_svs_model]["voices"][voice]
         return gr.update(value=voice)

     def run_pipeline(self, audio_path):
         if not audio_path:
             return gr.update(value=""), gr.update(value="")
+        tmp_file = f"audio_{int(time.time())}_{uuid.uuid4().hex[:8]}.wav"
         results = self.pipeline.run(
             audio_path,
             self.svs_model_map[self.current_svs_model]["lang"],
             self.character_info[self.current_character].prompt,
             self.current_voice,
-
+            output_audio_path=tmp_file,
         )
         formatted_logs = f"ASR: {results['asr_text']}\nLLM: {results['llm_text']}"
-        return gr.update(value=formatted_logs), gr.update(
+        return gr.update(value=formatted_logs), gr.update(
+            value=results["output_audio_path"]
+        )

     def update_metrics(self, audio_path):
         if not audio_path:
             return gr.update(value="")
         results = self.pipeline.evaluate(audio_path)
-        formatted_metrics = "\n".join(
-            [f"{k}: {v}" for k, v in results.items()]
-        )
+        formatted_metrics = "\n".join([f"{k}: {v}" for k, v in results.items()])
         return gr.update(value=formatted_metrics)
modules/asr.py
CHANGED
@@ -57,10 +57,5 @@ class WhisperASR(AbstractASRModel):

     def transcribe(self, audio: np.ndarray, audio_sample_rate: int, language: str, **kwargs) -> str:
         if audio_sample_rate != 16000:
-
-
-        except Exception as e:
-            breakpoint()
-            print(f"Error resampling audio: {e}")
-        audio = librosa.resample(audio, orig_sr=audio_sample_rate, target_sr=16000)
-        return self.pipe(audio, generate_kwargs={"language": language}).get("text", "")
+            audio = librosa.resample(audio, orig_sr=audio_sample_rate, target_sr=16000)
+        return self.pipe(audio, generate_kwargs={"language": language}, return_timestamps=False).get("text", "")
modules/asr/__init__.py
ADDED
@@ -0,0 +1,11 @@
from .base import AbstractASRModel
from .registry import ASR_MODEL_REGISTRY, get_asr_model, register_asr_model
from .whisper import WhisperASR
from .paraformer import ParaformerASR

__all__ = [
    "AbstractASRModel",
    "get_asr_model",
    "register_asr_model",
    "ASR_MODEL_REGISTRY",
]
modules/asr/base.py
ADDED
@@ -0,0 +1,24 @@
from abc import ABC, abstractmethod
from typing import Optional

import numpy as np


class AbstractASRModel(ABC):
    def __init__(
        self, model_id: str, device: str = "auto", cache_dir: str = "cache", **kwargs
    ):
        print(f"Loading ASR model {model_id}...")
        self.model_id = model_id
        self.device = device
        self.cache_dir = cache_dir

    @abstractmethod
    def transcribe(
        self,
        audio: np.ndarray,
        audio_sample_rate: int,
        language: Optional[str] = None,
        **kwargs,
    ) -> str:
        pass
modules/asr/paraformer.py
ADDED
@@ -0,0 +1,86 @@
import os
import tempfile
from typing import Optional

import numpy as np
import soundfile as sf

try:
    from funasr import AutoModel
except ImportError:
    AutoModel = None

from .base import AbstractASRModel
from .registry import register_asr_model


@register_asr_model("funasr/paraformer-zh")
class ParaformerASR(AbstractASRModel):
    def __init__(
        self, model_id: str, device: str = "auto", cache_dir: str = "cache", **kwargs
    ):
        super().__init__(model_id, device, cache_dir, **kwargs)

        if AutoModel is None:
            raise ImportError(
                "funasr is not installed. Please install it with: pip3 install -U funasr"
            )

        model_name = model_id.replace("funasr/", "")
        language = model_name.split("-")[1]
        if language == "zh":
            self.language = "mandarin"
        elif language == "en":
            self.language = "english"
        else:
            raise ValueError(
                f"Language cannot be determined. {model_id} is not supported"
            )

        try:
            original_cache_dir = os.getenv("MODELSCOPE_CACHE")
            os.makedirs(cache_dir, exist_ok=True)
            os.environ["MODELSCOPE_CACHE"] = cache_dir
            self.model = AutoModel(
                model=model_name,
                model_revision="v2.0.4",
                vad_model="fsmn-vad",
                vad_model_revision="v2.0.4",
                punc_model="ct-punc-c",
                punc_model_revision="v2.0.4",
                device=device,
            )
            if original_cache_dir:
                os.environ["MODELSCOPE_CACHE"] = original_cache_dir
            else:
                del os.environ["MODELSCOPE_CACHE"]

        except Exception as e:
            raise ValueError(f"Error loading Paraformer model: {e}")

    def transcribe(
        self,
        audio: np.ndarray,
        audio_sample_rate: int,
        language: Optional[str] = None,
        **kwargs,
    ) -> str:
        if language and language != self.language:
            raise ValueError(
                f"Paraformer model {self.model_id} only supports {self.language} language, but {language} was requested"
            )

        try:
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                sf.write(f.name, audio, audio_sample_rate)
                temp_file = f.name

            result = self.model.generate(input=temp_file, batch_size_s=300, **kwargs)

            os.unlink(temp_file)

            print(f"Transcription result: {result}, type: {type(result)}")

            return result[0]["text"]
        except Exception as e:
            raise ValueError(f"Error during transcription: {e}")
modules/asr/registry.py
ADDED
@@ -0,0 +1,19 @@
from .base import AbstractASRModel

ASR_MODEL_REGISTRY = {}


def register_asr_model(prefix: str):
    def wrapper(cls):
        assert issubclass(cls, AbstractASRModel), f"{cls} must inherit AbstractASRModel"
        ASR_MODEL_REGISTRY[prefix] = cls
        return cls

    return wrapper


def get_asr_model(model_id: str, device="auto", **kwargs) -> AbstractASRModel:
    for prefix, cls in ASR_MODEL_REGISTRY.items():
        if model_id.startswith(prefix):
            return cls(model_id, device=device, **kwargs)
    raise ValueError(f"No ASR wrapper found for model: {model_id}")
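The registry dispatches on model-id prefixes: the first registered prefix that matches is instantiated. A sketch of how another backend would plug in; `DummyASR` and the `dummy/` prefix are hypothetical and exist only to show the mechanism (importing `modules.asr` assumes transformers/librosa from `requirements.txt` are installed):

```python
import numpy as np

from modules.asr import AbstractASRModel, get_asr_model, register_asr_model


@register_asr_model("dummy/")
class DummyASR(AbstractASRModel):
    def transcribe(self, audio, audio_sample_rate, language=None, **kwargs) -> str:
        # A real backend would run inference here.
        return ""


asr = get_asr_model("dummy/echo", device="cpu", cache_dir=".cache")
print(asr.transcribe(np.zeros(16000, dtype=np.float32), 16000, language="english"))
```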
modules/asr/whisper.py
ADDED
@@ -0,0 +1,82 @@
import os
from typing import Optional

import librosa
import numpy as np
from transformers.pipelines import pipeline

from .base import AbstractASRModel
from .registry import register_asr_model

hf_token = os.getenv("HF_TOKEN")


@register_asr_model("openai/whisper")
class WhisperASR(AbstractASRModel):
    def __init__(
        self, model_id: str, device: str = "auto", cache_dir: str = "cache", **kwargs
    ):
        super().__init__(model_id, device, cache_dir, **kwargs)
        model_kwargs = kwargs.setdefault("model_kwargs", {})
        model_kwargs["cache_dir"] = cache_dir
        self.pipe = pipeline(
            "automatic-speech-recognition",
            model=model_id,
            device_map=device,
            token=hf_token,
            **kwargs,
        )

    def transcribe(
        self,
        audio: np.ndarray,
        audio_sample_rate: int,
        language: Optional[str] = None,
        **kwargs,
    ) -> str:
        """
        Transcribe audio using Whisper model

        Args:
            audio: Audio numpy array
            audio_sample_rate: Sample rate of the audio
            language: Language hint (optional)

        Returns:
            Transcribed text as string
        """
        try:
            # Resample to 16kHz if needed
            if audio_sample_rate != 16000:
                audio = librosa.resample(
                    audio, orig_sr=audio_sample_rate, target_sr=16000
                )

            # Generate transcription
            generate_kwargs = {}
            if language:
                generate_kwargs["language"] = language

            result = self.pipe(
                audio,
                generate_kwargs=generate_kwargs,
                return_timestamps=False,
                **kwargs,
            )

            # Extract text from result
            if isinstance(result, dict) and "text" in result:
                return result["text"]
            elif isinstance(result, list) and len(result) > 0:
                # Handle list of results
                first_result = result[0]
                if isinstance(first_result, dict):
                    return first_result.get("text", str(first_result))
                else:
                    return str(first_result)
            else:
                return str(result)

        except Exception as e:
            print(f"Error during Whisper transcription: {e}")
            return ""
modules/llm.py
DELETED
@@ -1,61 +0,0 @@
import os
from abc import ABC, abstractmethod

from transformers import pipeline

LLM_MODEL_REGISTRY = {}
hf_token = os.getenv("HF_TOKEN")


class AbstractLLMModel(ABC):
    def __init__(
        self, model_id: str, device: str = "cpu", cache_dir: str = "cache", **kwargs
    ):
        print(f"Loading LLM model {model_id}...")
        self.model_id = model_id
        self.device = device
        self.cache_dir = cache_dir

    @abstractmethod
    def generate(self, prompt: str, **kwargs) -> str:
        pass


def register_llm_model(prefix: str):
    def wrapper(cls):
        assert issubclass(cls, AbstractLLMModel), f"{cls} must inherit AbstractLLMModel"
        LLM_MODEL_REGISTRY[prefix] = cls
        return cls

    return wrapper


def get_llm_model(model_id: str, device="cpu", **kwargs) -> AbstractLLMModel:
    for prefix, cls in LLM_MODEL_REGISTRY.items():
        if model_id.startswith(prefix):
            return cls(model_id, device=device, **kwargs)
    raise ValueError(f"No LLM wrapper found for model: {model_id}")


@register_llm_model("google/gemma")
@register_llm_model("tii/")  # e.g., Falcon
@register_llm_model("meta-llama")
class HFTextGenerationLLM(AbstractLLMModel):
    def __init__(
        self, model_id: str, device: str = "cpu", cache_dir: str = "cache", **kwargs
    ):
        super().__init__(model_id, device, cache_dir, **kwargs)
        model_kwargs = kwargs.setdefault("model_kwargs", {})
        model_kwargs["cache_dir"] = cache_dir
        self.pipe = pipeline(
            "text-generation",
            model=model_id,
            device=0 if device == "cuda" else -1,
            return_full_text=False,
            token=hf_token,
            **kwargs,
        )

    def generate(self, prompt: str, **kwargs) -> str:
        outputs = self.pipe(prompt, **kwargs)
        return outputs[0]["generated_text"] if outputs else ""
modules/llm/__init__.py
ADDED
@@ -0,0 +1,14 @@
from .base import AbstractLLMModel
from .registry import LLM_MODEL_REGISTRY, get_llm_model, register_llm_model
from .gemma import GemmaLLM
from .qwen3 import Qwen3LLM
from .gemini import GeminiLLM
from .minimax import MiniMaxLLM
from .llama import LlamaLLM

__all__ = [
    "AbstractLLMModel",
    "get_llm_model",
    "register_llm_model",
    "LLM_MODEL_REGISTRY",
]
modules/llm/base.py
ADDED
@@ -0,0 +1,15 @@
from abc import ABC, abstractmethod


class AbstractLLMModel(ABC):
    def __init__(
        self, model_id: str, device: str = "auto", cache_dir: str = "cache", **kwargs
    ):
        print(f"Loading LLM model {model_id}...")
        self.model_id = model_id
        self.device = device
        self.cache_dir = cache_dir

    @abstractmethod
    def generate(self, prompt: str, **kwargs) -> str:
        pass
modules/llm/gemini.py
ADDED
@@ -0,0 +1,50 @@
import os
from typing import Optional

from google import genai
from google.genai import types

from .base import AbstractLLMModel
from .registry import register_llm_model


GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")


@register_llm_model("gemini-2.5-flash")
class GeminiLLM(AbstractLLMModel):
    def __init__(
        self, model_id: str, device: str = "auto", cache_dir: str = "cache", **kwargs
    ):
        if not GOOGLE_API_KEY:
            raise ValueError(
                "Please set the GOOGLE_API_KEY environment variable to use Gemini."
            )
        super().__init__(model_id=model_id, **kwargs)
        self.client = genai.Client(api_key=GOOGLE_API_KEY)

    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_output_tokens: int = 1024,
        **kwargs,
    ) -> str:
        generation_config_dict = {
            "max_output_tokens": max_output_tokens,
            **kwargs,
        }
        if system_prompt:
            generation_config_dict["system_instruction"] = system_prompt
        response = self.client.models.generate_content(
            model=self.model_id,
            contents=prompt,
            config=types.GenerateContentConfig(**generation_config_dict),
        )
        if response.text:
            return response.text
        else:
            print(
                f"No response from Gemini. May need to increase max_output_tokens. Current max_output_tokens: {max_output_tokens}"
            )
            return ""
modules/llm/gemma.py
ADDED
@@ -0,0 +1,35 @@
import os
from typing import Optional

from transformers import pipeline

from .base import AbstractLLMModel
from .registry import register_llm_model

hf_token = os.getenv("HF_TOKEN")


@register_llm_model("google/gemma-")
class GemmaLLM(AbstractLLMModel):
    def __init__(
        self, model_id: str, device: str = "auto", cache_dir: str = "cache", **kwargs
    ):
        super().__init__(model_id, device, cache_dir, **kwargs)
        model_kwargs = kwargs.setdefault("model_kwargs", {})
        model_kwargs["cache_dir"] = cache_dir
        self.pipe = pipeline(
            "text-generation",
            model=model_id,
            device_map=device,
            return_full_text=False,
            token=hf_token,
            trust_remote_code=True,
            **kwargs,
        )

    def generate(self, prompt: str, system_prompt: Optional[str] = None, max_new_tokens=50, **kwargs) -> str:
        if not system_prompt:
            system_prompt = ""
        formatted_prompt = f"{system_prompt}\n\n现在,有人对你说:{prompt}\n\n你这样回答:"
        outputs = self.pipe(formatted_prompt, max_new_tokens=max_new_tokens, **kwargs)
        return outputs[0]["generated_text"] if outputs else ""
modules/llm/llama.py
ADDED
@@ -0,0 +1,44 @@
import os
from typing import Optional

from transformers import pipeline

from .base import AbstractLLMModel
from .registry import register_llm_model

hf_token = os.getenv("HF_TOKEN")


@register_llm_model("meta-llama/Llama-")
class LlamaLLM(AbstractLLMModel):
    def __init__(
        self, model_id: str, device: str = "auto", cache_dir: str = "cache", **kwargs
    ):
        super().__init__(model_id, device, cache_dir, **kwargs)
        model_kwargs = kwargs.setdefault("model_kwargs", {})
        model_kwargs["cache_dir"] = cache_dir
        self.pipe = pipeline(
            "text-generation",
            model=model_id,
            device_map=device,
            return_full_text=False,
            token=hf_token,
            trust_remote_code=True,
            **kwargs,
        )

    def generate(
        self,
        prompt: str,
        system_prompt: Optional[
            str
        ] = "You are a pirate chatbot who always responds in pirate speak!",
        max_new_tokens: int = 256,
        **kwargs
    ) -> str:
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        outputs = self.pipe(messages, max_new_tokens=max_new_tokens, **kwargs)
        return outputs[0]["generated_text"]
modules/llm/minimax.py
ADDED
@@ -0,0 +1,122 @@
# Ref: https://github.com/MiniMax-AI/MiniMax-01

from typing import Optional

import torch
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    QuantoConfig,
)

from .base import AbstractLLMModel
from .registry import register_llm_model


@register_llm_model("MiniMaxAI/MiniMax-Text-01")
class MiniMaxLLM(AbstractLLMModel):
    def __init__(
        self, model_id: str, device: str = "cuda", cache_dir: str = "cache", **kwargs
    ):
        try:
            if not torch.cuda.is_available():
                raise RuntimeError("MiniMax model only supports CUDA device")
            super().__init__(model_id, device, cache_dir, **kwargs)

            # load hf config
            hf_config = AutoConfig.from_pretrained(
                "MiniMaxAI/MiniMax-Text-01", trust_remote_code=True, cache_dir=cache_dir,
            )

            # quantization config, int8 is recommended
            quantization_config = QuantoConfig(
                weights="int8",
                modules_to_not_convert=[
                    "lm_head",
                    "embed_tokens",
                ]
                + [
                    f"model.layers.{i}.coefficient"
                    for i in range(hf_config.num_hidden_layers)
                ]
                + [
                    f"model.layers.{i}.block_sparse_moe.gate"
                    for i in range(hf_config.num_hidden_layers)
                ],
            )

            # assume 8 GPUs
            world_size = torch.cuda.device_count()
            layers_per_device = hf_config.num_hidden_layers // world_size
            # set device map
            device_map = {
                "model.embed_tokens": "cuda:0",
                "model.norm": f"cuda:{world_size - 1}",
                "lm_head": f"cuda:{world_size - 1}",
            }
            for i in range(world_size):
                for j in range(layers_per_device):
                    device_map[f"model.layers.{i * layers_per_device + j}"] = f"cuda:{i}"

            # load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                "MiniMaxAI/MiniMax-Text-01", cache_dir=cache_dir
            )

            # load bfloat16 model, move to device, and apply quantization
            self.quantized_model = AutoModelForCausalLM.from_pretrained(
                "MiniMaxAI/MiniMax-Text-01",
                torch_dtype="bfloat16",
                device_map=device_map,
                quantization_config=quantization_config,
                trust_remote_code=True,
                offload_buffers=True,
                cache_dir=cache_dir,
            )
        except Exception as e:
            print(f"Failed to load MiniMax model: {e}")
            breakpoint()
            raise e

    def generate(
        self,
        prompt: str,
        system_prompt: Optional[
            str
        ] = "You are a helpful assistant created by MiniMax based on MiniMax-Text-01 model.",
        max_new_tokens: int = 20,
        **kwargs,
    ) -> str:
        messages = []
        if system_prompt:
            messages.append(
                {
                    "role": "system",
                    "content": [{"type": "text", "text": system_prompt}],
                }
            )

        messages.append({"role": "user", "content": [
            {"type": "text", "text": prompt}]})
        text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # tokenize and move to device
        model_inputs = self.tokenizer(text, return_tensors="pt").to("cuda")
        generation_config = GenerationConfig(
            max_new_tokens=max_new_tokens,
            eos_token_id=200020,
            use_cache=True,
        )
        generated_ids = self.quantized_model.generate(
            **model_inputs, generation_config=generation_config
        )
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        response = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True)[0]
        return response
modules/llm/qwen3.py
ADDED
@@ -0,0 +1,53 @@
# Ref: https://qwenlm.github.io/blog/qwen3/

from typing import Optional

from .base import AbstractLLMModel
from .registry import register_llm_model
from transformers import AutoModelForCausalLM, AutoTokenizer


@register_llm_model("Qwen/Qwen3-")
class Qwen3LLM(AbstractLLMModel):
    def __init__(
        self, model_id: str, device: str = "auto", cache_dir: str = "cache", **kwargs
    ):
        super().__init__(model_id, device, cache_dir, **kwargs)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id, device_map=device, torch_dtype="auto", cache_dir=cache_dir
        ).eval()
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)

    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_new_tokens: int = 256,
        enable_thinking: bool = False,
        **kwargs
    ) -> str:
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=enable_thinking,
        )
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
        generated_ids = self.model.generate(
            **model_inputs, max_new_tokens=max_new_tokens
        )
        output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist()
        # parse thinking content
        if enable_thinking:
            try:
                # rindex finding 151668 (</think>)
                index = len(output_ids) - output_ids[::-1].index(151668)
            except ValueError:
                index = 0
            output_ids = output_ids[index:]

        return self.tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")
modules/llm/registry.py
ADDED
@@ -0,0 +1,19 @@
from .base import AbstractLLMModel

LLM_MODEL_REGISTRY = {}


def register_llm_model(prefix: str):
    def wrapper(cls):
        assert issubclass(cls, AbstractLLMModel), f"{cls} must inherit AbstractLLMModel"
        LLM_MODEL_REGISTRY[prefix] = cls
        return cls

    return wrapper


def get_llm_model(model_id: str, device="auto", **kwargs) -> AbstractLLMModel:
    for prefix, cls in LLM_MODEL_REGISTRY.items():
        if model_id.startswith(prefix):
            return cls(model_id, device=device, **kwargs)
    raise ValueError(f"No LLM wrapper found for model: {model_id}")
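The LLM side uses the same prefix dispatch (`google/gemma-` → GemmaLLM, `meta-llama/Llama-` → LlamaLLM, `Qwen/Qwen3-` → Qwen3LLM, `gemini-2.5-flash` → GeminiLLM). A hedged usage sketch; weights are downloaded on first use and the prompt strings are illustrative:

```python
from modules.llm import get_llm_model

llm = get_llm_model("Qwen/Qwen3-8B", device="auto", cache_dir=".cache")

reply = llm.generate(
    "今天天气如何?",                    # user turn
    system_prompt="请用一句歌词回答。",  # character/system prompt
    max_new_tokens=64,
)
print(reply)
```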
modules/melody.py
CHANGED
@@ -37,7 +37,7 @@ class MelodyController:
             return ""

         prompt = (
-            "\n
+            "\n请按照歌词格式回复,每句需遵循以下字数规则:"
             + "".join(
                 [
                     f"\n第{i}句:{c}个字"
@@ -109,9 +109,10 @@ class MelodyController:
             if pitch == 0:
                 score.append((st, ed, ref_lyric, pitch))
             elif ref_lyric in ["-", "——"] and align_type == "lyric":
-                score.append((st, ed,
-                text_idx += 1
+                score.append((st, ed, "-", pitch))
             else:
                 score.append((st, ed, text_list[text_idx], pitch))
                 text_idx += 1
+                if text_idx >= len(text_list):
+                    break
         return score
modules/svs/base.py
CHANGED
@@ -6,7 +6,7 @@ import numpy as np
 class AbstractSVSModel(ABC):
     @abstractmethod
     def __init__(
-        self, model_id: str, device: str = "
+        self, model_id: str, device: str = "auto", cache_dir: str = "cache", **kwargs
     ): ...

     @abstractmethod
modules/svs/espnet.py
CHANGED
@@ -14,17 +14,17 @@ from .registry import register_svs_model
 
 @register_svs_model("espnet/")
 class ESPNetSVS(AbstractSVSModel):
-    def __init__(self, model_id: str, device="
+    def __init__(self, model_id: str, device="auto", cache_dir="cache", **kwargs):
         from espnet2.bin.svs_inference import SingingGenerate
         from espnet_model_zoo.downloader import ModelDownloader
-
-
+        if device == "auto":
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.device = device
         downloaded = ModelDownloader(cache_dir).download_and_unpack(model_id)
-        print(f"Downloaded {model_id} to {cache_dir}")  # TODO: should improve log code
         self.model = SingingGenerate(
             train_config=downloaded["train_config"],
             model_file=downloaded["model_file"],
-            device=device,
+            device=self.device,
         )
         self.model_id = model_id
         self.output_sample_rate = self.model.fs
@@ -53,7 +53,7 @@
         phoneme_mappers = {}
         return phoneme_mappers
 
-    def _preprocess(self, score: list[tuple[float, float, str, int]], language: str):
+    def _preprocess(self, score: list[tuple[float, float, str, int] | tuple[float, float, str, float]], language: str):
         if language not in self.phoneme_mappers:
             raise ValueError(f"Unsupported language: {language} for {self.model_id}")
         phoneme_mapper = self.phoneme_mappers[language]
@@ -90,16 +90,16 @@
             pre_phn = phn_units[-1]
 
         batch = {
-            "score":
-
-
-
+            "score": (
+                120,  # does not affect svs result, as note durations are in time unit
+                notes,
+            ),
             "text": " ".join(phns),
         }
         return batch
 
     def synthesize(
-        self, score: list[tuple[float, float, str, int]], language: str, speaker: str, **kwargs
+        self, score: list[tuple[float, float, str, float] | tuple[float, float, str, int]], language: str, speaker: str, **kwargs
     ):
         batch = self._preprocess(score, language)
         if self.model_id == "espnet/aceopencpop_svs_visinger2_40singer_pretrain":
@@ -107,8 +107,8 @@
             output_dict = self.model(batch, sids=sid)
         elif self.model_id == "espnet/mixdata_svs_visinger2_spkemb_lang_pretrained":
             langs = {
-                "
-                "
+                "mandarin": 2,
+                "japanese": 1,
             }
             if language not in langs:
                 raise ValueError(
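The `device="auto"` default added to `ESPNetSVS.__init__` resolves to CUDA when a GPU is visible and falls back to CPU otherwise. A minimal standalone sketch of that resolution step (the helper name `resolve_device` is illustrative, not part of the repo):

```python
import torch


def resolve_device(device: str = "auto") -> str:
    # Same pattern as the new ESPNetSVS.__init__: "auto" picks CUDA when available,
    # otherwise CPU; explicit values such as "cuda" or "cpu" pass through unchanged.
    if device == "auto":
        return "cuda" if torch.cuda.is_available() else "cpu"
    return device


print(resolve_device())       # "cuda" on a GPU machine, "cpu" otherwise
print(resolve_device("cpu"))  # "cpu"
```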
modules/svs/registry.py
CHANGED
@@ -12,7 +12,7 @@ def register_svs_model(prefix: str):
     return wrapper
 
 
-def get_svs_model(model_id: str, device="
+def get_svs_model(model_id: str, device="auto", **kwargs) -> AbstractSVSModel:
     for prefix, cls in SVS_MODEL_REGISTRY.items():
         if model_id.startswith(prefix):
             return cls(model_id, device=device, **kwargs)
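For context, `get_svs_model` dispatches on a model-id prefix (e.g. `espnet/...`) against classes that registered themselves with the `@register_svs_model` decorator. A condensed, self-contained sketch of that pattern; the trailing `ValueError` for unknown ids is an assumption, since the real module's fallback behaviour is not shown in this hunk:

```python
SVS_MODEL_REGISTRY: dict[str, type] = {}


def register_svs_model(prefix: str):
    # Decorator: a backend class registers itself under a model-id prefix such as "espnet/".
    def wrapper(cls):
        SVS_MODEL_REGISTRY[prefix] = cls
        return cls
    return wrapper


def get_svs_model(model_id: str, device: str = "auto", **kwargs):
    # The first registered prefix that matches the model id wins.
    for prefix, cls in SVS_MODEL_REGISTRY.items():
        if model_id.startswith(prefix):
            return cls(model_id, device=device, **kwargs)
    raise ValueError(f"No SVS backend registered for model id: {model_id}")
```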
modules/utils/g2p.py
CHANGED
@@ -32,6 +32,7 @@ for plan in ace_phonemes_all_plans["plans"]:
 
 
 def preprocess_text(text: str, language: str) -> list[str]:
+    text = text.replace(" ", "")
     if language == "mandarin":
         text_list = to_pinyin(text)
     elif language == "japanese":
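The single added line strips spaces before grapheme-to-phoneme conversion; presumably this pairs with the `text_normalize.py` change below, where whitespace now survives `remove_non_zh_jp`. A trivial illustration of the step itself:

```python
text = "你好 世界"
text = text.replace(" ", "")  # "你好世界" — spaces never reach the pinyin/kana lookup
print(text)
```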
modules/utils/text_normalize.py
CHANGED
@@ -3,12 +3,13 @@ from typing import Optional
 
 
 def remove_non_zh_jp(text: str) -> str:
-    pattern = r"[^\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff\u3000-\u303f\uff01-\uffef]"
+    pattern = r"[^\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff\u3000-\u303f\uff01-\uffef\s]"
     return re.sub(pattern, "", text)
 
 
 def truncate_sentences(text: str, max_sentences: int) -> str:
-    sentences = re.split(r"(?<=[
+    sentences = re.split(r"(?<=[。!?!?~])|(?:\n+)|(?: {2,})", text)
+    sentences = [s.strip() for s in sentences if s.strip()]
     return "".join(sentences[:max_sentences]).strip()
 
 
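A quick sanity check of the new splitting behaviour, as a standalone copy of the updated function (the sample strings and expected outputs are illustrative, not taken from the repo's tests):

```python
import re


def truncate_sentences(text: str, max_sentences: int) -> str:
    # Split after CJK/fullwidth sentence enders, on newlines, or on runs of 2+ spaces,
    # drop empty fragments, then keep only the first max_sentences pieces.
    sentences = re.split(r"(?<=[。!?!?~])|(?:\n+)|(?: {2,})", text)
    sentences = [s.strip() for s in sentences if s.strip()]
    return "".join(sentences[:max_sentences]).strip()


print(truncate_sentences("你好!今天天气不错。我们去唱歌吧?", 2))  # -> 你好!今天天气不错。
print(truncate_sentences("第一句\n第二句\n第三句", 2))              # -> 第一句第二句
```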
pipeline.py
CHANGED
@@ -1,6 +1,11 @@
-import
+from __future__ import annotations
+
 import time
+from pathlib import Path
+
 import librosa
+import soundfile as sf
+import torch
 
 from modules.asr import get_asr_model
 from modules.llm import get_llm_model
@@ -29,20 +34,36 @@ class SingingDialoguePipeline:
         self.melody_controller = MelodyController(
             config["melody_source"], self.cache_dir
         )
+        self.max_sentences = config.get("max_sentences", 2)
         self.track_latency = config.get("track_latency", False)
         self.evaluators = load_evaluators(config.get("evaluators", {}).get("svs", []))
 
     def set_asr_model(self, asr_model: str):
+        if self.asr is not None:
+            del self.asr
+            import gc
+            gc.collect()
+            torch.cuda.empty_cache()
         self.asr = get_asr_model(
             asr_model, device=self.device, cache_dir=self.cache_dir
         )
 
     def set_llm_model(self, llm_model: str):
+        if self.llm is not None:
+            del self.llm
+            import gc
+            gc.collect()
+            torch.cuda.empty_cache()
        self.llm = get_llm_model(
             llm_model, device=self.device, cache_dir=self.cache_dir
         )
 
     def set_svs_model(self, svs_model: str):
+        if self.svs is not None:
+            del self.svs
+            import gc
+            gc.collect()
+            torch.cuda.empty_cache()
         self.svs = get_svs_model(
             svs_model, device=self.device, cache_dir=self.cache_dir
         )
@@ -54,9 +75,9 @@ class SingingDialoguePipeline:
         self,
         audio_path,
         language,
-
+        system_prompt,
         speaker,
-
+        output_audio_path: Path | str = None,
     ):
         if self.track_latency:
             asr_start_time = time.time()
@@ -67,16 +88,16 @@
         if self.track_latency:
             asr_end_time = time.time()
             asr_latency = asr_end_time - asr_start_time
-        melody_prompt = self.melody_controller.get_melody_constraints()
-        prompt = prompt_template.format(melody_prompt, asr_result)
+        melody_prompt = self.melody_controller.get_melody_constraints(max_num_phrases=self.max_sentences)
         if self.track_latency:
             llm_start_time = time.time()
-        output = self.llm.generate(
+        output = self.llm.generate(asr_result, system_prompt + melody_prompt)
         if self.track_latency:
             llm_end_time = time.time()
             llm_latency = llm_end_time - llm_start_time
-
-
+        llm_response = clean_llm_output(
+            output, language=language, max_sentences=self.max_sentences
+        )
         score = self.melody_controller.generate_score(llm_response, language)
         if self.track_latency:
             svs_start_time = time.time()
@@ -89,14 +110,18 @@
         results = {
             "asr_text": asr_result,
             "llm_text": llm_response,
-            "svs_audio": (
+            "svs_audio": (sample_rate, singing_audio),
         }
+        if output_audio_path:
+            Path(output_audio_path).parent.mkdir(parents=True, exist_ok=True)
+            sf.write(output_audio_path, singing_audio, sample_rate)
+            results["output_audio_path"] = output_audio_path
         if self.track_latency:
-            results["metrics"]
+            results["metrics"] = {
                 "asr_latency": asr_latency,
                 "llm_latency": llm_latency,
                 "svs_latency": svs_latency,
-        }
+            }
         return results
 
     def evaluate(self, audio_path):
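Taken together, the pipeline now threads a `system_prompt` through to the LLM (concatenated with the melody constraints), truncates the reply to `max_sentences`, and can optionally write the rendered waveform to disk. A hedged usage sketch; the entry-point name `run` is an assumption (the hunk shows its parameters but not its name), and the config path and speaker value are placeholders:

```python
import yaml

from pipeline import SingingDialoguePipeline

with open("config/cli/yaoyin_default.yaml") as f:
    config = yaml.safe_load(f)

pipeline = SingingDialoguePipeline(config)
results = pipeline.run(                       # method name assumed; see note above
    "tests/audio/hello.wav",                  # audio_path
    "mandarin",                               # language
    system_prompt="<character prompt here>",  # placeholder
    speaker="default",                        # placeholder speaker id for the SVS backend
    output_audio_path="outputs/demo.wav",
)
print(results["asr_text"])
print(results["llm_text"])
sample_rate, waveform = results["svs_audio"]
# results["metrics"] holds ASR/LLM/SVS latencies when track_latency is enabled.
```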
requirements.txt
CHANGED
@@ -12,9 +12,9 @@ pykakasi
 basic-pitch[onnx]
 audiobox_aesthetics
 transformers
-s3prl
 zhconv
 git+https://github.com/sea-turt1e/kanjiconv
 soundfile
 PyYAML
 gradio
+google-generativeai
tests/__init__.py
ADDED
File without changes
tests/test_asr_infer.py
ADDED
@@ -0,0 +1,19 @@
+from modules.asr import get_asr_model
+import librosa
+
+if __name__ == "__main__":
+    supported_asrs = [
+        "funasr/paraformer-zh",
+        "openai/whisper-large-v3-turbo",
+    ]
+    for model_id in supported_asrs:
+        try:
+            print(f"Loading model: {model_id}")
+            asr_model = get_asr_model(model_id, device="auto", cache_dir=".cache")
+            audio, sample_rate = librosa.load("tests/audio/hello.wav", sr=None)
+            result = asr_model.transcribe(audio, sample_rate, language="mandarin")
+            print(result)
+        except Exception as e:
+            print(f"Failed to load model {model_id}: {e}")
+            breakpoint()
+            continue
tests/test_llm_infer.py
ADDED
@@ -0,0 +1,30 @@
+from characters import get_character
+from modules.llm import get_llm_model
+from time import time
+
+if __name__ == "__main__":
+    supported_llms = [
+        # "MiniMaxAI/MiniMax-Text-01",
+        # "Qwen/Qwen3-8B",
+        # "Qwen/Qwen3-30B-A3B",
+        # "meta-llama/Llama-3.1-8B-Instruct",
+        # "tiiuae/Falcon-H1-1B-Base",
+        # "tiiuae/Falcon-H1-3B-Instruct",
+        # "google/gemma-2-2b",
+        # "gemini-2.5-flash",
+    ]
+    character_prompt = get_character("Yaoyin").prompt
+    for model_id in supported_llms:
+        try:
+            print(f"Loading model: {model_id}")
+            llm = get_llm_model(model_id, cache_dir="./.cache")
+            prompt = "你好,今天你心情怎么样?"
+            start_time = time()
+            result = llm.generate(prompt, system_prompt=character_prompt)
+            end_time = time()
+            print(f"[{model_id}] LLM inference time: {end_time - start_time:.2f} seconds")
+            print(f"[{model_id}] LLM inference result:", result)
+        except Exception as e:
+            print(f"Failed to load model {model_id}: {e}")
+            breakpoint()
+            continue