# -*- coding: utf-8 -*-
"""
This script implements a multi-modal Swahili assistant for Hugging Face Spaces.
It uses Gradio for the user interface and loads models from the HF Hub.
"""
import gradio as gr
import numpy as np
import onnxruntime
import torch
import librosa
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, AutoTokenizer, pipeline, TextIteratorStreamer
from threading import Thread
from scipy.io.wavfile import write as write_wav
import os
import re
# --- Configuration ---
# IMPORTANT: Replace these with your actual model IDs on the Hugging Face Hub.
# You must upload your fine-tuned ASR model to the Hub.
STT_MODEL_ID = "YOUR_USERNAME/YOUR_ASR_MODEL_ID" # e.g., "MickyMike/SALAMA_B3_ASR"
# You can use any powerful multilingual model that supports Swahili.
LLM_MODEL_ID = "google/gemma-2-9b-it"
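# NOTE: gemma-2-9b-it is a gated model; you must accept its license on the Hub and
# (on a Space) add an HF_TOKEN secret so it can be downloaded.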
# This is the tokenizer for your ONNX TTS model.
TTS_TOKENIZER_ID = "facebook/mms-tts-swh"
TTS_ONNX_MODEL_PATH = "swahili_tts.onnx" # Make sure this file is in your Space repo
# Ensure the temporary directory for audio files exists
TEMP_DIR = "temp"
os.makedirs(TEMP_DIR, exist_ok=True)
class WeeboAssistant:
def __init__(self):
# Audio settings
self.STT_SAMPLE_RATE = 16000
self.TTS_SAMPLE_RATE = 16000
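        # Both the fine-tuned ASR model and MMS-TTS (VITS) operate on 16 kHz audio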
# System prompt for the LLM
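        # (Swahili: "You are an intelligent assistant; answer the question BRIEFLY and accurately.
        # Answer in Swahili only. No long answers.")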
self.SYSTEM_PROMPT = "Wewe ni msaidizi mwenye akili, jibu swali lililoulizwa kwa UFUPI na kwa usahihi. Jibu kwa lugha ya Kiswahili pekee. Hakuna jibu refu."
self._init_models()
def _init_models(self):
"""Initializes all models required for the pipeline."""
print("Initializing models...")
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.torch_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
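        # bfloat16 roughly halves GPU memory use; float32 is the safe default on CPU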
print(f"Using device: {self.device}")
# --- 1. Initialize Swahili Speech-to-Text (STT/ASR) ---
print(f"Loading STT model: {STT_MODEL_ID}")
try:
self.stt_processor = AutoProcessor.from_pretrained(STT_MODEL_ID)
self.stt_model = AutoModelForSpeechSeq2Seq.from_pretrained(
STT_MODEL_ID,
torch_dtype=self.torch_dtype,
low_cpu_mem_usage=True,
use_safetensors=True
)
self.stt_model.to(self.device)
print("STT model loaded successfully.")
except Exception as e:
print(f"FATAL: Could not load STT model. Please check the model ID and ensure you have access. Error: {e}")
# In a real app, you might want to handle this more gracefully
raise
# --- 2. Initialize Language Model (LLM) ---
print(f"Loading LLM: {LLM_MODEL_ID}")
try:
# We don't need a separate tokenizer for the pipeline
self.llm_pipeline = pipeline(
"text-generation",
model=LLM_MODEL_ID,
model_kwargs={"torch_dtype": self.torch_dtype},
device=self.device,
)
print("LLM pipeline loaded successfully.")
except Exception as e:
print(f"FATAL: Could not load LLM. Error: {e}")
raise
# --- 3. Initialize Swahili Text-to-Speech (TTS) ---
print(f"Loading TTS model: {TTS_ONNX_MODEL_PATH}")
try:
# The ONNX model should be in the same repository as app.py
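            # onnxruntime falls back to CPUExecutionProvider automatically if the CUDA
            # provider is unavailable (GPU inference requires the onnxruntime-gpu package)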
self.tts_session = onnxruntime.InferenceSession(
TTS_ONNX_MODEL_PATH,
providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
)
self.tts_tokenizer = AutoTokenizer.from_pretrained(TTS_TOKENIZER_ID)
print("TTS model and tokenizer loaded successfully.")
except Exception as e:
print(f"FATAL: Could not load TTS model. Make sure '{TTS_ONNX_MODEL_PATH}' is in the repository. Error: {e}")
raise
print("-" * 30)
print("All models initialized successfully! โ
")
def transcribe_audio(self, audio_tuple: tuple) -> str:
"""
Transcribes audio from Gradio's audio component.
The input is a tuple (sample_rate, numpy_array).
"""
if audio_tuple is None:
return ""
sample_rate, audio_data = audio_tuple
        # Normalize integer PCM to float32 in [-1, 1] *before* downmixing, because mean()
        # would silently promote the array to float64 and break the integer-based scaling.
        if np.issubdtype(audio_data.dtype, np.integer):
            audio_data = audio_data.astype(np.float32) / np.iinfo(audio_data.dtype).max
        elif audio_data.dtype != np.float32:
            audio_data = audio_data.astype(np.float32)
        # Downmix to mono
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=1)
# Resample if necessary
if sample_rate != self.STT_SAMPLE_RATE:
audio_data = librosa.resample(y=audio_data, orig_sr=sample_rate, target_sr=self.STT_SAMPLE_RATE)
if len(audio_data) < 1000: # Ignore very short audio clips
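            # 1000 samples is ~62 ms at 16 kHz, too little signal to transcribe reliably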
return "(Audio too short to transcribe)"
# Process and transcribe
inputs = self.stt_processor(audio_data, sampling_rate=self.STT_SAMPLE_RATE, return_tensors="pt")
        # Match the model's device and dtype (bfloat16 on GPU) to avoid dtype-mismatch errors
        inputs = inputs.to(self.device, dtype=self.torch_dtype)
with torch.no_grad():
generated_ids = self.stt_model.generate(**inputs, max_new_tokens=128)
transcription = self.stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
return transcription.strip()
def generate_speech(self, text: str) -> str:
"""
Generates audio from text and saves it to a temporary file.
Returns the path to the audio file.
"""
if not text:
return None
# Clean text
text = text.strip()
try:
inputs = self.tts_tokenizer(text, return_tensors="np")
input_ids = inputs.input_ids
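            # Assumes the exported VITS graph takes a single token-id tensor as input and
            # returns the raw waveform as its first output; adjust if your ONNX export differs.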
ort_inputs = {self.tts_session.get_inputs()[0].name: input_ids}
audio_waveform = self.tts_session.run(None, ort_inputs)[0].flatten()
# Save to a temporary WAV file
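            # scipy writes float32 samples as a 32-bit float WAV, which Gradio's Audio component can play back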
output_path = os.path.join(TEMP_DIR, f"{os.urandom(8).hex()}.wav")
write_wav(output_path, self.TTS_SAMPLE_RATE, audio_waveform)
return output_path
except Exception as e:
print(f"Error during audio generation: {e}")
return None
    def get_llm_response(self, chat_history: list):
        """
        Streams a response from the LLM.
        Yields the accumulated response text as new tokens arrive.
        """
        # Format messages for the Gemma-2 instruction-tuned chat template.
        # Gemma-2's template does not accept a 'system' role, so the system prompt is
        # folded into the first user turn instead.
        messages = []
        for i, turn in enumerate(chat_history):
            user_text = turn[0]
            if i == 0:
                user_text = f"{self.SYSTEM_PROMPT}\n\n{user_text}"
            messages.append({'role': 'user', 'content': user_text})
            if turn[1] is not None:
                messages.append({'role': 'assistant', 'content': turn[1]})
        prompt = self.llm_pipeline.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        # Stop on either the EOS token or Gemma-2's end-of-turn marker
        terminators = [
            self.llm_pipeline.tokenizer.eos_token_id,
            self.llm_pipeline.tokenizer.convert_tokens_to_ids("<end_of_turn>")
        ]
        # TextIteratorStreamer exposes generation as an iterator of text chunks;
        # generation runs in a background thread so we can yield as tokens arrive.
        streamer = TextIteratorStreamer(
            self.llm_pipeline.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = dict(
            text_inputs=prompt,
            max_new_tokens=512,
            eos_token_id=terminators,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            streamer=streamer,
        )
        thread = Thread(target=self.llm_pipeline, kwargs=generation_kwargs)
        thread.start()
        response = ""
        for new_text in streamer:
            response += new_text
            yield response
        thread.join()
# --- Gradio Interface Logic ---
# Instantiate the assistant
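# Models are downloaded and loaded once here, at import time, when the Space starts.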
assistant = WeeboAssistant()
def s2s_pipeline(audio_input, chat_history):
"""The main function for the Speech-to-Speech tab."""
# 1. Transcribe user's speech
user_text = assistant.transcribe_audio(audio_input)
if not user_text or user_text.startswith("("):
chat_history.append((user_text or "(No valid speech detected)", None))
yield chat_history, None, "Please record your voice again."
return
chat_history.append((user_text, None))
yield chat_history, None, "..." # Show user text and a thinking indicator
# 2. Get LLM response as a stream
response_stream = assistant.get_llm_response(chat_history)
# Stream the response text to the UI
llm_response_text = ""
for text_chunk in response_stream:
llm_response_text = text_chunk
chat_history[-1] = (user_text, llm_response_text)
yield chat_history, None, llm_response_text
# 3. Synthesize the final LLM response to speech
final_audio_path = assistant.generate_speech(llm_response_text)
# 4. Final update to the UI
yield chat_history, final_audio_path, llm_response_text
def t2t_pipeline(text_input, chat_history):
    """The main function for the Text-to-Text tab."""
    # Show the user's message immediately, then stream the assistant's reply into the same turn
    chat_history.append((text_input, None))
    yield chat_history
    response_stream = assistant.get_llm_response(chat_history)
    for partial_response in response_stream:
        chat_history[-1] = (text_input, partial_response)
        yield chat_history
# --- Build Gradio UI ---
with gr.Blocks(theme=gr.themes.Soft(), title="Msaidizi wa Kiswahili") as demo:
gr.Markdown("# ๐ค Msaidizi wa Sauti wa Kiswahili (Swahili Voice Assistant)")
gr.Markdown("Ongea na msaidizi kwa Kiswahili. Toa sauti, andika maandishi, na upate majibu kwa sauti au maandishi.")
with gr.Tabs():
# Tab 1: Speech-to-Speech
with gr.TabItem("๐๏ธ Sauti-kwa-Sauti (Speech-to-Speech)"):
with gr.Row():
with gr.Column(scale=2):
s2s_audio_in = gr.Audio(sources=["microphone"], type="numpy", label="Ongea Hapa (Speak Here)")
s2s_submit_btn = gr.Button("Tuma (Submit)", variant="primary")
with gr.Column(scale=3):
s2s_chatbot = gr.Chatbot(label="Mazungumzo (Conversation)", bubble_full_width=False, height=400)
s2s_audio_out = gr.Audio(type="filepath", label="Jibu la Sauti (Audio Response)", autoplay=True)
s2s_text_out = gr.Textbox(label="Jibu la Maandishi (Text Response)", interactive=False)
# Tab 2: Text-to-Text
with gr.TabItem("โจ๏ธ Maandishi-kwa-Maandishi (Text-to-Text)"):
t2t_chatbot = gr.Chatbot(label="Mazungumzo (Conversation)", bubble_full_width=False, height=500)
with gr.Row():
t2t_text_in = gr.Textbox(label="Andika Hapa (Write Here)", placeholder="Habari yako...", scale=4)
t2t_submit_btn = gr.Button("Tuma (Submit)", variant="primary", scale=1)
# Tab 3: Direct Tools
with gr.TabItem("๐ ๏ธ Zana (Tools)"):
with gr.Row():
# Speech to Text Tool
with gr.Column():
gr.Markdown("### Unukuzi wa Sauti (Speech Transcription)")
tool_s2t_audio_in = gr.Audio(sources=["microphone"], type="numpy", label="Sauti ya Kuingiza (Input Audio)")
tool_s2t_text_out = gr.Textbox(label="Maandishi Yaliyonukuliwa (Transcribed Text)", interactive=False)
tool_s2t_btn = gr.Button("Nukuu (Transcribe)")
# Text to Speech Tool
with gr.Column():
gr.Markdown("### Utengenezaji wa Sauti (Speech Synthesis)")
tool_t2s_text_in = gr.Textbox(label="Maandishi ya Kuingiza (Input Text)", placeholder="Andika Kiswahili hapa...")
tool_t2s_audio_out = gr.Audio(type="filepath", label="Sauti Iliyotengenezwa (Synthesized Audio)", autoplay=False)
tool_t2s_btn = gr.Button("Tengeneza Sauti (Synthesize)")
# --- Event Handlers ---
# Speech-to-Speech handler
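    # s2s_pipeline is a generator, so each yield is streamed to the outputs; this requires the queue.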
s2s_submit_btn.click(
fn=s2s_pipeline,
inputs=[s2s_audio_in, s2s_chatbot],
outputs=[s2s_chatbot, s2s_audio_out, s2s_text_out],
queue=True
)
    # Text-to-Text handler
    t2t_submit_btn.click(
        fn=t2t_pipeline,
        inputs=[t2t_text_in, t2t_chatbot],
        outputs=[t2t_chatbot],  # the text response is streamed directly into the chatbot UI
        queue=True
    ).then(
        lambda: "", None, t2t_text_in  # clear the input box once the response has finished streaming
    )
# Tool handlers
tool_s2t_btn.click(
fn=assistant.transcribe_audio,
inputs=tool_s2t_audio_in,
outputs=tool_s2t_text_out
)
tool_t2s_btn.click(
fn=assistant.generate_speech,
inputs=tool_t2s_text_in,
outputs=tool_t2s_audio_out
)
# Launch the Gradio app
demo.queue().launch(debug=True)