EYEDOL commited on
Commit
e1a9f6f
·
verified ·
1 Parent(s): d507929

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +209 -79
app.py CHANGED
@@ -1,28 +1,44 @@
1
  # -*- coding: utf-8 -*-
2
  """
3
- This script implements a multi-modal Swahili assistant for Hugging Face Spaces.
4
- It uses Gradio for the user interface and loads models from the HF Hub.
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  """
6
 
7
- import gradio as gr
 
 
8
  import numpy as np
9
- import onnxruntime
10
- import torch
11
  import librosa
12
- from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, AutoTokenizer, pipeline, TextIteratorStreamer
13
  from scipy.io.wavfile import write as write_wav
14
- import os
15
- import re
16
  from huggingface_hub import login
17
- import threading
18
 
19
- hf_token = os.environ.get("hugface") # Using "HF_TOKEN" is the standard on Spaces
20
- if not hf_token:
21
- raise ValueError("HF_TOKEN not found. Please set it in Hugging Face Space repository secrets.")
22
- login(token=hf_token)
23
- print("Successfully logged into Hugging Face Hub!")
 
 
 
 
24
 
25
- # --- Configuration ---
26
  STT_MODEL_ID = "EYEDOL/SALAMA_C3"
27
  LLM_MODEL_ID = "EYEDOL/Llama-3.2-1B_ON_ALPACA5"
28
  TTS_TOKENIZER_ID = "facebook/mms-tts-swh"
@@ -31,6 +47,18 @@ TTS_ONNX_MODEL_PATH = "swahili_tts.onnx"
31
  TEMP_DIR = "temp"
32
  os.makedirs(TEMP_DIR, exist_ok=True)
33
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  class WeeboAssistant:
36
  def __init__(self):
@@ -48,125 +76,224 @@ class WeeboAssistant:
48
  self.torch_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
49
  print(f"Using device: {self.device}")
50
 
51
- # STT
52
  print(f"Loading STT model: {STT_MODEL_ID}")
53
  self.stt_processor = AutoProcessor.from_pretrained(STT_MODEL_ID)
 
54
  self.stt_model = AutoModelForSpeechSeq2Seq.from_pretrained(
55
- STT_MODEL_ID,
56
- torch_dtype=self.torch_dtype,
57
- low_cpu_mem_usage=True,
58
- use_safetensors=True
59
- ).to(self.device)
 
 
 
 
 
60
  print("STT model loaded successfully.")
61
 
62
- # LLM
63
  print(f"Loading LLM: {LLM_MODEL_ID}")
64
- self.llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID)
65
- self.llm_pipeline = pipeline(
66
- "text-generation",
67
- model=LLM_MODEL_ID,
68
- model_kwargs={"torch_dtype": self.torch_dtype},
69
- tokenizer=self.llm_tokenizer,
70
- device=self.device,
71
- )
72
- print("LLM pipeline loaded successfully.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
- # TTS
75
  print(f"Loading TTS model: {TTS_ONNX_MODEL_PATH}")
76
- self.tts_session = onnxruntime.InferenceSession(
77
- TTS_ONNX_MODEL_PATH,
78
- providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
79
- )
 
80
  self.tts_tokenizer = AutoTokenizer.from_pretrained(TTS_TOKENIZER_ID)
81
  print("TTS model and tokenizer loaded successfully.")
82
 
83
  print("-" * 30)
84
  print("All models initialized successfully! ✅")
85
 
 
86
  def transcribe_audio(self, audio_tuple):
 
87
  if audio_tuple is None:
88
  return ""
89
  sample_rate, audio_data = audio_tuple
 
90
  if audio_data.ndim > 1:
91
  audio_data = audio_data.mean(axis=1)
 
92
  if audio_data.dtype != np.float32:
93
- audio_data = audio_data.astype(np.float32) / np.iinfo(audio_data.dtype).max
 
 
 
 
 
 
94
  if sample_rate != self.STT_SAMPLE_RATE:
95
  audio_data = librosa.resample(y=audio_data, orig_sr=sample_rate, target_sr=self.STT_SAMPLE_RATE)
96
  if len(audio_data) < 1000:
97
  return "(Audio too short to transcribe)"
 
98
  inputs = self.stt_processor(audio_data, sampling_rate=self.STT_SAMPLE_RATE, return_tensors="pt")
99
- inputs = {k: v.to(self.device) for k, v in inputs.items()}
100
  with torch.no_grad():
101
  generated_ids = self.stt_model.generate(**inputs, max_new_tokens=128)
102
  transcription = self.stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
103
  return transcription.strip()
104
 
105
  def generate_speech(self, text):
 
106
  if not text:
107
  return None
108
  text = text.strip()
 
109
  inputs = self.tts_tokenizer(text, return_tensors="np")
110
- ort_inputs = {self.tts_session.get_inputs()[0].name: inputs.input_ids}
 
111
  audio_waveform = self.tts_session.run(None, ort_inputs)[0].flatten()
 
 
 
 
 
 
 
 
 
 
112
  output_path = os.path.join(TEMP_DIR, f"{os.urandom(8).hex()}.wav")
113
- write_wav(output_path, self.TTS_SAMPLE_RATE, audio_waveform)
114
  return output_path
115
 
116
  def get_llm_response(self, chat_history):
117
- # <-- FIX: Reverted to using a 'system' role, which is correct for Llama 3 -->
118
- messages = [{'role': 'system', 'content': self.SYSTEM_PROMPT}]
 
 
 
 
 
119
  for user_msg, assistant_msg in chat_history:
120
- messages.append({"role": "user", "content": user_msg})
 
 
121
  if assistant_msg:
122
- messages.append({"role": "assistant", "content": assistant_msg})
123
-
124
- prompt = self.llm_pipeline.tokenizer.apply_chat_template(
125
- messages, tokenize=False, add_generation_prompt=True
126
- )
127
- terminators = [
128
- self.llm_pipeline.tokenizer.eos_token_id,
129
- self.llm_pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
130
- ]
131
-
132
- streamer = TextIteratorStreamer(
133
- self.llm_pipeline.tokenizer, skip_prompt=True, skip_special_tokens=True
134
- )
135
-
136
  generation_kwargs = dict(
137
- streamer=streamer,
 
138
  max_new_tokens=512,
139
- eos_token_id=terminators,
140
  do_sample=True,
141
  temperature=0.6,
142
  top_p=0.9,
 
 
143
  )
144
-
145
- thread = threading.Thread(target=self.llm_pipeline, args=[prompt], kwargs=generation_kwargs)
146
- thread.start()
147
-
 
148
  return streamer
149
 
 
 
150
  assistant = WeeboAssistant()
151
 
152
 
 
 
153
  def s2s_pipeline(audio_input, chat_history):
 
154
  user_text = assistant.transcribe_audio(audio_input)
155
  if not user_text or user_text.startswith("("):
156
  chat_history.append((user_text or "(No valid speech detected)", None))
157
  yield chat_history, None, "Please record your voice again."
158
  return
159
-
160
  chat_history.append((user_text, ""))
161
  yield chat_history, None, "..."
162
-
163
  response_stream = assistant.get_llm_response(chat_history)
164
  llm_response_text = ""
165
  for text_chunk in response_stream:
166
  llm_response_text += text_chunk
 
167
  chat_history[-1] = (user_text, llm_response_text)
168
  yield chat_history, None, llm_response_text
169
-
 
170
  final_audio_path = assistant.generate_speech(llm_response_text)
171
  yield chat_history, final_audio_path, llm_response_text
172
 
@@ -174,7 +301,7 @@ def s2s_pipeline(audio_input, chat_history):
174
  def t2t_pipeline(text_input, chat_history):
175
  chat_history.append((text_input, ""))
176
  yield chat_history
177
-
178
  response_stream = assistant.get_llm_response(chat_history)
179
  llm_response_text = ""
180
  for text_chunk in response_stream:
@@ -187,10 +314,11 @@ def clear_textbox():
187
  return gr.Textbox(value="")
188
 
189
 
 
190
  with gr.Blocks(theme=gr.themes.Soft(), title="Msaidizi wa Kiswahili") as demo:
191
  gr.Markdown("# 🤖 Msaidizi wa Sauti wa Kiswahili (Swahili Voice Assistant)")
192
  gr.Markdown("Ongea na msaidizi kwa Kiswahili. Toa sauti, andika maandishi, na upate majibu kwa sauti au maandishi.")
193
-
194
  with gr.Tabs():
195
  with gr.TabItem("🎙️ Sauti-kwa-Sauti (Speech-to-Speech)"):
196
  with gr.Row():
@@ -201,7 +329,7 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Msaidizi wa Kiswahili") as demo:
201
  s2s_chatbot = gr.Chatbot(label="Mazungumzo (Conversation)", bubble_full_width=False, height=400)
202
  s2s_audio_out = gr.Audio(type="filepath", label="Jibu la Sauti (Audio Response)", autoplay=True)
203
  s2s_text_out = gr.Textbox(label="Jibu la Maandishi (Text Response)", interactive=False)
204
-
205
  with gr.TabItem("⌨️ Maandishi-kwa-Maandishi (Text-to-Text)"):
206
  t2t_chatbot = gr.Chatbot(label="Mazungumzo (Conversation)", bubble_full_width=False, height=500)
207
  with gr.Row():
@@ -225,46 +353,48 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Msaidizi wa Kiswahili") as demo:
225
  fn=s2s_pipeline,
226
  inputs=[s2s_audio_in, s2s_chatbot],
227
  outputs=[s2s_chatbot, s2s_audio_out, s2s_text_out],
228
- queue=True
229
  ).then(
230
  fn=lambda: gr.Audio(value=None),
231
  inputs=None,
232
- outputs=s2s_audio_in
233
  )
234
 
235
  t2t_submit_btn.click(
236
  fn=t2t_pipeline,
237
  inputs=[t2t_text_in, t2t_chatbot],
238
  outputs=[t2t_chatbot],
239
- queue=True
240
  ).then(
241
  fn=clear_textbox,
242
  inputs=None,
243
- outputs=t2t_text_in
244
  )
245
-
246
  t2t_text_in.submit(
247
  fn=t2t_pipeline,
248
  inputs=[t2t_text_in, t2t_chatbot],
249
  outputs=[t2t_chatbot],
250
- queue=True
251
  ).then(
252
  fn=clear_textbox,
253
  inputs=None,
254
- outputs=t2t_text_in
255
  )
256
 
257
  tool_s2t_btn.click(
258
  fn=assistant.transcribe_audio,
259
  inputs=tool_s2t_audio_in,
260
  outputs=tool_s2t_text_out,
261
- queue=True
262
  )
 
263
  tool_t2s_btn.click(
264
  fn=assistant.generate_speech,
265
  inputs=tool_t2s_text_in,
266
  outputs=tool_t2s_audio_out,
267
- queue=True
268
  )
269
 
270
- demo.queue().launch(debug=True)
 
 
1
  # -*- coding: utf-8 -*-
2
  """
3
+ Fixed and self-contained Swahili multimodal assistant for Hugging Face Spaces.
4
+
5
+ Key fixes / improvements over original:
6
+ - Robust loading of an LLM repo that may lack `model_type` in config.json by
7
+ loading the model object directly and using `trust_remote_code=True` as a
8
+ fallback. Avoids `pipeline(... )` raising ValueError on AutoConfig.
9
+ - Correct handling of `pipeline(..., device=...)` which expects an int GPU
10
+ index or -1 for CPU (previously passed a string like "cpu").
11
+ - Streaming generation implemented by calling `model.generate(..., streamer=TextIteratorStreamer(...))`
12
+ in a background thread so the main thread can iterate over the streamer.
13
+ - Use standard HF env var `HF_TOKEN` and graceful error message if not set.
14
+ - Minor robustness improvements (resampling audio, handling mono/stereo, temp
15
+ filenames, etc.).
16
+
17
+ Drop this file into your Space and replace the old app.py contents.
18
  """
19
 
20
+ import os
21
+ import re
22
+ import threading
23
  import numpy as np
24
+ import gradio as gr
 
25
  import librosa
26
+ import torch
27
  from scipy.io.wavfile import write as write_wav
 
 
28
  from huggingface_hub import login
29
+ import onnxruntime
30
 
31
+ from transformers import (
32
+ AutoProcessor,
33
+ AutoModelForSpeechSeq2Seq,
34
+ AutoTokenizer,
35
+ AutoConfig,
36
+ AutoModelForCausalLM,
37
+ pipeline,
38
+ TextIteratorStreamer,
39
+ )
40
 
41
+ # -------------------- Configuration --------------------
42
  STT_MODEL_ID = "EYEDOL/SALAMA_C3"
43
  LLM_MODEL_ID = "EYEDOL/Llama-3.2-1B_ON_ALPACA5"
44
  TTS_TOKENIZER_ID = "facebook/mms-tts-swh"
 
47
  TEMP_DIR = "temp"
48
  os.makedirs(TEMP_DIR, exist_ok=True)
49
 
50
+ # Use the standard environment variable name used by Spaces
51
+ HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("hugface")
52
+ if not HF_TOKEN:
53
+ raise ValueError("HF_TOKEN not found. Please set it in Hugging Face Space repository secrets.")
54
+
55
+ # Attempt login to HF hub (Spaces typically already provides token, but this keeps parity)
56
+ try:
57
+ login(token=HF_TOKEN)
58
+ print("Successfully logged into Hugging Face Hub!")
59
+ except Exception as e:
60
+ print("Warning: could not call huggingface_hub.login(). Proceeding — ensure your token is valid in the environment. Error:", e)
61
+
62
 
63
  class WeeboAssistant:
64
  def __init__(self):
 
76
  self.torch_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
77
  print(f"Using device: {self.device}")
78
 
79
+ # ---------------- STT ----------------
80
  print(f"Loading STT model: {STT_MODEL_ID}")
81
  self.stt_processor = AutoProcessor.from_pretrained(STT_MODEL_ID)
82
+ # Speech seq2seq model (e.g. Whisper-like)
83
  self.stt_model = AutoModelForSpeechSeq2Seq.from_pretrained(
84
+ STT_MODEL_ID,
85
+ torch_dtype=self.torch_dtype,
86
+ low_cpu_mem_usage=True,
87
+ use_safetensors=True,
88
+ )
89
+ if self.device == "cuda":
90
+ try:
91
+ self.stt_model = self.stt_model.to("cuda")
92
+ except Exception:
93
+ pass
94
  print("STT model loaded successfully.")
95
 
96
+ # ---------------- LLM ----------------
97
  print(f"Loading LLM: {LLM_MODEL_ID}")
98
+ self.llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID, use_fast=True)
99
+
100
+ # Attempt robust loading. If the repo lacks a model_type in config.json,
101
+ # try loading with trust_remote_code=True (this allows custom model code in repo).
102
+ try:
103
+ config = AutoConfig.from_pretrained(LLM_MODEL_ID)
104
+ # If config loaded but missing model_type, continue to try direct load
105
+ if not getattr(config, "model_type", None):
106
+ raise ValueError("config missing model_type - forcing trusted load")
107
+
108
+ # Try to load into a causal LM class (works for many standard model types)
109
+ self.llm_model = AutoModelForCausalLM.from_pretrained(
110
+ LLM_MODEL_ID,
111
+ config=config,
112
+ torch_dtype=self.torch_dtype,
113
+ low_cpu_mem_usage=True,
114
+ )
115
+ except Exception as first_err:
116
+ print("Standard AutoConfig/AutoModel load failed or model_type missing. Trying trust_remote_code=True. Error:", first_err)
117
+ # Try using trust_remote_code which will import repo-specific model code if present
118
+ try:
119
+ config = AutoConfig.from_pretrained(LLM_MODEL_ID, trust_remote_code=True)
120
+ self.llm_model = AutoModelForCausalLM.from_pretrained(
121
+ LLM_MODEL_ID,
122
+ config=config,
123
+ torch_dtype=self.torch_dtype,
124
+ trust_remote_code=True,
125
+ low_cpu_mem_usage=True,
126
+ device_map="auto" if torch.cuda.is_available() else None,
127
+ )
128
+ except Exception as second_err:
129
+ # Final fallback: try to load without special configs — may still fail for custom repos
130
+ print("Fallback load also failed:", second_err)
131
+ raise RuntimeError(
132
+ "Unable to load LLM model. Check the model repo, ensure config.json contains a model_type or that trust_remote_code is allowed."
133
+ )
134
+
135
+ # If device_map wasn't used and model is on CPU, ensure model is moved to CPU
136
+ if self.device == "cpu":
137
+ try:
138
+ # Many Hugging Face helpers use device_map; if not used, move model
139
+ self.llm_model = self.llm_model.to("cpu")
140
+ except Exception:
141
+ pass
142
+
143
+ # For convenience, create a pipeline for non-streaming quick calls (device expects int or -1)
144
+ device_index = 0 if torch.cuda.is_available() else -1
145
+ try:
146
+ self.llm_pipeline = pipeline(
147
+ "text-generation",
148
+ model=self.llm_model,
149
+ tokenizer=self.llm_tokenizer,
150
+ device=device_index,
151
+ model_kwargs={"torch_dtype": self.torch_dtype},
152
+ )
153
+ except Exception:
154
+ # pipeline is optional; if it fails we still support the streaming flow via model.generate
155
+ self.llm_pipeline = None
156
+
157
+ print("LLM loaded successfully.")
158
 
159
+ # ---------------- TTS ----------------
160
  print(f"Loading TTS model: {TTS_ONNX_MODEL_PATH}")
161
+ # ONNX runtime session; providers include CUDA if available
162
+ providers = ["CPUExecutionProvider"]
163
+ if torch.cuda.is_available():
164
+ providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
165
+ self.tts_session = onnxruntime.InferenceSession(TTS_ONNX_MODEL_PATH, providers=providers)
166
  self.tts_tokenizer = AutoTokenizer.from_pretrained(TTS_TOKENIZER_ID)
167
  print("TTS model and tokenizer loaded successfully.")
168
 
169
  print("-" * 30)
170
  print("All models initialized successfully! ✅")
171
 
172
+ # ---------------- Utility methods ----------------
173
  def transcribe_audio(self, audio_tuple):
174
+ """Take a Gradio audio tuple (sample_rate, np_audio) and return transcription string."""
175
  if audio_tuple is None:
176
  return ""
177
  sample_rate, audio_data = audio_tuple
178
+ # Convert to mono
179
  if audio_data.ndim > 1:
180
  audio_data = audio_data.mean(axis=1)
181
+ # Normalize to float32
182
  if audio_data.dtype != np.float32:
183
+ # handle common integer audio dtypes
184
+ if np.issubdtype(audio_data.dtype, np.integer):
185
+ max_val = np.iinfo(audio_data.dtype).max
186
+ audio_data = audio_data.astype(np.float32) / float(max_val)
187
+ else:
188
+ audio_data = audio_data.astype(np.float32)
189
+ # Resample if needed
190
  if sample_rate != self.STT_SAMPLE_RATE:
191
  audio_data = librosa.resample(y=audio_data, orig_sr=sample_rate, target_sr=self.STT_SAMPLE_RATE)
192
  if len(audio_data) < 1000:
193
  return "(Audio too short to transcribe)"
194
+
195
  inputs = self.stt_processor(audio_data, sampling_rate=self.STT_SAMPLE_RATE, return_tensors="pt")
196
+ inputs = {k: v.to(next(self.stt_model.parameters()).device) for k, v in inputs.items()}
197
  with torch.no_grad():
198
  generated_ids = self.stt_model.generate(**inputs, max_new_tokens=128)
199
  transcription = self.stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
200
  return transcription.strip()
201
 
202
  def generate_speech(self, text):
203
+ """Synthesize speech using the ONNX TTS model and return a filepath to a WAV file."""
204
  if not text:
205
  return None
206
  text = text.strip()
207
+ # Tokenize with numpy arrays for ONNX
208
  inputs = self.tts_tokenizer(text, return_tensors="np")
209
+ input_name = self.tts_session.get_inputs()[0].name
210
+ ort_inputs = {input_name: inputs["input_ids"]}
211
  audio_waveform = self.tts_session.run(None, ort_inputs)[0].flatten()
212
+
213
+ # ONNX model might produce float audio in range [-1,1] or int16 depending on model. We'll safe-guard.
214
+ # Normalize to int16 WAV
215
+ if np.issubdtype(audio_waveform.dtype, np.floating):
216
+ # Clip and convert
217
+ audio_clip = np.clip(audio_waveform, -1.0, 1.0)
218
+ audio_int16 = (audio_clip * 32767).astype(np.int16)
219
+ else:
220
+ audio_int16 = audio_waveform.astype(np.int16)
221
+
222
  output_path = os.path.join(TEMP_DIR, f"{os.urandom(8).hex()}.wav")
223
+ write_wav(output_path, self.TTS_SAMPLE_RATE, audio_int16)
224
  return output_path
225
 
226
  def get_llm_response(self, chat_history):
227
+ """Return a TextIteratorStreamer that yields generated text pieces as the model produces them.
228
+
229
+ This implementation uses self.llm_model.generate(...) with a TextIteratorStreamer and
230
+ runs generate in a background thread so the caller can iterate over streamer.
231
+ """
232
+ # Build prompt from system + conversation. Adjust this template to match your LLM's preferred format.
233
+ prompt_lines = [self.SYSTEM_PROMPT.strip(), "\n"]
234
  for user_msg, assistant_msg in chat_history:
235
+ if user_msg:
236
+ # tag user messages clearly so model understands dialogue turns
237
+ prompt_lines.append("User: " + user_msg)
238
  if assistant_msg:
239
+ prompt_lines.append("Assistant: " + assistant_msg)
240
+ prompt_lines.append("Assistant: ")
241
+ prompt = "\n".join(prompt_lines)
242
+
243
+ # Tokenize and prepare inputs on the same device as the model
244
+ inputs = self.llm_tokenizer(prompt, return_tensors="pt")
245
+ try:
246
+ model_device = next(self.llm_model.parameters()).device
247
+ except StopIteration:
248
+ model_device = torch.device("cpu")
249
+ inputs = {k: v.to(model_device) for k, v in inputs.items()}
250
+
251
+ streamer = TextIteratorStreamer(self.llm_tokenizer, skip_prompt=True, skip_special_tokens=True)
252
+
253
  generation_kwargs = dict(
254
+ input_ids=inputs["input_ids"],
255
+ attention_mask=inputs.get("attention_mask", None),
256
  max_new_tokens=512,
 
257
  do_sample=True,
258
  temperature=0.6,
259
  top_p=0.9,
260
+ streamer=streamer,
261
+ eos_token_id=getattr(self.llm_tokenizer, "eos_token_id", None),
262
  )
263
+
264
+ # Launch generation in a thread so we can yield from the streamer in the main thread
265
+ gen_thread = threading.Thread(target=self.llm_model.generate, kwargs=generation_kwargs, daemon=True)
266
+ gen_thread.start()
267
+
268
  return streamer
269
 
270
+
271
+ # -------------------- Create assistant instance --------------------
272
  assistant = WeeboAssistant()
273
 
274
 
275
+ # -------------------- Gradio pipelines --------------------
276
+
277
  def s2s_pipeline(audio_input, chat_history):
278
+ # `chat_history` is expected to be a list of (user_text, assistant_text) tuples
279
  user_text = assistant.transcribe_audio(audio_input)
280
  if not user_text or user_text.startswith("("):
281
  chat_history.append((user_text or "(No valid speech detected)", None))
282
  yield chat_history, None, "Please record your voice again."
283
  return
284
+
285
  chat_history.append((user_text, ""))
286
  yield chat_history, None, "..."
287
+
288
  response_stream = assistant.get_llm_response(chat_history)
289
  llm_response_text = ""
290
  for text_chunk in response_stream:
291
  llm_response_text += text_chunk
292
+ # Update last turn in chat history
293
  chat_history[-1] = (user_text, llm_response_text)
294
  yield chat_history, None, llm_response_text
295
+
296
+ # Once finished, synthesize audio
297
  final_audio_path = assistant.generate_speech(llm_response_text)
298
  yield chat_history, final_audio_path, llm_response_text
299
 
 
301
  def t2t_pipeline(text_input, chat_history):
302
  chat_history.append((text_input, ""))
303
  yield chat_history
304
+
305
  response_stream = assistant.get_llm_response(chat_history)
306
  llm_response_text = ""
307
  for text_chunk in response_stream:
 
314
  return gr.Textbox(value="")
315
 
316
 
317
+ # -------------------- Gradio UI --------------------
318
  with gr.Blocks(theme=gr.themes.Soft(), title="Msaidizi wa Kiswahili") as demo:
319
  gr.Markdown("# 🤖 Msaidizi wa Sauti wa Kiswahili (Swahili Voice Assistant)")
320
  gr.Markdown("Ongea na msaidizi kwa Kiswahili. Toa sauti, andika maandishi, na upate majibu kwa sauti au maandishi.")
321
+
322
  with gr.Tabs():
323
  with gr.TabItem("🎙️ Sauti-kwa-Sauti (Speech-to-Speech)"):
324
  with gr.Row():
 
329
  s2s_chatbot = gr.Chatbot(label="Mazungumzo (Conversation)", bubble_full_width=False, height=400)
330
  s2s_audio_out = gr.Audio(type="filepath", label="Jibu la Sauti (Audio Response)", autoplay=True)
331
  s2s_text_out = gr.Textbox(label="Jibu la Maandishi (Text Response)", interactive=False)
332
+
333
  with gr.TabItem("⌨️ Maandishi-kwa-Maandishi (Text-to-Text)"):
334
  t2t_chatbot = gr.Chatbot(label="Mazungumzo (Conversation)", bubble_full_width=False, height=500)
335
  with gr.Row():
 
353
  fn=s2s_pipeline,
354
  inputs=[s2s_audio_in, s2s_chatbot],
355
  outputs=[s2s_chatbot, s2s_audio_out, s2s_text_out],
356
+ queue=True,
357
  ).then(
358
  fn=lambda: gr.Audio(value=None),
359
  inputs=None,
360
+ outputs=s2s_audio_in,
361
  )
362
 
363
  t2t_submit_btn.click(
364
  fn=t2t_pipeline,
365
  inputs=[t2t_text_in, t2t_chatbot],
366
  outputs=[t2t_chatbot],
367
+ queue=True,
368
  ).then(
369
  fn=clear_textbox,
370
  inputs=None,
371
+ outputs=t2t_text_in,
372
  )
373
+
374
  t2t_text_in.submit(
375
  fn=t2t_pipeline,
376
  inputs=[t2t_text_in, t2t_chatbot],
377
  outputs=[t2t_chatbot],
378
+ queue=True,
379
  ).then(
380
  fn=clear_textbox,
381
  inputs=None,
382
+ outputs=t2t_text_in,
383
  )
384
 
385
  tool_s2t_btn.click(
386
  fn=assistant.transcribe_audio,
387
  inputs=tool_s2t_audio_in,
388
  outputs=tool_s2t_text_out,
389
+ queue=True,
390
  )
391
+
392
  tool_t2s_btn.click(
393
  fn=assistant.generate_speech,
394
  inputs=tool_t2s_text_in,
395
  outputs=tool_t2s_audio_out,
396
+ queue=True,
397
  )
398
 
399
+
400
+ demo.queue().launch(debug=True)