Update orpheus-tts/engine_class.py

orpheus-tts/engine_class.py (+24 −9)
CHANGED
@@ -100,7 +100,8 @@ class OrpheusModel:
         input_ids = self.tokenizer(full_prompt, return_tensors="pt").input_ids
         modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
 
-
+        # Decode back to string for vLLM
+        prompt_string = self.tokenizer.decode(modified_input_ids[0], skip_special_tokens=False)
         return prompt_string
 
 
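For context, the first hunk sits in the prompt-formatting helper: the text prompt is tokenized, wrapped in the model's special start/end tokens, and then decoded back to a string, since vLLM consumes a prompt string rather than raw token IDs here. A minimal sketch of that round trip; the checkpoint name, example prompt, and token IDs below are placeholders, not values from this commit:

import torch
from transformers import AutoTokenizer

# Placeholder checkpoint; the real engine uses the Orpheus model's tokenizer.
tokenizer = AutoTokenizer.from_pretrained("some-org/some-llama-checkpoint")

full_prompt = "tara: Hello there!"
input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids

# Illustrative special-token IDs; the actual values come from the model.
start_token = torch.tensor([[128259]], dtype=torch.int64)
end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)

modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)

# Decoding with skip_special_tokens=False keeps the wrapper tokens in the
# string that vLLM will re-tokenize on its side.
prompt_string = tokenizer.decode(modified_input_ids[0], skip_special_tokens=False)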
@@ -122,14 +123,28 @@ class OrpheusModel:
 
         async def async_producer():
             nonlocal token_count
-
-
-
-
-
-
-
-
+            print(f"DEBUG: Starting vLLM generation with prompt: {repr(prompt_string[:100])}...")
+            print(f"DEBUG: Sampling params: temp={sampling_params.temperature}, top_p={sampling_params.top_p}, max_tokens={sampling_params.max_tokens}")
+
+            try:
+                async for result in self.engine.generate(prompt=prompt_string, sampling_params=sampling_params, request_id=request_id):
+                    # Place each token text into the queue.
+                    token_text = result.outputs[0].text
+                    print(f"DEBUG: Generated token {token_count}: {repr(token_text)}")
+                    token_queue.put(token_text)
+                    token_count += 1
+
+                    # Show progress every 10 tokens
+                    if token_count % 10 == 0:
+                        print(f"DEBUG: Generated {token_count} tokens so far...")
+
+                print(f"DEBUG: Generation completed. Total tokens: {token_count}")
+            except Exception as e:
+                print(f"DEBUG: Error during generation: {e}")
+                import traceback
+                traceback.print_exc()
+            finally:
+                token_queue.put(None) # Sentinel to indicate completion.
 
         def run_async():
             asyncio.run(async_producer())
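The second hunk instruments the producer half of a producer/consumer bridge: vLLM's async engine streams outputs on an asyncio event loop running in a worker thread, and tokens cross to the synchronous side through a thread-safe queue. A sketch of how the pieces around the patched closure plausibly fit together; the wiring and the enclosing function name are reconstructed from the identifiers in the diff, not copied from the repo:

import asyncio
import queue
import threading

def generate_tokens_sync(engine, prompt_string, sampling_params, request_id):
    token_queue = queue.Queue()  # thread-safe bridge between the two worlds
    token_count = 0

    async def async_producer():
        nonlocal token_count
        try:
            async for result in engine.generate(prompt=prompt_string,
                                                sampling_params=sampling_params,
                                                request_id=request_id):
                # Put the latest output text on the queue for the consumer.
                token_queue.put(result.outputs[0].text)
                token_count += 1
        finally:
            token_queue.put(None)  # Sentinel: producer is done, even on error.

    def run_async():
        asyncio.run(async_producer())

    # The producer gets its own event loop on a daemon thread so this
    # function can stay an ordinary synchronous generator.
    thread = threading.Thread(target=run_async, daemon=True)
    thread.start()

    # Consume until the sentinel arrives.
    while True:
        token_text = token_queue.get()
        if token_text is None:
            break
        yield token_text
    thread.join()

The sentinel placed in the patch's finally block is what keeps this pattern safe: if generation raises, the consumer still receives None and unblocks instead of waiting on token_queue.get() forever.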