SALAMA / app.py
EYEDOL's picture
Update app.py
9dc2bc4 verified
raw
history blame
9.6 kB
# -*- coding: utf-8 -*-
"""
This script implements a multi-modal Swahili assistant for Hugging Face Spaces.
It uses Gradio for the user interface and loads models from the HF Hub.
"""
import gradio as gr
import numpy as np
import onnxruntime
import torch
import librosa
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, AutoTokenizer, pipeline, TextIteratorStreamer
from scipy.io.wavfile import write as write_wav
import os
import re
from huggingface_hub import login
import threading
# --- Login to Hugging Face using secret ---
# Set HF_TOKEN in your Hugging Face Space > Settings > Repository secrets.
# The older secret name "hugface" is still accepted so existing deployments keep
# working; HF_TOKEN takes precedence when both are set.
hf_token = os.environ.get("HF_TOKEN") or os.environ.get("hugface")
if not hf_token:
    # The app cannot download its models without a token, so fail fast at startup.
    raise ValueError("HF_TOKEN not found. Please set it in Hugging Face Space repository secrets.")
login(token=hf_token)
print("Successfully logged into Hugging Face Hub!")
# --- Configuration ---
# Hugging Face Hub repository IDs, one per pipeline stage.
STT_MODEL_ID = "EYEDOL/SALAMA_C3"                  # speech-to-text (seq2seq)
LLM_MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"  # chat model used for replies
TTS_TOKENIZER_ID = "facebook/mms-tts-swh"          # tokenizer whose ids feed the ONNX TTS model

# Local paths.
TTS_ONNX_MODEL_PATH = "swahili_tts.onnx"  # exported TTS graph loaded with onnxruntime
TEMP_DIR = "temp"                         # synthesized WAV files are written here

# Ensure the output directory exists before any audio is generated.
os.makedirs(name=TEMP_DIR, exist_ok=True)
class WeeboAssistant:
    """Swahili assistant combining speech-to-text, a chat LLM and text-to-speech.

    Every model is loaded eagerly when an instance is created (see ``_init_models``),
    so construction downloads/loads weights and can take a while.
    """

    def __init__(self):
        """Set the audio sample rates and system prompt, then load every model."""
        # Input audio is resampled to this rate before feature extraction.
        self.STT_SAMPLE_RATE = 16000
        # Rate written into the synthesized WAV header; assumed to match the ONNX
        # model's output rate -- TODO confirm against the exported model.
        self.TTS_SAMPLE_RATE = 16000
        # Swahili system prompt, roughly: "You are an intelligent assistant; answer the
        # question asked BRIEFLY and accurately. Answer in Swahili only. No long answers."
        self.SYSTEM_PROMPT = (
            "Wewe ni msaidizi mwenye akili, jibu swali lililoulizwa kwa UFUPI na kwa usahihi. "
            "Jibu kwa lugha ya Kiswahili pekee. Hakuna jibu refu."
        )
        self._init_models()

    def _init_models(self):
        """Load the STT model/processor, the LLM pipeline and the ONNX TTS session.

        Uses CUDA with bfloat16 weights when available, otherwise CPU with float32.
        Sets ``device``, ``torch_dtype``, ``stt_processor``, ``stt_model``,
        ``llm_tokenizer``, ``llm_pipeline``, ``tts_session`` and ``tts_tokenizer``.
        """
        print("Initializing models...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
        print(f"Using device: {self.device}")

        # STT
        print(f"Loading STT model: {STT_MODEL_ID}")
        self.stt_processor = AutoProcessor.from_pretrained(STT_MODEL_ID)
        self.stt_model = AutoModelForSpeechSeq2Seq.from_pretrained(
            STT_MODEL_ID,
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
        ).to(self.device)
        print("STT model loaded successfully.")

        # LLM
        print(f"Loading LLM: {LLM_MODEL_ID}")
        self.llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID)
        self.llm_pipeline = pipeline(
            "text-generation",
            model=LLM_MODEL_ID,
            model_kwargs={"torch_dtype": self.torch_dtype},
            tokenizer=self.llm_tokenizer,
            device=self.device,
        )
        print("LLM pipeline loaded successfully.")

        # TTS: providers are listed in preference order, CUDA first with CPU fallback.
        print(f"Loading TTS model: {TTS_ONNX_MODEL_PATH}")
        self.tts_session = onnxruntime.InferenceSession(
            TTS_ONNX_MODEL_PATH,
            providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
        )
        self.tts_tokenizer = AutoTokenizer.from_pretrained(TTS_TOKENIZER_ID)
        print("TTS model and tokenizer loaded successfully.")
        print("-" * 30)
        print("All models initialized successfully! ✅")

    @staticmethod
    def _to_mono_float32(audio_data):
        """Return ``audio_data`` as float32 samples, averaged to mono if multi-channel.

        Integer PCM is divided by its dtype's maximum value (int16 -> roughly [-1, 1]);
        float input is only cast. Conversion happens BEFORE the channel average:
        averaging integer samples first produces float64, for which ``np.iinfo``
        raises. Multi-channel input is assumed to be (samples, channels), as the
        original ``mean(axis=1)`` implies -- TODO confirm for non-Gradio callers.
        """
        audio_data = np.asarray(audio_data)
        if np.issubdtype(audio_data.dtype, np.integer):
            scale = np.iinfo(audio_data.dtype).max
            audio_data = audio_data.astype(np.float32) / scale
        else:
            audio_data = audio_data.astype(np.float32, copy=False)
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=1)
        return audio_data

    def _prepare_stt_inputs(self, inputs):
        """Move processor outputs to ``self.device``, casting float tensors to ``self.torch_dtype``.

        The STT model is loaded in ``self.torch_dtype`` (bfloat16 on CUDA) while the
        processor emits float32 features; feeding float32 into a bfloat16 model fails
        with a dtype mismatch. Non-float tensors (e.g. an attention mask, if present)
        keep their dtype.
        """
        prepared = {}
        for name, tensor in inputs.items():
            if torch.is_floating_point(tensor):
                prepared[name] = tensor.to(self.device, dtype=self.torch_dtype)
            else:
                prepared[name] = tensor.to(self.device)
        return prepared

    def transcribe_audio(self, audio_tuple):
        """Transcribe a ``(sample_rate, samples)`` numpy audio value to text.

        Returns "" when ``audio_tuple`` is None, and "(Audio too short to transcribe)"
        when fewer than 1000 samples remain after resampling to ``STT_SAMPLE_RATE``.
        Otherwise returns the stripped transcription.
        """
        if audio_tuple is None:
            return ""
        sample_rate, audio_data = audio_tuple
        audio_data = self._to_mono_float32(audio_data)
        if sample_rate != self.STT_SAMPLE_RATE:
            audio_data = librosa.resample(y=audio_data, orig_sr=sample_rate, target_sr=self.STT_SAMPLE_RATE)
        if len(audio_data) < 1000:
            return "(Audio too short to transcribe)"
        inputs = self.stt_processor(audio_data, sampling_rate=self.STT_SAMPLE_RATE, return_tensors="pt")
        inputs = self._prepare_stt_inputs(inputs)
        with torch.no_grad():
            generated_ids = self.stt_model.generate(**inputs, max_new_tokens=128)
        transcription = self.stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return transcription.strip()

    def generate_speech(self, text):
        """Synthesize ``text`` with the ONNX TTS model and return the WAV file path.

        Returns None for None/empty/whitespace-only text. The file is written under
        ``TEMP_DIR`` with a random name and is not cleaned up here.
        """
        text = (text or "").strip()
        if not text:
            # Nothing to speak; also avoids running the TTS graph on empty token ids.
            return None
        inputs = self.tts_tokenizer(text, return_tensors="np")
        ort_inputs = {self.tts_session.get_inputs()[0].name: inputs.input_ids}
        audio_waveform = self.tts_session.run(None, ort_inputs)[0].flatten()
        output_path = os.path.join(TEMP_DIR, f"{os.urandom(8).hex()}.wav")
        write_wav(output_path, self.TTS_SAMPLE_RATE, audio_waveform)
        return output_path

    def get_llm_response(self, chat_history):
        """Start generating a reply in a background thread and return a text streamer.

        ``chat_history`` is a sequence of ``(user_msg, assistant_msg)`` pairs; a falsy
        ``assistant_msg`` (the turn being answered) is left out of the prompt. Iterating
        the returned ``TextIteratorStreamer`` yields decoded text chunks until
        generation finishes, or until it fails, in which case the stream is ended early
        and the exception is reported by the worker thread.
        """
        messages = [{"role": "system", "content": self.SYSTEM_PROMPT}]
        for user_msg, assistant_msg in chat_history:
            messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})

        tokenizer = self.llm_pipeline.tokenizer
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        # Stop on either the generic EOS token or Llama 3's end-of-turn token.
        terminators = [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            streamer=streamer,
            max_new_tokens=512,
            eos_token_id=terminators,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
        )

        def _generate():
            # Runs generation; the streamer receives tokens as they are produced.
            try:
                self.llm_pipeline(prompt, **generation_kwargs)
            except Exception:
                # Signal end-of-stream, otherwise the caller iterating the streamer
                # blocks forever; re-raise so the thread still reports the error.
                streamer.end()
                raise

        # Daemon so a stuck generation cannot keep the process alive at shutdown.
        threading.Thread(target=_generate, daemon=True).start()
        return streamer
# Module-level singleton: constructing it loads every model once, at import time.
# The s2s_pipeline and t2t_pipeline handlers use this shared instance.
assistant = WeeboAssistant()
def s2s_pipeline(audio_input, chat_history):
    """Run one speech-to-speech turn as a generator of UI updates.

    Each yielded value is a 3-tuple ``(chat_history, audio_path_or_None, text)``.
    ``chat_history`` (a list of ``(user, assistant)`` pairs) is mutated in place:
    the transcribed user turn is appended, then its reply is filled in as the LLM
    streams. Spoken audio is produced only once the full reply is available.
    """
    user_text = assistant.transcribe_audio(audio_input)

    # transcribe_audio returns "" for missing audio and a parenthesised notice for
    # unusable audio; either way the turn ends here without calling the LLM.
    transcription_failed = (not user_text) or user_text.startswith("(")
    if transcription_failed:
        placeholder = user_text if user_text else "(No valid speech detected)"
        chat_history.append((placeholder, None))
        yield chat_history, None, "Please record your voice again."
        return

    chat_history.append((user_text, ""))
    yield chat_history, None, "..."

    reply = ""
    for piece in assistant.get_llm_response(chat_history):
        reply += piece
        chat_history[-1] = (user_text, reply)
        yield chat_history, None, reply

    # Final update carries the synthesized audio for the complete reply.
    yield chat_history, assistant.generate_speech(reply), reply
def t2t_pipeline(text_input, chat_history):
    """Run one text-to-text chat turn, streaming the reply into the history.

    Yields ``chat_history`` (mutated in place) after each update: once with the new
    user turn and an empty reply, then after every streamed chunk of the reply.
    Blank or None input is ignored: the history is yielded once, unchanged, and the
    LLM is not called.
    """
    if not (text_input or "").strip():
        # Avoid appending an empty user turn and spending a full generation on it.
        yield chat_history
        return

    chat_history.append((text_input, ""))
    yield chat_history

    response_stream = assistant.get_llm_response(chat_history)
    llm_response_text = ""
    for text_chunk in response_stream:
        llm_response_text += text_chunk
        chat_history[-1] = (text_input, llm_response_text)
        yield chat_history
def clear_textbox():
    """Return a Gradio Textbox update whose value is the empty string.

    Presumably chained after a submit event to reset the input field; the wiring is
    outside this view -- confirm.
    """
    cleared = gr.Textbox(value="")
    return cleared
with gr.Blocks(theme=gr.themes.Soft(), title="Msaidizi wa Kiswahili") as demo:
gr.Markdown("# ๐Ÿค– Msaidizi wa Sauti wa Kiswahili (Swahili Voice Assistant)")
gr.Markdown("Ongea na msaidizi kwa Kiswahili. Toa sauti, andika maandishi, na upate majibu kwa sauti au maandishi.")
with gr.Tabs():
with gr.TabItem("๐ŸŽ™๏ธ Sauti-kwa-Sauti (Speech-to-Speech)"):
with gr.Row():
with gr.Column(scale=2):
s2s_audio_in = gr.Audio(sources=["microphone"], type="numpy", label="Ongea Hapa (Speak Here)")
s2s_submit_btn = gr.Button("Tuma (Submit)", variant="primary")
with gr.Column(scale=3):
s2s_chatbot = gr.Chatbot(label="Mazungumzo (Conversation)", bubble_full_width=False, height=400)
s2s_audio_out = gr.Audio(type="filepath", label="Jibu la Sauti (Audio Response)", autoplay=True)
s2s_text_out = gr.Textbox(label="Jibu la Maandishi (Text Response)", interactive=False)
with gr.TabItem("โŒจ๏ธ Maandishi-kwa-Maandishi (Text-to-Text)"):
t2t_chatbot = gr.Chatbot(label="Mazungumzo (Conversation)", bubble_full_width=False, height=500)
with gr.Row():
t2t_text_in = gr.Textbox(show_label=False, placeholder="Habari yako...", scale=4, container=False)
t2t_submit_btn = gr.Button("Tuma (Submit)", variant="primary", scale=1)
with gr.TabItem("๐Ÿ› ๏ธ Zana (Tools)"):
with gr.Row():
with gr.Column():
gr.Markdown("### Unukuzi wa Sauti (Speech Transcription)")
tool_s2t_audio_in = gr.Audio(sources=["microphone", "upload"], type="numpy", label="Sauti ya Kuingiza (Input Audio)")
tool_s2t_text_out = gr.Textbox(label="Maandishi Yaliyonukuliwa (Transcribed Text)", interactive=False)
tool_s2t_btn = gr.Button("Nukuu (Transcribe)")
with gr.Column():
gr.Markdown("### Utengenezaji wa Sauti (Speech Synthesis)")
tool_t2s_text_in = gr.Textbox(label="Maandishi ya Kuingiza (Input Text)", placeholder="Andika Kiswahili hapa...")
tool_t2s_audio_out = gr.Audio(type="filepath", label="Sauti Iliyotengenezwa (Synthesized Audio)", autoplay=False)
tool_t2s_btn = gr.Button("Tengeneza Sauti