# -*- coding: utf-8 -*-
"""
This script implements a multi-modal Swahili assistant for Hugging Face Spaces.
It uses Gradio for the user interface and loads models from the HF Hub.
"""

import gradio as gr
import numpy as np
import onnxruntime
import torch
import librosa
from threading import Thread
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, AutoTokenizer, TextIteratorStreamer, pipeline
from scipy.io.wavfile import write as write_wav
import os

# --- Configuration ---
# IMPORTANT: Replace these with your actual model IDs on the Hugging Face Hub.
# You must upload your fine-tuned ASR model to the Hub.
STT_MODEL_ID = "YOUR_USERNAME/YOUR_ASR_MODEL_ID"  # e.g., "MickyMike/SALAMA_B3_ASR"

# You can use any powerful multilingual model that supports Swahili.
LLM_MODEL_ID = "google/gemma-2-9b-it" 

# This is the tokenizer for your ONNX TTS model.
TTS_TOKENIZER_ID = "facebook/mms-tts-swh"
TTS_ONNX_MODEL_PATH = "swahili_tts.onnx" # Make sure this file is in your Space repo

# Ensure the temporary directory for audio files exists
TEMP_DIR = "temp"
os.makedirs(TEMP_DIR, exist_ok=True)
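
# Optional sketch (assumption): fetch the ONNX TTS file from a Hub model repo when it is
# not committed to this Space. TTS_ONNX_REPO_ID is a hypothetical placeholder read from
# the environment, so the default behaviour of the script is unchanged.
TTS_ONNX_REPO_ID = os.environ.get("TTS_ONNX_REPO_ID")  # e.g. "YOUR_USERNAME/YOUR_TTS_REPO"
if TTS_ONNX_REPO_ID and not os.path.exists(TTS_ONNX_MODEL_PATH):
    from huggingface_hub import hf_hub_download
    TTS_ONNX_MODEL_PATH = hf_hub_download(repo_id=TTS_ONNX_REPO_ID, filename=TTS_ONNX_MODEL_PATH)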


class WeeboAssistant:
    def __init__(self):
        # Audio settings
        self.STT_SAMPLE_RATE = 16000
        self.TTS_SAMPLE_RATE = 16000
        
        # System prompt for the LLM (Swahili). Roughly: "You are an intelligent assistant;
        # answer the question asked BRIEFLY and accurately. Answer only in Swahili. No long answers."
        self.SYSTEM_PROMPT = "Wewe ni msaidizi mwenye akili, jibu swali lililoulizwa kwa UFUPI na kwa usahihi. Jibu kwa lugha ya Kiswahili pekee. Hakuna jibu refu."

        self._init_models()

    def _init_models(self):
        """Initializes all models required for the pipeline."""
        print("Initializing models...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
        print(f"Using device: {self.device}")

        # --- 1. Initialize Swahili Speech-to-Text (STT/ASR) ---
        print(f"Loading STT model: {STT_MODEL_ID}")
        try:
            self.stt_processor = AutoProcessor.from_pretrained(STT_MODEL_ID)
            self.stt_model = AutoModelForSpeechSeq2Seq.from_pretrained(
                STT_MODEL_ID, 
                torch_dtype=self.torch_dtype, 
                low_cpu_mem_usage=True, 
                use_safetensors=True
            )
            self.stt_model.to(self.device)
            print("STT model loaded successfully.")
        except Exception as e:
            print(f"FATAL: Could not load STT model. Please check the model ID and ensure you have access. Error: {e}")
            # In a real app, you might want to handle this more gracefully
            raise

        # --- 2. Initialize Language Model (LLM) ---
        print(f"Loading LLM: {LLM_MODEL_ID}")
        try:
            # We don't need a separate tokenizer for the pipeline
            self.llm_pipeline = pipeline(
                "text-generation",
                model=LLM_MODEL_ID,
                model_kwargs={"torch_dtype": self.torch_dtype},
                device=self.device,
            )
            print("LLM pipeline loaded successfully.")
        except Exception as e:
            print(f"FATAL: Could not load LLM. Error: {e}")
            raise

        # --- 3. Initialize Swahili Text-to-Speech (TTS) ---
        print(f"Loading TTS model: {TTS_ONNX_MODEL_PATH}")
        try:
            # The ONNX model should be in the same repository as app.py
            self.tts_session = onnxruntime.InferenceSession(
                TTS_ONNX_MODEL_PATH,
                providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
            )
            self.tts_tokenizer = AutoTokenizer.from_pretrained(TTS_TOKENIZER_ID)
            print("TTS model and tokenizer loaded successfully.")
        except Exception as e:
            print(f"FATAL: Could not load TTS model. Make sure '{TTS_ONNX_MODEL_PATH}' is in the repository. Error: {e}")
            raise
            
        print("-" * 30)
        print("All models initialized successfully! โœ…")

    def transcribe_audio(self, audio_tuple: tuple) -> str:
        """
        Transcribes audio from Gradio's audio component.
        The input is a tuple (sample_rate, numpy_array).
        """
        if audio_tuple is None:
            return ""
            
        sample_rate, audio_data = audio_tuple
        
        # Convert to mono float32
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=1)
        if np.issubdtype(audio_data.dtype, np.integer):
            # Integer PCM (e.g. int16 from the browser) -> scale to [-1, 1] float32
            audio_data = audio_data.astype(np.float32) / np.iinfo(audio_data.dtype).max
        else:
            audio_data = audio_data.astype(np.float32)

        # Resample if necessary
        if sample_rate != self.STT_SAMPLE_RATE:
            audio_data = librosa.resample(y=audio_data, orig_sr=sample_rate, target_sr=self.STT_SAMPLE_RATE)
        
        if len(audio_data) < 1000:  # Ignore clips shorter than ~60 ms at 16 kHz
            return "(Audio too short to transcribe)"

        # Process and transcribe
        inputs = self.stt_processor(audio_data, sampling_rate=self.STT_SAMPLE_RATE, return_tensors="pt")
        # Match the model's dtype (e.g. bfloat16 on GPU) to avoid a dtype mismatch in generate()
        inputs = {
            key: (val.to(self.device, dtype=self.torch_dtype) if torch.is_floating_point(val) else val.to(self.device))
            for key, val in inputs.items()
        }
        
        with torch.no_grad():
            generated_ids = self.stt_model.generate(**inputs, max_new_tokens=128)
            
        transcription = self.stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return transcription.strip()
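
    # Usage sketch (assumption): transcribing a local WAV file outside of Gradio, where
    # "sample.wav" is a hypothetical file path, not something shipped with this Space.
    #
    #   audio, sr = librosa.load("sample.wav", sr=None, mono=True)
    #   text = assistant.transcribe_audio((sr, audio))
    #   print(text)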

    def generate_speech(self, text: str) -> str:
        """
        Generates audio from text and saves it to a temporary file.
        Returns the path to the audio file.
        """
        if not text:
            return None

        # Clean text
        text = text.strip()
        
        try:
            inputs = self.tts_tokenizer(text, return_tensors="np")
            input_ids = inputs.input_ids
            ort_inputs = {self.tts_session.get_inputs()[0].name: input_ids}
            audio_waveform = self.tts_session.run(None, ort_inputs)[0].flatten()
            
            # Save to a temporary WAV file
            output_path = os.path.join(TEMP_DIR, f"{os.urandom(8).hex()}.wav")
            write_wav(output_path, self.TTS_SAMPLE_RATE, audio_waveform)
            return output_path
        except Exception as e:
            print(f"Error during audio generation: {e}")
            return None
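
    # Usage sketch (assumption): direct synthesis, as used by the Tools tab; the return
    # value is the path to a temporary WAV file under TEMP_DIR.
    #
    #   wav_path = assistant.generate_speech("Habari ya leo?")
    #   print(wav_path)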

    def get_llm_response(self, chat_history: list):
        """
        Gets a streaming response from the LLM.
        Yields the accumulated response text as it grows.
        """
        # Format messages for the pipeline. Gemma-2 instruction-tuned models do not
        # accept a separate "system" role in their chat template, so the system prompt
        # is prepended to the first user turn instead.
        messages = []
        for i, turn in enumerate(chat_history):
            user_text = turn[0]
            if i == 0:
                user_text = f"{self.SYSTEM_PROMPT}\n\n{user_text}"
            messages.append({'role': 'user', 'content': user_text})
            if turn[1] is not None:
                messages.append({'role': 'assistant', 'content': turn[1]})

        prompt = self.llm_pipeline.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # Stream tokens via a background thread so the UI can update incrementally.
        streamer = TextIteratorStreamer(
            self.llm_pipeline.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True
        )
        generation_kwargs = dict(
            max_new_tokens=512,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            streamer=streamer,
        )
        thread = Thread(target=self.llm_pipeline, args=(prompt,), kwargs=generation_kwargs)
        thread.start()

        response_text = ""
        for new_text in streamer:
            response_text += new_text
            yield response_text
        thread.join()
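
    # For reference (assumption; the exact layout depends on the chat template shipped
    # with the model), the prompt produced by apply_chat_template for Gemma-2 looks
    # roughly like:
    #
    #   <start_of_turn>user
    #   {system prompt + question}<end_of_turn>
    #   <start_of_turn>model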

# --- Gradio Interface Logic ---

# Instantiate the assistant
assistant = WeeboAssistant()

def s2s_pipeline(audio_input, chat_history):
    """The main function for the Speech-to-Speech tab."""
    # 1. Transcribe user's speech
    user_text = assistant.transcribe_audio(audio_input)
    if not user_text or user_text.startswith("("):
        chat_history.append((user_text or "(No valid speech detected)", None))
        yield chat_history, None, "Please record your voice again."
        return

    chat_history.append((user_text, None))
    yield chat_history, None, "..." # Show user text and a thinking indicator

    # 2. Get LLM response as a stream
    response_stream = assistant.get_llm_response(chat_history)
    
    # Stream the response text to the UI
    llm_response_text = ""
    for text_chunk in response_stream:
        llm_response_text = text_chunk
        chat_history[-1] = (user_text, llm_response_text)
        yield chat_history, None, llm_response_text
        
    # 3. Synthesize the final LLM response to speech
    final_audio_path = assistant.generate_speech(llm_response_text)

    # 4. Final update to the UI
    yield chat_history, final_audio_path, llm_response_text

def t2t_pipeline(text_input, chat_history):
    """The main function for the Text-to-Text tab. Streams the chat and clears the input box."""
    if not text_input or not text_input.strip():
        yield chat_history, text_input
        return

    chat_history.append((text_input, None))
    yield chat_history, ""  # Clear the input box while the model "thinks"

    response_stream = assistant.get_llm_response(chat_history)

    for text_chunk in response_stream:
        chat_history[-1] = (text_input, text_chunk)
        yield chat_history, ""

# --- Build Gradio UI ---
with gr.Blocks(theme=gr.themes.Soft(), title="Msaidizi wa Kiswahili") as demo:
    gr.Markdown("# ๐Ÿค– Msaidizi wa Sauti wa Kiswahili (Swahili Voice Assistant)")
    gr.Markdown("Ongea na msaidizi kwa Kiswahili. Toa sauti, andika maandishi, na upate majibu kwa sauti au maandishi.")
    
    with gr.Tabs():
        # Tab 1: Speech-to-Speech
        with gr.TabItem("๐ŸŽ™๏ธ Sauti-kwa-Sauti (Speech-to-Speech)"):
            with gr.Row():
                with gr.Column(scale=2):
                    s2s_audio_in = gr.Audio(sources=["microphone"], type="numpy", label="Ongea Hapa (Speak Here)")
                    s2s_submit_btn = gr.Button("Tuma (Submit)", variant="primary")
                with gr.Column(scale=3):
                    s2s_chatbot = gr.Chatbot(label="Mazungumzo (Conversation)", bubble_full_width=False, height=400)
                    s2s_audio_out = gr.Audio(type="filepath", label="Jibu la Sauti (Audio Response)", autoplay=True)
                    s2s_text_out = gr.Textbox(label="Jibu la Maandishi (Text Response)", interactive=False)
            
        # Tab 2: Text-to-Text
        with gr.TabItem("โŒจ๏ธ Maandishi-kwa-Maandishi (Text-to-Text)"):
            t2t_chatbot = gr.Chatbot(label="Mazungumzo (Conversation)", bubble_full_width=False, height=500)
            with gr.Row():
                t2t_text_in = gr.Textbox(label="Andika Hapa (Write Here)", placeholder="Habari yako...", scale=4)
                t2t_submit_btn = gr.Button("Tuma (Submit)", variant="primary", scale=1)

        # Tab 3: Direct Tools
        with gr.TabItem("๐Ÿ› ๏ธ Zana (Tools)"):
            with gr.Row():
                # Speech to Text Tool
                with gr.Column():
                    gr.Markdown("### Unukuzi wa Sauti (Speech Transcription)")
                    tool_s2t_audio_in = gr.Audio(sources=["microphone"], type="numpy", label="Sauti ya Kuingiza (Input Audio)")
                    tool_s2t_text_out = gr.Textbox(label="Maandishi Yaliyonukuliwa (Transcribed Text)", interactive=False)
                    tool_s2t_btn = gr.Button("Nukuu (Transcribe)")
                # Text to Speech Tool
                with gr.Column():
                    gr.Markdown("### Utengenezaji wa Sauti (Speech Synthesis)")
                    tool_t2s_text_in = gr.Textbox(label="Maandishi ya Kuingiza (Input Text)", placeholder="Andika Kiswahili hapa...")
                    tool_t2s_audio_out = gr.Audio(type="filepath", label="Sauti Iliyotengenezwa (Synthesized Audio)", autoplay=False)
                    tool_t2s_btn = gr.Button("Tengeneza Sauti (Synthesize)")

    # --- Event Handlers ---
    
    # Speech-to-Speech handler
    s2s_submit_btn.click(
        fn=s2s_pipeline,
        inputs=[s2s_audio_in, s2s_chatbot],
        outputs=[s2s_chatbot, s2s_audio_out, s2s_text_out],
        queue=True
    )

    # Text-to-Text handler: the response streams into the chatbot and the
    # input box is cleared on submit.
    t2t_submit_btn.click(
        fn=t2t_pipeline,
        inputs=[t2t_text_in, t2t_chatbot],
        outputs=[t2t_chatbot, t2t_text_in],
        queue=True
    )

    # Tool handlers
    tool_s2t_btn.click(
        fn=assistant.transcribe_audio,
        inputs=tool_s2t_audio_in,
        outputs=tool_s2t_text_out
    )
    tool_t2s_btn.click(
        fn=assistant.generate_speech,
        inputs=tool_t2s_text_in,
        outputs=tool_t2s_audio_out
    )

# Launch the Gradio app
demo.queue().launch(debug=True)
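
# Note (assumption): when running this script outside a managed Space, e.g. in a plain
# Docker container, Gradio usually needs to bind to all interfaces explicitly:
#
#   demo.queue().launch(server_name="0.0.0.0", server_port=7860)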