Update orpheus-tts/engine_class.py
Browse files- orpheus-tts/engine_class.py +15 -14
orpheus-tts/engine_class.py
CHANGED
@@ -100,25 +100,26 @@ class OrpheusModel:
|
|
100 |
else:
|
101 |
full_prompt = prompt
|
102 |
|
103 |
-
# Kartoffel model
|
104 |
-
|
105 |
-
start_token_id = 128259 # Für Prompt-Start
|
106 |
-
end_token_ids = [128009, 128260] # Für Prompt-Ende
|
107 |
|
108 |
-
|
109 |
-
|
110 |
-
|
|
|
|
|
|
|
111 |
|
112 |
# Token-IDs zusammenfügen
|
113 |
-
|
114 |
-
print(f"DEBUG KARTOFFEL:
|
|
|
115 |
|
116 |
-
#
|
117 |
-
|
118 |
-
|
119 |
-
print(f"DEBUG KARTOFFEL: token_string: {token_string}")
|
120 |
|
121 |
-
return
|
122 |
else:
|
123 |
# Original Orpheus format (für Canopy-Deutsch und English)
|
124 |
if model_type == "smaller":
|
|
|
100 |
else:
|
101 |
full_prompt = prompt
|
102 |
|
103 |
+
# Kartoffel model format - exakt wie in der Referenz-Implementierung
|
104 |
+
import torch
|
|
|
|
|
105 |
|
106 |
+
start_token = torch.tensor([[128259]], dtype=torch.int64)
|
107 |
+
end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)
|
108 |
+
input_ids = self.tokenizer(full_prompt, return_tensors="pt").input_ids
|
109 |
+
|
110 |
+
print(f"DEBUG KARTOFFEL: Original prompt: '{full_prompt}'")
|
111 |
+
print(f"DEBUG KARTOFFEL: input_ids shape: {input_ids.shape}")
|
112 |
|
113 |
# Token-IDs zusammenfügen
|
114 |
+
modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
|
115 |
+
print(f"DEBUG KARTOFFEL: modified_input_ids shape: {modified_input_ids.shape}")
|
116 |
+
print(f"DEBUG KARTOFFEL: modified_input_ids: {modified_input_ids[0].tolist()}")
|
117 |
|
118 |
+
# Zurück zu Text dekodieren - EXAKT wie in der Referenz
|
119 |
+
decoded_text = self.tokenizer.decode(modified_input_ids[0], skip_special_tokens=False)
|
120 |
+
print(f"DEBUG KARTOFFEL: Final decoded prompt: '{decoded_text}'")
|
|
|
121 |
|
122 |
+
return decoded_text
|
123 |
else:
|
124 |
# Original Orpheus format (für Canopy-Deutsch und English)
|
125 |
if model_type == "smaller":
|