Stardust-minus committed
Commit b74cfae · verified · 1 Parent(s): 8b8a08e

Delete generate_cli.py

Files changed (1)
  1. generate_cli.py +0 -227
generate_cli.py DELETED
@@ -1,227 +0,0 @@
- import os
- import json
- import queue
- from pathlib import Path
- from typing import Optional
-
- import click
- import torch
- import soundfile as sf
- from loguru import logger
-
- from fish_speech.models.text2semantic.inference import (
-     CodebookSamplingParams,
-     SamplingParams,
-     generate_long,
-     launch_thread_safe_queue,
-     GenerateRequest,
-     WrappedGenerateResponse,
- )
- from fish_speech.models.text2semantic.llama import BaseTransformer, DualARTransformer
- from fish_speech.models.dac.inference import load_model as load_decoder_model
- from fish_speech.text import clean_text
- from fish_speech.inference_engine.vq_manager import VQManager
- from tools.api import load_audio
-
-
- def load_llm_model(model_path: str, device: str, compile: bool = False):
-     """Load the LLM model."""
-     logger.info(f"Loading LLM model from {model_path}")
-     model = BaseTransformer.from_pretrained(
-         path=model_path,
-         load_weights=True,
-     )
-     model = model.to(device=device, dtype=torch.bfloat16)
-
-     # Check the concrete class directly; probing __subclasses__ depends on
-     # registration order and is fragile.
-     if isinstance(model, DualARTransformer):
-         from fish_speech.models.text2semantic.inference import decode_one_token_ar as decode_one_token
-         logger.info("Using DualARTransformer")
-     else:
-         from fish_speech.models.text2semantic.inference import decode_one_token_naive as decode_one_token
-         logger.info("Using NaiveTransformer")
-
-     if compile:
-         logger.info("Compiling decode function...")
-         decode_one_token = torch.compile(
-             decode_one_token,
-             fullgraph=True,
-             backend="inductor" if torch.cuda.is_available() else "aot_eager",
-             mode="reduce-overhead" if torch.cuda.is_available() else None,
-         )
-
-     return model.eval(), decode_one_token
-
-
- def load_dac_model(config_name: str, checkpoint_path: str, device: str):
-     """Load the DAC model."""
-     logger.info(f"Loading DAC model from {checkpoint_path}")
-     model = load_decoder_model(
-         config_name=config_name,
-         checkpoint_path=checkpoint_path,
-         device=device,
-     )
-     return model
-
-
- @click.command()
- #@click.argument("text", type=str)
- @click.option("--llm-model-path", type=str, required=True, help="Path to the LLM model")
- @click.option("--dac-model-path", type=str, required=True, help="Path to the DAC model")
- @click.option("--dac-config-name", type=str, default="modded_dac_vq", help="DAC model config name")
- @click.option("--output-path", type=str, required=True, help="Path to save the output audio")
- @click.option("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use")
- @click.option("--max-new-tokens", type=int, default=4096, help="Maximum new tokens to generate")
- @click.option("--chunk-length", type=int, default=1000, help="Chunk length for synthesis")
- @click.option("--compile", is_flag=True, help="Whether to compile the model")
- @click.option("--iterative-prompt", is_flag=True, help="Whether to use iterative prompt")
- @click.option("--params-file", type=str, default="sampling_params_example.json", help="Path to JSON file containing sampling parameters")
- @click.option(
-     "--ref-audio",
-     type=click.Path(path_type=Path),  # a missing file is handled below with a warning
-     default="ref.wav",
-     help="Path to the reference audio file (default: ref.wav)",
- )
- def main(
-     #text: str,
-     llm_model_path: str,
-     dac_model_path: str,
-     dac_config_name: str,
-     output_path: str,
-     device: str,
-     max_new_tokens: int,
-     chunk_length: int,
-     compile: bool,
-     iterative_prompt: bool,
-     params_file: Optional[str],
-     ref_audio: Path,
- ):
-     """Generate speech in two steps: the LLM produces semantic tokens, then the DAC decodes them into audio."""
-
-     # Set precision
-     precision = torch.half if torch.cuda.is_available() else torch.bfloat16
-
-     # Load the LLM model (behind a thread-safe queue)
-     logger.info("Loading LLM model...")
-     llama_queue = launch_thread_safe_queue(
-         checkpoint_path=llm_model_path,
-         device="cuda:0",
-         precision=precision,
-         compile=compile,
-     )
-     logger.info("LLM model loaded")
-
-     # Load the DAC model
-     logger.info("Loading DAC model...")
-     dac_model = load_decoder_model(
-         config_name=dac_config_name,
-         checkpoint_path=dac_model_path,
-         device="cuda:1",
-     )
-     logger.info("DAC model loaded")
-
-     # Load sampling parameters
-     if params_file:
-         with open(params_file, "r", encoding="utf-8") as f:
-             params_data = json.load(f)
-         text = params_data.get("text", "")
-
-         semantic_params = CodebookSamplingParams(**params_data.get("semantic", {}))
-         codebook_params = [
-             CodebookSamplingParams(**params) for params in params_data.get("codebooks", [])
-         ]
-         sampling_params = SamplingParams(
-             semantic=semantic_params,
-             codebooks=codebook_params,
-         )
-     else:
-         sampling_params = SamplingParams()
-         text = ""  # no params file means no input text; avoids an unbound name below
-
-     # Clean the text
-     text = clean_text(text)
-
-     # Load the reference audio
-     if ref_audio.exists():
-         ref_audio_data, ref_sr = sf.read(ref_audio)
-         logger.info(f"Loaded reference audio: {ref_audio}, shape={ref_audio_data.shape}, sr={ref_sr}")
-         # Encode the reference audio into prompt_tokens
-         vq_manager = VQManager()
-         vq_manager.decoder_model = dac_model
-         vq_manager.load_audio = load_audio
-         prompt_tokens = vq_manager.encode_reference(ref_audio, enable_reference_audio=True)
-         logger.info(f"Encoded reference audio to prompt_tokens, shape={prompt_tokens.shape if prompt_tokens is not None else None}")
-     else:
-         prompt_tokens = []
-         logger.warning(f"Reference audio {ref_audio} not found.")
-
-     # Generate speech
-     logger.info(f"Generating speech for text: {text}")
-     logger.info(f"Using sampling parameters: {sampling_params}")
-
-     output_path = Path(output_path)
-     if not output_path.suffix:
-         output_path = output_path.with_suffix('.wav')
-     output_path.parent.mkdir(parents=True, exist_ok=True)
-
-     # Create the response queue
-     response_queue = queue.Queue()
-
-     # Prepare the request
-     request = dict(
-         device=device,
-         max_new_tokens=max_new_tokens,
-         text=text,
-         sampling_params=sampling_params,
-         compile=compile,
-         iterative_prompt=iterative_prompt,
-         chunk_length=chunk_length,
-         prompt_text=[],
-         prompt_tokens=[prompt_tokens] if prompt_tokens is not None and len(prompt_tokens) else [],
-         #prompt_text=["Through the dense morning fog that rolled across the peaceful valley, the distant church bells chimed their melodic song, echoing off ancient stone walls and mingling with the gentle rustling of maple leaves in the cool breeze. Inside the cozy lakeside cottage, fresh bread baked in the old clay oven filled every corner with its rich, comforting aroma, while steam rose lazily from ceramic mugs of fresh-brewed coffee on the handcrafted pine table. The persistent rain finally gave way to brilliant sunshine, transforming ordinary dewdrops into countless sparkling diamonds scattered across the vibrant garden flowers."],
-     )
-
-     # Send the request to the LLM model
-     llama_queue.put(GenerateRequest(request=request, response_queue=response_queue))
-
-     # Collect the generated tokens
-     all_tokens = []
-     while True:
-         wrapped_result: WrappedGenerateResponse = response_queue.get()
-
-         if wrapped_result.status == "error":
-             error = wrapped_result.response if isinstance(wrapped_result.response, Exception) else Exception("Unknown error")
-             logger.error(f"Error during generation: {error}")
-             break
-
-         result = wrapped_result.response
-         if result.action == "next":
-             break
-
-         all_tokens.append(result.codes)
-         logger.info(f"Generated chunk {len(all_tokens)}")
-
-     if not all_tokens:
-         logger.error("No tokens generated")
-         return
-
-     # Concatenate all token chunks
-     if len(all_tokens) > 1:
-         tokens = torch.cat(all_tokens, dim=1)
-     else:
-         tokens = all_tokens[0]
-
-     # Decode the tokens into audio with the DAC model
-     logger.info("Converting tokens to audio...")
-     feature_lengths = torch.tensor([tokens.shape[1]], device=device)
-     audio, _ = dac_model.decode(
-         indices=tokens[None].to("cuda:1"),
-         feature_lengths=feature_lengths.to("cuda:1")
-     )
-
-     # Save the audio
-     audio = audio[0, 0].detach().float().cpu().numpy()
-     sf.write(output_path, audio, dac_model.sample_rate)
-     logger.info(f"Saved audio to {output_path}")
-
-
- if __name__ == "__main__":
-     main()