mrrtmob committed on
Commit 4b72318 · 1 Parent(s): 399c1c4

Add new dependencies: openai-whisper, requests, and gradio to requirements.txt

Files changed (2)
  1. app.py +689 -208
  2. requirements.txt +4 -1
app.py CHANGED
@@ -1,233 +1,714 @@
- import spaces
- from snac import SNAC
- import torch
  import gradio as gr
- from transformers import AutoModelForCausalLM, AutoTokenizer
- from huggingface_hub import snapshot_download
  from dotenv import load_dotenv
  load_dotenv()
- # Check if CUDA is available
- device = "cuda" if torch.cuda.is_available() else "cpu"
- print("Loading SNAC model...")
- snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
- snac_model = snac_model.to(device)
- model_name = "mrrtmob/tts-khm-3"
- # Download only model config and safetensors
- snapshot_download(
-     repo_id=model_name,
-     allow_patterns=[
-         "config.json",
-         "*.safetensors",
-         "model.safetensors.index.json",
-     ],
-     ignore_patterns=[
-         "optimizer.pt",
-         "pytorch_model.bin",
-         "training_args.bin",
-         "scheduler.pt",
-         "tokenizer.json",
-         "tokenizer_config.json",
-         "special_tokens_map.json",
-         "vocab.json",
-         "merges.txt",
-         "tokenizer.*"
-     ]
- )
- model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
- model.to(device)
- tokenizer = AutoTokenizer.from_pretrained(model_name)
- print(f"Khmer TTS model loaded to {device}")
- # Process text prompt
- def process_prompt(prompt, voice, tokenizer, device):
-     prompt = f"{voice}: {prompt}"
-     input_ids = tokenizer(prompt, return_tensors="pt").input_ids
-
-     start_token = torch.tensor([[128259]], dtype=torch.int64) # Start of human
-     end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64) # End of text, End of human
-
-     modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1) # SOH SOT Text EOT EOH
-
-     # No padding needed for single input
-     attention_mask = torch.ones_like(modified_input_ids)
-
-     return modified_input_ids.to(device), attention_mask.to(device)
- # Parse output tokens to audio
- def parse_output(generated_ids):
-     token_to_find = 128257
-     token_to_remove = 128258
-
-     token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
-     if len(token_indices[1]) > 0:
-         last_occurrence_idx = token_indices[1][-1].item()
-         cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
-     else:
-         cropped_tensor = generated_ids
-
-     processed_rows = []
-     for row in cropped_tensor:
-         masked_row = row[row != token_to_remove]
-         processed_rows.append(masked_row)
-
-     code_lists = []
-     for row in processed_rows:
-         row_length = row.size(0)
-         new_length = (row_length // 7) * 7
-         trimmed_row = row[:new_length]
-         trimmed_row = [t - 128266 for t in trimmed_row]
-         code_lists.append(trimmed_row)
-
-     return code_lists[0] # Return just the first one for single sample
- # Redistribute codes for audio generation
- def redistribute_codes(code_list, snac_model):
-     device = next(snac_model.parameters()).device # Get the device of SNAC model
-
-     layer_1 = []
-     layer_2 = []
-     layer_3 = []
-     for i in range((len(code_list)+1)//7):
-         layer_1.append(code_list[7*i])
-         layer_2.append(code_list[7*i+1]-4096)
-         layer_3.append(code_list[7*i+2]-(2*4096))
-         layer_3.append(code_list[7*i+3]-(3*4096))
-         layer_2.append(code_list[7*i+4]-(4*4096))
-         layer_3.append(code_list[7*i+5]-(5*4096))
-         layer_3.append(code_list[7*i+6]-(6*4096))
-
-     # Move tensors to the same device as the SNAC model
-     codes = [
-         torch.tensor(layer_1, device=device).unsqueeze(0),
-         torch.tensor(layer_2, device=device).unsqueeze(0),
-         torch.tensor(layer_3, device=device).unsqueeze(0)
-     ]

-     audio_hat = snac_model.decode(codes)
-     return audio_hat.detach().squeeze().cpu().numpy() # Always return CPU numpy array
- # Main generation function
- @spaces.GPU()
- def generate_speech(text, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=1200, voice="Elise", progress=gr.Progress()):
-     if not text.strip():
          return None

      try:
-         progress(0.1, "Processing text...")
-         input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)

-         progress(0.3, "Generating speech tokens...")
          with torch.no_grad():
-             generated_ids = model.generate(
-                 input_ids=input_ids,
-                 attention_mask=attention_mask,
-                 max_new_tokens=max_new_tokens,
-                 do_sample=True,
-                 temperature=temperature,
-                 top_p=top_p,
-                 repetition_penalty=repetition_penalty,
-                 num_return_sequences=1,
-                 eos_token_id=128258,
-             )

-         progress(0.6, "Processing speech tokens...")
-         code_list = parse_output(generated_ids)

-         progress(0.8, "Converting to audio...")
-         audio_samples = redistribute_codes(code_list, snac_model)

-         return (24000, audio_samples) # Return sample rate and audio
      except Exception as e:
-         print(f"Error generating speech: {e}")
          return None
- # Examples for the UI - Khmer text examples
- examples = [
-     ["ជំរាបសួរ ខ្ញុំឈ្មោះ តារា ហើយខ្ញុំគឺជាម៉ូដែលផលិតសំលេងនិយាយ។"],
-     ["ខ្ញុំអាចបង្កើតសំលេងនិយាយផ្សេងៗ ដូចជា <laugh> សើច ឬ <sigh> ថប់ដង្ហើម។"],
-     ["ខ្ញុំរស់នៅក្នុងទីក្រុងភ្នំពេញ ហើយមានប៉ារ៉ាម៉ែត្រ <gasp> ច្រើនណាស់។"],
-     ["ពេលខ្លះ ពេលខ្ញុំនិយាយច្រើនពេក ខ្ញុំត្រូវ <cough> សុំទោស។"],
-     ["ការនិយាយនៅចំពោះមុខសាធារណៈ អាចមានការពិបាក។ <groan> ប៉ុន្តែបើហាត់ហាន គេអាចធ្វើបាន។"],
- ]
- # Available voices (commented out for simpler UI)
- # VOICES = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe", "jing", "Elise"]
- # Available Emotive Tags
- EMOTIVE_TAGS = ["`<laugh>`", "`<chuckle>`", "`<sigh>`", "`<cough>`", "`<sniffle>`", "`<groan>`", "`<yawn>`", "`<gasp>`"]
- # Create Gradio interface
- with gr.Blocks(title="Khmer Text-to-Speech") as demo:
-     gr.Markdown(f"""
-     # 🎵 Khmer Text-to-Speech
-     **ម៉ូដែលបម្លែងអត្ថបទជាសំលេង**
-
-     បញ្ចូលអត្ថបទខ្មែររបស់អ្នក ហើយស្តាប់ការបម្លែងទៅជាសំលេងនិយាយ។
-
-     💡 **Tips**: Add emotive tags like {", ".join(EMOTIVE_TAGS)} for more expressive speech!
-     """)
-
-     text_input = gr.Textbox(
-         label="Enter Khmer text (បញ្ចូលអត្ថបទខ្មែរ)",
-         placeholder="បញ្ចូលអត្ថបទខ្មែររបស់អ្នកនៅទីនេះ...",
-         lines=4
-     )

-     # Voice selector (commented out)
-     # voice = gr.Dropdown(
-     #     choices=VOICES,
-     #     value="tara",
-     #     label="Voice (សំលេង)"
-     # )
-
-     # Advanced Settings
-     with gr.Accordion("🔧 Advanced Settings", open=False):
-         with gr.Row():
-             temperature = gr.Slider(
-                 minimum=0.1, maximum=1.5, value=0.6, step=0.05,
-                 label="Temperature",
-                 info="Higher values create more expressive speech"
-             )
-             top_p = gr.Slider(
-                 minimum=0.1, maximum=1.0, value=0.95, step=0.05,
-                 label="Top P",
-                 info="Nucleus sampling threshold"
-             )
-         with gr.Row():
-             repetition_penalty = gr.Slider(
-                 minimum=1.0, maximum=2.0, value=1.1, step=0.05,
-                 label="Repetition Penalty",
-                 info="Higher values discourage repetitive patterns"
-             )
-             max_new_tokens = gr.Slider(
-                 minimum=100, maximum=2000, value=1200, step=100,
-                 label="Max Length",
-                 info="Maximum length of generated audio"
              )

      with gr.Row():
-         submit_btn = gr.Button("🎤 Generate Speech", variant="primary", size="lg")
-         clear_btn = gr.Button("🗑️ Clear", size="lg")

-     audio_output = gr.Audio(
-         label="Generated Speech (សំលេងដែលបង្កើតឡើង)",
-         type="numpy",
-         show_label=True
-     )

-     # Set up examples (NO CACHE)
-     gr.Examples(
-         examples=examples,
-         inputs=[text_input],
-         outputs=audio_output,
-         fn=lambda text: generate_speech(text),
-         cache_examples=False,
-     )

-     # Set up event handlers
-     submit_btn.click(
-         fn=generate_speech,
-         inputs=[text_input, temperature, top_p, repetition_penalty, max_new_tokens],
-         outputs=audio_output
      )

-     clear_btn.click(
-         fn=lambda: (None, None),
-         inputs=[],
-         outputs=[text_input, audio_output]
      )
- # Launch the app
  if __name__ == "__main__":
-     demo.queue().launch(share=False)

+ # -*- coding: utf-8 -*-
+ # --- Imports ---
+ import os
+ import random
+ import uuid
+ import json
+ import time
+ import traceback
+ import re
  import gradio as gr
+ import torch
+ import numpy as np
+ import scipy.io.wavfile as wavfile
+ import requests
  from dotenv import load_dotenv
+
+ # --- Whisper Import ---
+ try:
+     import whisper
+     print("Whisper library imported successfully.")
+ except ImportError:
+     print("ERROR: Whisper library not found. Please install it:")
+     print("pip install -U openai-whisper")
+     print("Ensure ffmpeg is also installed and in your system PATH.")
+     exit(1)
+
+ # --- SNAC Import ---
+ try:
+     from snac import SNAC
+ except ImportError:
+     print("ERROR: SNAC library not found. Please install it:")
+     print("pip install git+https://github.com/hubertsiuzdak/snac.git")
+     exit(1)
+
+ # --- Load Environment Variables ---
  load_dotenv()
+
+ # --- Unified Server Configuration ---
+ SERVER_BASE_URL = os.getenv("SERVER_BASE_URL", "http://127.0.0.1:1234")
+
+ # Ollama Configuration
+ OLLAMA_API_ENDPOINT = f"{SERVER_BASE_URL}/v1/chat/completions"
+ OLLAMA_MODEL = "dolphin3.0-llama3.1-8b-abliterated" # Updated model name
+
+ OLLAMA_SYSTEM_PROMPT = "You in a chat with conscenting adults and you are an impatient midwestern foul mouthed 32 year old who just wants a link to github to get the code. You hate demos... You share a range of emotions in your responses of a single sentence where you keep asking in lightly profane creative ways for the link."
+
+ # TTS Configuration
+ TTS_API_ENDPOINT = f"{SERVER_BASE_URL}/v1/completions"
+ TTS_MODEL = "mrrtmob/tts-khm-3"
+
+ # --- Device Setup ---
+ if torch.cuda.is_available():
+     tts_device = "cuda"
+     stt_device = "cuda"
+     print("SNAC vocoder and Whisper STT will use CUDA if possible.")
+ else:
+     tts_device = "cpu"
+     stt_device = "cpu"
+     print("CUDA not available. SNAC vocoder and Whisper STT will use CPU.")
+
+ # --- Model Loading ---
+ print("Loading SNAC vocoder model...")
+ snac_model = None
+ try:
+     snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
+     snac_model = snac_model.to(tts_device)
+     snac_model.eval()
+     print(f"SNAC vocoder loaded to {tts_device}")
+ except Exception as e:
+     print(f"Error loading SNAC model: {e}")
+
+ print("Loading Whisper STT model...")
+ WHISPER_MODEL_NAME = "base.en"
+ whisper_model = None
+ try:
+     whisper_model = whisper.load_model(WHISPER_MODEL_NAME, device=stt_device)
+     print(f"Whisper model '{WHISPER_MODEL_NAME}' loaded successfully.")
+ except Exception as e:
+     print(f"Error loading Whisper model: {e}")
+
+ # --- Constants ---
+ MAX_MAX_NEW_TOKENS = 4096
+ DEFAULT_OLLAMA_MAX_TOKENS = -1 # Updated to match model's recommendation
+ MAX_SEED = np.iinfo(np.int32).max
+ ORPHEUS_MIN_ID = 10
+ ORPHEUS_TOKENS_PER_LAYER = 4096
+ ORPHEUS_N_LAYERS = 7
+ ORPHEUS_MAX_ID = ORPHEUS_MIN_ID + (ORPHEUS_N_LAYERS * ORPHEUS_TOKENS_PER_LAYER)
+ DEFAULT_OLLAMA_TEMP = 0.7
+ DEFAULT_OLLAMA_TOP_P = 0.9
+ DEFAULT_OLLAMA_TOP_K = 40
+ DEFAULT_OLLAMA_REP_PENALTY = 1.1
+ DEFAULT_TTS_TEMP = 0.4
+ DEFAULT_TTS_TOP_P = 0.9
+ DEFAULT_TTS_TOP_K = 40
+ DEFAULT_TTS_REP_PENALTY = 1.1
+ CONTEXT_TURN_LIMIT = 3
+
+ # --- Utility Functions ---
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     return seed
+
+ def clean_chat_history(limited_chat_history):
+     cleaned_ollama_format = []
+     if not limited_chat_history:
+         return []
+     for user_msg_display, bot_msg_display in limited_chat_history:
+         user_text = None
+         if isinstance(user_msg_display, str):
+             if user_msg_display.startswith("🎤 (Audio Input): "):
+                 user_text = user_msg_display.split("🎤 (Audio Input): ", 1)[1]
+             elif user_msg_display.startswith(("@tara-tts ", "@tara-llm ")):
+                 user_text = user_msg_display.split(" ", 1)[1]
+             else:
+                 user_text = user_msg_display
+         elif isinstance(user_msg_display, tuple):
+             if len(user_msg_display) > 1 and isinstance(user_msg_display[1], str):
+                 user_text = user_msg_display[1].replace("🎤: ", "")
+             elif isinstance(user_msg_display[0], str) and not user_msg_display[0].endswith((".wav", ".mp3")):
+                 user_text = user_msg_display[0]
+
+         bot_text = None
+         if isinstance(bot_msg_display, tuple):
+             if len(bot_msg_display) > 1 and isinstance(bot_msg_display[1], str):
+                 bot_text = bot_msg_display[1]
+         elif isinstance(bot_msg_display, str):
+             if not bot_msg_display.startswith(("[Error", "(Error", "Sorry,", "(No input", "Processing", "(TTS failed")):
+                 bot_text = bot_msg_display
+
+         if user_text and user_text.strip():
+             cleaned_ollama_format.append({"role": "user", "content": user_text})
+         if bot_text and bot_text.strip():
+             cleaned_ollama_format.append({"role": "assistant", "content": bot_text})
+     return cleaned_ollama_format
+
+ # --- TTS Pipeline Functions ---
+ def parse_gguf_codes(response_text):
+     absolute_ids = []
+     matches = re.findall(r"<custom_token_(\d+)>", response_text)
+     if not matches:
+         return []
+     for number_str in matches:
+         try:
+             token_id = int(number_str)
+             if ORPHEUS_MIN_ID <= token_id < ORPHEUS_MAX_ID:
+                 absolute_ids.append(token_id)
+         except ValueError:
+             continue
+     print(f" - Parsed {len(absolute_ids)} valid audio token IDs using regex.")
+     return absolute_ids
+
+ def redistribute_codes(absolute_code_list, target_snac_model):
+     if not absolute_code_list or target_snac_model is None:
+         return None
+
+     snac_device = next(target_snac_model.parameters()).device
+     layer_1, layer_2, layer_3 = [], [], []
+     num_tokens = len(absolute_code_list)
+     num_groups = num_tokens // ORPHEUS_N_LAYERS

+     if num_groups == 0:
          return None

+     print(f" - Processing {num_groups} groups of {ORPHEUS_N_LAYERS} codes for SNAC...")
+
+     for i in range(num_groups):
+         base_idx = i * ORPHEUS_N_LAYERS
+         if base_idx + ORPHEUS_N_LAYERS > num_tokens:
+             break
+
+         group_codes = absolute_code_list[base_idx:base_idx + ORPHEUS_N_LAYERS]
+         processed_group = [None] * ORPHEUS_N_LAYERS
+         valid_group = True
+
+         for j, token_id in enumerate(group_codes):
+             if not (ORPHEUS_MIN_ID <= token_id < ORPHEUS_MAX_ID):
+                 valid_group = False
+                 break
+
+             layer_index = (token_id - ORPHEUS_MIN_ID) // ORPHEUS_TOKENS_PER_LAYER
+             code_index = (token_id - ORPHEUS_MIN_ID) % ORPHEUS_TOKENS_PER_LAYER
+
+             if layer_index != j:
+                 valid_group = False
+                 break
+
+             processed_group[j] = code_index
+
+         if not valid_group:
+             continue
+
+         try:
+             layer_1.append(processed_group[0])
+             layer_2.append(processed_group[1])
+             layer_3.append(processed_group[2])
+             layer_3.append(processed_group[3])
+             layer_2.append(processed_group[4])
+             layer_3.append(processed_group[5])
+             layer_3.append(processed_group[6])
+         except (IndexError, TypeError):
+             continue
+
      try:
+         if not layer_1 or not layer_2 or not layer_3:
+             return None
+
+         print(f" - Final SNAC layer sizes: L1={len(layer_1)}, L2={len(layer_2)}, L3={len(layer_3)}")
+
+         codes = [
+             torch.tensor(layer_1, device=snac_device, dtype=torch.long).unsqueeze(0),
+             torch.tensor(layer_2, device=snac_device, dtype=torch.long).unsqueeze(0),
+             torch.tensor(layer_3, device=snac_device, dtype=torch.long).unsqueeze(0)
+         ]

          with torch.no_grad():
+             audio_hat = target_snac_model.decode(codes)
+
+         return audio_hat.detach().squeeze().cpu().numpy()
+     except Exception as e:
+         print(f"Error during tensor creation or SNAC decoding: {e}")
+         return None
+
+ def generate_speech_gguf(text, voice, tts_temperature, tts_top_p, tts_repetition_penalty, max_new_tokens_audio):
+     if not text.strip() or snac_model is None:
+         return None

+     print(f"Generating speech via TTS server for: '{text[:50]}...'")
+     start_time = time.time()
+
+     payload = {
+         "model": TTS_MODEL,
+         "prompt": f"<|audio|>{voice}: {text}<|eot_id|>",
+         "temperature": tts_temperature,
+         "top_p": tts_top_p,
+         "repeat_penalty": tts_repetition_penalty,
+         "max_tokens": max_new_tokens_audio,
+         "stop": ["<|eot_id|>", "<|audio|>"],
+         "stream": False
+     }
+
+     print(f" - Sending payload to {TTS_API_ENDPOINT} (Model: {TTS_MODEL})")
+
+     try:
+         headers = {"Content-Type": "application/json"}
+         response = requests.post(
+             TTS_API_ENDPOINT,
+             json=payload,
+             headers=headers,
+             timeout=180
+         )
+         response.raise_for_status()
+         response_json = response.json()

+         print(f" - Raw TTS response: {json.dumps(response_json, indent=2)[:200]}...")

+         if "choices" in response_json and len(response_json["choices"]) > 0:
+             raw_generated_text = response_json["choices"][0].get("text", "").strip()
+             if not raw_generated_text:
+                 print("Error: Empty text in TTS response")
+                 return None
+
+             req_time = time.time()
+             print(f" - TTS server request took {req_time - start_time:.2f}s")
+
+             absolute_id_list = parse_gguf_codes(raw_generated_text)
+             if not absolute_id_list:
+                 print("Error: No valid audio codes parsed. Raw text:", raw_generated_text[:200])
+                 return None
+
+             audio_samples = redistribute_codes(absolute_id_list, snac_model)
+             if audio_samples is None:
+                 print("Error: Failed to generate audio samples from tokens")
+                 return None
+
+             snac_time = time.time()
+             print(f" - Generated audio samples via SNAC, shape: {audio_samples.shape}")
+             print(f" - Total TTS generation time: {snac_time - start_time:.2f}s")
+             return (24000, audio_samples)
+
+         else:
+             print(f"Error: Unexpected TTS response format: {response_json}")
+             return None
+
+     except requests.exceptions.RequestException as e:
+         print(f"Error during request to TTS server: {e}")
+         return None
      except Exception as e:
+         print(f"Error during TTS generation pipeline: {e}")
+         traceback.print_exc()
          return None
+
+ # --- Ollama Communication Helper ---
+ def call_ollama_non_streaming(ollama_payload, generation_params):
+     final_response = "[Error: Default response]"
+     try:
+         payload = {
+             "model": OLLAMA_MODEL,
+             "messages": ollama_payload["messages"],
+             "temperature": generation_params.get('ollama_temperature', DEFAULT_OLLAMA_TEMP),
+             "top_p": generation_params.get('ollama_top_p', DEFAULT_OLLAMA_TOP_P),
+             "max_tokens": generation_params.get('ollama_max_new_tokens', DEFAULT_OLLAMA_MAX_TOKENS),
+             "repeat_penalty": generation_params.get('ollama_repetition_penalty', DEFAULT_OLLAMA_REP_PENALTY),
+             "stream": False
+         }
+
+         print(f" - Sending to {OLLAMA_API_ENDPOINT} with model {OLLAMA_MODEL}")
+
+         headers = {"Content-Type": "application/json"}
+         start_time = time.time()
+         response = requests.post(
+             OLLAMA_API_ENDPOINT,
+             json=payload,
+             headers=headers,
+             timeout=180
+         )
+         response.raise_for_status()
+         response_json = response.json()
+         end_time = time.time()
+
+         print(f" - LLM request took {end_time - start_time:.2f}s")
+
+         if "choices" in response_json and len(response_json["choices"]) > 0:
+             choice = response_json["choices"][0]
+             if "message" in choice:
+                 final_response = choice["message"]["content"].strip()
+             elif "text" in choice:
+                 final_response = choice["text"].strip()
+             else:
+                 final_response = "[Error: Unexpected response format]"
+         else:
+             final_response = f"[Error: {response_json.get('error', 'Unknown error')}]"
+
+     except requests.exceptions.RequestException as e:
+         final_response = f"[Error connecting to LLM: {e}]"
+     except Exception as e:
+         final_response = f"[Unexpected Error: {e}]"
+         traceback.print_exc()
+
+     print(f" - LLM response: '{final_response[:100]}...'")
+     return final_response
+
+ # --- Main Gradio Backend Function ---
+ def process_input_blocks(
+     text_input: str, audio_input_path: str,
+     auto_prefix_tts_checkbox: bool,
+     auto_prefix_llm_checkbox: bool,
+     plain_llm_checkbox: bool,
+     ollama_max_new_tokens: int, ollama_temperature: float, ollama_top_p: float,
+     ollama_top_k: int, ollama_repetition_penalty: float,
+     tts_temperature: float, tts_top_p: float, tts_repetition_penalty: float,
+     chat_history: list
+ ):
+     global whisper_model, snac_model
+     original_user_input_text = ""
+     user_display_input = None
+     text_to_process = ""
+     transcription_source = "text"
+     bot_response = ""
+     bot_audio_tuple = None
+     audio_filepath_to_clean = None
+     is_purely_text_input = False
+     prefix_to_add = None
+     force_plain_llm = False
+
+     # Handle Audio Input
+     if audio_input_path and whisper_model:
+         if os.path.isfile(audio_input_path):
+             audio_filepath_to_clean = audio_input_path
+             transcription_source = "voice"
+             print(f"Processing audio input: {audio_input_path}")
+             try:
+                 stt_start_time = time.time()
+                 result = whisper_model.transcribe(audio_input_path, fp16=(stt_device == 'cuda'))
+                 original_user_input_text = result["text"].strip()
+                 stt_end_time = time.time()
+                 print(f" - Whisper transcription: '{original_user_input_text}' (took {stt_end_time - stt_start_time:.2f}s)")
+                 user_display_input = f"🎤 (Audio Input): {original_user_input_text}"
+                 text_to_process = original_user_input_text
+
+                 # Check if transcription is already a command
+                 known_prefixes = ["@tara-tts", "@jess-tts", "@leo-tts", "@leah-tts", "@dan-tts", "@mia-tts", "@zac-tts", "@zoe-tts",
+                                   "@tara-llm", "@jess-llm", "@leo-llm", "@leah-llm", "@dan-llm", "@mia-llm", "@zac-llm", "@zoe-llm"]
+                 is_already_command = any(original_user_input_text.lower().startswith(p) for p in known_prefixes)
+
+                 if not is_already_command:
+                     if plain_llm_checkbox:
+                         prefix_to_add = None
+                         force_plain_llm = True
+                         print(f" - Plain LLM checked. Processing audio as text input for LLM.")
+                     elif auto_prefix_tts_checkbox:
+                         prefix_to_add = "@tara-tts"
+                         print(f" - Auto-prefix TTS checked. Applying to audio.")
+                     elif auto_prefix_llm_checkbox:
+                         prefix_to_add = "@tara-llm"
+                         print(f" - Auto-prefix LLM checked. Applying to audio.")
+                     else:
+                         print(f" - No default prefix checkbox checked for audio. Processing as text for LLM.")
+
+                     if prefix_to_add:
+                         text_to_process = f"{prefix_to_add} {original_user_input_text}"
+                 else:
+                     print(f" - Transcribed audio is already a command '{original_user_input_text[:20]}...'.")
+                     text_to_process = original_user_input_text
+
+             except Exception as e:
+                 print(f"Error during Whisper transcription: {e}")
+                 traceback.print_exc()
+                 error_msg = f"[Error during local transcription: {e}]"
+                 chat_history.append((f"🎤 (Audio Input Error: {audio_input_path})", error_msg))
+                 if audio_filepath_to_clean and os.path.exists(audio_filepath_to_clean):
+                     try:
+                         os.remove(audio_filepath_to_clean)
+                     except Exception as e_clean:
+                         print(f"Warning: Could not clean up STT temp file {audio_filepath_to_clean}: {e_clean}")
+                 return chat_history, None, None
+         else:
+             print(f"Received invalid audio path: {audio_input_path}, falling back to text.")
+
+     # Handle Text Input
+     if not text_to_process and text_input:
+         original_user_input_text = text_input.strip()
+         user_display_input = original_user_input_text
+         print(f"Processing text input: '{original_user_input_text}'")
+         transcription_source = "text"
+         text_to_process = original_user_input_text
+
+         known_prefixes = ["@tara-tts", "@jess-tts", "@leo-tts", "@leah-tts", "@dan-tts", "@mia-tts", "@zac-tts", "@zoe-tts",
+                           "@tara-llm", "@jess-llm", "@leo-llm", "@leah-llm", "@dan-llm", "@mia-llm", "@zac-llm", "@zoe-llm"]
+         is_already_command = any(original_user_input_text.lower().startswith(p) for p in known_prefixes)
+
+         if not is_already_command:
+             if plain_llm_checkbox:
+                 prefix_to_add = None
+                 force_plain_llm = True
+                 print(f" - Plain LLM checked. Processing text input for LLM.")
+             elif auto_prefix_tts_checkbox:
+                 prefix_to_add = "@tara-tts"
+                 print(f" - Auto-prefix TTS checked. Applying to text.")
+             elif auto_prefix_llm_checkbox:
+                 prefix_to_add = "@tara-llm"
+                 print(f" - Auto-prefix LLM checked. Applying to text.")
+             else:
+                 print(f" - No default prefix checkbox enabled for text input.")
+         else:
+             print(f" - User provided command in text '{original_user_input_text[:20]}...', not auto-prepending.")
+
+         if prefix_to_add:
+             text_to_process = f"{prefix_to_add} {original_user_input_text}"
+
+     # Cleanup audio file
+     if audio_filepath_to_clean and os.path.exists(audio_filepath_to_clean):
+         try:
+             os.remove(audio_filepath_to_clean)
+             print(f" - Cleaned up temporary STT audio file: {audio_filepath_to_clean}")
+         except Exception as e_clean:
+             print(f"Warning: Could not clean up temp STT audio file {audio_filepath_to_clean}: {e_clean}")
+
+     if not text_to_process:
+         print("No valid text or audio input to process.")
+         return chat_history, None, None
+
+     chat_history.append((user_display_input, None))
+
+     # Process Input Text
+     lower_text = text_to_process.lower()
+     print(f" - Routing query ({transcription_source}): '{text_to_process[:100]}...'")

+     all_voices = ["tara", "jess", "leo", "leah", "dan", "mia", "zac", "zoe"]
+     tts_tags = {f"@{voice}-tts": voice for voice in all_voices}
+     llm_tags = {f"@{voice}-llm": voice for voice in all_voices}
+
+     final_bot_message = None
+
+     try:
+         matched_tts = False
+         matched_llm_tts = False
+
+         # Check Branches
+         if not force_plain_llm:
+             # Branch 1: Direct TTS
+             for tag, voice in tts_tags.items():
+                 if lower_text.startswith(tag):
+                     matched_tts = True
+                     text_to_speak = text_to_process[len(tag):].strip()
+                     print(f" - Direct TTS request for voice '{voice}': '{text_to_speak[:50]}...'")
+                     if snac_model is None:
+                         raise ValueError("SNAC vocoder not loaded.")
+                     audio_output = generate_speech_gguf(
+                         text_to_speak, voice,
+                         tts_temperature, tts_top_p, tts_repetition_penalty,
+                         MAX_MAX_NEW_TOKENS
+                     )
+                     if audio_output:
+                         sample_rate, audio_data = audio_output
+                         if audio_data.dtype != np.int16:
+                             if np.issubdtype(audio_data.dtype, np.floating):
+                                 max_val = np.max(np.abs(audio_data))
+                                 audio_data = np.int16(audio_data/max_val*32767) if max_val > 1e-6 else np.zeros_like(audio_data, dtype=np.int16)
+                             else:
+                                 audio_data = audio_data.astype(np.int16)
+                         temp_dir = "temp_audio_files"
+                         os.makedirs(temp_dir, exist_ok=True)
+                         temp_audio_path = os.path.join(temp_dir, f"temp_audio_{uuid.uuid4().hex}.wav")
+                         wavfile.write(temp_audio_path, sample_rate, audio_data)
+                         print(f" - Saved TTS audio: {temp_audio_path}")
+                         final_bot_message = (temp_audio_path, None)
+                     else:
+                         final_bot_message = f"Sorry, couldn't generate speech for '{text_to_speak[:50]}...'."
+                     break
+
+             # Branch 2: LLM + TTS
+             if not matched_tts:
+                 for tag, voice in llm_tags.items():
+                     if lower_text.startswith(tag):
+                         matched_llm_tts = True
+                         prompt_for_llm = text_to_process[len(tag):].strip()
+                         print(f" - LLM+TTS request for voice '{voice}': '{prompt_for_llm[:75]}...'")
+                         if snac_model is None:
+                             raise ValueError("SNAC vocoder not loaded.")
+
+                         history_before_current = chat_history[:-1]
+                         limited_history_turns = history_before_current[-CONTEXT_TURN_LIMIT:]
+                         cleaned_hist_for_llm = clean_chat_history(limited_history_turns)
+
+                         messages = [
+                             {"role": "system", "content": OLLAMA_SYSTEM_PROMPT}
+                         ] + cleaned_hist_for_llm + [
+                             {"role": "user", "content": prompt_for_llm}
+                         ]
+
+                         llm_params = {
+                             'ollama_temperature': ollama_temperature,
+                             'ollama_top_p': ollama_top_p,
+                             'ollama_top_k': ollama_top_k,
+                             'ollama_max_new_tokens': ollama_max_new_tokens,
+                             'ollama_repetition_penalty': ollama_repetition_penalty
+                         }
+
+                         llm_response_text = call_ollama_non_streaming(
+                             {"messages": messages},
+                             llm_params
+                         )
+
+                         if llm_response_text and not llm_response_text.startswith("[Error"):
+                             audio_output = generate_speech_gguf(
+                                 llm_response_text, voice,
+                                 tts_temperature, tts_top_p, tts_repetition_penalty,
+                                 MAX_MAX_NEW_TOKENS
+                             )
+                             if audio_output:
+                                 sample_rate, audio_data = audio_output
+                                 if audio_data.dtype != np.int16:
+                                     if np.issubdtype(audio_data.dtype, np.floating):
+                                         max_val = np.max(np.abs(audio_data))
+                                         audio_data = np.int16(audio_data/max_val*32767) if max_val > 1e-6 else np.zeros_like(audio_data, dtype=np.int16)
+                                     else:
+                                         audio_data = audio_data.astype(np.int16)
+                                 temp_dir = "temp_audio_files"
+                                 os.makedirs(temp_dir, exist_ok=True)
+                                 temp_audio_path = os.path.join(temp_dir, f"temp_audio_{uuid.uuid4().hex}.wav")
+                                 wavfile.write(temp_audio_path, sample_rate, audio_data)
+                                 print(f" - Saved LLM+TTS audio: {temp_audio_path}")
+                                 final_bot_message = (temp_audio_path, llm_response_text)
+                             else:
+                                 print("Warning: TTS generation failed...")
+                                 final_bot_message = f"{llm_response_text}\n\n(TTS failed...)"
+                         else:
+                             final_bot_message = llm_response_text
+                         break
+
+         # Branch 3: Plain LLM
+         if force_plain_llm or (not matched_tts and not matched_llm_tts):
+             if force_plain_llm:
+                 print(f" - Plain LLM chat mode forced by checkbox...")
+             else:
+                 print(f" - Default text chat (no command prefix detected/added)...")
+
+             history_before_current = chat_history[:-1]
+             limited_history_turns = history_before_current[-CONTEXT_TURN_LIMIT:]
+             cleaned_hist_for_llm = clean_chat_history(limited_history_turns)
+
+             messages = [
+                 {"role": "system", "content": OLLAMA_SYSTEM_PROMPT}
+             ] + cleaned_hist_for_llm + [
+                 {"role": "user", "content": original_user_input_text}
+             ]
+
+             llm_params = {
+                 'ollama_temperature': ollama_temperature,
+                 'ollama_top_p': ollama_top_p,
+                 'ollama_top_k': ollama_top_k,
+                 'ollama_max_new_tokens': ollama_max_new_tokens,
+                 'ollama_repetition_penalty': ollama_repetition_penalty
+             }
+
+             final_bot_message = call_ollama_non_streaming(
+                 {"messages": messages},
+                 llm_params
              )
+
+     except Exception as e:
+         print(f"Error during processing: {e}")
+         traceback.print_exc()
+         final_bot_message = f"[An unexpected error occurred: {e}]"
+
+     chat_history[-1] = (user_display_input, final_bot_message)
+     return chat_history, None, None
+
+ # --- Gradio Interface ---
+ def update_prefix_checkboxes(selected_checkbox_label):
+     if selected_checkbox_label == "tts":
+         return gr.update(value=True), gr.update(value=False), gr.update(value=False)
+     elif selected_checkbox_label == "llm":
+         return gr.update(value=False), gr.update(value=True), gr.update(value=False)
+     elif selected_checkbox_label == "plain":
+         return gr.update(value=False), gr.update(value=False), gr.update(value=True)
+     else:
+         return gr.update(), gr.update(), gr.update()
+
+ print("Setting up Gradio Interface with gr.Blocks...")
+ theme_to_use = None
+
+ with gr.Blocks(theme=theme_to_use) as demo:
+     gr.Markdown(f"# Orpheus Edge 🎤 ({OLLAMA_MODEL}) Chat & TTS")
+
+     chatbot = gr.Chatbot(label="Chat History", height=500)

      with gr.Row():
+         with gr.Column(scale=3):
+             text_input_box = gr.Textbox(label="Type your message or use microphone", lines=2)
+         with gr.Column(scale=1):
+             audio_input_mic = gr.Audio(label="Record Audio Input", type="filepath")

+     with gr.Row():
+         auto_prefix_tts_checkbox = gr.Checkbox(label="Default to TTS (@tara-tts)", value=True, elem_id="cb_tts")
+         auto_prefix_llm_checkbox = gr.Checkbox(label="Default to LLM+TTS (@tara-llm)", value=False, elem_id="cb_llm")
+         plain_llm_checkbox = gr.Checkbox(label="Plain LLM Chat (Text Out)", value=False, elem_id="cb_plain")

+     with gr.Row():
+         submit_button = gr.Button("Send / Submit")
+         clear_button = gr.ClearButton([text_input_box, audio_input_mic, chatbot])
+
+     with gr.Accordion("Generation Parameters", open=False):
+         gr.Markdown("### LLM Parameters")
+         ollama_max_new_tokens_slider = gr.Slider(label="Max New Tokens", minimum=32, maximum=4096, step=32, value=DEFAULT_OLLAMA_MAX_TOKENS)
+         ollama_temperature_slider = gr.Slider(label="Temperature", minimum=0.0, maximum=2.0, step=0.05, value=DEFAULT_OLLAMA_TEMP)
+         ollama_top_p_slider = gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=DEFAULT_OLLAMA_TOP_P)
+         ollama_top_k_slider = gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=DEFAULT_OLLAMA_TOP_K)
+         ollama_repetition_penalty_slider = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=DEFAULT_OLLAMA_REP_PENALTY)
+
+         gr.Markdown("---")
+         gr.Markdown("### TTS Parameters")
+         tts_temperature_slider = gr.Slider(label="Temperature", minimum=0.0, maximum=2.0, step=0.05, value=DEFAULT_TTS_TEMP)
+         tts_top_p_slider = gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=DEFAULT_TTS_TOP_P)
+         tts_repetition_penalty_slider = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=DEFAULT_TTS_REP_PENALTY)

+     param_inputs = [
+         ollama_max_new_tokens_slider, ollama_temperature_slider, ollama_top_p_slider,
+         ollama_top_k_slider, ollama_repetition_penalty_slider,
+         tts_temperature_slider, tts_top_p_slider, tts_repetition_penalty_slider
+     ]
+
+     auto_prefix_tts_checkbox.change(
+         lambda: update_prefix_checkboxes("tts"),
+         None,
+         [auto_prefix_tts_checkbox, auto_prefix_llm_checkbox, plain_llm_checkbox]
+     )
+     auto_prefix_llm_checkbox.change(
+         lambda: update_prefix_checkboxes("llm"),
+         None,
+         [auto_prefix_tts_checkbox, auto_prefix_llm_checkbox, plain_llm_checkbox]
      )
+     plain_llm_checkbox.change(
+         lambda: update_prefix_checkboxes("plain"),
+         None,
+         [auto_prefix_tts_checkbox, auto_prefix_llm_checkbox, plain_llm_checkbox]
+     )
+
+     all_inputs = [
+         text_input_box, audio_input_mic,
+         auto_prefix_tts_checkbox, auto_prefix_llm_checkbox, plain_llm_checkbox
+     ] + param_inputs + [chatbot]

+     submit_button.click(
+         fn=process_input_blocks,
+         inputs=all_inputs,
+         outputs=[chatbot, text_input_box, audio_input_mic]
+     )
+     text_input_box.submit(
+         fn=process_input_blocks,
+         inputs=all_inputs,
+         outputs=[chatbot, text_input_box, audio_input_mic]
      )
+
+ # --- Application Entry Point ---
  if __name__ == "__main__":
+     print("-" * 50)
+     print(f"Launching Gradio {gr.__version__} Interface")
+     print(f"Whisper STT Model: {WHISPER_MODEL_NAME} on {stt_device}")
+     print(f"SNAC Vocoder loaded to {tts_device}")
+     print(f"Server URL: {SERVER_BASE_URL}")
+     print(f"LLM Model: {OLLAMA_MODEL}")
+     print(f"TTS Model: {TTS_MODEL}")
+     print("-" * 50)
+     print("Default Parameters:")
+     print(f" LLM: Temp={DEFAULT_OLLAMA_TEMP}, TopP={DEFAULT_OLLAMA_TOP_P}")
+     print(f" TTS: Temp={DEFAULT_TTS_TEMP}, TopP={DEFAULT_TTS_TOP_P}")
+     print("-" * 50)
+     print("Ensure your LM Studio server is running with both models loaded")
+     os.makedirs("temp_audio_files", exist_ok=True)
+     demo.launch(share=False)
+     print("Gradio Interface launched. Press Ctrl+C to stop.")
requirements.txt CHANGED
@@ -2,4 +2,7 @@ snac
  python-dotenv
  transformers
  torch
- spaces

  python-dotenv
  transformers
  torch
+ spaces
+ openai-whisper
+ requests
+ gradio
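The three new entries back the rewritten app: openai-whisper provides the local speech-to-text model (imported as whisper; the app itself also warns that ffmpeg must be on your PATH), requests drives the HTTP calls to the LLM and TTS endpoints, and gradio is now listed explicitly. A quick, illustrative import check after installing the updated requirements (this snippet is not part of the commit):

# Sanity-check the dependencies added in this commit.
import whisper    # installed by the openai-whisper package
import requests
import gradio

print("gradio", gradio.__version__, "| requests", requests.__version__)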