mrrtmob committed
Commit ecf63b4 · 1 Parent(s): 1476484

Integrate OpenAI API for enhanced LLM functionality and update configuration settings

Files changed (2):
  1. app.py +413 -607
  2. requirements.txt +3 -1
app.py CHANGED
@@ -6,13 +6,14 @@ import uuid
  import json
  import time
  import traceback
- import re
  import gradio as gr
  import torch
  import numpy as np
  import scipy.io.wavfile as wavfile
- import requests
  from dotenv import load_dotenv

  # --- Whisper Import ---
  try:
@@ -27,6 +28,7 @@ except ImportError:
  # --- SNAC Import ---
  try:
      from snac import SNAC
  except ImportError:
      print("ERROR: SNAC library not found. Please install it:")
      print("pip install git+https://github.com/hubertsiuzdak/snac.git")
@@ -35,680 +37,484 @@ except ImportError:
  # --- Load Environment Variables ---
  load_dotenv()

- # --- Unified Server Configuration ---
- SERVER_BASE_URL = os.getenv("SERVER_BASE_URL", "http://127.0.0.1:1234")

- # Ollama Configuration
- OLLAMA_API_ENDPOINT = f"{SERVER_BASE_URL}/v1/chat/completions"
- OLLAMA_MODEL = "dolphin3.0-llama3.1-8b-abliterated"  # Updated model name
-
- OLLAMA_SYSTEM_PROMPT = "You in a chat with conscenting adults and you are an impatient midwestern foul mouthed 32 year old who just wants a link to github to get the code. You hate demos... You share a range of emotions in your responses of a single sentence where you keep asking in lightly profane creative ways for the link."

- # TTS Configuration
- TTS_API_ENDPOINT = f"{SERVER_BASE_URL}/v1/completions"
- TTS_MODEL = "mrrtmob/tts-khm-3"

  # --- Device Setup ---
- if torch.cuda.is_available():
-     tts_device = "cuda"
-     stt_device = "cuda"
-     print("SNAC vocoder and Whisper STT will use CUDA if possible.")
- else:
-     tts_device = "cpu"
-     stt_device = "cpu"
-     print("CUDA not available. SNAC vocoder and Whisper STT will use CPU.")

  # --- Model Loading ---
  print("Loading SNAC vocoder model...")
- snac_model = None
  try:
      snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
-     snac_model = snac_model.to(tts_device)
      snac_model.eval()
-     print(f"SNAC vocoder loaded to {tts_device}")
  except Exception as e:
      print(f"Error loading SNAC model: {e}")

  print("Loading Whisper STT model...")
  WHISPER_MODEL_NAME = "base.en"
- whisper_model = None
  try:
      whisper_model = whisper.load_model(WHISPER_MODEL_NAME, device=stt_device)
      print(f"Whisper model '{WHISPER_MODEL_NAME}' loaded successfully.")
  except Exception as e:
      print(f"Error loading Whisper model: {e}")

  # --- Constants ---
  MAX_MAX_NEW_TOKENS = 4096
- DEFAULT_OLLAMA_MAX_TOKENS = -1  # Updated to match model's recommendation
  MAX_SEED = np.iinfo(np.int32).max
- ORPHEUS_MIN_ID = 10
- ORPHEUS_TOKENS_PER_LAYER = 4096
- ORPHEUS_N_LAYERS = 7
- ORPHEUS_MAX_ID = ORPHEUS_MIN_ID + (ORPHEUS_N_LAYERS * ORPHEUS_TOKENS_PER_LAYER)
- DEFAULT_OLLAMA_TEMP = 0.7
- DEFAULT_OLLAMA_TOP_P = 0.9
- DEFAULT_OLLAMA_TOP_K = 40
- DEFAULT_OLLAMA_REP_PENALTY = 1.1
- DEFAULT_TTS_TEMP = 0.4
- DEFAULT_TTS_TOP_P = 0.9
- DEFAULT_TTS_TOP_K = 40
- DEFAULT_TTS_REP_PENALTY = 1.1
- CONTEXT_TURN_LIMIT = 3
-
- # --- Utility Functions ---
- def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-     return seed
-
- def clean_chat_history(limited_chat_history):
-     cleaned_ollama_format = []
-     if not limited_chat_history:
-         return []
-     for user_msg_display, bot_msg_display in limited_chat_history:
-         user_text = None
-         if isinstance(user_msg_display, str):
-             if user_msg_display.startswith("🎀 (Audio Input): "):
-                 user_text = user_msg_display.split("🎀 (Audio Input): ", 1)[1]
-             elif user_msg_display.startswith(("@tara-tts ", "@tara-llm ")):
-                 user_text = user_msg_display.split(" ", 1)[1]
-             else:
-                 user_text = user_msg_display
-         elif isinstance(user_msg_display, tuple):
-             if len(user_msg_display) > 1 and isinstance(user_msg_display[1], str):
-                 user_text = user_msg_display[1].replace("🎀: ", "")
-             elif isinstance(user_msg_display[0], str) and not user_msg_display[0].endswith((".wav", ".mp3")):
-                 user_text = user_msg_display[0]
-
-         bot_text = None
-         if isinstance(bot_msg_display, tuple):
-             if len(bot_msg_display) > 1 and isinstance(bot_msg_display[1], str):
-                 bot_text = bot_msg_display[1]
-         elif isinstance(bot_msg_display, str):
-             if not bot_msg_display.startswith(("[Error", "(Error", "Sorry,", "(No input", "Processing", "(TTS failed")):
-                 bot_text = bot_msg_display
-
-         if user_text and user_text.strip():
-             cleaned_ollama_format.append({"role": "user", "content": user_text})
-         if bot_text and bot_text.strip():
-             cleaned_ollama_format.append({"role": "assistant", "content": bot_text})
-     return cleaned_ollama_format

- # --- TTS Pipeline Functions ---
- def parse_gguf_codes(response_text):
-     absolute_ids = []
-     matches = re.findall(r"<custom_token_(\d+)>", response_text)
-     if not matches:
-         return []
-     for number_str in matches:
-         try:
-             token_id = int(number_str)
-             if ORPHEUS_MIN_ID <= token_id < ORPHEUS_MAX_ID:
-                 absolute_ids.append(token_id)
-         except ValueError:
-             continue
-     print(f" - Parsed {len(absolute_ids)} valid audio token IDs using regex.")
-     return absolute_ids

- def redistribute_codes(absolute_code_list, target_snac_model):
-     if not absolute_code_list or target_snac_model is None:
-         return None

-     snac_device = next(target_snac_model.parameters()).device
      layer_1, layer_2, layer_3 = [], [], []
-     num_tokens = len(absolute_code_list)
-     num_groups = num_tokens // ORPHEUS_N_LAYERS

-     if num_groups == 0:
-         return None

-     print(f" - Processing {num_groups} groups of {ORPHEUS_N_LAYERS} codes for SNAC...")

-     for i in range(num_groups):
-         base_idx = i * ORPHEUS_N_LAYERS
-         if base_idx + ORPHEUS_N_LAYERS > num_tokens:
-             break
-
-         group_codes = absolute_code_list[base_idx:base_idx + ORPHEUS_N_LAYERS]
-         processed_group = [None] * ORPHEUS_N_LAYERS
-         valid_group = True
-
-         for j, token_id in enumerate(group_codes):
-             if not (ORPHEUS_MIN_ID <= token_id < ORPHEUS_MAX_ID):
-                 valid_group = False
-                 break
-
-             layer_index = (token_id - ORPHEUS_MIN_ID) // ORPHEUS_TOKENS_PER_LAYER
-             code_index = (token_id - ORPHEUS_MIN_ID) % ORPHEUS_TOKENS_PER_LAYER
-
-             if layer_index != j:
-                 valid_group = False
-                 break
-
-             processed_group[j] = code_index
-
-         if not valid_group:
-             continue
-
-         try:
-             layer_1.append(processed_group[0])
-             layer_2.append(processed_group[1])
-             layer_3.append(processed_group[2])
-             layer_3.append(processed_group[3])
-             layer_2.append(processed_group[4])
-             layer_3.append(processed_group[5])
-             layer_3.append(processed_group[6])
-         except (IndexError, TypeError):
-             continue

-     try:
-         if not layer_1 or not layer_2 or not layer_3:
-             return None
-
-         print(f" - Final SNAC layer sizes: L1={len(layer_1)}, L2={len(layer_2)}, L3={len(layer_3)}")
-
-         codes = [
-             torch.tensor(layer_1, device=snac_device, dtype=torch.long).unsqueeze(0),
-             torch.tensor(layer_2, device=snac_device, dtype=torch.long).unsqueeze(0),
-             torch.tensor(layer_3, device=snac_device, dtype=torch.long).unsqueeze(0)
-         ]
-
-         with torch.no_grad():
-             audio_hat = target_snac_model.decode(codes)
-
-         return audio_hat.detach().squeeze().cpu().numpy()
-     except Exception as e:
-         print(f"Error during tensor creation or SNAC decoding: {e}")
-         return None

- def generate_speech_gguf(text, voice, tts_temperature, tts_top_p, tts_repetition_penalty, max_new_tokens_audio):
-     if not text.strip() or snac_model is None:
          return None
-
-     print(f"Generating speech via TTS server for: '{text[:50]}...'")
-     start_time = time.time()
-
-     payload = {
-         "model": TTS_MODEL,
-         "prompt": f"<|audio|>{voice}: {text}<|eot_id|>",
-         "temperature": tts_temperature,
-         "top_p": tts_top_p,
-         "repeat_penalty": tts_repetition_penalty,
-         "max_tokens": max_new_tokens_audio,
-         "stop": ["<|eot_id|>", "<|audio|>"],
-         "stream": False
-     }
-
-     print(f" - Sending payload to {TTS_API_ENDPOINT} (Model: {TTS_MODEL})")

      try:
-         headers = {"Content-Type": "application/json"}
-         response = requests.post(
-             TTS_API_ENDPOINT,
-             json=payload,
-             headers=headers,
-             timeout=180
-         )
-         response.raise_for_status()
-         response_json = response.json()

-         print(f" - Raw TTS response: {json.dumps(response_json, indent=2)[:200]}...")

-         if "choices" in response_json and len(response_json["choices"]) > 0:
-             raw_generated_text = response_json["choices"][0].get("text", "").strip()
-             if not raw_generated_text:
-                 print("Error: Empty text in TTS response")
-                 return None
-
-             req_time = time.time()
-             print(f" - TTS server request took {req_time - start_time:.2f}s")
-
-             absolute_id_list = parse_gguf_codes(raw_generated_text)
-             if not absolute_id_list:
-                 print("Error: No valid audio codes parsed. Raw text:", raw_generated_text[:200])
-                 return None
-
-             audio_samples = redistribute_codes(absolute_id_list, snac_model)
-             if audio_samples is None:
-                 print("Error: Failed to generate audio samples from tokens")
-                 return None
-
-             snac_time = time.time()
-             print(f" - Generated audio samples via SNAC, shape: {audio_samples.shape}")
-             print(f" - Total TTS generation time: {snac_time - start_time:.2f}s")
              return (24000, audio_samples)
-
          else:
-             print(f"Error: Unexpected TTS response format: {response_json}")
              return None

-     except requests.exceptions.RequestException as e:
-         print(f"Error during request to TTS server: {e}")
-         return None
      except Exception as e:
-         print(f"Error during TTS generation pipeline: {e}")
          traceback.print_exc()
          return None

- # --- Ollama Communication Helper ---
- def call_ollama_non_streaming(ollama_payload, generation_params):
-     final_response = "[Error: Default response]"
      try:
-         payload = {
-             "model": OLLAMA_MODEL,
-             "messages": ollama_payload["messages"],
-             "temperature": generation_params.get('ollama_temperature', DEFAULT_OLLAMA_TEMP),
-             "top_p": generation_params.get('ollama_top_p', DEFAULT_OLLAMA_TOP_P),
-             "max_tokens": generation_params.get('ollama_max_new_tokens', DEFAULT_OLLAMA_MAX_TOKENS),
-             "repeat_penalty": generation_params.get('ollama_repetition_penalty', DEFAULT_OLLAMA_REP_PENALTY),
-             "stream": False
-         }
-
-         print(f" - Sending to {OLLAMA_API_ENDPOINT} with model {OLLAMA_MODEL}")

-         headers = {"Content-Type": "application/json"}
          start_time = time.time()
-         response = requests.post(
-             OLLAMA_API_ENDPOINT,
-             json=payload,
-             headers=headers,
-             timeout=180
          )
-         response.raise_for_status()
-         response_json = response.json()
-         end_time = time.time()

-         print(f" - LLM request took {end_time - start_time:.2f}s")

-         if "choices" in response_json and len(response_json["choices"]) > 0:
-             choice = response_json["choices"][0]
-             if "message" in choice:
-                 final_response = choice["message"]["content"].strip()
-             elif "text" in choice:
-                 final_response = choice["text"].strip()
-             else:
-                 final_response = "[Error: Unexpected response format]"
          else:
-             final_response = f"[Error: {response_json.get('error', 'Unknown error')}]"

-     except requests.exceptions.RequestException as e:
-         final_response = f"[Error connecting to LLM: {e}]"
      except Exception as e:
-         final_response = f"[Unexpected Error: {e}]"
          traceback.print_exc()
-
-     print(f" - LLM response: '{final_response[:100]}...'")
-     return final_response

- # --- Main Gradio Backend Function ---
- def process_input_blocks(
-     text_input: str, audio_input_path: str,
-     auto_prefix_tts_checkbox: bool,
-     auto_prefix_llm_checkbox: bool,
-     plain_llm_checkbox: bool,
-     ollama_max_new_tokens: int, ollama_temperature: float, ollama_top_p: float,
-     ollama_top_k: int, ollama_repetition_penalty: float,
-     tts_temperature: float, tts_top_p: float, tts_repetition_penalty: float,
-     chat_history: list
  ):
-     global whisper_model, snac_model
-     original_user_input_text = ""
-     user_display_input = None
-     text_to_process = ""
-     transcription_source = "text"
-     bot_response = ""
-     bot_audio_tuple = None
-     audio_filepath_to_clean = None
-     is_purely_text_input = False
-     prefix_to_add = None
-     force_plain_llm = False
-
-     # Handle Audio Input
-     if audio_input_path and whisper_model:
-         if os.path.isfile(audio_input_path):
-             audio_filepath_to_clean = audio_input_path
-             transcription_source = "voice"
-             print(f"Processing audio input: {audio_input_path}")
-             try:
-                 stt_start_time = time.time()
-                 result = whisper_model.transcribe(audio_input_path, fp16=(stt_device == 'cuda'))
-                 original_user_input_text = result["text"].strip()
-                 stt_end_time = time.time()
-                 print(f" - Whisper transcription: '{original_user_input_text}' (took {stt_end_time - stt_start_time:.2f}s)")
-                 user_display_input = f"🎀 (Audio Input): {original_user_input_text}"
-                 text_to_process = original_user_input_text
-
-                 # Check if transcription is already a command
-                 known_prefixes = ["@tara-tts", "@jess-tts", "@leo-tts", "@leah-tts", "@dan-tts", "@mia-tts", "@zac-tts", "@zoe-tts",
-                                   "@tara-llm", "@jess-llm", "@leo-llm", "@leah-llm", "@dan-llm", "@mia-llm", "@zac-llm", "@zoe-llm"]
-                 is_already_command = any(original_user_input_text.lower().startswith(p) for p in known_prefixes)
-
-                 if not is_already_command:
-                     if plain_llm_checkbox:
-                         prefix_to_add = None
-                         force_plain_llm = True
-                         print(f" - Plain LLM checked. Processing audio as text input for LLM.")
-                     elif auto_prefix_tts_checkbox:
-                         prefix_to_add = "@tara-tts"
-                         print(f" - Auto-prefix TTS checked. Applying to audio.")
-                     elif auto_prefix_llm_checkbox:
-                         prefix_to_add = "@tara-llm"
-                         print(f" - Auto-prefix LLM checked. Applying to audio.")
-                     else:
-                         print(f" - No default prefix checkbox checked for audio. Processing as text for LLM.")
-
-                     if prefix_to_add:
-                         text_to_process = f"{prefix_to_add} {original_user_input_text}"
-                 else:
-                     print(f" - Transcribed audio is already a command '{original_user_input_text[:20]}...'.")
-                     text_to_process = original_user_input_text
-
-             except Exception as e:
-                 print(f"Error during Whisper transcription: {e}")
-                 traceback.print_exc()
-                 error_msg = f"[Error during local transcription: {e}]"
-                 chat_history.append((f"🎀 (Audio Input Error: {audio_input_path})", error_msg))
-                 if audio_filepath_to_clean and os.path.exists(audio_filepath_to_clean):
-                     try:
-                         os.remove(audio_filepath_to_clean)
-                     except Exception as e_clean:
-                         print(f"Warning: Could not clean up STT temp file {audio_filepath_to_clean}: {e_clean}")
-                 return chat_history, None, None
          else:
-             print(f"Received invalid audio path: {audio_input_path}, falling back to text.")
-
-     # Handle Text Input
-     if not text_to_process and text_input:
-         original_user_input_text = text_input.strip()
-         user_display_input = original_user_input_text
-         print(f"Processing text input: '{original_user_input_text}'")
-         transcription_source = "text"
-         text_to_process = original_user_input_text
-
-         known_prefixes = ["@tara-tts", "@jess-tts", "@leo-tts", "@leah-tts", "@dan-tts", "@mia-tts", "@zac-tts", "@zoe-tts",
-                           "@tara-llm", "@jess-llm", "@leo-llm", "@leah-llm", "@dan-llm", "@mia-llm", "@zac-llm", "@zoe-llm"]
-         is_already_command = any(original_user_input_text.lower().startswith(p) for p in known_prefixes)
-
-         if not is_already_command:
-             if plain_llm_checkbox:
-                 prefix_to_add = None
-                 force_plain_llm = True
-                 print(f" - Plain LLM checked. Processing text input for LLM.")
-             elif auto_prefix_tts_checkbox:
-                 prefix_to_add = "@tara-tts"
-                 print(f" - Auto-prefix TTS checked. Applying to text.")
-             elif auto_prefix_llm_checkbox:
-                 prefix_to_add = "@tara-llm"
-                 print(f" - Auto-prefix LLM checked. Applying to text.")
-             else:
-                 print(f" - No default prefix checkbox enabled for text input.")
          else:
-             print(f" - User provided command in text '{original_user_input_text[:20]}...', not auto-prepending.")
-
-         if prefix_to_add:
-             text_to_process = f"{prefix_to_add} {original_user_input_text}"
-
-     # Cleanup audio file
-     if audio_filepath_to_clean and os.path.exists(audio_filepath_to_clean):
-         try:
-             os.remove(audio_filepath_to_clean)
-             print(f" - Cleaned up temporary STT audio file: {audio_filepath_to_clean}")
-         except Exception as e_clean:
-             print(f"Warning: Could not clean up temp STT audio file {audio_filepath_to_clean}: {e_clean}")
-
-     if not text_to_process:
-         print("No valid text or audio input to process.")
-         return chat_history, None, None
-
-     chat_history.append((user_display_input, None))
-
-     # Process Input Text
-     lower_text = text_to_process.lower()
-     print(f" - Routing query ({transcription_source}): '{text_to_process[:100]}...'")

-     all_voices = ["tara", "jess", "leo", "leah", "dan", "mia", "zac", "zoe"]
-     tts_tags = {f"@{voice}-tts": voice for voice in all_voices}
-     llm_tags = {f"@{voice}-llm": voice for voice in all_voices}

-     final_bot_message = None
-
      try:
-         matched_tts = False
-         matched_llm_tts = False
-
-         # Check Branches
-         if not force_plain_llm:
-             # Branch 1: Direct TTS
-             for tag, voice in tts_tags.items():
-                 if lower_text.startswith(tag):
-                     matched_tts = True
-                     text_to_speak = text_to_process[len(tag):].strip()
-                     print(f" - Direct TTS request for voice '{voice}': '{text_to_speak[:50]}...'")
-                     if snac_model is None:
-                         raise ValueError("SNAC vocoder not loaded.")
-                     audio_output = generate_speech_gguf(
-                         text_to_speak, voice,
-                         tts_temperature, tts_top_p, tts_repetition_penalty,
-                         MAX_MAX_NEW_TOKENS
-                     )
-                     if audio_output:
-                         sample_rate, audio_data = audio_output
-                         if audio_data.dtype != np.int16:
-                             if np.issubdtype(audio_data.dtype, np.floating):
-                                 max_val = np.max(np.abs(audio_data))
-                                 audio_data = np.int16(audio_data/max_val*32767) if max_val > 1e-6 else np.zeros_like(audio_data, dtype=np.int16)
-                             else:
-                                 audio_data = audio_data.astype(np.int16)
-                         temp_dir = "temp_audio_files"
-                         os.makedirs(temp_dir, exist_ok=True)
-                         temp_audio_path = os.path.join(temp_dir, f"temp_audio_{uuid.uuid4().hex}.wav")
-                         wavfile.write(temp_audio_path, sample_rate, audio_data)
-                         print(f" - Saved TTS audio: {temp_audio_path}")
-                         final_bot_message = (temp_audio_path, None)
-                     else:
-                         final_bot_message = f"Sorry, couldn't generate speech for '{text_to_speak[:50]}...'."
-                     break
-
-             # Branch 2: LLM + TTS
-             if not matched_tts:
-                 for tag, voice in llm_tags.items():
-                     if lower_text.startswith(tag):
-                         matched_llm_tts = True
-                         prompt_for_llm = text_to_process[len(tag):].strip()
-                         print(f" - LLM+TTS request for voice '{voice}': '{prompt_for_llm[:75]}...'")
-                         if snac_model is None:
-                             raise ValueError("SNAC vocoder not loaded.")
-
-                         history_before_current = chat_history[:-1]
-                         limited_history_turns = history_before_current[-CONTEXT_TURN_LIMIT:]
-                         cleaned_hist_for_llm = clean_chat_history(limited_history_turns)
-
-                         messages = [
-                             {"role": "system", "content": OLLAMA_SYSTEM_PROMPT}
-                         ] + cleaned_hist_for_llm + [
-                             {"role": "user", "content": prompt_for_llm}
-                         ]
-
-                         llm_params = {
-                             'ollama_temperature': ollama_temperature,
-                             'ollama_top_p': ollama_top_p,
-                             'ollama_top_k': ollama_top_k,
-                             'ollama_max_new_tokens': ollama_max_new_tokens,
-                             'ollama_repetition_penalty': ollama_repetition_penalty
-                         }
-
-                         llm_response_text = call_ollama_non_streaming(
-                             {"messages": messages},
-                             llm_params
-                         )
-
-                         if llm_response_text and not llm_response_text.startswith("[Error"):
-                             audio_output = generate_speech_gguf(
-                                 llm_response_text, voice,
-                                 tts_temperature, tts_top_p, tts_repetition_penalty,
-                                 MAX_MAX_NEW_TOKENS
-                             )
-                             if audio_output:
-                                 sample_rate, audio_data = audio_output
-                                 if audio_data.dtype != np.int16:
-                                     if np.issubdtype(audio_data.dtype, np.floating):
-                                         max_val = np.max(np.abs(audio_data))
-                                         audio_data = np.int16(audio_data/max_val*32767) if max_val > 1e-6 else np.zeros_like(audio_data, dtype=np.int16)
-                                     else:
-                                         audio_data = audio_data.astype(np.int16)
-                                 temp_dir = "temp_audio_files"
-                                 os.makedirs(temp_dir, exist_ok=True)
-                                 temp_audio_path = os.path.join(temp_dir, f"temp_audio_{uuid.uuid4().hex}.wav")
-                                 wavfile.write(temp_audio_path, sample_rate, audio_data)
-                                 print(f" - Saved LLM+TTS audio: {temp_audio_path}")
-                                 final_bot_message = (temp_audio_path, llm_response_text)
-                             else:
-                                 print("Warning: TTS generation failed...")
-                                 final_bot_message = f"{llm_response_text}\n\n(TTS failed...)"
-                         else:
-                             final_bot_message = llm_response_text
-                         break
-
-         # Branch 3: Plain LLM
-         if force_plain_llm or (not matched_tts and not matched_llm_tts):
-             if force_plain_llm:
-                 print(f" - Plain LLM chat mode forced by checkbox...")
-             else:
-                 print(f" - Default text chat (no command prefix detected/added)...")
-
-             history_before_current = chat_history[:-1]
-             limited_history_turns = history_before_current[-CONTEXT_TURN_LIMIT:]
-             cleaned_hist_for_llm = clean_chat_history(limited_history_turns)
-
-             messages = [
-                 {"role": "system", "content": OLLAMA_SYSTEM_PROMPT}
-             ] + cleaned_hist_for_llm + [
-                 {"role": "user", "content": original_user_input_text}
-             ]
-
-             llm_params = {
-                 'ollama_temperature': ollama_temperature,
-                 'ollama_top_p': ollama_top_p,
-                 'ollama_top_k': ollama_top_k,
-                 'ollama_max_new_tokens': ollama_max_new_tokens,
-                 'ollama_repetition_penalty': ollama_repetition_penalty
-             }
-
-             final_bot_message = call_ollama_non_streaming(
-                 {"messages": messages},
-                 llm_params
              )
-
      except Exception as e:
-         print(f"Error during processing: {e}")
          traceback.print_exc()
-         final_bot_message = f"[An unexpected error occurred: {e}]"
-
-     chat_history[-1] = (user_display_input, final_bot_message)
-     return chat_history, None, None

  # --- Gradio Interface ---
- def update_prefix_checkboxes(selected_checkbox_label):
-     if selected_checkbox_label == "tts":
-         return gr.update(value=True), gr.update(value=False), gr.update(value=False)
-     elif selected_checkbox_label == "llm":
-         return gr.update(value=False), gr.update(value=True), gr.update(value=False)
-     elif selected_checkbox_label == "plain":
-         return gr.update(value=False), gr.update(value=False), gr.update(value=True)
-     else:
-         return gr.update(), gr.update(), gr.update()
-
- print("Setting up Gradio Interface with gr.Blocks...")
- theme_to_use = None
-
- with gr.Blocks(theme=theme_to_use) as demo:
-     gr.Markdown(f"# Orpheus Edge 🎀 ({OLLAMA_MODEL}) Chat & TTS")
-
-     chatbot = gr.Chatbot(label="Chat History", height=500)
-
-     with gr.Row():
-         with gr.Column(scale=3):
-             text_input_box = gr.Textbox(label="Type your message or use microphone", lines=2)
-         with gr.Column(scale=1):
-             audio_input_mic = gr.Audio(label="Record Audio Input", type="filepath")
-
-     with gr.Row():
-         auto_prefix_tts_checkbox = gr.Checkbox(label="Default to TTS (@tara-tts)", value=True, elem_id="cb_tts")
-         auto_prefix_llm_checkbox = gr.Checkbox(label="Default to LLM+TTS (@tara-llm)", value=False, elem_id="cb_llm")
-         plain_llm_checkbox = gr.Checkbox(label="Plain LLM Chat (Text Out)", value=False, elem_id="cb_plain")
-
-     with gr.Row():
-         submit_button = gr.Button("Send / Submit")
-         clear_button = gr.ClearButton([text_input_box, audio_input_mic, chatbot])
-
-     with gr.Accordion("Generation Parameters", open=False):
-         gr.Markdown("### LLM Parameters")
-         ollama_max_new_tokens_slider = gr.Slider(label="Max New Tokens", minimum=32, maximum=4096, step=32, value=DEFAULT_OLLAMA_MAX_TOKENS)
-         ollama_temperature_slider = gr.Slider(label="Temperature", minimum=0.0, maximum=2.0, step=0.05, value=DEFAULT_OLLAMA_TEMP)
-         ollama_top_p_slider = gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=DEFAULT_OLLAMA_TOP_P)
-         ollama_top_k_slider = gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=DEFAULT_OLLAMA_TOP_K)
-         ollama_repetition_penalty_slider = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=DEFAULT_OLLAMA_REP_PENALTY)

-         gr.Markdown("---")
-         gr.Markdown("### TTS Parameters")
-         tts_temperature_slider = gr.Slider(label="Temperature", minimum=0.0, maximum=2.0, step=0.05, value=DEFAULT_TTS_TEMP)
-         tts_top_p_slider = gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=DEFAULT_TTS_TOP_P)
-         tts_repetition_penalty_slider = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=DEFAULT_TTS_REP_PENALTY)
-
-     param_inputs = [
-         ollama_max_new_tokens_slider, ollama_temperature_slider, ollama_top_p_slider,
-         ollama_top_k_slider, ollama_repetition_penalty_slider,
-         tts_temperature_slider, tts_top_p_slider, tts_repetition_penalty_slider
-     ]
-
-     auto_prefix_tts_checkbox.change(
-         lambda: update_prefix_checkboxes("tts"),
-         None,
-         [auto_prefix_tts_checkbox, auto_prefix_llm_checkbox, plain_llm_checkbox]
-     )
-     auto_prefix_llm_checkbox.change(
-         lambda: update_prefix_checkboxes("llm"),
-         None,
-         [auto_prefix_tts_checkbox, auto_prefix_llm_checkbox, plain_llm_checkbox]
-     )
-     plain_llm_checkbox.change(
-         lambda: update_prefix_checkboxes("plain"),
-         None,
-         [auto_prefix_tts_checkbox, auto_prefix_llm_checkbox, plain_llm_checkbox]
-     )
-
-     all_inputs = [
-         text_input_box, audio_input_mic,
-         auto_prefix_tts_checkbox, auto_prefix_llm_checkbox, plain_llm_checkbox
-     ] + param_inputs + [chatbot]

-     submit_button.click(
-         fn=process_input_blocks,
-         inputs=all_inputs,
-         outputs=[chatbot, text_input_box, audio_input_mic]
-     )
-     text_input_box.submit(
-         fn=process_input_blocks,
-         inputs=all_inputs,
-         outputs=[chatbot, text_input_box, audio_input_mic]
-     )

- # --- Application Entry Point ---
  if __name__ == "__main__":
-     print("-" * 50)
-     print(f"Launching Gradio {gr.__version__} Interface")
-     print(f"Whisper STT Model: {WHISPER_MODEL_NAME} on {stt_device}")
-     print(f"SNAC Vocoder loaded to {tts_device}")
-     print(f"Server URL: {SERVER_BASE_URL}")
-     print(f"LLM Model: {OLLAMA_MODEL}")
-     print(f"TTS Model: {TTS_MODEL}")
-     print("-" * 50)
-     print("Default Parameters:")
-     print(f"  LLM: Temp={DEFAULT_OLLAMA_TEMP}, TopP={DEFAULT_OLLAMA_TOP_P}")
-     print(f"  TTS: Temp={DEFAULT_TTS_TEMP}, TopP={DEFAULT_TTS_TOP_P}")
-     print("-" * 50)
-     print("Ensure your LM Studio server is running with both models loaded")
-     os.makedirs("temp_audio_files", exist_ok=True)
-     demo.launch(share=False)
-     print("Gradio Interface launched. Press Ctrl+C to stop.")

  import json
  import time
  import traceback
  import gradio as gr
  import torch
  import numpy as np
  import scipy.io.wavfile as wavfile
+ import openai
  from dotenv import load_dotenv
+ from huggingface_hub import snapshot_download
+ from transformers import AutoModelForCausalLM, AutoTokenizer

  # --- Whisper Import ---
  try:
  # --- SNAC Import ---
  try:
      from snac import SNAC
+     print("SNAC library imported successfully.")
  except ImportError:
      print("ERROR: SNAC library not found. Please install it:")
      print("pip install git+https://github.com/hubertsiuzdak/snac.git")
  # --- Load Environment Variables ---
  load_dotenv()

+ # --- OpenAI Configuration ---
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+ OPENAI_API_BASE = os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1")
+ OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o")

+ # Initialize OpenAI Client
+ if OPENAI_API_KEY:
+     print("OpenAI API key found.")
+ else:
+     print("ERROR: OPENAI_API_KEY environment variable not set.")
+     exit(1)

+ SYSTEM_PROMPT = "You are a helpful AI assistant. Respond in a conversational and friendly manner."

  # --- Device Setup ---
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ stt_device = device
+ print(f"Using device: {device}")

  # --- Model Loading ---
  print("Loading SNAC vocoder model...")
  try:
      snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
+     snac_model = snac_model.to(device)
      snac_model.eval()
+     print(f"SNAC vocoder loaded to {device}")
  except Exception as e:
      print(f"Error loading SNAC model: {e}")
+     exit(1)
+
+ print("Loading TTS model...")
+ tts_model_name = "mrrtmob/tts-khm-3"
+ try:
+     snapshot_download(
+         repo_id=tts_model_name,
+         allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
+         ignore_patterns=["optimizer.pt", "pytorch_model.bin", "training_args.bin", "scheduler.pt", "tokenizer.json", "tokenizer_config.json", "special_tokens_map.json", "vocab.json", "merges.txt", "tokenizer.*"]
+     )
+
+     tts_model = AutoModelForCausalLM.from_pretrained(tts_model_name, torch_dtype=torch.bfloat16)
+     tts_model.to(device)
+     tts_tokenizer = AutoTokenizer.from_pretrained(tts_model_name)
+     print(f"TTS model '{tts_model_name}' loaded to {device}")
+ except Exception as e:
+     print(f"Error loading TTS model: {e}")
+     exit(1)

  print("Loading Whisper STT model...")
  WHISPER_MODEL_NAME = "base.en"
  try:
      whisper_model = whisper.load_model(WHISPER_MODEL_NAME, device=stt_device)
      print(f"Whisper model '{WHISPER_MODEL_NAME}' loaded successfully.")
  except Exception as e:
      print(f"Error loading Whisper model: {e}")
+     exit(1)

  # --- Constants ---
  MAX_MAX_NEW_TOKENS = 4096
+ DEFAULT_OPENAI_MAX_TOKENS = 512
  MAX_SEED = np.iinfo(np.int32).max
+ DEFAULT_OPENAI_TEMP = 0.7
+ DEFAULT_OPENAI_TOP_P = 0.9
+ DEFAULT_TTS_TEMP = 0.6
+ DEFAULT_TTS_TOP_P = 0.95
+ CONTEXT_TURN_LIMIT = 5
+
+ # --- TTS Functions ---
+ def process_prompt(prompt, voice, tokenizer, device):
+     """Process text prompt for TTS model"""
+     prompt = f"{voice}: {prompt}"
+     input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+
+     start_token = torch.tensor([[128259]], dtype=torch.int64)
+     end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)
+
+     modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
+     attention_mask = torch.ones_like(modified_input_ids)
+
+     return modified_input_ids.to(device), attention_mask.to(device)

+ def parse_output(generated_ids):
+     """Parse TTS model output tokens"""
+     token_to_find = 128257
+     token_to_remove = 128258
+
+     token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
+     if len(token_indices[1]) > 0:
+         last_occurrence_idx = token_indices[1][-1].item()
+         cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
+     else:
+         cropped_tensor = generated_ids
+
+     processed_rows = []
+     for row in cropped_tensor:
+         masked_row = row[row != token_to_remove]
+         processed_rows.append(masked_row)
+
+     code_lists = []
+     for row in processed_rows:
+         row_length = row.size(0)
+         new_length = (row_length // 7) * 7
+         trimmed_row = row[:new_length]
+         trimmed_row = [t - 128266 for t in trimmed_row]
+         code_lists.append(trimmed_row)
+
+     return code_lists[0] if code_lists else []

+ def redistribute_codes(code_list, snac_model):
+     """Convert codes to audio using SNAC"""
+     if not code_list:
+         return np.array([])
+
+     snac_device = next(snac_model.parameters()).device
      layer_1, layer_2, layer_3 = [], [], []

+     for i in range(len(code_list) // 7):
+         base = 7 * i
+         if base + 6 < len(code_list):
+             layer_1.append(code_list[base])
+             layer_2.append(code_list[base+1] - 4096)
+             layer_3.append(code_list[base+2] - (2*4096))
+             layer_3.append(code_list[base+3] - (3*4096))
+             layer_2.append(code_list[base+4] - (4*4096))
+             layer_3.append(code_list[base+5] - (5*4096))
+             layer_3.append(code_list[base+6] - (6*4096))

+     if not layer_1:
+         return np.array([])

+     codes = [
+         torch.tensor(layer_1, device=snac_device).unsqueeze(0),
+         torch.tensor(layer_2, device=snac_device).unsqueeze(0),
+         torch.tensor(layer_3, device=snac_device).unsqueeze(0)
+     ]

+     with torch.no_grad():
+         audio_hat = snac_model.decode(codes)
+
+     return audio_hat.detach().squeeze().cpu().numpy()

+ def generate_speech(text, voice="tara", temperature=0.6, top_p=0.95, max_new_tokens=1200):
+     """Generate speech from text"""
+     if not text.strip():
          return None

      try:
+         print(f"Generating speech for: '{text[:50]}...'")

+         input_ids, attention_mask = process_prompt(text, voice, tts_tokenizer, device)

+         with torch.no_grad():
+             generated_ids = tts_model.generate(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 max_new_tokens=max_new_tokens,
+                 do_sample=True,
+                 temperature=temperature,
+                 top_p=top_p,
+                 repetition_penalty=1.1,
+                 num_return_sequences=1,
+                 eos_token_id=128258,
+             )
+
+         code_list = parse_output(generated_ids)
+         audio_samples = redistribute_codes(code_list, snac_model)
+
+         if len(audio_samples) > 0:
+             print(f"Generated audio with shape: {audio_samples.shape}")
              return (24000, audio_samples)
          else:
+             print("No audio generated")
              return None

      except Exception as e:
+         print(f"Error during speech generation: {e}")
          traceback.print_exc()
          return None

+ # --- STT Function ---
+ def transcribe_audio(audio_path):
+     """Transcribe audio to text using Whisper"""
+     if not audio_path:
+         return ""
+
      try:
+         print(f"Transcribing audio: {audio_path}")
+         result = whisper_model.transcribe(audio_path, fp16=(stt_device == 'cuda'))
+         transcribed_text = result["text"].strip()
+         print(f"Transcribed: '{transcribed_text}'")
+         return transcribed_text
+     except Exception as e:
+         print(f"Error during transcription: {e}")
+         return f"[Transcription Error: {e}]"
+
+ # --- OpenAI LLM Function ---
+ def call_openai_llm(messages, temperature=0.7, top_p=0.9, max_tokens=512):
+     """Call OpenAI API for text generation"""
+     try:
+         client = openai.OpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_API_BASE)

+         print(f"Sending to OpenAI with model {OPENAI_MODEL}")
          start_time = time.time()
+
+         response = client.chat.completions.create(
+             model=OPENAI_MODEL,
+             messages=messages,
+             temperature=temperature,
+             top_p=top_p,
+             max_tokens=max_tokens,
+             stream=False
          )

+         end_time = time.time()
+         print(f"LLM request took {end_time - start_time:.2f}s")

+         if response.choices:
+             response_text = response.choices[0].message.content.strip()
+             print(f"LLM response: '{response_text[:100]}...'")
+             return response_text
          else:
+             return "[Error: No response from model]"

+     except openai.APIError as e:
+         return f"[OpenAI API Error: {e}]"
      except Exception as e:
          traceback.print_exc()
+         return f"[Unexpected Error: {e}]"

+ # --- Utility Functions ---
+ def clean_chat_history(chat_history, limit=CONTEXT_TURN_LIMIT):
+     """Clean and format chat history for OpenAI API"""
+     if not chat_history:
+         return []
+
+     messages = []
+     recent_history = chat_history[-limit:] if len(chat_history) > limit else chat_history
+
+     for user_msg, bot_msg in recent_history:
+         if user_msg and isinstance(user_msg, str):
+             # Handle audio input display format
+             if user_msg.startswith("🎀 Audio: "):
+                 user_text = user_msg.replace("🎀 Audio: ", "")
+             else:
+                 user_text = user_msg
+
+             if user_text.strip():
+                 messages.append({"role": "user", "content": user_text.strip()})
+
+         if bot_msg and isinstance(bot_msg, str) and not bot_msg.startswith("[Error"):
+             messages.append({"role": "assistant", "content": bot_msg.strip()})
+
+     return messages
+
+ # --- Main Processing Function ---
+ def process_conversation(
+     text_input, audio_input,
+     mode, enable_tts,
+     openai_temp, openai_top_p, openai_max_tokens,
+     tts_temp, tts_top_p,
+     chat_history
  ):
+     """Main function to process user input and generate responses"""
+
+     user_text = ""
+     display_input = ""
+
+     # Handle audio input
+     if audio_input and mode in ["audio_only", "audio_text"]:
+         transcribed = transcribe_audio(audio_input)
+         if transcribed and not transcribed.startswith("["):
+             user_text = transcribed
+             display_input = f"🎀 Audio: {transcribed}"
          else:
+             chat_history.append((f"🎀 Audio Input", transcribed))
+             return chat_history, "", None
+
+     # Handle text input
+     if text_input and mode in ["text_only", "audio_text"]:
+         if user_text:  # If we already have audio transcription
+             user_text += f" {text_input}"
+             display_input = f"🎀 Audio: {transcribed} + Text: {text_input}"
          else:
+             user_text = text_input
+             display_input = text_input

+     if not user_text.strip():
+         return chat_history, "", None
+
+     # Add user message to chat
+     chat_history.append((display_input, "Thinking..."))

      try:
+         # Prepare messages for OpenAI
+         cleaned_history = clean_chat_history(chat_history[:-1])  # Exclude current "Thinking..." entry
+         messages = [{"role": "system", "content": SYSTEM_PROMPT}] + cleaned_history
+         messages.append({"role": "user", "content": user_text})
+
+         # Get LLM response
+         llm_response = call_openai_llm(
+             messages=messages,
+             temperature=openai_temp,
+             top_p=openai_top_p,
+             max_tokens=openai_max_tokens
+         )
+
+         if llm_response.startswith("[Error"):
+             chat_history[-1] = (display_input, llm_response)
+             return chat_history, "", None
+
+         # Generate TTS if enabled
+         bot_response = llm_response
+         audio_output = None
+
+         if enable_tts and llm_response:
+             audio_result = generate_speech(
+                 text=llm_response,
+                 temperature=tts_temp,
+                 top_p=tts_top_p
              )
+
+             if audio_result:
+                 audio_output = audio_result
+                 # For display purposes, we'll show text + indicate audio is available
+                 bot_response = llm_response
+
+         # Update chat history
+         chat_history[-1] = (display_input, bot_response)
+
+         return chat_history, "", audio_output
+
      except Exception as e:
+         error_msg = f"[Processing Error: {e}]"
+         chat_history[-1] = (display_input, error_msg)
          traceback.print_exc()
+         return chat_history, "", None

  # --- Gradio Interface ---
+ def create_interface():
+     with gr.Blocks(title="STT + OpenAI + TTS Demo") as demo:
+         gr.Markdown("""
+         # 🎀 Complete AI Assistant: Speech-to-Text + OpenAI + Text-to-Speech
+
+         This demo combines:
+         - **Whisper STT**: Convert your speech to text
+         - **OpenAI LLM**: Generate intelligent responses
+         - **Local TTS**: Convert responses back to speech
+         """)
+
+         with gr.Row():
+             chatbot = gr.Chatbot(
+                 label="Conversation",
+                 height=400,
+                 show_copy_button=True
+             )
+
+         with gr.Row():
+             with gr.Column(scale=3):
+                 text_input = gr.Textbox(
+                     label="Type your message",
+                     placeholder="Enter your message here...",
+                     lines=2
+                 )
+             with gr.Column(scale=2):
+                 audio_input = gr.Audio(
+                     label="Record Audio",
+                     type="filepath"
+                 )
+
+         with gr.Row():
+             mode = gr.Radio(
+                 choices=["text_only", "audio_only", "audio_text"],
+                 value="text_only",
+                 label="Input Mode",
+                 info="How do you want to provide input?"
+             )
+             enable_tts = gr.Checkbox(
+                 label="Enable Text-to-Speech Output",
+                 value=True,
+                 info="Convert responses to speech"
+             )
+
+         with gr.Row():
+             submit_btn = gr.Button("Send", variant="primary", size="lg")
+             clear_btn = gr.Button("Clear Conversation", size="lg")
+
+         audio_output = gr.Audio(
+             label="AI Response (Audio)",
+             type="numpy",
+             autoplay=False
+         )
+
+         with gr.Accordion("⚙️ Advanced Settings", open=False):
+             gr.Markdown("### OpenAI Settings")
+             with gr.Row():
+                 openai_temp = gr.Slider(
+                     minimum=0.1, maximum=2.0, value=DEFAULT_OPENAI_TEMP, step=0.1,
+                     label="Temperature",
+                     info="Higher = more creative"
+                 )
+                 openai_top_p = gr.Slider(
+                     minimum=0.1, maximum=1.0, value=DEFAULT_OPENAI_TOP_P, step=0.05,
+                     label="Top-p",
+                     info="Nucleus sampling"
+                 )
+                 openai_max_tokens = gr.Slider(
+                     minimum=50, maximum=2048, value=DEFAULT_OPENAI_MAX_TOKENS, step=50,
+                     label="Max Tokens",
+                     info="Maximum response length"
+                 )
+
+             gr.Markdown("### TTS Settings")
+             with gr.Row():
+                 tts_temp = gr.Slider(
+                     minimum=0.1, maximum=1.5, value=DEFAULT_TTS_TEMP, step=0.05,
+                     label="TTS Temperature",
+                     info="Speech expressiveness"
+                 )
+                 tts_top_p = gr.Slider(
+                     minimum=0.1, maximum=1.0, value=DEFAULT_TTS_TOP_P, step=0.05,
+                     label="TTS Top-p",
+                     info="Speech variation"
+                 )
+
+         # Event handlers
+         inputs = [
+             text_input, audio_input,
+             mode, enable_tts,
+             openai_temp, openai_top_p, openai_max_tokens,
+             tts_temp, tts_top_p,
+             chatbot
+         ]
+
+         outputs = [chatbot, text_input, audio_output]
+
+         submit_btn.click(
+             fn=process_conversation,
+             inputs=inputs,
+             outputs=outputs
+         )
+
+         text_input.submit(
+             fn=process_conversation,
+             inputs=inputs,
+             outputs=outputs
+         )
+
+         clear_btn.click(
+             fn=lambda: ([], "", None),
+             outputs=[chatbot, text_input, audio_output]
+         )
+
+         # Examples
+         gr.Examples(
+             examples=[
+                 ["Hello, how are you today?"],
+                 ["Can you explain quantum computing in simple terms?"],
+                 ["Tell me a short joke"],
+                 ["What's the weather like?"]
+             ],
+             inputs=text_input,
+         )
+
+     return demo

+ # --- Launch Application ---
  if __name__ == "__main__":
+     print("-" * 60)
+     print("🚀 Initializing Complete AI Assistant")
+     print(f"📱 Whisper STT: {WHISPER_MODEL_NAME} on {stt_device}")
+     print(f"🤖 OpenAI Model: {OPENAI_MODEL}")
+     print(f"🔊 TTS Model: {tts_model_name}")
+     print(f"💾 SNAC Vocoder on {device}")
+     print("-" * 60)
+
+     demo = create_interface()
+     demo.launch(
+         share=False,
+         server_name="0.0.0.0",
+         server_port=7860
+     )
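With the local LM Studio endpoints gone, the LLM path now depends entirely on three environment variables read at startup: OPENAI_API_KEY (required), plus the optional OPENAI_API_BASE and OPENAI_MODEL overrides. A minimal standalone sketch of that configuration, using the same client calls as call_openai_llm above; the "ping" prompt is a hypothetical smoke test, not part of the app:

# Hypothetical smoke test for the OpenAI configuration the new app.py reads.
import os
import openai
from dotenv import load_dotenv

load_dotenv()  # picks up OPENAI_API_KEY / OPENAI_API_BASE / OPENAI_MODEL from a .env file
client = openai.OpenAI(
    api_key=os.environ["OPENAI_API_KEY"],  # required; keep the key out of source control
    base_url=os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
)
reply = client.chat.completions.create(
    model=os.getenv("OPENAI_MODEL", "gpt-4o"),
    messages=[{"role": "user", "content": "ping"}],
    max_tokens=5,
)
print(reply.choices[0].message.content)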
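For reference, the TTS decoding path above (parse_output followed by redistribute_codes) treats the generated IDs as groups of seven audio codes: 128257 marks the start of audio, 128258 is the end-of-speech token that gets stripped, 128266 is the base offset subtracted from every code, and each of the seven slots in a group occupies its own 4096-code range. A worked example under those constants (the sample token IDs are made up):

# Minimal sketch of the 7-slot code layout used by parse_output() and
# redistribute_codes(); constants come from the diff, sample IDs are hypothetical.
AUDIO_CODE_BASE = 128266  # subtracted from raw token IDs in parse_output()
CODES_PER_LAYER = 4096    # each of the 7 slots has its own 4096-code range

# One group of 7 raw token IDs, built so every slot decodes to code index 10:
raw_tokens = [AUDIO_CODE_BASE + k * CODES_PER_LAYER + 10 for k in range(7)]
code_list = [t - AUDIO_CODE_BASE for t in raw_tokens]  # what parse_output() yields

# redistribute_codes() fans each group out to SNAC's three codebook layers:
layer_1 = [code_list[0]]                                  # slot 0
layer_2 = [code_list[1] - 4096, code_list[4] - 4 * 4096]  # slots 1 and 4
layer_3 = [code_list[2] - 2 * 4096, code_list[3] - 3 * 4096,
           code_list[5] - 5 * 4096, code_list[6] - 6 * 4096]  # slots 2, 3, 5, 6
print(layer_1, layer_2, layer_3)  # -> [10] [10, 10] [10, 10, 10, 10]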
 
 
requirements.txt CHANGED
@@ -6,4 +6,6 @@ spaces
  openai-whisper
  requests
  gradio
- scipy
+ scipy
+ openai
+ huggingface-hub
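Note that snac itself is not added here: app.py treats it as a separate install and prints `pip install git+https://github.com/hubertsiuzdak/snac.git` if the import fails. The transformers and torch packages the new code imports are presumably covered by the requirement lines above this hunk, which are not shown in the diff.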