Tomtom84 committed on
Commit
d662a4e
·
verified ·
1 Parent(s): d660fca

Update orpheus-tts/engine_class.py

Browse files
Files changed (1) hide show
  1. orpheus-tts/engine_class.py +16 -24
orpheus-tts/engine_class.py CHANGED
@@ -86,33 +86,25 @@ class OrpheusModel:
86
  if voice not in self.engine.available_voices:
87
  raise ValueError(f"Voice {voice} is not available for model {self.model_name}")
88
 
89
- def _format_prompt(self, prompt, voice="tara", model_type="larger"):
90
- if model_type == "smaller":
91
- if voice:
92
- return f"<custom_token_3>{prompt}[{voice}]<custom_token_4><custom_token_5>"
93
- else:
94
- return f"<custom_token_3>{prompt}<custom_token_4><custom_token_5>"
95
  else:
96
- if voice:
97
- adapted_prompt = f"{voice}: {prompt}"
98
- prompt_tokens = self.tokenizer(adapted_prompt, return_tensors="pt")
99
- start_token = torch.tensor([[ 128259]], dtype=torch.int64)
100
- end_tokens = torch.tensor([[128009, 128260, 128261, 128257]], dtype=torch.int64)
101
- all_input_ids = torch.cat([start_token, prompt_tokens.input_ids, end_tokens], dim=1)
102
- prompt_string = self.tokenizer.decode(all_input_ids[0])
103
- return prompt_string
104
- else:
105
- prompt_tokens = self.tokenizer(prompt, return_tensors="pt")
106
- start_token = torch.tensor([[ 128259]], dtype=torch.int64)
107
- end_tokens = torch.tensor([[128009, 128260, 128261, 128257]], dtype=torch.int64)
108
- all_input_ids = torch.cat([start_token, prompt_tokens.input_ids, end_tokens], dim=1)
109
- prompt_string = self.tokenizer.decode(all_input_ids[0])
110
- return prompt_string
111
 
112
 
113
-
114
-
115
- def generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.8, max_tokens=1200, stop_token_ids = [128258], repetition_penalty=1.3):
116
  prompt_string = self._format_prompt(prompt, voice)
117
  print(f"DEBUG: Original prompt: {prompt}")
118
  print(f"DEBUG: Formatted prompt: {prompt_string}")
 
86
  if voice not in self.engine.available_voices:
87
  raise ValueError(f"Voice {voice} is not available for model {self.model_name}")
88
 
89
+ def _format_prompt(self, prompt, voice="Sophie", model_type="larger"):
90
+ # Use Kartoffel model format based on documentation
91
+ if voice:
92
+ full_prompt = f"{voice}: {prompt}"
 
 
93
  else:
94
+ full_prompt = prompt
95
+
96
+ # Kartoffel model token format
97
+ start_token = torch.tensor([[128259]], dtype=torch.int64)
98
+ end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)
99
+
100
+ input_ids = self.tokenizer(full_prompt, return_tensors="pt").input_ids
101
+ modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
102
+
103
+ prompt_string = self.tokenizer.decode(modified_input_ids[0])
104
+ return prompt_string
 
 
 
 
105
 
106
 
107
+ def generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.95, max_tokens=4000, stop_token_ids = [128258], repetition_penalty=1.1):
 
 
108
  prompt_string = self._format_prompt(prompt, voice)
109
  print(f"DEBUG: Original prompt: {prompt}")
110
  print(f"DEBUG: Formatted prompt: {prompt_string}")