Tomtom84 committed
Commit 63ebe58 · verified · 1 Parent(s): c3e502f

Create orpheus-tts/engine_class.py

Files changed (1):
  orpheus-tts/engine_class.py (+146, -0)
orpheus-tts/engine_class.py ADDED
@@ -0,0 +1,146 @@
import asyncio
import os
import queue
import threading

import torch
from transformers import AutoTokenizer
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams

from .decoder import tokens_decoder_sync


class OrpheusModel:
    def __init__(self, model_name, dtype=torch.bfloat16, tokenizer=None, **engine_kwargs):
        self.model_name = self._map_model_params(model_name)
        self.dtype = dtype
        self.engine_kwargs = engine_kwargs  # vLLM engine kwargs
        self.engine = self._setup_engine()

        # Available voices for the German "Kartoffel" model
        if "german" in model_name.lower() or "kartoffel" in model_name.lower():
            self.available_voices = ["Jakob", "Anton", "Julian", "Sophie", "Marie", "Mia"]
        else:
            # Original English voices as fallback
            self.available_voices = ["zoe", "zac", "jess", "leo", "mia", "julia", "leah", "tara"]

        # Use the provided tokenizer path, or derive one from the model name.
        # For German models, try the model itself first; otherwise fall back
        # to the original Orpheus tokenizer.
        if tokenizer:
            tokenizer_path = tokenizer
        elif "german" in model_name.lower() or "kartoffel" in model_name.lower():
            tokenizer_path = model_name  # try using the model itself as tokenizer
        else:
            tokenizer_path = "canopylabs/orpheus-3b-0.1-pretrained"  # original fallback

        self.tokenizer = self._load_tokenizer(tokenizer_path)

    def _load_tokenizer(self, tokenizer_path):
        """Load a tokenizer from a local path or the Hugging Face hub."""
        try:
            # Treat tokenizer_path as local if it is an existing directory.
            if os.path.isdir(tokenizer_path):
                return AutoTokenizer.from_pretrained(tokenizer_path, local_files_only=True)
            return AutoTokenizer.from_pretrained(tokenizer_path)
        except Exception as e:
            print(f"Error loading tokenizer: {e}")
            print("Falling back to default tokenizer")
            return AutoTokenizer.from_pretrained("gpt2")

    def _map_model_params(self, model_name):
        model_map = {
            # "nano-150m":  {"repo_id": "canopylabs/orpheus-tts-0.1-finetune-prod"},
            # "micro-400m": {"repo_id": "canopylabs/orpheus-tts-0.1-finetune-prod"},
            # "small-1b":   {"repo_id": "canopylabs/orpheus-tts-0.1-finetune-prod"},
            "medium-3b": {"repo_id": "canopylabs/orpheus-tts-0.1-finetune-prod"},
        }
        unsupported_models = ["nano-150m", "micro-400m", "small-1b"]
        if model_name in unsupported_models:
            raise ValueError(
                f"Model {model_name} is not supported. Only medium-3b is supported; "
                "the small, micro, and nano models will be released very soon."
            )
        elif model_name in model_map:
            # Fixed: the original indexed model_name[model_name] here.
            return model_map[model_name]["repo_id"]
        else:
            return model_name

    def _setup_engine(self):
        engine_args = AsyncEngineArgs(
            model=self.model_name,
            dtype=self.dtype,
            **self.engine_kwargs,
        )
        return AsyncLLMEngine.from_engine_args(engine_args)

    def validate_voice(self, voice):
        # Fixed: voices are tracked on this class, not on the vLLM engine.
        if voice and voice not in self.available_voices:
            raise ValueError(f"Voice {voice} is not available for model {self.model_name}")

    def _format_prompt(self, prompt, voice="tara", model_type="larger"):
        if model_type == "smaller":
            if voice:
                return f"<custom_token_3>{prompt}[{voice}]<custom_token_4><custom_token_5>"
            return f"<custom_token_3>{prompt}<custom_token_4><custom_token_5>"
        # Larger models: prefix the voice name, then wrap the token ids with
        # the special start/end markers the model expects.
        adapted_prompt = f"{voice}: {prompt}" if voice else prompt
        prompt_tokens = self.tokenizer(adapted_prompt, return_tensors="pt")
        start_token = torch.tensor([[128259]], dtype=torch.int64)
        end_tokens = torch.tensor([[128009, 128260, 128261, 128257]], dtype=torch.int64)
        all_input_ids = torch.cat([start_token, prompt_tokens.input_ids, end_tokens], dim=1)
        return self.tokenizer.decode(all_input_ids[0])

    def generate_tokens_sync(self, prompt, voice=None, request_id="req-001",
                             temperature=0.6, top_p=0.8, max_tokens=1200,
                             stop_token_ids=None, repetition_penalty=1.3):
        prompt_string = self._format_prompt(prompt, voice)
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,  # adjust as needed
            stop_token_ids=stop_token_ids if stop_token_ids is not None else [49158],
            repetition_penalty=repetition_penalty,
        )

        token_queue = queue.Queue()

        async def async_producer():
            async for result in self.engine.generate(
                prompt=prompt_string,
                sampling_params=sampling_params,
                request_id=request_id,
            ):
                # Put each streamed output text into the queue.
                token_queue.put(result.outputs[0].text)
            token_queue.put(None)  # Sentinel to indicate completion.

        def run_async():
            asyncio.run(async_producer())

        # Run the async vLLM generator on a background thread and consume
        # its output synchronously through the queue.
        thread = threading.Thread(target=run_async)
        thread.start()

        while True:
            token = token_queue.get()
            if token is None:
                break
            yield token

        thread.join()

    def generate_speech(self, **kwargs):
        return tokens_decoder_sync(self.generate_tokens_sync(**kwargs))
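For context, a minimal usage sketch of this class: it streams audio from generate_speech into a WAV file. This assumes tokens_decoder_sync from the sibling decoder module yields 16-bit mono PCM byte chunks at 24 kHz, as in the upstream orpheus-tts examples, and that the package is importable as orpheus_tts; the model name, voice, prompt, and output path are all illustrative.

# Minimal usage sketch. Assumes tokens_decoder_sync yields 16-bit mono PCM
# chunks at 24 kHz (as in the upstream orpheus-tts examples); the import
# path, model name, voice, and output path are illustrative.
import wave

from orpheus_tts import OrpheusModel

model = OrpheusModel(model_name="canopylabs/orpheus-tts-0.1-finetune-prod")

with wave.open("output.wav", "wb") as wf:
    wf.setnchannels(1)      # mono
    wf.setsampwidth(2)      # 16-bit samples
    wf.setframerate(24000)  # 24 kHz
    for chunk in model.generate_speech(prompt="Hello there!", voice="tara"):
        wf.writeframes(chunk)

The background-thread design in generate_tokens_sync exists because vLLM's AsyncLLMEngine only exposes an async generator; bridging it through a queue lets callers like the loop above consume tokens from plain synchronous code.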