# -*- coding: utf-8 -*-
"""
This script implements a multi-modal Swahili assistant for Hugging Face Spaces.
It uses Gradio for the user interface and loads models from the HF Hub.
"""

import gradio as gr
import numpy as np
import onnxruntime
import torch
import librosa
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, AutoTokenizer, pipeline, TextIteratorStreamer
from threading import Thread
from scipy.io.wavfile import write as write_wav
import os
import re
from huggingface_hub import login

# --- Login to Hugging Face using secret ---
# Make sure HF_TOKEN is set in your Hugging Face Space > Settings > Repository secrets
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    raise ValueError("HF_TOKEN not found. Please set it in Hugging Face Space repository secrets.")
login(token=hf_token)
print("Successfully logged into Hugging Face Hub!")

# --- Configuration ---
STT_MODEL_ID = "EYEDOL/SALAMA_C3"
LLM_MODEL_ID = "google/gemma-3-1b-it"
TTS_TOKENIZER_ID = "facebook/mms-tts-swh"
TTS_ONNX_MODEL_PATH = "swahili_tts.onnx"

TEMP_DIR = "temp"
os.makedirs(TEMP_DIR, exist_ok=True)


class WeeboAssistant:
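    """Wraps the Swahili STT, LLM, and TTS models behind small helper methods."""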
    def __init__(self):
        self.STT_SAMPLE_RATE = 16000
        self.TTS_SAMPLE_RATE = 16000
        self.SYSTEM_PROMPT = (
            "Wewe ni msaidizi mwenye akili, jibu swali lililoulizwa kwa UFUPI na kwa usahihi. "
            "Jibu kwa lugha ya Kiswahili pekee. Hakuna jibu refu."
        )
        self._init_models()

    def _init_models(self):
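        """Load the STT, LLM, and TTS models onto the best available device."""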
        print("Initializing models...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
        print(f"Using device: {self.device}")

        # STT
        print(f"Loading STT model: {STT_MODEL_ID}")
        self.stt_processor = AutoProcessor.from_pretrained(STT_MODEL_ID)
        self.stt_model = AutoModelForSpeechSeq2Seq.from_pretrained(
            STT_MODEL_ID, 
            torch_dtype=self.torch_dtype, 
            low_cpu_mem_usage=True, 
            use_safetensors=True
        ).to(self.device)
        print("STT model loaded successfully.")

        # LLM
        print(f"Loading LLM: {LLM_MODEL_ID}")
        self.llm_pipeline = pipeline(
            "text-generation",
            model=LLM_MODEL_ID,
            model_kwargs={"torch_dtype": self.torch_dtype},
            device=self.device,
        )
        print("LLM pipeline loaded successfully.")

        # TTS
        print(f"Loading TTS model: {TTS_ONNX_MODEL_PATH}")
        self.tts_session = onnxruntime.InferenceSession(
            TTS_ONNX_MODEL_PATH,
            providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
        )
        self.tts_tokenizer = AutoTokenizer.from_pretrained(TTS_TOKENIZER_ID)
        print("TTS model and tokenizer loaded successfully.")

        print("-" * 30)
        print("All models initialized successfully! ✅")

    def transcribe_audio(self, audio_tuple):
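        """Transcribe a Gradio (sample_rate, waveform) tuple into Swahili text."""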
        if audio_tuple is None:
            return ""
        sample_rate, audio_data = audio_tuple
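        # Downmix to mono, normalize to float32 in [-1, 1], and resample to the
        # 16 kHz rate the STT model expects.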
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=1)
        if np.issubdtype(audio_data.dtype, np.integer):
            audio_data = audio_data.astype(np.float32) / np.iinfo(audio_data.dtype).max
        else:
            audio_data = audio_data.astype(np.float32)
        if sample_rate != self.STT_SAMPLE_RATE:
            audio_data = librosa.resample(y=audio_data, orig_sr=sample_rate, target_sr=self.STT_SAMPLE_RATE)
        if len(audio_data) < 1000:
            return "(Audio too short to transcribe)"
        inputs = self.stt_processor(audio_data, sampling_rate=self.STT_SAMPLE_RATE, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            generated_ids = self.stt_model.generate(**inputs, max_new_tokens=128)
        transcription = self.stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return transcription.strip()

    def generate_speech(self, text):
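        """Synthesize Swahili speech with the ONNX MMS-TTS model and return a WAV path."""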
        if not text:
            return None
        text = text.strip()
        inputs = self.tts_tokenizer(text, return_tensors="np")
        ort_inputs = {self.tts_session.get_inputs()[0].name: inputs.input_ids}
        audio_waveform = self.tts_session.run(None, ort_inputs)[0].flatten()
        output_path = os.path.join(TEMP_DIR, f"{os.urandom(8).hex()}.wav")
        write_wav(output_path, self.TTS_SAMPLE_RATE, audio_waveform)
        return output_path

    def get_llm_response(self, chat_history):
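        """Stream the LLM's reply for the given chat history as accumulating text."""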
        messages = [{'role': 'system', 'content': self.SYSTEM_PROMPT}]
        for turn in chat_history:
            messages.append({'role': 'user', 'content': turn[0]})
            if turn[1] is not None:
                messages.append({'role': 'assistant', 'content': turn[1]})
        prompt = self.llm_pipeline.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # Gemma marks the end of a turn with <end_of_turn>; stop on it as well as EOS.
        terminators = [
            self.llm_pipeline.tokenizer.eos_token_id,
            self.llm_pipeline.tokenizer.convert_tokens_to_ids("<end_of_turn>"),
        ]
        # Stream the generation: TextIteratorStreamer yields decoded text chunks
        # while the pipeline runs in a background thread.
        streamer = TextIteratorStreamer(
            self.llm_pipeline.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = dict(
            text_inputs=prompt,
            max_new_tokens=512,
            eos_token_id=terminators,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            streamer=streamer,
        )
        thread = Thread(target=self.llm_pipeline, kwargs=generation_kwargs)
        thread.start()
        # Yield the accumulated response so callers can update the UI incrementally.
        partial_text = ""
        for new_text in streamer:
            partial_text += new_text
            yield partial_text
        thread.join()


assistant = WeeboAssistant()


def s2s_pipeline(audio_input, chat_history):
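    """Full speech-to-speech flow: transcribe, stream the LLM reply, then synthesize audio."""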
    user_text = assistant.transcribe_audio(audio_input)
    if not user_text or user_text.startswith("("):
        chat_history.append((user_text or "(No valid speech detected)", None))
        yield chat_history, None, "Please record your voice again."
        return
    chat_history.append((user_text, None))
    yield chat_history, None, "..."
    response_stream = assistant.get_llm_response(chat_history)
    llm_response_text = ""
    for text_chunk in response_stream:
        llm_response_text = text_chunk
        chat_history[-1] = (user_text, llm_response_text)
        yield chat_history, None, llm_response_text
    final_audio_path = assistant.generate_speech(llm_response_text)
    yield chat_history, final_audio_path, llm_response_text


def t2t_pipeline(text_input, chat_history):
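    """Text-only flow: append the user message and stream the LLM reply into the chat."""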
    chat_history.append((text_input, None))
    yield chat_history, "..."
    response_stream = assistant.get_llm_response(chat_history)
    llm_response_text = ""
    for text_chunk in response_stream:
        llm_response_text = text_chunk
        chat_history[-1] = (text_input, llm_response_text)
        yield chat_history, llm_response_text


def clear_textbox():
    return ""


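# --- Gradio UI ---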
with gr.Blocks(theme=gr.themes.Soft(), title="Msaidizi wa Kiswahili") as demo:
    gr.Markdown("# 🤖 Msaidizi wa Sauti wa Kiswahili (Swahili Voice Assistant)")
    gr.Markdown("Ongea na msaidizi kwa Kiswahili. Toa sauti, andika maandishi, na upate majibu kwa sauti au maandishi.")
    
    with gr.Tabs():
        with gr.TabItem("🎙️ Sauti-kwa-Sauti (Speech-to-Speech)"):
            with gr.Row():
                with gr.Column(scale=2):
                    s2s_audio_in = gr.Audio(sources=["microphone"], type="numpy", label="Ongea Hapa (Speak Here)")
                    s2s_submit_btn = gr.Button("Tuma (Submit)", variant="primary")
                with gr.Column(scale=3):
                    s2s_chatbot = gr.Chatbot(label="Mazungumzo (Conversation)", bubble_full_width=False, height=400)
                    s2s_audio_out = gr.Audio(type="filepath", label="Jibu la Sauti (Audio Response)", autoplay=True)
                    s2s_text_out = gr.Textbox(label="Jibu la Maandishi (Text Response)", interactive=False)
            
        with gr.TabItem("⌨️ Maandishi-kwa-Maandishi (Text-to-Text)"):
            t2t_chatbot = gr.Chatbot(label="Mazungumzo (Conversation)", bubble_full_width=False, height=500)
            with gr.Row():
                t2t_text_in = gr.Textbox(label="Andika Hapa (Write Here)", placeholder="Habari yako...", scale=4)
                t2t_submit_btn = gr.Button("Tuma (Submit)", variant="primary", scale=1)

        with gr.TabItem("🛠️ Zana (Tools)"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Unukuzi wa Sauti (Speech Transcription)")
                    tool_s2t_audio_in = gr.Audio(sources=["microphone"], type="numpy", label="Sauti ya Kuingiza (Input Audio)")
                    tool_s2t_text_out = gr.Textbox(label="Maandishi Yaliyonukuliwa (Transcribed Text)", interactive=False)
                    tool_s2t_btn = gr.Button("Nukuu (Transcribe)")
                with gr.Column():
                    gr.Markdown("### Utengenezaji wa Sauti (Speech Synthesis)")
                    tool_t2s_text_in = gr.Textbox(label="Maandishi ya Kuingiza (Input Text)", placeholder="Andika Kiswahili hapa...")
                    tool_t2s_audio_out = gr.Audio(type="filepath", label="Sauti Iliyotengenezwa (Synthesized Audio)", autoplay=False)
                    tool_t2s_btn = gr.Button("Tengeneza Sauti (Synthesize)")

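    # Wire the UI events to the pipelines defined above.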
    s2s_submit_btn.click(
        fn=s2s_pipeline,
        inputs=[s2s_audio_in, s2s_chatbot],
        outputs=[s2s_chatbot, s2s_audio_out, s2s_text_out],
        queue=True
    )

    t2t_submit_btn.click(
        fn=t2t_pipeline,
        inputs=[t2t_text_in, t2t_chatbot],
        outputs=[t2t_chatbot, t2t_text_in],
        queue=True
    ).then(
        fn=clear_textbox,
        inputs=None,
        outputs=t2t_text_in
    )

    tool_s2t_btn.click(
        fn=assistant.transcribe_audio,
        inputs=tool_s2t_audio_in,
        outputs=tool_s2t_text_out
    )
    tool_t2s_btn.click(
        fn=assistant.generate_speech,
        inputs=tool_t2s_text_in,
        outputs=tool_t2s_audio_out
    )

demo.queue().launch(debug=True)