EYEDOL committed on
Commit
6fefd54
·
verified ·
1 Parent(s): 8454ce0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -21
app.py CHANGED
@@ -9,15 +9,16 @@ import numpy as np
9
  import onnxruntime
10
  import torch
11
  import librosa
12
- from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, AutoTokenizer, pipeline
13
  from scipy.io.wavfile import write as write_wav
14
  import os
15
  import re
16
  from huggingface_hub import login
 
17
 
18
  # --- Login to Hugging Face using secret ---
19
  # Make sure HF_TOKEN is set in your Hugging Face Space > Settings > Repository secrets
20
- hf_token = os.environ.get("hugface")
21
  if not hf_token:
22
  raise ValueError("HF_TOKEN not found. Please set it in Hugging Face Space repository secrets.")
23
  login(token=hf_token)
@@ -25,7 +26,7 @@ print("Successfully logged into Hugging Face Hub!")
25
 
26
  # --- Configuration ---
27
  STT_MODEL_ID = "EYEDOL/SALAMA_C3"
28
- LLM_MODEL_ID = "google/gemma-3-1b-it"
29
  TTS_TOKENIZER_ID = "facebook/mms-tts-swh"
30
  TTS_ONNX_MODEL_PATH = "swahili_tts.onnx"
31
 
@@ -62,10 +63,13 @@ class WeeboAssistant:
62
 
63
  # LLM
64
  print(f"Loading LLM: {LLM_MODEL_ID}")
 
 
65
  self.llm_pipeline = pipeline(
66
  "text-generation",
67
  model=LLM_MODEL_ID,
68
  model_kwargs={"torch_dtype": self.torch_dtype},
 
69
  device=self.device,
70
  )
71
  print("LLM pipeline loaded successfully.")
@@ -118,6 +122,7 @@ class WeeboAssistant:
118
  messages.append({'role': 'user', 'content': turn[0]})
119
  if turn[1] is not None:
120
  messages.append({'role': 'assistant', 'content': turn[1]})
 
121
  prompt = self.llm_pipeline.tokenizer.apply_chat_template(
122
  messages, tokenize=False, add_generation_prompt=True
123
  )
@@ -125,17 +130,27 @@ class WeeboAssistant:
125
  self.llm_pipeline.tokenizer.eos_token_id,
126
  self.llm_pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
127
  ]
128
- streamer = self.llm_pipeline(
129
- prompt,
 
 
 
 
 
 
130
  max_new_tokens=512,
131
  eos_token_id=terminators,
132
  do_sample=True,
133
  temperature=0.6,
134
  top_p=0.9,
135
- streamer=gr.TextIterator(),
136
  )
 
 
 
 
 
137
  return streamer
138
-
139
 
140
  assistant = WeeboAssistant()
141
 
@@ -146,31 +161,35 @@ def s2s_pipeline(audio_input, chat_history):
146
  chat_history.append((user_text or "(No valid speech detected)", None))
147
  yield chat_history, None, "Please record your voice again."
148
  return
149
- chat_history.append((user_text, None))
150
- yield chat_history, None, "..."
 
 
151
  response_stream = assistant.get_llm_response(chat_history)
152
  llm_response_text = ""
153
  for text_chunk in response_stream:
154
- llm_response_text = text_chunk
155
  chat_history[-1] = (user_text, llm_response_text)
156
  yield chat_history, None, llm_response_text
 
157
  final_audio_path = assistant.generate_speech(llm_response_text)
158
  yield chat_history, final_audio_path, llm_response_text
159
 
160
 
161
  def t2t_pipeline(text_input, chat_history):
162
- chat_history.append((text_input, None))
163
- yield chat_history, "..."
 
164
  response_stream = assistant.get_llm_response(chat_history)
165
  llm_response_text = ""
166
  for text_chunk in response_stream:
167
- llm_response_text = text_chunk
168
  chat_history[-1] = (text_input, llm_response_text)
169
- yield chat_history, llm_response_text
170
 
171
 
172
  def clear_textbox():
173
- return ""
174
 
175
 
176
  with gr.Blocks(theme=gr.themes.Soft(), title="Msaidizi wa Kiswahili") as demo:
@@ -191,14 +210,14 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Msaidizi wa Kiswahili") as demo:
191
  with gr.TabItem("⌨️ Maandishi-kwa-Maandishi (Text-to-Text)"):
192
  t2t_chatbot = gr.Chatbot(label="Mazungumzo (Conversation)", bubble_full_width=False, height=500)
193
  with gr.Row():
194
- t2t_text_in = gr.Textbox(label="Andika Hapa (Write Here)", placeholder="Habari yako...", scale=4)
195
  t2t_submit_btn = gr.Button("Tuma (Submit)", variant="primary", scale=1)
196
 
197
  with gr.TabItem("🛠️ Zana (Tools)"):
198
  with gr.Row():
199
  with gr.Column():
200
  gr.Markdown("### Unukuzi wa Sauti (Speech Transcription)")
201
- tool_s2t_audio_in = gr.Audio(sources=["microphone"], type="numpy", label="Sauti ya Kuingiza (Input Audio)")
202
  tool_s2t_text_out = gr.Textbox(label="Maandishi Yaliyonukuliwa (Transcribed Text)", interactive=False)
203
  tool_s2t_btn = gr.Button("Nukuu (Transcribe)")
204
  with gr.Column():
@@ -212,12 +231,28 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Msaidizi wa Kiswahili") as demo:
212
  inputs=[s2s_audio_in, s2s_chatbot],
213
  outputs=[s2s_chatbot, s2s_audio_out, s2s_text_out],
214
  queue=True
 
 
 
 
215
  )
216
 
217
  t2t_submit_btn.click(
218
  fn=t2t_pipeline,
219
  inputs=[t2t_text_in, t2t_chatbot],
220
- outputs=[t2t_chatbot, t2t_text_in],
 
 
 
 
 
 
 
 
 
 
 
 
221
  queue=True
222
  ).then(
223
  fn=clear_textbox,
@@ -225,15 +260,18 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Msaidizi wa Kiswahili") as demo:
225
  outputs=t2t_text_in
226
  )
227
 
 
228
  tool_s2t_btn.click(
229
  fn=assistant.transcribe_audio,
230
  inputs=tool_s2t_audio_in,
231
- outputs=tool_s2t_text_out
 
232
  )
233
  tool_t2s_btn.click(
234
  fn=assistant.generate_speech,
235
  inputs=tool_t2s_text_in,
236
- outputs=tool_t2s_audio_out
 
237
  )
238
 
239
- demo.queue().launch(debug=True)
 
9
  import onnxruntime
10
  import torch
11
  import librosa
12
+ from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, AutoTokenizer, pipeline, TextIteratorStreamer
13
  from scipy.io.wavfile import write as write_wav
14
  import os
15
  import re
16
  from huggingface_hub import login
17
+ import threading # <-- FIX: Added threading import
18
 
19
  # --- Login to Hugging Face using secret ---
20
  # Make sure HF_TOKEN is set in your Hugging Face Space > Settings > Repository secrets
21
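+ hf_token = os.environ.get("hugface") # NOTE: the token is read from the "hugface" secret, not "HF_TOKEN" as the comment and error message say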
22
  if not hf_token:
23
  raise ValueError("HF_TOKEN not found. Please set it in Hugging Face Space repository secrets.")
24
  login(token=hf_token)
 
26
 
27
  # --- Configuration ---
28
  STT_MODEL_ID = "EYEDOL/SALAMA_C3"
29
+ LLM_MODEL_ID = "google/gemma-1.1-2b-it"
30
  TTS_TOKENIZER_ID = "facebook/mms-tts-swh"
31
  TTS_ONNX_MODEL_PATH = "swahili_tts.onnx"
32
 
 
63
 
64
  # LLM
65
  print(f"Loading LLM: {LLM_MODEL_ID}")
66
+ # <-- FIX: Initialize tokenizer separately to use it with the streamer
67
+ self.llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID)
68
  self.llm_pipeline = pipeline(
69
  "text-generation",
70
  model=LLM_MODEL_ID,
71
  model_kwargs={"torch_dtype": self.torch_dtype},
72
+ tokenizer=self.llm_tokenizer, # Pass the tokenizer here
73
  device=self.device,
74
  )
75
  print("LLM pipeline loaded successfully.")
 
122
  messages.append({'role': 'user', 'content': turn[0]})
123
  if turn[1] is not None:
124
  messages.append({'role': 'assistant', 'content': turn[1]})
125
+
126
  prompt = self.llm_pipeline.tokenizer.apply_chat_template(
127
  messages, tokenize=False, add_generation_prompt=True
128
  )
 
130
  self.llm_pipeline.tokenizer.eos_token_id,
131
  self.llm_pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
132
  ]
133
+
134
+ # <-- START OF FIX: Use TextIteratorStreamer instead of gr.TextIterator -->
135
+ streamer = TextIteratorStreamer(
136
+ self.llm_pipeline.tokenizer, skip_prompt=True, skip_special_tokens=True
137
+ )
138
+
139
+ generation_kwargs = dict(
140
+ streamer=streamer,
141
  max_new_tokens=512,
142
  eos_token_id=terminators,
143
  do_sample=True,
144
  temperature=0.6,
145
  top_p=0.9,
 
146
  )
147
+
148
+ # Run the pipeline in a separate thread to enable streaming
149
+ thread = threading.Thread(target=self.llm_pipeline, args=[prompt], kwargs=generation_kwargs)
150
+ thread.start()
151
+
152
  return streamer
153
+ # <-- END OF FIX -->
154
 
155
  assistant = WeeboAssistant()
156
 
 
161
  chat_history.append((user_text or "(No valid speech detected)", None))
162
  yield chat_history, None, "Please record your voice again."
163
  return
164
+
165
+ chat_history.append((user_text, ""))
166
+ yield chat_history, None, "..." # Show thinking indicator
167
+
168
  response_stream = assistant.get_llm_response(chat_history)
169
  llm_response_text = ""
170
  for text_chunk in response_stream:
171
+ llm_response_text += text_chunk # <-- FIX: Append chunk to full response
172
  chat_history[-1] = (user_text, llm_response_text)
173
  yield chat_history, None, llm_response_text
174
+
175
  final_audio_path = assistant.generate_speech(llm_response_text)
176
  yield chat_history, final_audio_path, llm_response_text
177
 
178
 
179
  def t2t_pipeline(text_input, chat_history):
180
+ chat_history.append((text_input, ""))
181
+ yield chat_history
182
+
183
  response_stream = assistant.get_llm_response(chat_history)
184
  llm_response_text = ""
185
  for text_chunk in response_stream:
186
+ llm_response_text += text_chunk # <-- FIX: Append chunk to full response
187
  chat_history[-1] = (text_input, llm_response_text)
188
+ yield chat_history
189
 
190
 
191
  def clear_textbox():
192
+ return gr.Textbox(value="")
193
 
194
 
195
  with gr.Blocks(theme=gr.themes.Soft(), title="Msaidizi wa Kiswahili") as demo:
 
210
  with gr.TabItem("⌨️ Maandishi-kwa-Maandishi (Text-to-Text)"):
211
  t2t_chatbot = gr.Chatbot(label="Mazungumzo (Conversation)", bubble_full_width=False, height=500)
212
  with gr.Row():
213
+ t2t_text_in = gr.Textbox(show_label=False, placeholder="Habari yako...", scale=4, container=False)
214
  t2t_submit_btn = gr.Button("Tuma (Submit)", variant="primary", scale=1)
215
 
216
  with gr.TabItem("🛠️ Zana (Tools)"):
217
  with gr.Row():
218
  with gr.Column():
219
  gr.Markdown("### Unukuzi wa Sauti (Speech Transcription)")
220
+ tool_s2t_audio_in = gr.Audio(sources=["microphone", "upload"], type="numpy", label="Sauti ya Kuingiza (Input Audio)")
221
  tool_s2t_text_out = gr.Textbox(label="Maandishi Yaliyonukuliwa (Transcribed Text)", interactive=False)
222
  tool_s2t_btn = gr.Button("Nukuu (Transcribe)")
223
  with gr.Column():
 
231
  inputs=[s2s_audio_in, s2s_chatbot],
232
  outputs=[s2s_chatbot, s2s_audio_out, s2s_text_out],
233
  queue=True
234
+ ).then(
235
+ fn=lambda: gr.Audio(value=None), # Clear audio input after submit
236
+ inputs=None,
237
+ outputs=s2s_audio_in
238
  )
239
 
240
  t2t_submit_btn.click(
241
  fn=t2t_pipeline,
242
  inputs=[t2t_text_in, t2t_chatbot],
243
+ outputs=[t2t_chatbot], # <-- FIX: Only output to the chatbot
244
+ queue=True
245
+ ).then(
246
+ fn=clear_textbox,
247
+ inputs=None,
248
+ outputs=t2t_text_in
249
+ )
250
+
251
+ # Also allow Enter key to submit text
252
+ t2t_text_in.submit(
253
+ fn=t2t_pipeline,
254
+ inputs=[t2t_text_in, t2t_chatbot],
255
+ outputs=[t2t_chatbot],
256
  queue=True
257
  ).then(
258
  fn=clear_textbox,
 
260
  outputs=t2t_text_in
261
  )
262
 
263
+
264
  tool_s2t_btn.click(
265
  fn=assistant.transcribe_audio,
266
  inputs=tool_s2t_audio_in,
267
+ outputs=tool_s2t_text_out,
268
+ queue=True
269
  )
270
  tool_t2s_btn.click(
271
  fn=assistant.generate_speech,
272
  inputs=tool_t2s_text_in,
273
+ outputs=tool_t2s_audio_out,
274
+ queue=True
275
  )
276
 
277
+ demo.queue().launch(debug=True)