SALAMA / app.py
EYEDOL's picture
Create app.py
a248e18 verified
raw
history blame
12.7 kB
# -*- coding: utf-8 -*-
"""
This script implements a multi-modal Swahili assistant for Hugging Face Spaces.
It uses Gradio for the user interface and loads models from the HF Hub.
"""
import os
import re
from threading import Thread

import gradio as gr
import librosa
import numpy as np
import onnxruntime
import torch
from scipy.io.wavfile import write as write_wav
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, AutoTokenizer, pipeline
from transformers import TextIteratorStreamer
# --- Configuration ---
# Module-level settings read by WeeboAssistant when it loads its models.
#
# IMPORTANT: Replace these with your actual model IDs on the Hugging Face Hub.
# STT_MODEL_ID is a placeholder: it is passed to AutoProcessor and
# AutoModelForSpeechSeq2Seq, and a load failure is re-raised at startup,
# so upload your fine-tuned ASR model to the Hub and put its repo id here.
STT_MODEL_ID = "YOUR_USERNAME/YOUR_ASR_MODEL_ID" # e.g., "MickyMike/SALAMA_B3_ASR"
# Model id handed to the "text-generation" pipeline. You can use any powerful
# multilingual model that supports Swahili.
LLM_MODEL_ID = "google/gemma-2-9b-it"
# Tokenizer (loaded with AutoTokenizer) that turns text into the input ids
# fed to the ONNX TTS model.
TTS_TOKENIZER_ID = "facebook/mms-tts-swh"
# Path given to onnxruntime.InferenceSession; it is relative, so it is
# resolved against the working directory the app is started from.
TTS_ONNX_MODEL_PATH = "swahili_tts.onnx" # Make sure this file is in your Space repo
# Directory where synthesized WAV files are written. It is created at import
# time; exist_ok=True makes repeated start-ups harmless.
TEMP_DIR = "temp"
os.makedirs(TEMP_DIR, exist_ok=True)
class WeeboAssistant:
    """Swahili assistant bundling speech-to-text, an LLM and ONNX text-to-speech.

    Constructing an instance loads all three models (see ``_init_models``),
    which is slow and raises if any of them cannot be loaded. The model ids
    and paths come from the module-level configuration constants.
    """

    def __init__(self):
        # Audio settings (Hz): rate the STT processor is fed with, and rate
        # written into the WAV header of synthesized speech.
        self.STT_SAMPLE_RATE = 16000
        self.TTS_SAMPLE_RATE = 16000
        # System prompt for the LLM
        self.SYSTEM_PROMPT = "Wewe ni msaidizi mwenye akili, jibu swali lililoulizwa kwa UFUPI na kwa usahihi. Jibu kwa lugha ya Kiswahili pekee. Hakuna jibu refu."
        self._init_models()

    def _init_models(self):
        """Initializes all models required for the pipeline.

        Sets ``device``/``torch_dtype`` and the ``stt_*``, ``llm_pipeline``
        and ``tts_*`` attributes.

        Raises:
            Exception: whatever the underlying loader raised; the error is
                printed first and then re-raised unchanged.
        """
        print("Initializing models...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
        print(f"Using device: {self.device}")

        # --- 1. Initialize Swahili Speech-to-Text (STT/ASR) ---
        print(f"Loading STT model: {STT_MODEL_ID}")
        try:
            self.stt_processor = AutoProcessor.from_pretrained(STT_MODEL_ID)
            self.stt_model = AutoModelForSpeechSeq2Seq.from_pretrained(
                STT_MODEL_ID,
                torch_dtype=self.torch_dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True
            )
            self.stt_model.to(self.device)
            print("STT model loaded successfully.")
        except Exception as e:
            print(f"FATAL: Could not load STT model. Please check the model ID and ensure you have access. Error: {e}")
            # In a real app, you might want to handle this more gracefully
            raise

        # --- 2. Initialize Language Model (LLM) ---
        print(f"Loading LLM: {LLM_MODEL_ID}")
        try:
            # The pipeline loads its own tokenizer; it is reused later through
            # ``self.llm_pipeline.tokenizer``.
            self.llm_pipeline = pipeline(
                "text-generation",
                model=LLM_MODEL_ID,
                model_kwargs={"torch_dtype": self.torch_dtype},
                device=self.device,
            )
            print("LLM pipeline loaded successfully.")
        except Exception as e:
            print(f"FATAL: Could not load LLM. Error: {e}")
            raise

        # --- 3. Initialize Swahili Text-to-Speech (TTS) ---
        print(f"Loading TTS model: {TTS_ONNX_MODEL_PATH}")
        try:
            # The ONNX model should be in the same repository as app.py
            self.tts_session = onnxruntime.InferenceSession(
                TTS_ONNX_MODEL_PATH,
                providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
            )
            self.tts_tokenizer = AutoTokenizer.from_pretrained(TTS_TOKENIZER_ID)
            print("TTS model and tokenizer loaded successfully.")
        except Exception as e:
            print(f"FATAL: Could not load TTS model. Make sure '{TTS_ONNX_MODEL_PATH}' is in the repository. Error: {e}")
            raise

        print("-" * 30)
        print("All models initialized successfully! ✅")

    def transcribe_audio(self, audio_tuple: tuple) -> str:
        """Transcribes audio from Gradio's audio component.

        Args:
            audio_tuple: ``(sample_rate, numpy_array)`` or ``None``. The array
                may be integer PCM or floating point, mono or multi-channel
                (channels on axis 1).

        Returns:
            The stripped transcription, ``""`` when no audio was supplied, or
            the marker ``"(Audio too short to transcribe)"`` when fewer than
            1000 samples remain after resampling.
        """
        if audio_tuple is None:
            return ""
        sample_rate, audio_data = audio_tuple
        audio_data = np.asarray(audio_data)

        # Scale integer PCM to [-1, 1] BEFORE down-mixing: mean() promotes
        # integers to float64, and np.iinfo() raises ValueError on any float
        # dtype, so the scale must be read while the dtype is still integral.
        if np.issubdtype(audio_data.dtype, np.integer):
            audio_data = audio_data.astype(np.float32) / np.iinfo(audio_data.dtype).max
        else:
            audio_data = audio_data.astype(np.float32)

        # Down-mix multi-channel audio to mono
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=1)

        # Resample if necessary
        if sample_rate != self.STT_SAMPLE_RATE:
            audio_data = librosa.resample(y=audio_data, orig_sr=sample_rate, target_sr=self.STT_SAMPLE_RATE)

        if len(audio_data) < 1000:  # Ignore very short audio clips
            return "(Audio too short to transcribe)"

        # Process and transcribe
        inputs = self.stt_processor(audio_data, sampling_rate=self.STT_SAMPLE_RATE, return_tensors="pt")
        # Floating-point features must match the model's dtype (bfloat16 on
        # CUDA), otherwise the forward pass fails on a dtype mismatch.
        inputs = {
            key: (
                val.to(self.device, dtype=self.torch_dtype)
                if torch.is_floating_point(val)
                else val.to(self.device)
            )
            for key, val in inputs.items()
        }
        with torch.no_grad():
            generated_ids = self.stt_model.generate(**inputs, max_new_tokens=128)
        transcription = self.stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return transcription.strip()

    def generate_speech(self, text: str) -> str:
        """Generates audio from text and saves it to a temporary file.

        Args:
            text: Text to synthesize.

        Returns:
            Path of the WAV file written under ``TEMP_DIR``, or ``None`` when
            the text is empty/whitespace-only or synthesis failed (the error
            is printed; this is deliberately best-effort so the UI keeps
            working without audio).
        """
        # Strip first so whitespace-only input is rejected without running
        # the tokenizer and the model on an empty string.
        text = text.strip() if text else ""
        if not text:
            return None
        try:
            inputs = self.tts_tokenizer(text, return_tensors="np")
            input_ids = inputs.input_ids
            # The session's first declared input receives the token ids.
            ort_inputs = {self.tts_session.get_inputs()[0].name: input_ids}
            audio_waveform = self.tts_session.run(None, ort_inputs)[0].flatten()
            # Save to a temporary WAV file with a random, collision-resistant name
            output_path = os.path.join(TEMP_DIR, f"{os.urandom(8).hex()}.wav")
            write_wav(output_path, self.TTS_SAMPLE_RATE, audio_waveform)
            return output_path
        except Exception as e:
            print(f"Error during audio generation: {e}")
            return None

    def _build_messages(self, chat_history: list, system_role: bool = True) -> list:
        """Converts ``(user, assistant)`` turns into chat-template messages.

        Earlier turns that never received an answer (e.g. a failed
        transcription shown in the chat) are skipped, because two consecutive
        user messages break templates that require strict role alternation.

        Args:
            chat_history: Sequence of ``(user_text, assistant_text_or_None)``.
            system_role: When true the system prompt is a separate ``system``
                message; otherwise it is prepended to the first user message.
        """
        turns = []
        last_index = len(chat_history) - 1
        for index, (user_text, bot_text) in enumerate(chat_history):
            if bot_text is None and index != last_index:
                continue
            turns.append({'role': 'user', 'content': user_text})
            if bot_text is not None:
                turns.append({'role': 'assistant', 'content': bot_text})

        if system_role:
            return [{'role': 'system', 'content': self.SYSTEM_PROMPT}] + turns
        if turns and turns[0]['role'] == 'user':
            turns[0] = {
                'role': 'user',
                'content': f"{self.SYSTEM_PROMPT}\n\n{turns[0]['content']}",
            }
        return turns

    def get_llm_response(self, chat_history: list):
        """Gets a streaming response from the LLM.

        Generation runs in a background thread and is consumed through a
        ``TextIteratorStreamer``.

        Args:
            chat_history: Sequence of ``(user_text, assistant_text_or_None)``
                turns; the last turn is the one to answer.

        Yields:
            The full response text accumulated so far, once per streamed chunk.
        """
        tokenizer = self.llm_pipeline.tokenizer
        model = self.llm_pipeline.model

        def encode(system_role):
            return tokenizer.apply_chat_template(
                self._build_messages(chat_history, system_role=system_role),
                add_generation_prompt=True,
                return_tensors="pt",
                return_dict=True,
            )

        try:
            model_inputs = encode(True)
        except Exception:
            # Some chat templates (Gemma's among them) raise on a "system"
            # role; fall back to folding the prompt into the first user turn.
            model_inputs = encode(False)
        model_inputs = model_inputs.to(model.device)

        # Stop on EOS plus whichever end-of-turn marker this tokenizer knows.
        # Unknown tokens resolve to the unk id (or None) and must not be used.
        terminators = []
        if tokenizer.eos_token_id is not None:
            terminators.append(tokenizer.eos_token_id)
        for token in ("<end_of_turn>", "<|eot_id|>"):
            token_id = tokenizer.convert_tokens_to_ids(token)
            if (
                token_id is not None
                and token_id != tokenizer.unk_token_id
                and token_id not in terminators
            ):
                terminators.append(token_id)

        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            **model_inputs,
            max_new_tokens=512,
            eos_token_id=terminators,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            streamer=streamer,
        )

        def run_generation():
            try:
                model.generate(**generation_kwargs)
            except Exception as e:
                print(f"Error during text generation: {e}")
                # Unblock the consumer loop below, which would otherwise wait
                # forever for a stop signal that never arrives.
                streamer.end()

        worker = Thread(target=run_generation, daemon=True)
        worker.start()

        response = ""
        for chunk in streamer:
            response += chunk
            yield response
        worker.join()
# --- Gradio Interface Logic ---
# Single module-level assistant shared by every Gradio callback below.
# Constructing it calls WeeboAssistant._init_models(), so all models are
# loaded here, at import time, and a load failure aborts start-up.
assistant = WeeboAssistant()
def s2s_pipeline(audio_input, chat_history):
    """Drive the Speech-to-Speech tab: transcribe, answer, then speak.

    This is a generator; every yield is a
    ``(chat_history, audio_path_or_None, status_or_response_text)`` triple
    matching the tab's three output components. ``chat_history`` is updated
    in place with ``(user_text, response_text)`` tuples.
    """
    # Step 1: turn the recording into text.
    transcript = assistant.transcribe_audio(audio_input)

    # Empty results and parenthesised markers (e.g. the "too short" notice)
    # are both treated as "nothing usable was said".
    usable = bool(transcript) and not transcript.startswith("(")

    if usable:
        # Show the user's words right away together with a thinking indicator.
        chat_history.append((transcript, None))
        yield chat_history, None, "..."

        # Step 2: stream the model's answer; each item is the full text so far.
        reply = ""
        for partial in assistant.get_llm_response(chat_history):
            reply = partial
            chat_history[-1] = (transcript, reply)
            yield chat_history, None, reply

        # Step 3: synthesize the finished answer and publish the final state.
        spoken_path = assistant.generate_speech(reply)
        yield chat_history, spoken_path, reply
    else:
        shown = transcript if transcript else "(No valid speech detected)"
        chat_history.append((shown, None))
        yield chat_history, None, "Please record your voice again."
def t2t_pipeline(text_input, chat_history):
    """The main function for the Text-to-Text tab.

    This is a generator; every yield is a ``(chat_history, text)`` pair where
    ``text`` is first a ``"..."`` thinking indicator and then the response
    accumulated so far.

    Args:
        text_input: The user's message. Blank or ``None`` input is ignored:
            the history is yielded back unchanged with an empty text instead
            of sending an empty user turn to the LLM.
        chat_history: List of ``(user_text, response_text)`` tuples, updated
            in place; ``None`` is treated as an empty history.
    """
    if chat_history is None:
        chat_history = []

    # Guard: nothing to ask, so do not touch the history or call the model.
    if not text_input or not text_input.strip():
        yield chat_history, ""
        return

    chat_history.append((text_input, None))
    yield chat_history, "..."

    llm_response_text = ""
    for text_chunk in assistant.get_llm_response(chat_history):
        # Each chunk is the full response so far, so replace rather than append.
        llm_response_text = text_chunk
        chat_history[-1] = (text_input, llm_response_text)
        yield chat_history, llm_response_text
# --- Build Gradio UI ---
# Three tabs: full speech-to-speech chat, text chat, and the two stand-alone
# tools (transcription and synthesis). Event listeners are registered inside
# the Blocks context so they are attached to this app.
with gr.Blocks(theme=gr.themes.Soft(), title="Msaidizi wa Kiswahili") as demo:
    gr.Markdown("# 🤖 Msaidizi wa Sauti wa Kiswahili (Swahili Voice Assistant)")
    gr.Markdown("Ongea na msaidizi kwa Kiswahili. Toa sauti, andika maandishi, na upate majibu kwa sauti au maandishi.")

    with gr.Tabs():
        # Tab 1: Speech-to-Speech
        with gr.TabItem("🎙️ Sauti-kwa-Sauti (Speech-to-Speech)"):
            with gr.Row():
                with gr.Column(scale=2):
                    s2s_audio_in = gr.Audio(sources=["microphone"], type="numpy", label="Ongea Hapa (Speak Here)")
                    s2s_submit_btn = gr.Button("Tuma (Submit)", variant="primary")
                with gr.Column(scale=3):
                    s2s_chatbot = gr.Chatbot(label="Mazungumzo (Conversation)", bubble_full_width=False, height=400)
                    s2s_audio_out = gr.Audio(type="filepath", label="Jibu la Sauti (Audio Response)", autoplay=True)
                    s2s_text_out = gr.Textbox(label="Jibu la Maandishi (Text Response)", interactive=False)

        # Tab 2: Text-to-Text
        with gr.TabItem("⌨️ Maandishi-kwa-Maandishi (Text-to-Text)"):
            t2t_chatbot = gr.Chatbot(label="Mazungumzo (Conversation)", bubble_full_width=False, height=500)
            # Receives the second value yielded by t2t_pipeline (thinking
            # indicator, then the streamed response), mirroring the speech tab.
            t2t_text_out = gr.Textbox(label="Jibu la Maandishi (Text Response)", interactive=False)
            with gr.Row():
                t2t_text_in = gr.Textbox(label="Andika Hapa (Write Here)", placeholder="Habari yako...", scale=4)
                t2t_submit_btn = gr.Button("Tuma (Submit)", variant="primary", scale=1)

        # Tab 3: Direct Tools
        with gr.TabItem("🛠️ Zana (Tools)"):
            with gr.Row():
                # Speech to Text Tool
                with gr.Column():
                    gr.Markdown("### Unukuzi wa Sauti (Speech Transcription)")
                    tool_s2t_audio_in = gr.Audio(sources=["microphone"], type="numpy", label="Sauti ya Kuingiza (Input Audio)")
                    tool_s2t_text_out = gr.Textbox(label="Maandishi Yaliyonukuliwa (Transcribed Text)", interactive=False)
                    tool_s2t_btn = gr.Button("Nukuu (Transcribe)")
                # Text to Speech Tool
                with gr.Column():
                    gr.Markdown("### Utengenezaji wa Sauti (Speech Synthesis)")
                    tool_t2s_text_in = gr.Textbox(label="Maandishi ya Kuingiza (Input Text)", placeholder="Andika Kiswahili hapa...")
                    tool_t2s_audio_out = gr.Audio(type="filepath", label="Sauti Iliyotengenezwa (Synthesized Audio)", autoplay=False)
                    tool_t2s_btn = gr.Button("Tengeneza Sauti (Synthesize)")

    # --- Event Handlers ---
    # Speech-to-Speech handler
    s2s_submit_btn.click(
        fn=s2s_pipeline,
        inputs=[s2s_audio_in, s2s_chatbot],
        outputs=[s2s_chatbot, s2s_audio_out, s2s_text_out],
        queue=True
    )

    # Text-to-Text handler. `outputs` must be components, one per value the
    # generator yields; the input box is cleared in a follow-up step once the
    # response has finished streaming.
    t2t_submit_btn.click(
        fn=t2t_pipeline,
        inputs=[t2t_text_in, t2t_chatbot],
        outputs=[t2t_chatbot, t2t_text_out],
        queue=True
    ).then(
        fn=lambda: "",
        inputs=None,
        outputs=t2t_text_in
    )

    # Tool handlers
    tool_s2t_btn.click(
        fn=assistant.transcribe_audio,
        inputs=tool_s2t_audio_in,
        outputs=tool_s2t_text_out
    )
    tool_t2s_btn.click(
        fn=assistant.generate_speech,
        inputs=tool_t2s_text_in,
        outputs=tool_t2s_audio_out
    )

# Launch the Gradio app
demo.queue().launch(debug=True)