# -*- coding: utf-8 -*-
"""
This script implements a multi-modal Swahili assistant for Hugging Face Spaces.
It uses Gradio for the user interface and loads models from the HF Hub.
"""
import os
import re
import threading

import gradio as gr
import librosa
import numpy as np
import onnxruntime
import torch
from huggingface_hub import login
from scipy.io.wavfile import write as write_wav
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    AutoTokenizer,
    TextIteratorStreamer,
    pipeline,
)

# --- Login to Hugging Face using secret ---
# Make sure HF_TOKEN is set in your Hugging Face Space > Settings > Repository secrets
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    raise ValueError("HF_TOKEN not found. Please set it in Hugging Face Space repository secrets.")
login(token=hf_token)
print("Successfully logged into Hugging Face Hub!")

# --- Configuration ---
STT_MODEL_ID = "EYEDOL/SALAMA_C3"
LLM_MODEL_ID = "google/gemma-3-1b-it"
TTS_TOKENIZER_ID = "facebook/mms-tts-swh"
TTS_ONNX_MODEL_PATH = "swahili_tts.onnx"
TEMP_DIR = "temp"
os.makedirs(TEMP_DIR, exist_ok=True)


class WeeboAssistant:
    """Speech-to-text, LLM chat and text-to-speech models behind one object.

    All models are loaded eagerly in ``__init__`` (via ``_init_models``).
    """

    def __init__(self):
        self.STT_SAMPLE_RATE = 16000
        self.TTS_SAMPLE_RATE = 16000
        # System prompt (Swahili): "You are an intelligent assistant; answer the
        # question asked BRIEFLY and accurately. Answer in Swahili only. No long answers."
        self.SYSTEM_PROMPT = (
            "Wewe ni msaidizi mwenye akili, jibu swali lililoulizwa kwa UFUPI na kwa usahihi. "
            "Jibu kwa lugha ya Kiswahili pekee. Hakuna jibu refu."
        )
        self._init_models()

    def _init_models(self):
        """Load the STT model, the LLM text-generation pipeline and the ONNX TTS session."""
        print("Initializing models...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
        print(f"Using device: {self.device}")

        # STT
        print(f"Loading STT model: {STT_MODEL_ID}")
        self.stt_processor = AutoProcessor.from_pretrained(STT_MODEL_ID)
        self.stt_model = AutoModelForSpeechSeq2Seq.from_pretrained(
            STT_MODEL_ID,
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
        ).to(self.device)
        print("STT model loaded successfully.")

        # LLM
        print(f"Loading LLM: {LLM_MODEL_ID}")
        self.llm_pipeline = pipeline(
            "text-generation",
            model=LLM_MODEL_ID,
            model_kwargs={"torch_dtype": self.torch_dtype},
            device=self.device,
        )
        print("LLM pipeline loaded successfully.")

        # TTS
        print(f"Loading TTS model: {TTS_ONNX_MODEL_PATH}")
        self.tts_session = onnxruntime.InferenceSession(
            TTS_ONNX_MODEL_PATH, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
        )
        self.tts_tokenizer = AutoTokenizer.from_pretrained(TTS_TOKENIZER_ID)
        print("TTS model and tokenizer loaded successfully.")
        print("-" * 30)
        print("All models initialized successfully! ✅")

    def transcribe_audio(self, audio_tuple):
        """Transcribe a Gradio numpy audio value ``(sample_rate, samples)`` to text.

        Returns "" when ``audio_tuple`` is None and the notice
        "(Audio too short to transcribe)" when fewer than 1000 samples remain
        after resampling. Callers treat a leading "(" as "no usable speech".
        """
        if audio_tuple is None:
            return ""
        sample_rate, audio_data = audio_tuple
        audio_data = np.asarray(audio_data)

        # Convert to float32 BEFORE down-mixing. Calling ``mean`` on an integer
        # array returns float64, and ``np.iinfo`` on a float dtype raises ValueError.
        if np.issubdtype(audio_data.dtype, np.integer):
            audio_data = audio_data.astype(np.float32) / np.iinfo(audio_data.dtype).max
        else:
            audio_data = audio_data.astype(np.float32)

        if audio_data.ndim > 1:
            # Down-mix multi-channel audio (samples along axis 0) to mono.
            audio_data = audio_data.mean(axis=1)

        if audio_data.size and sample_rate != self.STT_SAMPLE_RATE:
            audio_data = librosa.resample(
                y=audio_data, orig_sr=sample_rate, target_sr=self.STT_SAMPLE_RATE
            )

        if len(audio_data) < 1000:
            return "(Audio too short to transcribe)"

        features = self.stt_processor(
            audio_data, sampling_rate=self.STT_SAMPLE_RATE, return_tensors="pt"
        )
        # Floating-point features must match the model dtype (bfloat16 on CUDA),
        # otherwise generate() fails with a dtype mismatch. Integer tensors keep their dtype.
        inputs = {
            key: (
                value.to(self.device, dtype=self.torch_dtype)
                if torch.is_floating_point(value)
                else value.to(self.device)
            )
            for key, value in features.items()
        }
        with torch.no_grad():
            generated_ids = self.stt_model.generate(**inputs, max_new_tokens=128)
        transcription = self.stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return transcription.strip()

    def generate_speech(self, text):
        """Synthesize ``text`` with the ONNX TTS model.

        Writes a WAV file into TEMP_DIR and returns its path. Returns None when
        ``text`` is empty or whitespace-only.
        """
        if not text or not text.strip():
            return None
        text = text.strip()
        inputs = self.tts_tokenizer(text, return_tensors="np")
        # The ONNX graph is fed only the token ids, bound to its first declared input.
        ort_inputs = {self.tts_session.get_inputs()[0].name: inputs.input_ids}
        audio_waveform = self.tts_session.run(None, ort_inputs)[0].flatten()
        output_path = os.path.join(TEMP_DIR, f"{os.urandom(8).hex()}.wav")
        write_wav(output_path, self.TTS_SAMPLE_RATE, audio_waveform)
        return output_path

    def get_llm_response(self, chat_history):
        """Start generating a reply to the last turn of ``chat_history``.

        ``chat_history`` is a sequence of ``(user_message, assistant_message)``
        pairs, where the final pair is the pending ``(user_message, None)``.
        Generation runs on a background thread. The returned
        ``TextIteratorStreamer`` yields newly generated text *deltas*, so
        callers must concatenate them.
        """
        messages = [{"role": "system", "content": self.SYSTEM_PROMPT}]
        last_index = len(chat_history) - 1
        for index, (user_msg, bot_msg) in enumerate(chat_history):
            if bot_msg is None and index != last_index:
                # Unanswered earlier turns (e.g. failed transcriptions) would
                # yield consecutive user messages, which Gemma's chat template
                # rejects ("roles must alternate").
                continue
            messages.append({"role": "user", "content": user_msg})
            if bot_msg is not None:
                messages.append({"role": "assistant", "content": bot_msg})

        tokenizer = self.llm_pipeline.tokenizer
        model = self.llm_pipeline.model
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # The rendered template already starts with <bos>; don't add a second one.
        inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

        # Gemma ends assistant turns with <end_of_turn>, not Llama-3's <|eot_id|>.
        terminators = [tokenizer.eos_token_id]
        end_of_turn_id = tokenizer.convert_tokens_to_ids("<end_of_turn>")
        if end_of_turn_id is not None and end_of_turn_id != tokenizer.unk_token_id:
            terminators.append(end_of_turn_id)

        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        def _generate():
            try:
                model.generate(
                    **inputs,
                    streamer=streamer,
                    max_new_tokens=512,
                    eos_token_id=terminators,
                    do_sample=True,
                    temperature=0.6,
                    top_p=0.9,
                )
            except Exception:
                # Close the stream so the consuming loop doesn't block forever,
                # then let the thread report the error.
                streamer.end()
                raise

        threading.Thread(target=_generate, daemon=True).start()
        return streamer


assistant = WeeboAssistant()


def s2s_pipeline(audio_input, chat_history):
    """Run speech -> text -> LLM -> speech, streaming UI updates.

    Yields ``(chat_history, audio_path_or_None, response_text)`` tuples.
    """
    chat_history = chat_history or []
    user_text = assistant.transcribe_audio(audio_input)
    if not user_text or user_text.startswith("("):
        chat_history.append((user_text or "(No valid speech detected)", None))
        yield chat_history, None, "Please record your voice again."
        return

    chat_history.append((user_text, None))
    yield chat_history, None, "..."

    llm_response_text = ""
    for text_chunk in assistant.get_llm_response(chat_history):
        # The streamer yields deltas; accumulate the full reply.
        llm_response_text += text_chunk
        chat_history[-1] = (user_text, llm_response_text)
        yield chat_history, None, llm_response_text

    final_audio_path = assistant.generate_speech(llm_response_text)
    yield chat_history, final_audio_path, llm_response_text


def t2t_pipeline(text_input, chat_history):
    """Run text -> LLM -> text, streaming the reply into the chat window.

    Yields ``(chat_history, textbox_value)``. The input textbox is cleared
    immediately rather than used as an output for the reply.
    """
    chat_history = chat_history or []
    if not text_input or not text_input.strip():
        yield chat_history, ""
        return

    chat_history.append((text_input, None))
    yield chat_history, ""

    llm_response_text = ""
    for text_chunk in assistant.get_llm_response(chat_history):
        llm_response_text += text_chunk
        chat_history[-1] = (text_input, llm_response_text)
        yield chat_history, ""


def clear_textbox():
    """Return an empty string to clear a Gradio Textbox."""
    return ""


with gr.Blocks(theme=gr.themes.Soft(), title="Msaidizi wa Kiswahili") as demo:
    gr.Markdown("# 🤖 Msaidizi wa Sauti wa Kiswahili (Swahili Voice Assistant)")
    gr.Markdown("Ongea na msaidizi kwa Kiswahili. Toa sauti, andika maandishi, na upate majibu kwa sauti au maandishi.")

    with gr.Tabs():
        with gr.TabItem("🎙️ Sauti-kwa-Sauti (Speech-to-Speech)"):
            with gr.Row():
                with gr.Column(scale=2):
                    s2s_audio_in = gr.Audio(sources=["microphone"], type="numpy", label="Ongea Hapa (Speak Here)")
                    s2s_submit_btn = gr.Button("Tuma (Submit)", variant="primary")
                with gr.Column(scale=3):
                    s2s_chatbot = gr.Chatbot(label="Mazungumzo (Conversation)", bubble_full_width=False, height=400)
                    s2s_audio_out = gr.Audio(type="filepath", label="Jibu la Sauti (Audio Response)", autoplay=True)
                    s2s_text_out = gr.Textbox(label="Jibu la Maandishi (Text Response)", interactive=False)

        with gr.TabItem("⌨️ Maandishi-kwa-Maandishi (Text-to-Text)"):
            t2t_chatbot = gr.Chatbot(label="Mazungumzo (Conversation)", bubble_full_width=False, height=500)
            with gr.Row():
                t2t_text_in = gr.Textbox(label="Andika Hapa (Write Here)", placeholder="Habari yako...", scale=4)
                t2t_submit_btn = gr.Button("Tuma (Submit)", variant="primary", scale=1)

        with gr.TabItem("🛠️ Zana (Tools)"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Unukuzi wa Sauti (Speech Transcription)")
                    tool_s2t_audio_in = gr.Audio(sources=["microphone"], type="numpy", label="Sauti ya Kuingiza (Input Audio)")
                    tool_s2t_text_out = gr.Textbox(label="Maandishi Yaliyonukuliwa (Transcribed Text)", interactive=False)
                    tool_s2t_btn = gr.Button("Nukuu (Transcribe)")
                with gr.Column():
                    gr.Markdown("### Utengenezaji wa Sauti (Speech Synthesis)")
                    tool_t2s_text_in = gr.Textbox(label="Maandishi ya Kuingiza (Input Text)", placeholder="Andika Kiswahili hapa...")
                    tool_t2s_audio_out = gr.Audio(type="filepath", label="Sauti Iliyotengenezwa (Synthesized Audio)", autoplay=False)
                    tool_t2s_btn = gr.Button("Tengeneza Sauti (Synthesize)")

    # Event wiring must happen inside the Blocks context.
    s2s_submit_btn.click(
        fn=s2s_pipeline,
        inputs=[s2s_audio_in, s2s_chatbot],
        outputs=[s2s_chatbot, s2s_audio_out, s2s_text_out],
        queue=True,
    )

    t2t_submit_btn.click(
        fn=t2t_pipeline,
        inputs=[t2t_text_in, t2t_chatbot],
        outputs=[t2t_chatbot, t2t_text_in],
        queue=True,
    ).then(
        fn=clear_textbox,
        inputs=None,
        outputs=t2t_text_in,
    )

    tool_s2t_btn.click(
        fn=assistant.transcribe_audio,
        inputs=tool_s2t_audio_in,
        outputs=tool_s2t_text_out,
    )

    tool_t2s_btn.click(
        fn=assistant.generate_speech,
        inputs=tool_t2s_text_in,
        outputs=tool_t2s_audio_out,
    )

demo.queue().launch(debug=True)