# Source: Hugging Face Space "SALAMA" by EYEDOL — app.py
# Commit 6fefd54 ("Update app.py"), 11.1 kB
# -*- coding: utf-8 -*-
"""
This script implements a multi-modal Swahili assistant for Hugging Face Spaces.
It uses Gradio for the user interface and loads models from the HF Hub.
"""
import gradio as gr
import numpy as np
import onnxruntime
import torch
import librosa
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, AutoTokenizer, pipeline, TextIteratorStreamer
from scipy.io.wavfile import write as write_wav
import os
import re
from huggingface_hub import login
import threading # <-- FIX: Added threading import
# --- Login to Hugging Face using secret ---
# Store the access token as a repository secret (Space > Settings > Repository
# secrets). Either secret name is accepted: "hugface" (the name this Space has
# used so far) or the conventional "HF_TOKEN".
hf_token = os.environ.get("hugface") or os.environ.get("HF_TOKEN")
if not hf_token:
    raise ValueError(
        "Hugging Face token not found. Set the 'hugface' or 'HF_TOKEN' secret "
        "in the Space repository settings."
    )
login(token=hf_token)
print("Successfully logged into Hugging Face Hub!")
# --- Configuration ---
# Scratch directory for synthesized .wav files (see generate_speech); created up front.
TEMP_DIR = "temp"
os.makedirs(TEMP_DIR, exist_ok=True)

# Hugging Face Hub model repositories.
STT_MODEL_ID = "EYEDOL/SALAMA_C3"          # speech-to-text
LLM_MODEL_ID = "google/gemma-1.1-2b-it"    # chat LLM
TTS_TOKENIZER_ID = "facebook/mms-tts-swh"  # tokenizer used for the TTS input ids

# Local ONNX export of the TTS model, resolved against the working directory.
TTS_ONNX_MODEL_PATH = "swahili_tts.onnx"
class WeeboAssistant:
    """Swahili voice assistant bundling speech-to-text, an LLM and text-to-speech.

    All models are loaded once in ``__init__`` (via ``_init_models``). The public
    methods ``transcribe_audio``, ``generate_speech`` and ``get_llm_response`` are
    called by the Gradio handlers.
    """

    def __init__(self):
        # Rate (Hz) the STT processor is fed with; input audio is resampled to it.
        self.STT_SAMPLE_RATE = 16000
        # Rate (Hz) written into synthesized WAV files.
        # NOTE(review): assumed to equal the ONNX TTS model's output rate — confirm.
        self.TTS_SAMPLE_RATE = 16000
        # Swahili instruction: "You are an intelligent assistant; answer the question
        # asked BRIEFLY and accurately. Answer in Swahili only. No long answers."
        self.SYSTEM_PROMPT = (
            "Wewe ni msaidizi mwenye akili, jibu swali lililoulizwa kwa UFUPI na kwa usahihi. "
            "Jibu kwa lugha ya Kiswahili pekee. Hakuna jibu refu."
        )
        self._init_models()

    def _init_models(self):
        """Load the STT model, the LLM text-generation pipeline and the ONNX TTS session."""
        print("Initializing models...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # bfloat16 on GPU to save memory; full precision on CPU.
        self.torch_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
        print(f"Using device: {self.device}")

        # STT
        print(f"Loading STT model: {STT_MODEL_ID}")
        self.stt_processor = AutoProcessor.from_pretrained(STT_MODEL_ID)
        self.stt_model = AutoModelForSpeechSeq2Seq.from_pretrained(
            STT_MODEL_ID,
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
        ).to(self.device)
        print("STT model loaded successfully.")

        # LLM: keep a direct handle on the tokenizer for chat templating and streaming.
        print(f"Loading LLM: {LLM_MODEL_ID}")
        self.llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID)
        self.llm_pipeline = pipeline(
            "text-generation",
            model=LLM_MODEL_ID,
            model_kwargs={"torch_dtype": self.torch_dtype},
            tokenizer=self.llm_tokenizer,
            device=self.device,
        )
        print("LLM pipeline loaded successfully.")

        # TTS: request the CUDA provider only if this onnxruntime build offers it.
        available = set(onnxruntime.get_available_providers())
        providers = [
            p for p in ("CUDAExecutionProvider", "CPUExecutionProvider") if p in available
        ] or ["CPUExecutionProvider"]
        print(f"Loading TTS model: {TTS_ONNX_MODEL_PATH}")
        self.tts_session = onnxruntime.InferenceSession(TTS_ONNX_MODEL_PATH, providers=providers)
        self.tts_tokenizer = AutoTokenizer.from_pretrained(TTS_TOKENIZER_ID)
        print("TTS model and tokenizer loaded successfully.")
        print("-" * 30)
        print("All models initialized successfully! ✅")

    @staticmethod
    def _to_mono_float32(audio_data):
        """Return *audio_data* as a 1-D float32 array.

        Integer PCM is divided by its dtype's maximum (int16 -> about [-1, 1]);
        floating-point input is only cast. Multi-channel input is averaged over
        axis 1 (assumed (samples, channels), as Gradio's numpy audio provides —
        confirm for other callers). Scaling is done before down-mixing because
        averaging an integer array yields float64 and loses the integer dtype.
        """
        audio = np.asarray(audio_data)
        if np.issubdtype(audio.dtype, np.integer):
            full_scale = np.iinfo(audio.dtype).max
            audio = audio.astype(np.float32) / full_scale
        else:
            audio = audio.astype(np.float32)
        if audio.ndim > 1:
            audio = audio.mean(axis=1)
        return audio.astype(np.float32, copy=False)

    def transcribe_audio(self, audio_tuple):
        """Transcribe a Gradio ``(sample_rate, samples)`` tuple to text.

        Returns "" when there is no audio, the notice
        "(Audio too short to transcribe)" for clips under 1000 samples after
        resampling, otherwise the stripped transcription.
        """
        if audio_tuple is None:
            return ""
        sample_rate, audio_data = audio_tuple
        audio_data = self._to_mono_float32(audio_data)
        if sample_rate != self.STT_SAMPLE_RATE:
            audio_data = librosa.resample(y=audio_data, orig_sr=sample_rate, target_sr=self.STT_SAMPLE_RATE)
        # 1000 samples is ~60 ms at 16 kHz: too little to contain speech.
        if len(audio_data) < 1000:
            return "(Audio too short to transcribe)"
        inputs = self.stt_processor(audio_data, sampling_rate=self.STT_SAMPLE_RATE, return_tensors="pt")
        # Float features must match the model dtype (bfloat16 on GPU) or the first
        # layer raises a dtype mismatch; integer tensors (e.g. masks) keep theirs.
        inputs = {
            name: (
                tensor.to(self.device, dtype=self.torch_dtype)
                if tensor.is_floating_point()
                else tensor.to(self.device)
            )
            for name, tensor in inputs.items()
        }
        with torch.no_grad():
            generated_ids = self.stt_model.generate(**inputs, max_new_tokens=128)
        transcription = self.stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return transcription.strip()

    def generate_speech(self, text):
        """Synthesize *text* into a WAV file under TEMP_DIR and return its path.

        Returns None for None, empty or whitespace-only text.
        """
        text = (text or "").strip()
        if not text:
            return None
        inputs = self.tts_tokenizer(text, return_tensors="np")
        # Only the graph's first input (the token ids) is fed.
        ort_inputs = {self.tts_session.get_inputs()[0].name: inputs.input_ids}
        audio_waveform = self.tts_session.run(None, ort_inputs)[0].flatten()
        # Random file name so outputs of different requests don't collide.
        # NOTE(review): files are never deleted; TEMP_DIR grows over time.
        output_path = os.path.join(TEMP_DIR, f"{os.urandom(8).hex()}.wav")
        write_wav(output_path, self.TTS_SAMPLE_RATE, audio_waveform)
        return output_path

    def _build_messages(self, chat_history):
        """Turn chat history into a strictly alternating user/assistant message list.

        Gemma's chat template rejects a ``system`` role and requires roles to
        alternate, so:
        - the system prompt is prepended to the first user message;
        - earlier turns without an assistant reply (e.g. failed transcriptions)
          are dropped;
        - only the user side of the final turn (the one being answered) is kept.

        Raises:
            ValueError: if *chat_history* is empty.
        """
        if not chat_history:
            raise ValueError("chat_history must contain at least one turn")
        messages = []
        for turn in chat_history[:-1]:
            user_text, reply = turn[0], turn[1]
            if user_text and reply:
                messages.append({'role': 'user', 'content': user_text})
                messages.append({'role': 'assistant', 'content': reply})
        messages.append({'role': 'user', 'content': chat_history[-1][0]})
        messages[0]['content'] = f"{self.SYSTEM_PROMPT}\n\n{messages[0]['content']}"
        return messages

    def get_llm_response(self, chat_history):
        """Start streaming an LLM reply to the last turn of *chat_history*.

        *chat_history* is a list of ``(user_text, assistant_text)`` pairs whose
        last entry is the turn to answer. Generation runs on a background daemon
        thread. The returned ``TextIteratorStreamer`` yields decoded text chunks
        and ends when generation finishes or fails.
        """
        tokenizer = self.llm_tokenizer
        model = self.llm_pipeline.model
        prompt = tokenizer.apply_chat_template(
            self._build_messages(chat_history), tokenize=False, add_generation_prompt=True
        )
        # The rendered template already begins with <bos>; don't add a second one.
        model_inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

        # Stop on EOS or Gemma's <end_of_turn>. An id that resolves to <unk> is
        # skipped, since it is not a real terminator for this vocabulary.
        terminators = [tokenizer.eos_token_id]
        end_of_turn_id = tokenizer.convert_tokens_to_ids("<end_of_turn>")
        if end_of_turn_id is not None and end_of_turn_id != tokenizer.unk_token_id:
            terminators.append(end_of_turn_id)

        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            model_inputs,
            streamer=streamer,
            max_new_tokens=512,
            eos_token_id=terminators,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
        )

        def _generate():
            try:
                model.generate(**generation_kwargs)
            except Exception:
                # Close the stream so the consumer's loop ends instead of blocking forever.
                streamer.end()
                raise

        threading.Thread(target=_generate, daemon=True).start()
        return streamer
# Single shared instance, created at import time. Constructing it loads every
# model (WeeboAssistant.__init__ calls _init_models), so startup waits for that.
# All Gradio handlers below call into this object.
assistant = WeeboAssistant()
def s2s_pipeline(audio_input, chat_history):
    """Speech-to-speech handler: transcribe, stream an LLM reply, then voice it.

    Yields ``(chat_history, audio_path, reply_text)`` triples for the chatbot,
    the audio player and the text box. ``audio_path`` stays None until the
    complete reply has been synthesized in the final yield.
    """
    heard = assistant.transcribe_audio(audio_input)
    # transcribe_audio returns "" for no input and a "(...)" notice for unusable audio.
    transcription_failed = (not heard) or heard.startswith("(")
    if transcription_failed:
        chat_history.append((heard or "(No valid speech detected)", None))
        yield chat_history, None, "Please record your voice again."
        return

    chat_history.append((heard, ""))
    yield chat_history, None, "..."  # thinking indicator until tokens arrive

    reply = ""
    for piece in assistant.get_llm_response(chat_history):
        reply += piece
        chat_history[-1] = (heard, reply)
        yield chat_history, None, reply

    spoken_reply_path = assistant.generate_speech(reply)
    yield chat_history, spoken_reply_path, reply
def t2t_pipeline(text_input, chat_history):
    """Text-to-text handler: append the user's message and stream the LLM reply.

    Yields the updated chat history after each streamed chunk. Empty or
    whitespace-only input is ignored: the history is yielded unchanged and the
    LLM is not called.
    """
    if not text_input or not text_input.strip():
        # Re-emit the unchanged history so the chatbot output stays valid.
        yield chat_history
        return
    chat_history.append((text_input, ""))
    yield chat_history
    llm_response_text = ""
    for text_chunk in assistant.get_llm_response(chat_history):
        llm_response_text += text_chunk
        chat_history[-1] = (text_input, llm_response_text)
        yield chat_history
def clear_textbox():
    """Return a Textbox update that empties the text input after a submission."""
    blank = ""
    return gr.Textbox(value=blank)
# Gradio UI: three tabs (speech-to-speech, text chat, standalone STT/TTS tools).
with gr.Blocks(theme=gr.themes.Soft(), title="Msaidizi wa Kiswahili") as demo:
    gr.Markdown("# 🤖 Msaidizi wa Sauti wa Kiswahili (Swahili Voice Assistant)")
    gr.Markdown("Ongea na msaidizi kwa Kiswahili. Toa sauti, andika maandishi, na upate majibu kwa sauti au maandishi.")
    with gr.Tabs():
        with gr.TabItem("🎙️ Sauti-kwa-Sauti (Speech-to-Speech)"):
            with gr.Row():
                with gr.Column(scale=2):
                    s2s_audio_in = gr.Audio(sources=["microphone"], type="numpy", label="Ongea Hapa (Speak Here)")
                    s2s_submit_btn = gr.Button("Tuma (Submit)", variant="primary")
                with gr.Column(scale=3):
                    s2s_chatbot = gr.Chatbot(label="Mazungumzo (Conversation)", bubble_full_width=False, height=400)
                    s2s_audio_out = gr.Audio(type="filepath", label="Jibu la Sauti (Audio Response)", autoplay=True)
                    s2s_text_out = gr.Textbox(label="Jibu la Maandishi (Text Response)", interactive=False)
        with gr.TabItem("⌨️ Maandishi-kwa-Maandishi (Text-to-Text)"):
            t2t_chatbot = gr.Chatbot(label="Mazungumzo (Conversation)", bubble_full_width=False, height=500)
            with gr.Row():
                t2t_text_in = gr.Textbox(show_label=False, placeholder="Habari yako...", scale=4, container=False)
                t2t_submit_btn = gr.Button("Tuma (Submit)", variant="primary", scale=1)
        with gr.TabItem("🛠️ Zana (Tools)"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Unukuzi wa Sauti (Speech Transcription)")
                    tool_s2t_audio_in = gr.Audio(sources=["microphone", "upload"], type="numpy", label="Sauti ya Kuingiza (Input Audio)")
                    tool_s2t_text_out = gr.Textbox(label="Maandishi Yaliyonukuliwa (Transcribed Text)", interactive=False)
                    tool_s2t_btn = gr.Button("Nukuu (Transcribe)")
                with gr.Column():
                    gr.Markdown("### Utengenezaji wa Sauti (Speech Synthesis)")
                    tool_t2s_text_in = gr.Textbox(label="Maandishi ya Kuingiza (Input Text)", placeholder="Andika Kiswahili hapa...")
                    tool_t2s_audio_out = gr.Audio(type="filepath", label="Sauti Iliyotengenezwa (Synthesized Audio)", autoplay=False)
                    tool_t2s_btn = gr.Button("Tengeneza Sauti (Synthesize)")

    s2s_submit_btn.click(
        fn=s2s_pipeline,
        inputs=[s2s_audio_in, s2s_chatbot],
        outputs=[s2s_chatbot, s2s_audio_out, s2s_text_out],
        queue=True,
    ).then(
        fn=lambda: gr.Audio(value=None),  # clear the recording after submit
        inputs=None,
        outputs=s2s_audio_in,
    )

    # The Submit button and the Enter key share one handler chain: stream the
    # reply into the chatbot, then clear the input box.
    for t2t_trigger in (t2t_submit_btn.click, t2t_text_in.submit):
        t2t_trigger(
            fn=t2t_pipeline,
            inputs=[t2t_text_in, t2t_chatbot],
            outputs=[t2t_chatbot],
            queue=True,
        ).then(
            fn=clear_textbox,
            inputs=None,
            outputs=t2t_text_in,
        )

    tool_s2t_btn.click(
        fn=assistant.transcribe_audio,
        inputs=tool_s2t_audio_in,
        outputs=tool_s2t_text_out,
        queue=True,
    )
    tool_t2s_btn.click(
        fn=assistant.generate_speech,
        inputs=tool_t2s_text_in,
        outputs=tool_t2s_audio_out,
        queue=True,
    )

if __name__ == "__main__":
    demo.queue().launch(debug=True)