EYEDOL committed
Commit a248e18 · verified · 1 Parent(s): da6f089

Create app.py

Files changed (1)
  1. app.py +313 -0
app.py ADDED
@@ -0,0 +1,313 @@
# -*- coding: utf-8 -*-
"""
This script implements a multi-modal Swahili assistant for Hugging Face Spaces.
It uses Gradio for the user interface and loads models from the HF Hub.
"""

import os
from threading import Thread

import gradio as gr
import librosa
import numpy as np
import onnxruntime
import torch
from scipy.io.wavfile import write as write_wav
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    AutoTokenizer,
    TextIteratorStreamer,
    pipeline,
)

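# The imports above imply a Space requirements.txt roughly like the following
# (assumed and unpinned; adjust to your environment):
#   gradio
#   torch
#   transformers
#   accelerate      # required for low_cpu_mem_usage=True
#   onnxruntime     # or onnxruntime-gpu on CUDA hardware
#   librosa
#   scipy
#   numpy
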
# --- Configuration ---
# IMPORTANT: Replace these with your actual model IDs on the Hugging Face Hub.
# You must upload your fine-tuned ASR model to the Hub.
STT_MODEL_ID = "YOUR_USERNAME/YOUR_ASR_MODEL_ID"  # e.g., "MickyMike/SALAMA_B3_ASR"

# You can use any powerful multilingual model that supports Swahili.
LLM_MODEL_ID = "google/gemma-2-9b-it"

# This is the tokenizer for your ONNX TTS model.
TTS_TOKENIZER_ID = "facebook/mms-tts-swh"
TTS_ONNX_MODEL_PATH = "swahili_tts.onnx"  # Make sure this file is in your Space repo

# Ensure the temporary directory for audio files exists
TEMP_DIR = "temp"
os.makedirs(TEMP_DIR, exist_ok=True)

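# NOTE: swahili_tts.onnx is expected to already exist in this repository. The
# commented-out sketch below shows ONE possible way to export it offline from
# the MMS-TTS checkpoint; this is an assumption, not part of the app, and VITS
# export details may need adjustment for your transformers/PyTorch versions.
#
#   import torch
#   from transformers import VitsModel, AutoTokenizer
#
#   class _VitsWrapper(torch.nn.Module):
#       def __init__(self, vits):
#           super().__init__()
#           self.vits = vits
#
#       def forward(self, input_ids):
#           return self.vits(input_ids=input_ids).waveform
#
#   vits = VitsModel.from_pretrained("facebook/mms-tts-swh").eval()
#   tok = AutoTokenizer.from_pretrained("facebook/mms-tts-swh")
#   dummy = tok("Habari ya leo", return_tensors="pt")["input_ids"]
#   torch.onnx.export(
#       _VitsWrapper(vits), (dummy,), "swahili_tts.onnx",
#       input_names=["input_ids"], output_names=["waveform"],
#       dynamic_axes={"input_ids": {0: "batch", 1: "sequence"},
#                     "waveform": {0: "batch", 1: "samples"}},
#       opset_version=17,
#   )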

class WeeboAssistant:
    def __init__(self):
        # Audio settings
        self.STT_SAMPLE_RATE = 16000
        self.TTS_SAMPLE_RATE = 16000

        # System prompt for the LLM (English: "You are an intelligent assistant;
        # answer the question BRIEFLY and accurately. Respond only in Swahili.
        # No long answers.")
        self.SYSTEM_PROMPT = "Wewe ni msaidizi mwenye akili, jibu swali lililoulizwa kwa UFUPI na kwa usahihi. Jibu kwa lugha ya Kiswahili pekee. Hakuna jibu refu."

        self._init_models()

    def _init_models(self):
        """Initializes all models required for the pipeline."""
        print("Initializing models...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
        print(f"Using device: {self.device}")

        # --- 1. Initialize Swahili Speech-to-Text (STT/ASR) ---
        print(f"Loading STT model: {STT_MODEL_ID}")
        try:
            self.stt_processor = AutoProcessor.from_pretrained(STT_MODEL_ID)
            self.stt_model = AutoModelForSpeechSeq2Seq.from_pretrained(
                STT_MODEL_ID,
                torch_dtype=self.torch_dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True
            )
            self.stt_model.to(self.device)
            print("STT model loaded successfully.")
        except Exception as e:
            print(f"FATAL: Could not load STT model. Please check the model ID and ensure you have access. Error: {e}")
            # In a real app, you might want to handle this more gracefully
            raise

        # --- 2. Initialize Language Model (LLM) ---
        print(f"Loading LLM: {LLM_MODEL_ID}")
        try:
            # We don't need a separate tokenizer; the pipeline bundles its own.
            self.llm_pipeline = pipeline(
                "text-generation",
                model=LLM_MODEL_ID,
                model_kwargs={"torch_dtype": self.torch_dtype},
                device=self.device,
            )
            print("LLM pipeline loaded successfully.")
        except Exception as e:
            print(f"FATAL: Could not load LLM. Error: {e}")
            raise

        # --- 3. Initialize Swahili Text-to-Speech (TTS) ---
        print(f"Loading TTS model: {TTS_ONNX_MODEL_PATH}")
        try:
            # The ONNX model should be in the same repository as app.py
            self.tts_session = onnxruntime.InferenceSession(
                TTS_ONNX_MODEL_PATH,
                providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
            )
            self.tts_tokenizer = AutoTokenizer.from_pretrained(TTS_TOKENIZER_ID)
            print("TTS model and tokenizer loaded successfully.")
        except Exception as e:
            print(f"FATAL: Could not load TTS model. Make sure '{TTS_ONNX_MODEL_PATH}' is in the repository. Error: {e}")
            raise

        print("-" * 30)
        print("All models initialized successfully! ✅")

    def transcribe_audio(self, audio_tuple: tuple) -> str:
        """
        Transcribes audio from Gradio's audio component.
        The input is a tuple (sample_rate, numpy_array).
        """
        if audio_tuple is None:
            return ""

        sample_rate, audio_data = audio_tuple

        # Convert to mono float32
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=1)
        if audio_data.dtype != np.float32:
            if np.issubdtype(audio_data.dtype, np.integer):
                # Integer PCM (e.g. int16 from the microphone) -> scale to [-1, 1]
                audio_data = audio_data.astype(np.float32) / np.iinfo(audio_data.dtype).max
            else:
                audio_data = audio_data.astype(np.float32)

        # Resample if necessary
        if sample_rate != self.STT_SAMPLE_RATE:
            audio_data = librosa.resample(y=audio_data, orig_sr=sample_rate, target_sr=self.STT_SAMPLE_RATE)

        if len(audio_data) < 1000:  # Ignore very short audio clips
            return "(Audio too short to transcribe)"

        # Process and transcribe
        inputs = self.stt_processor(audio_data, sampling_rate=self.STT_SAMPLE_RATE, return_tensors="pt")
        # Match the model's device and dtype (the model may be in bfloat16 on GPU)
        inputs = {
            key: val.to(self.device, dtype=self.torch_dtype) if val.is_floating_point() else val.to(self.device)
            for key, val in inputs.items()
        }

        with torch.no_grad():
            generated_ids = self.stt_model.generate(**inputs, max_new_tokens=128)

        transcription = self.stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return transcription.strip()

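    # Illustrative usage (assumed values): Gradio's microphone component passes
    # audio as a (sample_rate, np.ndarray) tuple, typically int16 samples, e.g.
    #   assistant.transcribe_audio((48000, np.zeros(48000, dtype=np.int16)))
    # which is downmixed, rescaled, and resampled to 16 kHz before decoding.
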
    def generate_speech(self, text: str) -> str:
        """
        Generates audio from text and saves it to a temporary file.
        Returns the path to the audio file, or None on failure.
        """
        if not text:
            return None

        # Clean text
        text = text.strip()

        try:
            inputs = self.tts_tokenizer(text, return_tensors="np")
            input_ids = inputs.input_ids
            ort_inputs = {self.tts_session.get_inputs()[0].name: input_ids}
            audio_waveform = self.tts_session.run(None, ort_inputs)[0].flatten()

            # Convert the float waveform to 16-bit PCM for broad player compatibility
            audio_waveform = np.clip(audio_waveform, -1.0, 1.0)
            audio_pcm16 = (audio_waveform * 32767).astype(np.int16)

            # Save to a temporary WAV file
            output_path = os.path.join(TEMP_DIR, f"{os.urandom(8).hex()}.wav")
            write_wav(output_path, self.TTS_SAMPLE_RATE, audio_pcm16)
            return output_path
        except Exception as e:
            print(f"Error during audio generation: {e}")
            return None

    def get_llm_response(self, chat_history: list):
        """
        Gets a streaming response from the LLM.
        Yields the accumulated response text at each step.
        """
        # Format messages for the pipeline.
        # Gemma-2 instruction-tuned models do not accept a dedicated "system"
        # role in their chat template, so the system prompt is prepended to the
        # first user turn instead.
        messages = []
        for i, turn in enumerate(chat_history):
            user_msg = turn[0]
            if i == 0:
                user_msg = f"{self.SYSTEM_PROMPT}\n\n{user_msg}"
            messages.append({'role': 'user', 'content': user_msg})
            if turn[1] is not None:
                messages.append({'role': 'assistant', 'content': turn[1]})

        prompt = self.llm_pipeline.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        terminators = [
            self.llm_pipeline.tokenizer.eos_token_id,
            self.llm_pipeline.tokenizer.convert_tokens_to_ids("<end_of_turn>"),
        ]

        # gr.TextIterator does not exist in Gradio; instead, stream tokens with a
        # TextIteratorStreamer while the pipeline runs in a background thread.
        streamer = TextIteratorStreamer(
            self.llm_pipeline.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = dict(
            text_inputs=prompt,
            max_new_tokens=512,
            eos_token_id=terminators,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            streamer=streamer,
        )
        thread = Thread(target=self.llm_pipeline, kwargs=generation_kwargs)
        thread.start()

        response_text = ""
        for new_text in streamer:
            response_text += new_text
            yield response_text
        thread.join()

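    # For reference (assumed from Gemma-2's chat template): apply_chat_template
    # renders turns roughly as
    #   <start_of_turn>user
    #   {message}<end_of_turn>
    #   <start_of_turn>model
    # which is why "<end_of_turn>" is used as an extra stop token above.
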

# --- Gradio Interface Logic ---

# Instantiate the assistant
assistant = WeeboAssistant()


def s2s_pipeline(audio_input, chat_history):
    """The main function for the Speech-to-Speech tab."""
    # 1. Transcribe user's speech
    user_text = assistant.transcribe_audio(audio_input)
    if not user_text or user_text.startswith("("):
        chat_history.append((user_text or "(No valid speech detected)", None))
        yield chat_history, None, "Please record your voice again."
        return

    chat_history.append((user_text, None))
    yield chat_history, None, "..."  # Show user text and a thinking indicator

    # 2. Get LLM response as a stream
    response_stream = assistant.get_llm_response(chat_history)

    # Stream the response text to the UI
    llm_response_text = ""
    for text_chunk in response_stream:
        llm_response_text = text_chunk
        chat_history[-1] = (user_text, llm_response_text)
        yield chat_history, None, llm_response_text

    # 3. Synthesize the final LLM response to speech
    final_audio_path = assistant.generate_speech(llm_response_text)

    # 4. Final update to the UI
    yield chat_history, final_audio_path, llm_response_text

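# Note: Gradio runs generator functions as streaming event handlers, so each
# `yield` above (and in t2t_pipeline below) pushes an intermediate UI update.
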

def t2t_pipeline(text_input, chat_history):
    """The main function for the Text-to-Text tab."""
    if not text_input or not text_input.strip():
        yield chat_history, text_input
        return

    chat_history.append((text_input, None))
    # Clear the input box as soon as the message is accepted
    yield chat_history, ""

    response_stream = assistant.get_llm_response(chat_history)

    llm_response_text = ""
    for text_chunk in response_stream:
        llm_response_text = text_chunk
        chat_history[-1] = (text_input, llm_response_text)
        yield chat_history, ""

# --- Build Gradio UI ---
with gr.Blocks(theme=gr.themes.Soft(), title="Msaidizi wa Kiswahili") as demo:
    gr.Markdown("# 🤖 Msaidizi wa Sauti wa Kiswahili (Swahili Voice Assistant)")
    gr.Markdown("Ongea na msaidizi kwa Kiswahili. Toa sauti, andika maandishi, na upate majibu kwa sauti au maandishi.")

    with gr.Tabs():
        # Tab 1: Speech-to-Speech
        with gr.TabItem("🎙️ Sauti-kwa-Sauti (Speech-to-Speech)"):
            with gr.Row():
                with gr.Column(scale=2):
                    s2s_audio_in = gr.Audio(sources=["microphone"], type="numpy", label="Ongea Hapa (Speak Here)")
                    s2s_submit_btn = gr.Button("Tuma (Submit)", variant="primary")
                with gr.Column(scale=3):
                    s2s_chatbot = gr.Chatbot(label="Mazungumzo (Conversation)", bubble_full_width=False, height=400)
                    s2s_audio_out = gr.Audio(type="filepath", label="Jibu la Sauti (Audio Response)", autoplay=True)
                    s2s_text_out = gr.Textbox(label="Jibu la Maandishi (Text Response)", interactive=False)

        # Tab 2: Text-to-Text
        with gr.TabItem("⌨️ Maandishi-kwa-Maandishi (Text-to-Text)"):
            t2t_chatbot = gr.Chatbot(label="Mazungumzo (Conversation)", bubble_full_width=False, height=500)
            with gr.Row():
                t2t_text_in = gr.Textbox(label="Andika Hapa (Write Here)", placeholder="Habari yako...", scale=4)
                t2t_submit_btn = gr.Button("Tuma (Submit)", variant="primary", scale=1)

        # Tab 3: Direct Tools
        with gr.TabItem("🛠️ Zana (Tools)"):
            with gr.Row():
                # Speech to Text Tool
                with gr.Column():
                    gr.Markdown("### Unukuzi wa Sauti (Speech Transcription)")
                    tool_s2t_audio_in = gr.Audio(sources=["microphone"], type="numpy", label="Sauti ya Kuingiza (Input Audio)")
                    tool_s2t_text_out = gr.Textbox(label="Maandishi Yaliyonukuliwa (Transcribed Text)", interactive=False)
                    tool_s2t_btn = gr.Button("Nukuu (Transcribe)")
                # Text to Speech Tool
                with gr.Column():
                    gr.Markdown("### Utengenezaji wa Sauti (Speech Synthesis)")
                    tool_t2s_text_in = gr.Textbox(label="Maandishi ya Kuingiza (Input Text)", placeholder="Andika Kiswahili hapa...")
                    tool_t2s_audio_out = gr.Audio(type="filepath", label="Sauti Iliyotengenezwa (Synthesized Audio)", autoplay=False)
                    tool_t2s_btn = gr.Button("Tengeneza Sauti (Synthesize)")

    # --- Event Handlers ---

    # Speech-to-Speech handler
    s2s_submit_btn.click(
        fn=s2s_pipeline,
        inputs=[s2s_audio_in, s2s_chatbot],
        outputs=[s2s_chatbot, s2s_audio_out, s2s_text_out],
        queue=True
    )

    # Text-to-Text handler: stream to the chatbot and clear the input box
    t2t_submit_btn.click(
        fn=t2t_pipeline,
        inputs=[t2t_text_in, t2t_chatbot],
        outputs=[t2t_chatbot, t2t_text_in],
        queue=True
    )

    # Tool handlers
    tool_s2t_btn.click(
        fn=assistant.transcribe_audio,
        inputs=tool_s2t_audio_in,
        outputs=tool_s2t_text_out
    )
    tool_t2s_btn.click(
        fn=assistant.generate_speech,
        inputs=tool_t2s_text_in,
        outputs=tool_t2s_audio_out
    )

# Launch the Gradio app
demo.queue().launch(debug=True)