Tomtom84 commited on
Commit
9369755
·
verified ·
1 Parent(s): d4e3b98

Update orpheus-tts/engine_class.py

Browse files
Files changed (1) hide show
  1. orpheus-tts/engine_class.py +11 -7
orpheus-tts/engine_class.py CHANGED
@@ -94,15 +94,19 @@ class OrpheusModel:
94
  else:
95
  full_prompt = prompt
96
 
97
- # Kartoffel model token format
98
- start_token = torch.tensor([[128259]], dtype=torch.int64)
99
- end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)
100
 
101
- input_ids = self.tokenizer(full_prompt, return_tensors="pt").input_ids
102
- modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
 
 
 
 
 
 
103
 
104
- # Decode back to string for vLLM
105
- prompt_string = self.tokenizer.decode(modified_input_ids[0], skip_special_tokens=False)
106
  return prompt_string
107
 
108
 
 
94
  else:
95
  full_prompt = prompt
96
 
97
+ # Kartoffel model token format - direkt die Token-IDs einfügen
98
+ start_token_id = 128259
99
+ end_token_ids = [128009, 128260]
100
 
101
+ # Text tokenisieren
102
+ input_ids = self.tokenizer(full_prompt, return_tensors="pt").input_ids[0].tolist()
103
+
104
+ # Token-IDs zusammenfügen
105
+ all_token_ids = [start_token_id] + input_ids + end_token_ids
106
+
107
+ # Zurück zu String dekodieren - aber die speziellen Token-IDs bleiben erhalten
108
+ prompt_string = self.tokenizer.decode(all_token_ids, skip_special_tokens=False)
109
 
 
 
110
  return prompt_string
111
 
112