Update orpheus-tts/engine_class.py
Browse files
orpheus-tts/engine_class.py
CHANGED
@@ -5,7 +5,7 @@ from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
|
|
5 |
from transformers import AutoTokenizer
|
6 |
import threading
|
7 |
import queue
|
8 |
-
from
|
9 |
|
10 |
class OrpheusModel:
|
11 |
def __init__(self, model_name, dtype=torch.bfloat16, tokenizer=None, **engine_kwargs):
|
@@ -63,7 +63,7 @@ class OrpheusModel:
|
|
63 |
if (model_name in unsupported_models):
|
64 |
raise ValueError(f"Model {model_name} is not supported. Only medium-3b is supported, small, micro and nano models will be released very soon")
|
65 |
elif model_name in model_map:
|
66 |
-
return
|
67 |
else:
|
68 |
return model_name
|
69 |
|
|
|
5 |
from transformers import AutoTokenizer
|
6 |
import threading
|
7 |
import queue
|
8 |
+
from decoder import tokens_decoder_sync
|
9 |
|
10 |
class OrpheusModel:
|
11 |
def __init__(self, model_name, dtype=torch.bfloat16, tokenizer=None, **engine_kwargs):
|
|
|
63 |
if (model_name in unsupported_models):
|
64 |
raise ValueError(f"Model {model_name} is not supported. Only medium-3b is supported, small, micro and nano models will be released very soon")
|
65 |
elif model_name in model_map:
|
66 |
+
return model_map[model_name]["repo_id"]
|
67 |
else:
|
68 |
return model_name
|
69 |
|