# -*- coding: utf-8 -*-
"""
This script implements a multi-modal Swahili assistant for Hugging Face Spaces.
It uses Gradio for the user interface and loads models from the HF Hub.
"""
import gradio as gr
import numpy as np
import onnxruntime
import torch
import librosa
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, AutoTokenizer, pipeline, TextIteratorStreamer
from scipy.io.wavfile import write as write_wav
import os
import re
from huggingface_hub import login
import threading
# --- Login to Hugging Face using a secret ---
# Make sure HF_TOKEN is set in your Hugging Face Space > Settings > Repository secrets.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    raise ValueError("HF_TOKEN not found. Please set it in the Hugging Face Space repository secrets.")
login(token=hf_token)
print("Successfully logged into Hugging Face Hub!")
# --- Configuration ---
STT_MODEL_ID = "EYEDOL/SALAMA_C3"
LLM_MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"  # gated model; the Hub login above grants access
TTS_TOKENIZER_ID = "facebook/mms-tts-swh"
TTS_ONNX_MODEL_PATH = "swahili_tts.onnx"
TEMP_DIR = "temp"
os.makedirs(TEMP_DIR, exist_ok=True)
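# --- Assistant ---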
class WeeboAssistant:
def __init__(self):
self.STT_SAMPLE_RATE = 16000
self.TTS_SAMPLE_RATE = 16000
self.SYSTEM_PROMPT = (
"Wewe ni msaidizi mwenye akili, jibu swali lililoulizwa kwa UFUPI na kwa usahihi. "
"Jibu kwa lugha ya Kiswahili pekee. Hakuna jibu refu."
)
self._init_models()
def _init_models(self):
print("Initializing models...")
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.torch_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
print(f"Using device: {self.device}")
# STT
print(f"Loading STT model: {STT_MODEL_ID}")
self.stt_processor = AutoProcessor.from_pretrained(STT_MODEL_ID)
self.stt_model = AutoModelForSpeechSeq2Seq.from_pretrained(
STT_MODEL_ID,
torch_dtype=self.torch_dtype,
low_cpu_mem_usage=True,
use_safetensors=True
).to(self.device)
print("STT model loaded successfully.")
# LLM
print(f"Loading LLM: {LLM_MODEL_ID}")
self.llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID)
self.llm_pipeline = pipeline(
"text-generation",
model=LLM_MODEL_ID,
model_kwargs={"torch_dtype": self.torch_dtype},
tokenizer=self.llm_tokenizer,
device=self.device,
)
print("LLM pipeline loaded successfully.")
# TTS
print(f"Loading TTS model: {TTS_ONNX_MODEL_PATH}")
self.tts_session = onnxruntime.InferenceSession(
TTS_ONNX_MODEL_PATH,
providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
)
self.tts_tokenizer = AutoTokenizer.from_pretrained(TTS_TOKENIZER_ID)
print("TTS model and tokenizer loaded successfully.")
print("-" * 30)
print("All models initialized successfully! โ
")
def transcribe_audio(self, audio_tuple):
if audio_tuple is None:
return ""
sample_rate, audio_data = audio_tuple
if audio_data.ndim > 1:
audio_data = audio_data.mean(axis=1)
        # Normalize samples to float32 in [-1, 1]; Gradio microphone input is typically int16 PCM.
        if np.issubdtype(audio_data.dtype, np.integer):
            audio_data = audio_data.astype(np.float32) / np.iinfo(audio_data.dtype).max
        elif audio_data.dtype != np.float32:
            audio_data = audio_data.astype(np.float32)
if sample_rate != self.STT_SAMPLE_RATE:
audio_data = librosa.resample(y=audio_data, orig_sr=sample_rate, target_sr=self.STT_SAMPLE_RATE)
if len(audio_data) < 1000:
return "(Audio too short to transcribe)"
inputs = self.stt_processor(audio_data, sampling_rate=self.STT_SAMPLE_RATE, return_tensors="pt")
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
generated_ids = self.stt_model.generate(**inputs, max_new_tokens=128)
transcription = self.stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
return transcription.strip()
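    # Text-to-speech: tokenize with the MMS tokenizer, run the ONNX session,
    # and write the waveform to a uniquely named WAV file in TEMP_DIR.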
def generate_speech(self, text):
if not text:
return None
text = text.strip()
inputs = self.tts_tokenizer(text, return_tensors="np")
ort_inputs = {self.tts_session.get_inputs()[0].name: inputs.input_ids}
audio_waveform = self.tts_session.run(None, ort_inputs)[0].flatten()
output_path = os.path.join(TEMP_DIR, f"{os.urandom(8).hex()}.wav")
write_wav(output_path, self.TTS_SAMPLE_RATE, audio_waveform)
return output_path
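    # LLM: format the chat history with the Llama 3 chat template and stream
    # tokens from a background generation thread via TextIteratorStreamer.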
def get_llm_response(self, chat_history):
        # Llama 3 chat templates support a 'system' role, so the instruction prompt leads the conversation.
messages = [{'role': 'system', 'content': self.SYSTEM_PROMPT}]
for user_msg, assistant_msg in chat_history:
messages.append({"role": "user", "content": user_msg})
if assistant_msg:
messages.append({"role": "assistant", "content": assistant_msg})
prompt = self.llm_pipeline.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
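        # Llama 3 marks end-of-turn with <|eot_id|>, so treat it as an extra stop token.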
terminators = [
self.llm_pipeline.tokenizer.eos_token_id,
self.llm_pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
]
streamer = TextIteratorStreamer(
self.llm_pipeline.tokenizer, skip_prompt=True, skip_special_tokens=True
)
generation_kwargs = dict(
streamer=streamer,
max_new_tokens=512,
eos_token_id=terminators,
do_sample=True,
temperature=0.6,
top_p=0.9,
)
thread = threading.Thread(target=self.llm_pipeline, args=[prompt], kwargs=generation_kwargs)
thread.start()
return streamer
assistant = WeeboAssistant()
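# Full speech-to-speech loop: transcribe the recording, stream the LLM reply
# into the chatbot, then synthesize the final reply to audio.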
def s2s_pipeline(audio_input, chat_history):
user_text = assistant.transcribe_audio(audio_input)
if not user_text or user_text.startswith("("):
chat_history.append((user_text or "(No valid speech detected)", None))
yield chat_history, None, "Please record your voice again."
return
chat_history.append((user_text, ""))
yield chat_history, None, "..."
response_stream = assistant.get_llm_response(chat_history)
llm_response_text = ""
for text_chunk in response_stream:
llm_response_text += text_chunk
chat_history[-1] = (user_text, llm_response_text)
yield chat_history, None, llm_response_text
final_audio_path = assistant.generate_speech(llm_response_text)
yield chat_history, final_audio_path, llm_response_text
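# Text-only chat: stream the LLM reply into the chatbot, with no TTS step.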
def t2t_pipeline(text_input, chat_history):
chat_history.append((text_input, ""))
yield chat_history
response_stream = assistant.get_llm_response(chat_history)
llm_response_text = ""
for text_chunk in response_stream:
llm_response_text += text_chunk
chat_history[-1] = (text_input, llm_response_text)
yield chat_history
def clear_textbox():
return gr.Textbox(value="")
with gr.Blocks(theme=gr.themes.Soft(), title="Msaidizi wa Kiswahili") as demo:
gr.Markdown("# ๐ค Msaidizi wa Sauti wa Kiswahili (Swahili Voice Assistant)")
gr.Markdown("Ongea na msaidizi kwa Kiswahili. Toa sauti, andika maandishi, na upate majibu kwa sauti au maandishi.")
with gr.Tabs():
with gr.TabItem("๐๏ธ Sauti-kwa-Sauti (Speech-to-Speech)"):
with gr.Row():
with gr.Column(scale=2):
s2s_audio_in = gr.Audio(sources=["microphone"], type="numpy", label="Ongea Hapa (Speak Here)")
s2s_submit_btn = gr.Button("Tuma (Submit)", variant="primary")
with gr.Column(scale=3):
s2s_chatbot = gr.Chatbot(label="Mazungumzo (Conversation)", bubble_full_width=False, height=400)
s2s_audio_out = gr.Audio(type="filepath", label="Jibu la Sauti (Audio Response)", autoplay=True)
s2s_text_out = gr.Textbox(label="Jibu la Maandishi (Text Response)", interactive=False)
with gr.TabItem("โจ๏ธ Maandishi-kwa-Maandishi (Text-to-Text)"):
t2t_chatbot = gr.Chatbot(label="Mazungumzo (Conversation)", bubble_full_width=False, height=500)
with gr.Row():
t2t_text_in = gr.Textbox(show_label=False, placeholder="Habari yako...", scale=4, container=False)
t2t_submit_btn = gr.Button("Tuma (Submit)", variant="primary", scale=1)
with gr.TabItem("๐ ๏ธ Zana (Tools)"):
with gr.Row():
with gr.Column():
gr.Markdown("### Unukuzi wa Sauti (Speech Transcription)")
tool_s2t_audio_in = gr.Audio(sources=["microphone", "upload"], type="numpy", label="Sauti ya Kuingiza (Input Audio)")
tool_s2t_text_out = gr.Textbox(label="Maandishi Yaliyonukuliwa (Transcribed Text)", interactive=False)
tool_s2t_btn = gr.Button("Nukuu (Transcribe)")
with gr.Column():
gr.Markdown("### Utengenezaji wa Sauti (Speech Synthesis)")
tool_t2s_text_in = gr.Textbox(label="Maandishi ya Kuingiza (Input Text)", placeholder="Andika Kiswahili hapa...")
tool_t2s_audio_out = gr.Audio(type="filepath", label="Sauti Iliyotengenezwa (Synthesized Audio)", autoplay=False)
                    tool_t2s_btn = gr.Button("Tengeneza Sauti (Synthesize)")