Tomtom84 commited on
Commit
ce68a3b
·
verified ·
1 Parent(s): 9369755

Update orpheus-tts/engine_class.py

Browse files
Files changed (1) hide show
  1. orpheus-tts/engine_class.py +61 -27
orpheus-tts/engine_class.py CHANGED
@@ -14,9 +14,11 @@ class OrpheusModel:
14
  self.dtype = dtype
15
  self.engine_kwargs = engine_kwargs # vLLM engine kwargs
16
  self.engine = self._setup_engine()
17
- # Available voices for German Kartoffel model
18
- if "german" in model_name.lower() or "kartoffel" in model_name.lower():
19
  self.available_voices = ["Jakob", "Anton", "Julian", "Sophie", "Marie", "Mia"]
 
 
20
  else:
21
  # Original English voices as fallback
22
  self.available_voices = ["zoe", "zac", "jess", "leo", "mia", "julia", "leah", "tara"]
@@ -57,7 +59,7 @@ class OrpheusModel:
57
  # "repo_id": "canopylabs/orpheus-tts-0.1-finetune-prod",
58
  # },
59
  "medium-3b":{
60
- "repo_id": "canopylabs/orpheus-tts-0.1-finetune-prod",
61
  },
62
  }
63
  unsupported_models = ["nano-150m", "micro-400m", "small-1b"]
@@ -88,31 +90,60 @@ class OrpheusModel:
88
  raise ValueError(f"Voice {voice} is not available for model {self.model_name}")
89
 
90
  def _format_prompt(self, prompt, voice="Jakob", model_type="larger"):
91
- # Use Kartoffel model format based on documentation
92
- if voice:
93
- full_prompt = f"{voice}: {prompt}"
94
- else:
95
- full_prompt = prompt
 
 
 
 
 
 
 
 
96
 
97
- # Kartoffel model token format - direkt die Token-IDs einfügen
98
- start_token_id = 128259
99
- end_token_ids = [128009, 128260]
100
-
101
- # Text tokenisieren
102
- input_ids = self.tokenizer(full_prompt, return_tensors="pt").input_ids[0].tolist()
103
-
104
- # Token-IDs zusammenfügen
105
- all_token_ids = [start_token_id] + input_ids + end_token_ids
106
-
107
- # Zurück zu String dekodieren - aber die speziellen Token-IDs bleiben erhalten
108
- prompt_string = self.tokenizer.decode(all_token_ids, skip_special_tokens=False)
109
-
110
- return prompt_string
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
 
113
  def generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.95, max_tokens=4000, stop_token_ids = [128258], repetition_penalty=1.1):
114
- prompt_string = self._format_prompt(prompt, voice)
115
  print(f"DEBUG: Original prompt: {prompt}")
 
 
 
116
  print(f"DEBUG: Formatted prompt: {prompt_string}")
117
 
118
  sampling_params = SamplingParams(
@@ -171,12 +202,15 @@ class OrpheusModel:
171
  token_generator = self.generate_tokens_sync(**kwargs)
172
  print("DEBUG: Token generator created successfully")
173
 
174
- # Verwende Kartoffel-Decoder für deutsche Modelle
175
- if "german" in self.model_name.lower() or "kartoffel" in self.model_name.lower():
176
- print("DEBUG: Using Kartoffel decoder for German model")
177
  audio_generator = tokens_decoder_kartoffel_sync(token_generator, self.tokenizer)
 
 
 
178
  else:
179
- print("DEBUG: Using original decoder")
180
  audio_generator = tokens_decoder_sync(token_generator)
181
 
182
  print("DEBUG: Audio decoder called successfully")
 
14
  self.dtype = dtype
15
  self.engine_kwargs = engine_kwargs # vLLM engine kwargs
16
  self.engine = self._setup_engine()
17
+ # Available voices based on model type
18
+ if "kartoffel" in model_name.lower():
19
  self.available_voices = ["Jakob", "Anton", "Julian", "Sophie", "Marie", "Mia"]
20
+ elif "3b-de-ft" in model_name.lower():
21
+ self.available_voices = ["jana", "thomas", "max"]
22
  else:
23
  # Original English voices as fallback
24
  self.available_voices = ["zoe", "zac", "jess", "leo", "mia", "julia", "leah", "tara"]
 
59
  # "repo_id": "canopylabs/orpheus-tts-0.1-finetune-prod",
60
  # },
61
  "medium-3b":{
62
+ "repo_id": "canopylabs/3b-de-ft-research_release",
63
  },
64
  }
65
  unsupported_models = ["nano-150m", "micro-400m", "small-1b"]
 
90
  raise ValueError(f"Voice {voice} is not available for model {self.model_name}")
91
 
92
  def _format_prompt(self, prompt, voice="Jakob", model_type="larger"):
93
+ # Unterschiedliche Formate für verschiedene Modelle
94
+ print(f"DEBUG: Model name for format check: {self.model_name}")
95
+ if "kartoffel" in self.model_name.lower():
96
+ print("DEBUG: Using Kartoffel format")
97
+ # Kartoffel model format
98
+ if voice:
99
+ full_prompt = f"{voice}: {prompt}"
100
+ else:
101
+ full_prompt = prompt
102
+
103
+ # Kartoffel model token format - direkt die Token-IDs einfügen
104
+ start_token_id = 128259
105
+ end_token_ids = [128009, 128260]
106
 
107
+ # Text tokenisieren
108
+ input_ids = self.tokenizer(full_prompt, return_tensors="pt").input_ids[0].tolist()
109
+
110
+ # Token-IDs zusammenfügen
111
+ all_token_ids = [start_token_id] + input_ids + end_token_ids
112
+
113
+ # Zurück zu String dekodieren
114
+ prompt_string = self.tokenizer.decode(all_token_ids, skip_special_tokens=False)
115
+
116
+ return prompt_string
117
+ else:
118
+ # Original Orpheus format (für Canopy-Deutsch und English)
119
+ if model_type == "smaller":
120
+ if voice:
121
+ return f"<custom_token_3>{prompt}[{voice}]<custom_token_4><custom_token_5>"
122
+ else:
123
+ return f"<custom_token_3>{prompt}<custom_token_4><custom_token_5>"
124
+ else:
125
+ if voice:
126
+ adapted_prompt = f"{voice}: {prompt}"
127
+ prompt_tokens = self.tokenizer(adapted_prompt, return_tensors="pt")
128
+ start_token = torch.tensor([[ 128259]], dtype=torch.int64)
129
+ end_tokens = torch.tensor([[128009, 128260, 128261, 128257]], dtype=torch.int64)
130
+ all_input_ids = torch.cat([start_token, prompt_tokens.input_ids, end_tokens], dim=1)
131
+ prompt_string = self.tokenizer.decode(all_input_ids[0])
132
+ return prompt_string
133
+ else:
134
+ prompt_tokens = self.tokenizer(prompt, return_tensors="pt")
135
+ start_token = torch.tensor([[ 128259]], dtype=torch.int64)
136
+ end_tokens = torch.tensor([[128009, 128260, 128261, 128257]], dtype=torch.int64)
137
+ all_input_ids = torch.cat([start_token, prompt_tokens.input_ids, end_tokens], dim=1)
138
+ prompt_string = self.tokenizer.decode(all_input_ids[0])
139
+ return prompt_string
140
 
141
 
142
  def generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.95, max_tokens=4000, stop_token_ids = [128258], repetition_penalty=1.1):
 
143
  print(f"DEBUG: Original prompt: {prompt}")
144
+ print(f"DEBUG: Voice: {voice}")
145
+ print(f"DEBUG: Model name: {self.model_name}")
146
+ prompt_string = self._format_prompt(prompt, voice)
147
  print(f"DEBUG: Formatted prompt: {prompt_string}")
148
 
149
  sampling_params = SamplingParams(
 
202
  token_generator = self.generate_tokens_sync(**kwargs)
203
  print("DEBUG: Token generator created successfully")
204
 
205
+ # Verwende Kartoffel-Decoder nur für Kartoffel-Modell, Original-Decoder für Canopy-Deutsch
206
+ if "kartoffel" in self.model_name.lower():
207
+ print("DEBUG: Using Kartoffel decoder for Kartoffel model")
208
  audio_generator = tokens_decoder_kartoffel_sync(token_generator, self.tokenizer)
209
+ elif "3b-de-ft" in self.model_name.lower() or "german" in self.model_name.lower():
210
+ print("DEBUG: Using original decoder for Canopy German model")
211
+ audio_generator = tokens_decoder_sync(token_generator)
212
  else:
213
+ print("DEBUG: Using original decoder for English model")
214
  audio_generator = tokens_decoder_sync(token_generator)
215
 
216
  print("DEBUG: Audio decoder called successfully")