Update app.py
Browse files
app.py
CHANGED
@@ -9,15 +9,16 @@ import numpy as np
|
|
9 |
import onnxruntime
|
10 |
import torch
|
11 |
import librosa
|
12 |
-
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, AutoTokenizer, pipeline
|
13 |
from scipy.io.wavfile import write as write_wav
|
14 |
import os
|
15 |
import re
|
16 |
from huggingface_hub import login
|
|
|
17 |
|
18 |
# --- Login to Hugging Face using secret ---
|
19 |
# Make sure HF_TOKEN is set in your Hugging Face Space > Settings > Repository secrets
|
20 |
-
hf_token = os.environ.get("hugface")
|
21 |
if not hf_token:
|
22 |
raise ValueError("HF_TOKEN not found. Please set it in Hugging Face Space repository secrets.")
|
23 |
login(token=hf_token)
|
@@ -25,7 +26,7 @@ print("Successfully logged into Hugging Face Hub!")
|
|
25 |
|
26 |
# --- Configuration ---
|
27 |
STT_MODEL_ID = "EYEDOL/SALAMA_C3"
|
28 |
-
LLM_MODEL_ID = "google/gemma-
|
29 |
TTS_TOKENIZER_ID = "facebook/mms-tts-swh"
|
30 |
TTS_ONNX_MODEL_PATH = "swahili_tts.onnx"
|
31 |
|
@@ -62,10 +63,13 @@ class WeeboAssistant:
|
|
62 |
|
63 |
# LLM
|
64 |
print(f"Loading LLM: {LLM_MODEL_ID}")
|
|
|
|
|
65 |
self.llm_pipeline = pipeline(
|
66 |
"text-generation",
|
67 |
model=LLM_MODEL_ID,
|
68 |
model_kwargs={"torch_dtype": self.torch_dtype},
|
|
|
69 |
device=self.device,
|
70 |
)
|
71 |
print("LLM pipeline loaded successfully.")
|
@@ -118,6 +122,7 @@ class WeeboAssistant:
|
|
118 |
messages.append({'role': 'user', 'content': turn[0]})
|
119 |
if turn[1] is not None:
|
120 |
messages.append({'role': 'assistant', 'content': turn[1]})
|
|
|
121 |
prompt = self.llm_pipeline.tokenizer.apply_chat_template(
|
122 |
messages, tokenize=False, add_generation_prompt=True
|
123 |
)
|
@@ -125,17 +130,27 @@ class WeeboAssistant:
|
|
125 |
self.llm_pipeline.tokenizer.eos_token_id,
|
126 |
self.llm_pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
|
127 |
]
|
128 |
-
|
129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
max_new_tokens=512,
|
131 |
eos_token_id=terminators,
|
132 |
do_sample=True,
|
133 |
temperature=0.6,
|
134 |
top_p=0.9,
|
135 |
-
streamer=gr.TextIterator(),
|
136 |
)
|
|
|
|
|
|
|
|
|
|
|
137 |
return streamer
|
138 |
-
|
139 |
|
140 |
assistant = WeeboAssistant()
|
141 |
|
@@ -146,31 +161,35 @@ def s2s_pipeline(audio_input, chat_history):
|
|
146 |
chat_history.append((user_text or "(No valid speech detected)", None))
|
147 |
yield chat_history, None, "Please record your voice again."
|
148 |
return
|
149 |
-
|
150 |
-
|
|
|
|
|
151 |
response_stream = assistant.get_llm_response(chat_history)
|
152 |
llm_response_text = ""
|
153 |
for text_chunk in response_stream:
|
154 |
-
llm_response_text
|
155 |
chat_history[-1] = (user_text, llm_response_text)
|
156 |
yield chat_history, None, llm_response_text
|
|
|
157 |
final_audio_path = assistant.generate_speech(llm_response_text)
|
158 |
yield chat_history, final_audio_path, llm_response_text
|
159 |
|
160 |
|
161 |
def t2t_pipeline(text_input, chat_history):
|
162 |
-
chat_history.append((text_input,
|
163 |
-
yield chat_history
|
|
|
164 |
response_stream = assistant.get_llm_response(chat_history)
|
165 |
llm_response_text = ""
|
166 |
for text_chunk in response_stream:
|
167 |
-
llm_response_text
|
168 |
chat_history[-1] = (text_input, llm_response_text)
|
169 |
-
yield chat_history
|
170 |
|
171 |
|
172 |
def clear_textbox():
|
173 |
-
return ""
|
174 |
|
175 |
|
176 |
with gr.Blocks(theme=gr.themes.Soft(), title="Msaidizi wa Kiswahili") as demo:
|
@@ -191,14 +210,14 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Msaidizi wa Kiswahili") as demo:
|
|
191 |
with gr.TabItem("⌨️ Maandishi-kwa-Maandishi (Text-to-Text)"):
|
192 |
t2t_chatbot = gr.Chatbot(label="Mazungumzo (Conversation)", bubble_full_width=False, height=500)
|
193 |
with gr.Row():
|
194 |
-
t2t_text_in = gr.Textbox(
|
195 |
t2t_submit_btn = gr.Button("Tuma (Submit)", variant="primary", scale=1)
|
196 |
|
197 |
with gr.TabItem("🛠️ Zana (Tools)"):
|
198 |
with gr.Row():
|
199 |
with gr.Column():
|
200 |
gr.Markdown("### Unukuzi wa Sauti (Speech Transcription)")
|
201 |
-
tool_s2t_audio_in = gr.Audio(sources=["microphone"], type="numpy", label="Sauti ya Kuingiza (Input Audio)")
|
202 |
tool_s2t_text_out = gr.Textbox(label="Maandishi Yaliyonukuliwa (Transcribed Text)", interactive=False)
|
203 |
tool_s2t_btn = gr.Button("Nukuu (Transcribe)")
|
204 |
with gr.Column():
|
@@ -212,12 +231,28 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Msaidizi wa Kiswahili") as demo:
|
|
212 |
inputs=[s2s_audio_in, s2s_chatbot],
|
213 |
outputs=[s2s_chatbot, s2s_audio_out, s2s_text_out],
|
214 |
queue=True
|
|
|
|
|
|
|
|
|
215 |
)
|
216 |
|
217 |
t2t_submit_btn.click(
|
218 |
fn=t2t_pipeline,
|
219 |
inputs=[t2t_text_in, t2t_chatbot],
|
220 |
-
outputs=[t2t_chatbot,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
221 |
queue=True
|
222 |
).then(
|
223 |
fn=clear_textbox,
|
@@ -225,15 +260,18 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Msaidizi wa Kiswahili") as demo:
|
|
225 |
outputs=t2t_text_in
|
226 |
)
|
227 |
|
|
|
228 |
tool_s2t_btn.click(
|
229 |
fn=assistant.transcribe_audio,
|
230 |
inputs=tool_s2t_audio_in,
|
231 |
-
outputs=tool_s2t_text_out
|
|
|
232 |
)
|
233 |
tool_t2s_btn.click(
|
234 |
fn=assistant.generate_speech,
|
235 |
inputs=tool_t2s_text_in,
|
236 |
-
outputs=tool_t2s_audio_out
|
|
|
237 |
)
|
238 |
|
239 |
-
demo.queue().launch(debug=True)
|
|
|
9 |
import onnxruntime
|
10 |
import torch
|
11 |
import librosa
|
12 |
+
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, AutoTokenizer, pipeline, TextIteratorStreamer
|
13 |
from scipy.io.wavfile import write as write_wav
|
14 |
import os
|
15 |
import re
|
16 |
from huggingface_hub import login
|
17 |
+
import threading # <-- FIX: Added threading import
|
18 |
|
19 |
# --- Login to Hugging Face using secret ---
|
20 |
# Make sure HF_TOKEN is set in your Hugging Face Space > Settings > Repository secrets
|
21 |
+
hf_token = os.environ.get("hugface") #
|
22 |
if not hf_token:
|
23 |
raise ValueError("HF_TOKEN not found. Please set it in Hugging Face Space repository secrets.")
|
24 |
login(token=hf_token)
|
|
|
26 |
|
27 |
# --- Configuration ---
|
28 |
STT_MODEL_ID = "EYEDOL/SALAMA_C3"
|
29 |
+
LLM_MODEL_ID = "google/gemma-1.1-2b-it"
|
30 |
TTS_TOKENIZER_ID = "facebook/mms-tts-swh"
|
31 |
TTS_ONNX_MODEL_PATH = "swahili_tts.onnx"
|
32 |
|
|
|
63 |
|
64 |
# LLM
|
65 |
print(f"Loading LLM: {LLM_MODEL_ID}")
|
66 |
+
# <-- FIX: Initialize tokenizer separately to use it with the streamer
|
67 |
+
self.llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID)
|
68 |
self.llm_pipeline = pipeline(
|
69 |
"text-generation",
|
70 |
model=LLM_MODEL_ID,
|
71 |
model_kwargs={"torch_dtype": self.torch_dtype},
|
72 |
+
tokenizer=self.llm_tokenizer, # Pass the tokenizer here
|
73 |
device=self.device,
|
74 |
)
|
75 |
print("LLM pipeline loaded successfully.")
|
|
|
122 |
messages.append({'role': 'user', 'content': turn[0]})
|
123 |
if turn[1] is not None:
|
124 |
messages.append({'role': 'assistant', 'content': turn[1]})
|
125 |
+
|
126 |
prompt = self.llm_pipeline.tokenizer.apply_chat_template(
|
127 |
messages, tokenize=False, add_generation_prompt=True
|
128 |
)
|
|
|
130 |
self.llm_pipeline.tokenizer.eos_token_id,
|
131 |
self.llm_pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
|
132 |
]
|
133 |
+
|
134 |
+
# <-- START OF FIX: Use TextIteratorStreamer instead of gr.TextIterator -->
|
135 |
+
streamer = TextIteratorStreamer(
|
136 |
+
self.llm_pipeline.tokenizer, skip_prompt=True, skip_special_tokens=True
|
137 |
+
)
|
138 |
+
|
139 |
+
generation_kwargs = dict(
|
140 |
+
streamer=streamer,
|
141 |
max_new_tokens=512,
|
142 |
eos_token_id=terminators,
|
143 |
do_sample=True,
|
144 |
temperature=0.6,
|
145 |
top_p=0.9,
|
|
|
146 |
)
|
147 |
+
|
148 |
+
# Run the pipeline in a separate thread to enable streaming
|
149 |
+
thread = threading.Thread(target=self.llm_pipeline, args=[prompt], kwargs=generation_kwargs)
|
150 |
+
thread.start()
|
151 |
+
|
152 |
return streamer
|
153 |
+
# <-- END OF FIX -->
|
154 |
|
155 |
assistant = WeeboAssistant()
|
156 |
|
|
|
161 |
chat_history.append((user_text or "(No valid speech detected)", None))
|
162 |
yield chat_history, None, "Please record your voice again."
|
163 |
return
|
164 |
+
|
165 |
+
chat_history.append((user_text, ""))
|
166 |
+
yield chat_history, None, "..." # Show thinking indicator
|
167 |
+
|
168 |
response_stream = assistant.get_llm_response(chat_history)
|
169 |
llm_response_text = ""
|
170 |
for text_chunk in response_stream:
|
171 |
+
llm_response_text += text_chunk # <-- FIX: Append chunk to full response
|
172 |
chat_history[-1] = (user_text, llm_response_text)
|
173 |
yield chat_history, None, llm_response_text
|
174 |
+
|
175 |
final_audio_path = assistant.generate_speech(llm_response_text)
|
176 |
yield chat_history, final_audio_path, llm_response_text
|
177 |
|
178 |
|
179 |
def t2t_pipeline(text_input, chat_history):
|
180 |
+
chat_history.append((text_input, ""))
|
181 |
+
yield chat_history
|
182 |
+
|
183 |
response_stream = assistant.get_llm_response(chat_history)
|
184 |
llm_response_text = ""
|
185 |
for text_chunk in response_stream:
|
186 |
+
llm_response_text += text_chunk # <-- FIX: Append chunk to full response
|
187 |
chat_history[-1] = (text_input, llm_response_text)
|
188 |
+
yield chat_history
|
189 |
|
190 |
|
191 |
def clear_textbox():
|
192 |
+
return gr.Textbox(value="")
|
193 |
|
194 |
|
195 |
with gr.Blocks(theme=gr.themes.Soft(), title="Msaidizi wa Kiswahili") as demo:
|
|
|
210 |
with gr.TabItem("⌨️ Maandishi-kwa-Maandishi (Text-to-Text)"):
|
211 |
t2t_chatbot = gr.Chatbot(label="Mazungumzo (Conversation)", bubble_full_width=False, height=500)
|
212 |
with gr.Row():
|
213 |
+
t2t_text_in = gr.Textbox(show_label=False, placeholder="Habari yako...", scale=4, container=False)
|
214 |
t2t_submit_btn = gr.Button("Tuma (Submit)", variant="primary", scale=1)
|
215 |
|
216 |
with gr.TabItem("🛠️ Zana (Tools)"):
|
217 |
with gr.Row():
|
218 |
with gr.Column():
|
219 |
gr.Markdown("### Unukuzi wa Sauti (Speech Transcription)")
|
220 |
+
tool_s2t_audio_in = gr.Audio(sources=["microphone", "upload"], type="numpy", label="Sauti ya Kuingiza (Input Audio)")
|
221 |
tool_s2t_text_out = gr.Textbox(label="Maandishi Yaliyonukuliwa (Transcribed Text)", interactive=False)
|
222 |
tool_s2t_btn = gr.Button("Nukuu (Transcribe)")
|
223 |
with gr.Column():
|
|
|
231 |
inputs=[s2s_audio_in, s2s_chatbot],
|
232 |
outputs=[s2s_chatbot, s2s_audio_out, s2s_text_out],
|
233 |
queue=True
|
234 |
+
).then(
|
235 |
+
fn=lambda: gr.Audio(value=None), # Clear audio input after submit
|
236 |
+
inputs=None,
|
237 |
+
outputs=s2s_audio_in
|
238 |
)
|
239 |
|
240 |
t2t_submit_btn.click(
|
241 |
fn=t2t_pipeline,
|
242 |
inputs=[t2t_text_in, t2t_chatbot],
|
243 |
+
outputs=[t2t_chatbot], # <-- FIX: Only output to the chatbot
|
244 |
+
queue=True
|
245 |
+
).then(
|
246 |
+
fn=clear_textbox,
|
247 |
+
inputs=None,
|
248 |
+
outputs=t2t_text_in
|
249 |
+
)
|
250 |
+
|
251 |
+
# Also allow Enter key to submit text
|
252 |
+
t2t_text_in.submit(
|
253 |
+
fn=t2t_pipeline,
|
254 |
+
inputs=[t2t_text_in, t2t_chatbot],
|
255 |
+
outputs=[t2t_chatbot],
|
256 |
queue=True
|
257 |
).then(
|
258 |
fn=clear_textbox,
|
|
|
260 |
outputs=t2t_text_in
|
261 |
)
|
262 |
|
263 |
+
|
264 |
tool_s2t_btn.click(
|
265 |
fn=assistant.transcribe_audio,
|
266 |
inputs=tool_s2t_audio_in,
|
267 |
+
outputs=tool_s2t_text_out,
|
268 |
+
queue=True
|
269 |
)
|
270 |
tool_t2s_btn.click(
|
271 |
fn=assistant.generate_speech,
|
272 |
inputs=tool_t2s_text_in,
|
273 |
+
outputs=tool_t2s_audio_out,
|
274 |
+
queue=True
|
275 |
)
|
276 |
|
277 |
+
demo.queue().launch(debug=True)
|