Update orpheus-tts/engine_class.py
Browse files- orpheus-tts/engine_class.py +16 -24
orpheus-tts/engine_class.py
CHANGED
@@ -86,33 +86,25 @@ class OrpheusModel:
|
|
86 |
if voice not in self.engine.available_voices:
|
87 |
raise ValueError(f"Voice {voice} is not available for model {self.model_name}")
|
88 |
|
89 |
-
def _format_prompt(self, prompt, voice="
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
else:
|
94 |
-
return f"<custom_token_3>{prompt}<custom_token_4><custom_token_5>"
|
95 |
else:
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
end_tokens = torch.tensor([[128009, 128260, 128261, 128257]], dtype=torch.int64)
|
108 |
-
all_input_ids = torch.cat([start_token, prompt_tokens.input_ids, end_tokens], dim=1)
|
109 |
-
prompt_string = self.tokenizer.decode(all_input_ids[0])
|
110 |
-
return prompt_string
|
111 |
|
112 |
|
113 |
-
|
114 |
-
|
115 |
-
def generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.8, max_tokens=1200, stop_token_ids = [128258], repetition_penalty=1.3):
|
116 |
prompt_string = self._format_prompt(prompt, voice)
|
117 |
print(f"DEBUG: Original prompt: {prompt}")
|
118 |
print(f"DEBUG: Formatted prompt: {prompt_string}")
|
|
|
86 |
if voice not in self.engine.available_voices:
|
87 |
raise ValueError(f"Voice {voice} is not available for model {self.model_name}")
|
88 |
|
89 |
+
def _format_prompt(self, prompt, voice="Sophie", model_type="larger"):
    """Build the Kartoffel-format prompt string for ``prompt``.

    Args:
        prompt: Text to be spoken.
        voice: Speaker name. When truthy it is prepended as
            ``"<voice>: "``; when falsy (``None`` or ``""``) the prompt
            is used unchanged.
        model_type: Accepted for caller compatibility; not read by this
            method.

    Returns:
        The string obtained by tokenizing the (optionally voice-prefixed)
        text, framing the ids with the fixed start/end token ids, and
        decoding the framed sequence with ``self.tokenizer``.
    """
    # Kartoffel convention: the speaker tag precedes the text as "voice: ".
    text = f"{voice}: {prompt}" if voice else prompt

    encoded = self.tokenizer(text, return_tensors="pt")

    # Kartoffel token framing: one start id before the text ids and two
    # end ids after them, concatenated along the sequence dimension.
    framed_ids = torch.cat(
        [
            torch.tensor([[128259]], dtype=torch.int64),
            encoded.input_ids,
            torch.tensor([[128009, 128260]], dtype=torch.int64),
        ],
        dim=1,
    )

    # Decode the single row so the caller receives a plain string.
    return self.tokenizer.decode(framed_ids[0])
|
|
|
|
|
|
|
|
|
105 |
|
106 |
|
107 |
+
def generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.95, max_tokens=4000, stop_token_ids = [128258], repetition_penalty=1.1):
|
|
|
|
|
108 |
prompt_string = self._format_prompt(prompt, voice)
|
109 |
print(f"DEBUG: Original prompt: {prompt}")
|
110 |
print(f"DEBUG: Formatted prompt: {prompt_string}")
|