Tomtom84 committed on
Commit d660fca · verified · 1 Parent(s): e97cf4d

Update orpheus-tts/engine_class.py

Files changed (1):
  1. orpheus-tts/engine_class.py +13 -5
orpheus-tts/engine_class.py CHANGED
@@ -112,23 +112,31 @@ class OrpheusModel:
 
 
 
-    def generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.8, max_tokens=1200, stop_token_ids = [49158], repetition_penalty=1.3):
+    def generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.8, max_tokens=1200, stop_token_ids = [128258], repetition_penalty=1.3):
         prompt_string = self._format_prompt(prompt, voice)
-        print(prompt)
+        print(f"DEBUG: Original prompt: {prompt}")
+        print(f"DEBUG: Formatted prompt: {prompt_string}")
+
         sampling_params = SamplingParams(
             temperature=temperature,
             top_p=top_p,
             max_tokens=max_tokens,  # Adjust max_tokens as needed.
-            stop_token_ids = stop_token_ids,
-            repetition_penalty=repetition_penalty,
+            stop_token_ids = stop_token_ids,
+            repetition_penalty=repetition_penalty,
         )
 
         token_queue = queue.Queue()
+        token_count = 0
 
         async def async_producer():
+            nonlocal token_count  # must be declared before token_count is first used in this scope
             async for result in self.engine.generate(prompt=prompt_string, sampling_params=sampling_params, request_id=request_id):
                 # Place each token text into the queue.
-                token_queue.put(result.outputs[0].text)
+                token_text = result.outputs[0].text
+                print(f"DEBUG: Generated token {token_count}: {repr(token_text)}")
+                token_queue.put(token_text)
+                token_count += 1
+            print(f"DEBUG: Generation completed. Total tokens: {token_count}")
             token_queue.put(None)  # Sentinel to indicate completion.
 
         def run_async():
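
The diff ends at "def run_async():", so the consumer half of this producer/consumer pattern is not shown. As a rough orientation, here is a minimal, self-contained sketch of how such a sync-over-async loop is typically completed: a background thread runs the async producer on its own event loop while the caller drains the queue synchronously until the None sentinel arrives. The names generate_tokens_sync_sketch and demo_producer are illustrative assumptions, not code from this repository.

# Hypothetical sketch of the consumer side that the diff truncates.
# Assumes the producer puts token strings on a queue and a final None
# sentinel, as async_producer() does above; the rest is a guess.
import asyncio
import queue
import threading

def generate_tokens_sync_sketch(producer):
    token_queue = queue.Queue()

    def run_async():
        # Drive the async producer to completion on a private event loop
        # inside this background thread, keeping the caller synchronous.
        asyncio.run(producer(token_queue))

    thread = threading.Thread(target=run_async)
    thread.start()

    # Drain the queue synchronously until the None sentinel arrives.
    while True:
        token = token_queue.get()
        if token is None:
            break
        yield token
    thread.join()

async def demo_producer(token_queue):
    # Stand-in for self.engine.generate(...): emits a few fake tokens.
    for text in ["Hello", ",", " world"]:
        await asyncio.sleep(0)  # yield control, simulating async work
        token_queue.put(text)
    token_queue.put(None)  # completion sentinel

if __name__ == "__main__":
    print("".join(generate_tokens_sync_sketch(demo_producer)))  # Hello, world

Note that because the sketch is a generator, the background thread only starts once iteration begins; the real generate_tokens_sync may differ in that detail.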