Update orpheus-tts/engine_class.py

orpheus-tts/engine_class.py (+13 -5)
@@ -112,23 +112,31 @@ class OrpheusModel:
 
 
 
-    def generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.8, max_tokens=1200, stop_token_ids = [
+    def generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.8, max_tokens=1200, stop_token_ids = [128258], repetition_penalty=1.3):
         prompt_string = self._format_prompt(prompt, voice)
-        print(prompt)
+        print(f"DEBUG: Original prompt: {prompt}")
+        print(f"DEBUG: Formatted prompt: {prompt_string}")
+
         sampling_params = SamplingParams(
             temperature=temperature,
             top_p=top_p,
             max_tokens=max_tokens, # Adjust max_tokens as needed.
-            stop_token_ids = stop_token_ids,
-            repetition_penalty=repetition_penalty,
+            stop_token_ids = stop_token_ids,
+            repetition_penalty=repetition_penalty,
         )
 
         token_queue = queue.Queue()
+        token_count = 0
 
         async def async_producer():
+            nonlocal token_count
             async for result in self.engine.generate(prompt=prompt_string, sampling_params=sampling_params, request_id=request_id):
                 # Place each token text into the queue.
-                token_queue.put(result.outputs[0].text)
+                token_text = result.outputs[0].text
+                print(f"DEBUG: Generated token {token_count}: {repr(token_text)}")
+                token_queue.put(token_text)
+                token_count += 1
+            print(f"DEBUG: Generation completed. Total tokens: {token_count}")
             token_queue.put(None) # Sentinel to indicate completion.
 
         def run_async():
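
For context: the method bridges vLLM's async `generate` stream to synchronous callers by running an asyncio producer on a worker thread, pushing token text onto a thread-safe `queue.Queue`, and using a `None` sentinel to signal completion. Below is a minimal, self-contained sketch of that pattern using only the standard library; `sync_stream` and `items` are illustrative stand-ins, not names from this repo:

```python
import asyncio
import queue
import threading

def sync_stream(items):
    """Bridge an async producer to a synchronous generator via a queue."""
    token_queue = queue.Queue()
    token_count = 0

    async def async_producer():
        # nonlocal must come before the first use of the name;
        # declaring it after the counter is read is a compile-time SyntaxError.
        nonlocal token_count
        for item in items:          # stand-in for `async for result in engine.generate(...)`
            await asyncio.sleep(0)  # yield control, as a real engine would
            token_queue.put(item)
            token_count += 1
        token_queue.put(None)  # sentinel to indicate completion

    # asyncio.run gives the worker thread its own event loop.
    thread = threading.Thread(target=lambda: asyncio.run(async_producer()))
    thread.start()

    while True:
        token = token_queue.get()  # blocks until the producer pushes a token
        if token is None:
            break
        yield token

    thread.join()
    print(f"streamed {token_count} tokens")

print(list(sync_stream(["a", "b", "c"])))  # -> ['a', 'b', 'c']
```

This is also why the `nonlocal token_count` declaration sits at the top of `async_producer` in the patched method: the counter is read by the first debug print inside the loop before it is incremented.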
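A hedged example of calling the patched method; the import path, checkpoint name, and voice below are assumptions for illustration, not part of this commit:

```python
from orpheus_tts import OrpheusModel  # assumed import path

# Hypothetical checkpoint; substitute whatever model you actually serve.
model = OrpheusModel(model_name="canopylabs/orpheus-tts-0.1-finetune-prod")

# generate_tokens_sync yields token text as the background producer emits it.
for token_text in model.generate_tokens_sync("Hello there!", voice="tara"):
    print(token_text, end="", flush=True)
```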