Tomtom84 committed
Commit 90d77aa · verified · 1 Parent(s): 3fee54e

Update orpheus-tts/engine_class.py

Files changed (1)
  1. orpheus-tts/engine_class.py +24 -9
orpheus-tts/engine_class.py CHANGED
@@ -100,7 +100,8 @@ class OrpheusModel:
         input_ids = self.tokenizer(full_prompt, return_tensors="pt").input_ids
         modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
 
-        prompt_string = self.tokenizer.decode(modified_input_ids[0])
+        # Decode back to string for vLLM
+        prompt_string = self.tokenizer.decode(modified_input_ids[0], skip_special_tokens=False)
         return prompt_string
 
 
@@ -122,14 +123,28 @@ class OrpheusModel:
 
         async def async_producer():
             nonlocal token_count
-            async for result in self.engine.generate(prompt=prompt_string, sampling_params=sampling_params, request_id=request_id):
-                # Place each token text into the queue.
-                token_text = result.outputs[0].text
-                print(f"DEBUG: Generated token {token_count}: {repr(token_text)}")
-                token_queue.put(token_text)
-                token_count += 1
-            print(f"DEBUG: Generation completed. Total tokens: {token_count}")
-            token_queue.put(None)  # Sentinel to indicate completion.
+            print(f"DEBUG: Starting vLLM generation with prompt: {repr(prompt_string[:100])}...")
+            print(f"DEBUG: Sampling params: temp={sampling_params.temperature}, top_p={sampling_params.top_p}, max_tokens={sampling_params.max_tokens}")
+
+            try:
+                async for result in self.engine.generate(prompt=prompt_string, sampling_params=sampling_params, request_id=request_id):
+                    # Place each token text into the queue.
+                    token_text = result.outputs[0].text
+                    print(f"DEBUG: Generated token {token_count}: {repr(token_text)}")
+                    token_queue.put(token_text)
+                    token_count += 1
+
+                    # Show progress every 10 tokens
+                    if token_count % 10 == 0:
+                        print(f"DEBUG: Generated {token_count} tokens so far...")
+
+                print(f"DEBUG: Generation completed. Total tokens: {token_count}")
+            except Exception as e:
+                print(f"DEBUG: Error during generation: {e}")
+                import traceback
+                traceback.print_exc()
+            finally:
+                token_queue.put(None)  # Sentinel to indicate completion.
 
         def run_async():
             asyncio.run(async_producer())
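
Note on the first hunk (not part of the commit): a minimal sketch of the ids-to-string round trip, using GPT-2 as a stand-in tokenizer and its <|endoftext|> id in place of the Orpheus start/end marker tokens. With skip_special_tokens=False the manually concatenated markers survive in the decoded prompt string; with True they would be stripped before the prompt reaches vLLM.

import torch
from transformers import AutoTokenizer

# Stand-in tokenizer and marker ids; Orpheus uses its own tokenizer and audio markers.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

input_ids = tokenizer("Hello there", return_tensors="pt").input_ids
start_token = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.int64)
end_tokens = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.int64)
modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)

# Keeps the <|endoftext|> markers in the decoded prompt
print(repr(tokenizer.decode(modified_input_ids[0], skip_special_tokens=False)))
# Strips them, yielding a different prompt string
print(repr(tokenizer.decode(modified_input_ids[0], skip_special_tokens=True)))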
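
Note on the second hunk: wrapping the producer in try/except/finally ensures the None sentinel is always enqueued, even if vLLM raises, so the consumer cannot hang. Below is a minimal, self-contained sketch of that queue-plus-sentinel pattern with a fake async generator standing in for self.engine.generate; the consumer loop is illustrative, not the file's exact code.

import asyncio
import queue
import threading

def stream_tokens():
    token_queue = queue.Queue()

    async def fake_engine():
        # Stand-in for the vLLM async generator used in the commit.
        for text in ["hel", "lo ", "wor", "ld"]:
            await asyncio.sleep(0.01)
            yield text

    async def async_producer():
        try:
            async for token_text in fake_engine():
                token_queue.put(token_text)
        finally:
            token_queue.put(None)  # sentinel: generation finished (or failed)

    def run_async():
        asyncio.run(async_producer())

    # Run the asyncio producer in a background thread and drain the queue
    # synchronously until the sentinel arrives.
    thread = threading.Thread(target=run_async)
    thread.start()
    while True:
        token = token_queue.get()
        if token is None:
            break
        yield token
    thread.join()

if __name__ == "__main__":
    print(list(stream_tokens()))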