Update orpheus-tts/engine_class.py
Browse files: orpheus-tts/engine_class.py (+61 -27)
orpheus-tts/engine_class.py
CHANGED
@@ -14,9 +14,11 @@ class OrpheusModel:
|
|
14 |
self.dtype = dtype
|
15 |
self.engine_kwargs = engine_kwargs # vLLM engine kwargs
|
16 |
self.engine = self._setup_engine()
|
17 |
-
# Available voices
|
18 |
-
if "
|
19 |
self.available_voices = ["Jakob", "Anton", "Julian", "Sophie", "Marie", "Mia"]
|
|
|
|
|
20 |
else:
|
21 |
# Original English voices as fallback
|
22 |
self.available_voices = ["zoe", "zac", "jess", "leo", "mia", "julia", "leah", "tara"]
|
@@ -57,7 +59,7 @@ class OrpheusModel:
|
|
57 |
# "repo_id": "canopylabs/orpheus-tts-0.1-finetune-prod",
|
58 |
# },
|
59 |
"medium-3b":{
|
60 |
-
"repo_id": "canopylabs/
|
61 |
},
|
62 |
}
|
63 |
unsupported_models = ["nano-150m", "micro-400m", "small-1b"]
|
@@ -88,31 +90,60 @@ class OrpheusModel:
|
|
88 |
raise ValueError(f"Voice {voice} is not available for model {self.model_name}")
|
89 |
|
90 |
def _format_prompt(self, prompt, voice="Jakob", model_type="larger"):
|
91 |
-
#
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
|
112 |
|
113 |
def generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.95, max_tokens=4000, stop_token_ids = [128258], repetition_penalty=1.1):
|
114 |
-
prompt_string = self._format_prompt(prompt, voice)
|
115 |
print(f"DEBUG: Original prompt: {prompt}")
|
|
|
|
|
|
|
116 |
print(f"DEBUG: Formatted prompt: {prompt_string}")
|
117 |
|
118 |
sampling_params = SamplingParams(
|
@@ -171,12 +202,15 @@ class OrpheusModel:
|
|
171 |
token_generator = self.generate_tokens_sync(**kwargs)
|
172 |
print("DEBUG: Token generator created successfully")
|
173 |
|
174 |
-
# Verwende Kartoffel-Decoder für
|
175 |
-
if "
|
176 |
-
print("DEBUG: Using Kartoffel decoder for
|
177 |
audio_generator = tokens_decoder_kartoffel_sync(token_generator, self.tokenizer)
|
|
|
|
|
|
|
178 |
else:
|
179 |
-
print("DEBUG: Using original decoder")
|
180 |
audio_generator = tokens_decoder_sync(token_generator)
|
181 |
|
182 |
print("DEBUG: Audio decoder called successfully")
|
|
|
14 |
self.dtype = dtype
|
15 |
self.engine_kwargs = engine_kwargs # vLLM engine kwargs
|
16 |
self.engine = self._setup_engine()
|
17 |
+
# Available voices based on model type
|
18 |
+
if "kartoffel" in model_name.lower():
|
19 |
self.available_voices = ["Jakob", "Anton", "Julian", "Sophie", "Marie", "Mia"]
|
20 |
+
elif "3b-de-ft" in model_name.lower():
|
21 |
+
self.available_voices = ["jana", "thomas", "max"]
|
22 |
else:
|
23 |
# Original English voices as fallback
|
24 |
self.available_voices = ["zoe", "zac", "jess", "leo", "mia", "julia", "leah", "tara"]
|
|
|
59 |
# "repo_id": "canopylabs/orpheus-tts-0.1-finetune-prod",
|
60 |
# },
|
61 |
"medium-3b":{
|
62 |
+
"repo_id": "canopylabs/3b-de-ft-research_release",
|
63 |
},
|
64 |
}
|
65 |
unsupported_models = ["nano-150m", "micro-400m", "small-1b"]
|
|
|
90 |
raise ValueError(f"Voice {voice} is not available for model {self.model_name}")
|
91 |
|
92 |
def _format_prompt(self, prompt, voice="Jakob", model_type="larger"):
    """Build the prompt string passed to the engine for ``prompt``.

    The layout is chosen from ``self.model_name`` and ``model_type``:

    * Model name contains "kartoffel" (case-insensitive): the text,
      prefixed with ``"{voice}: "`` when ``voice`` is truthy, is tokenized,
      wrapped with token id 128259 in front and 128009, 128260 behind,
      then decoded back to a string with ``skip_special_tokens=False``.
    * Any other model with ``model_type == "smaller"``: a
      ``<custom_token_3>...<custom_token_4><custom_token_5>`` template,
      with ``[voice]`` appended to the text when ``voice`` is truthy.
    * Any other model otherwise: the same wrapping as the Kartoffel path,
      but with trailing ids 128009, 128260, 128261, 128257 and the
      tokenizer's default ``decode`` arguments.

    Args:
        prompt: Text to be spoken.
        voice: Speaker name; a falsy value omits the voice from the prompt.
        model_type: ``"smaller"`` selects the template format for
            non-Kartoffel models; any other value selects token wrapping.

    Returns:
        The formatted prompt string.
    """
    # Different models expect different prompt formats.
    print(f"DEBUG: Model name for format check: {self.model_name}")
    is_kartoffel = "kartoffel" in self.model_name.lower()
    if is_kartoffel:
        print("DEBUG: Using Kartoffel format")

    # Original Orpheus template format (no tokenizer round trip needed).
    if not is_kartoffel and model_type == "smaller":
        if voice:
            return f"<custom_token_3>{prompt}[{voice}]<custom_token_4><custom_token_5>"
        return f"<custom_token_3>{prompt}<custom_token_4><custom_token_5>"

    # Both remaining formats tokenize the (optionally voice-prefixed) text.
    text = f"{voice}: {prompt}" if voice else prompt
    encoded = self.tokenizer(text, return_tensors="pt")

    if is_kartoffel:
        # Kartoffel format: splice the token ids in directly, then decode
        # back to a string while keeping the special tokens.
        body_ids = encoded.input_ids[0].tolist()
        wrapped_ids = [128259] + body_ids + [128009, 128260]
        return self.tokenizer.decode(wrapped_ids, skip_special_tokens=False)

    # Original Orpheus format (Canopy German and English models).
    head = torch.tensor([[128259]], dtype=torch.int64)
    tail = torch.tensor([[128009, 128260, 128261, 128257]], dtype=torch.int64)
    wrapped = torch.cat([head, encoded.input_ids, tail], dim=1)
    return self.tokenizer.decode(wrapped[0])
|
140 |
|
141 |
|
142 |
def generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.95, max_tokens=4000, stop_token_ids = [128258], repetition_penalty=1.1):
|
|
|
143 |
print(f"DEBUG: Original prompt: {prompt}")
|
144 |
+
print(f"DEBUG: Voice: {voice}")
|
145 |
+
print(f"DEBUG: Model name: {self.model_name}")
|
146 |
+
prompt_string = self._format_prompt(prompt, voice)
|
147 |
print(f"DEBUG: Formatted prompt: {prompt_string}")
|
148 |
|
149 |
sampling_params = SamplingParams(
|
|
|
202 |
token_generator = self.generate_tokens_sync(**kwargs)
|
203 |
print("DEBUG: Token generator created successfully")
|
204 |
|
205 |
+
# Verwende Kartoffel-Decoder nur für Kartoffel-Modell, Original-Decoder für Canopy-Deutsch
|
206 |
+
if "kartoffel" in self.model_name.lower():
|
207 |
+
print("DEBUG: Using Kartoffel decoder for Kartoffel model")
|
208 |
audio_generator = tokens_decoder_kartoffel_sync(token_generator, self.tokenizer)
|
209 |
+
elif "3b-de-ft" in self.model_name.lower() or "german" in self.model_name.lower():
|
210 |
+
print("DEBUG: Using original decoder for Canopy German model")
|
211 |
+
audio_generator = tokens_decoder_sync(token_generator)
|
212 |
else:
|
213 |
+
print("DEBUG: Using original decoder for English model")
|
214 |
audio_generator = tokens_decoder_sync(token_generator)
|
215 |
|
216 |
print("DEBUG: Audio decoder called successfully")
|