"""
Salama Assistant — fixed full app.py with PEFT adapter loading (base + adapter)

Drop this file into your Hugging Face Space (replace your existing app.py).

Requirements:
- transformers
- peft
- onnxruntime
- librosa
- huggingface_hub
- gradio

Note: install `peft` (e.g. add to requirements.txt: "peft>=0.4.0") or pip install in your environment.
"""
|
|
|
import os |
|
import json |
|
import tempfile |
|
import threading |
|
import numpy as np |
|
import gradio as gr |
|
import librosa |
|
import torch |
|
from scipy.io.wavfile import write as write_wav |
|
from huggingface_hub import login |
|
import onnxruntime |
|
|
|
from transformers import ( |
|
AutoProcessor, |
|
AutoModelForSpeechSeq2Seq, |
|
AutoTokenizer, |
|
AutoConfig, |
|
AutoModelForCausalLM, |
|
pipeline, |
|
TextIteratorStreamer, |
|
) |
|
|
|
|
|
from peft import PeftModel, PeftConfig |
|
|
|
|
|
# Hugging Face repo ids / local paths for each stage of the assistant.
STT_MODEL_ID = "EYEDOL/SALAMA_C3"  # speech-to-text model
ADAPTER_REPO_ID = "EYEDOL/Llama-3.2-3b_ON_ALPACA5"  # PEFT adapter applied on top of the base LLM
BASE_MODEL_ID = "unsloth/Llama-3.2-3B-Instruct"  # base causal LM
TTS_TOKENIZER_ID = "facebook/mms-tts-swh"  # tokenizer used to feed the TTS ONNX model
TTS_ONNX_MODEL_PATH = "swahili_tts.onnx"  # ONNX TTS model file, relative to the working directory

# Directory that receives the synthesized .wav files; created on import.
TEMP_DIR = "temp"
os.makedirs(TEMP_DIR, exist_ok=True)
|
|
|
|
|
# Read the Hub token from HF_TOKEN, falling back to the "hugface" variable,
# and log in when one is present. A failed login is reported but not fatal.
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("hugface")
if not HF_TOKEN:
    print("Warning: HF_TOKEN not found in env. Public models may still load, but private repos require a token.")
else:
    try:
        login(token=HF_TOKEN)
        print("Successfully logged into Hugging Face Hub!")
    except Exception as e:
        print("Warning: huggingface_hub.login() failed:", e)
|
|
|
|
|
class WeeboAssistant:
    """Swahili voice assistant combining speech-to-text, an LLM and text-to-speech.

    Constructing an instance loads every model (see ``_init_models``), so it
    is expensive and needs the model repos / ONNX file to be reachable.
    """

    def __init__(self):
        self.STT_SAMPLE_RATE = 16000
        self.TTS_SAMPLE_RATE = 16000
        self.SYSTEM_PROMPT = (
            "Wewe ni msaidizi mwenye akili, jibu swali lililoulizwa KWA UFUPI na kwa usahihi kwa sauti ya mazungumzo. "
            "Jibu kwa lugha ya Kiswahili pekee. Hakuna jibu refu."
        )
        self._init_models()

    def _init_models(self):
        """Pick the device/dtype and load the STT, LLM and TTS models in turn."""
        print("Initializing models...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
        print(f"Using device: {self.device}")

        self._load_stt()
        self._load_llm()
        self._load_tts()

        print("-" * 30)
        print("All models initialized successfully! ✅")

    def _load_stt(self):
        """Load the speech-to-text processor and model."""
        print(f"Loading STT model: {STT_MODEL_ID}")
        self.stt_processor = AutoProcessor.from_pretrained(STT_MODEL_ID)
        self.stt_model = AutoModelForSpeechSeq2Seq.from_pretrained(
            STT_MODEL_ID,
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
        )
        if self.device == "cuda":
            try:
                self.stt_model = self.stt_model.to("cuda")
            except Exception as e:
                # Best effort: keep the model where it is, but say so
                # instead of failing silently.
                print("Warning: could not move STT model to CUDA, leaving it on its current device. Error:", e)
        print("STT model loaded successfully.")

    def _load_llm(self):
        """Load the base LLM, apply the PEFT adapter and build the optional pipeline.

        Raises:
            RuntimeError: if the base model or the adapter cannot be loaded.
        """
        print(f"Loading base LLM: {BASE_MODEL_ID} and applying adapter: {ADAPTER_REPO_ID}")

        try:
            self.llm_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)
        except Exception as e:
            print("Warning: could not load base tokenizer, falling back to adapter tokenizer. Error:", e)
            self.llm_tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO_ID, use_fast=True)

        device_map = "auto" if torch.cuda.is_available() else None
        try:
            self.llm_model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL_ID,
                torch_dtype=self.torch_dtype,
                low_cpu_mem_usage=True,
                device_map=device_map,
                trust_remote_code=True,
            )
        except Exception as e:
            raise RuntimeError(
                "Failed to load base model. Ensure the base model ID is correct and the HF_TOKEN has access if private. Error: "
                + str(e)
            ) from e

        try:
            # PeftModel.from_pretrained reads the adapter config itself, so no
            # separate PeftConfig fetch is needed.
            self.llm_model = PeftModel.from_pretrained(
                self.llm_model,
                ADAPTER_REPO_ID,
                device_map=device_map,
                torch_dtype=self.torch_dtype,
                low_cpu_mem_usage=True,
            )
        except Exception as e:
            raise RuntimeError(
                "Failed to load/apply PEFT adapter from adapter repo. Make sure adapter files (adapter_config.json and adapter_model.safetensors) are present and HF_TOKEN has access if private. Error: "
                + str(e)
            ) from e

        try:
            # A model placed via device_map must not also be given an explicit
            # pipeline device; only pin the device in the CPU case.
            pipeline_kwargs = {}
            if device_map is None:
                pipeline_kwargs["device"] = -1
            self.llm_pipeline = pipeline(
                "text-generation",
                model=self.llm_model,
                tokenizer=self.llm_tokenizer,
                model_kwargs={"torch_dtype": self.torch_dtype},
                **pipeline_kwargs,
            )
        except Exception as e:
            print("Warning: could not create text-generation pipeline. Streaming generate will still work. Error:", e)
            self.llm_pipeline = None

        print("LLM base + adapter loaded successfully.")

    def _load_tts(self):
        """Create the ONNX runtime session for TTS and load its tokenizer."""
        print(f"Loading TTS model: {TTS_ONNX_MODEL_PATH}")
        providers = ["CPUExecutionProvider"]
        if torch.cuda.is_available():
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        self.tts_session = onnxruntime.InferenceSession(TTS_ONNX_MODEL_PATH, providers=providers)
        self.tts_tokenizer = AutoTokenizer.from_pretrained(TTS_TOKENIZER_ID)
        print("TTS model and tokenizer loaded successfully.")

    @staticmethod
    def _to_mono_float32(audio_data):
        """Return *audio_data* as a 1-D float32 array.

        Integer samples are scaled by the dtype's maximum value *before* the
        channels are averaged; averaging first would turn them into floats
        and skip the scaling. Multi-channel input is averaged over axis 1.
        """
        audio_data = np.asarray(audio_data)
        if np.issubdtype(audio_data.dtype, np.integer):
            full_scale = float(np.iinfo(audio_data.dtype).max)
            audio_data = audio_data.astype(np.float32) / full_scale
        elif audio_data.dtype != np.float32:
            audio_data = audio_data.astype(np.float32)
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=1)
        return audio_data

    def transcribe_audio(self, audio_tuple):
        """Transcribe a ``(sample_rate, samples)`` tuple to text.

        Returns:
            The stripped transcription, ``""`` when *audio_tuple* is ``None``,
            or ``"(Audio too short to transcribe)"`` when fewer than 1000
            samples remain after resampling.
        """
        if audio_tuple is None:
            return ""
        sample_rate, audio_data = audio_tuple
        audio_data = self._to_mono_float32(audio_data)
        if sample_rate != self.STT_SAMPLE_RATE:
            audio_data = librosa.resample(y=audio_data, orig_sr=sample_rate, target_sr=self.STT_SAMPLE_RATE)
        if len(audio_data) < 1000:
            return "(Audio too short to transcribe)"

        inputs = self.stt_processor(audio_data, sampling_rate=self.STT_SAMPLE_RATE, return_tensors="pt")
        # Match both the device and the floating dtype of the model: the
        # processor emits float32 while the model may be bfloat16.
        reference = next(self.stt_model.parameters())
        inputs = {
            name: (
                tensor.to(reference.device, dtype=reference.dtype)
                if tensor.is_floating_point()
                else tensor.to(reference.device)
            )
            for name, tensor in inputs.items()
        }
        with torch.no_grad():
            generated_ids = self.stt_model.generate(**inputs, max_new_tokens=128)
        transcription = self.stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return transcription.strip()

    def generate_speech(self, text):
        """Synthesize *text* to a 16-bit WAV file under ``TEMP_DIR``.

        Returns:
            The path of the written file, or ``None`` when *text* is empty
            or contains only whitespace.
        """
        text = (text or "").strip()
        if not text:
            return None
        inputs = self.tts_tokenizer(text, return_tensors="np")
        input_name = self.tts_session.get_inputs()[0].name
        ort_inputs = {input_name: inputs["input_ids"]}
        audio_waveform = self.tts_session.run(None, ort_inputs)[0].flatten()

        if np.issubdtype(audio_waveform.dtype, np.floating):
            # Clip to [-1, 1] so the int16 conversion cannot wrap around.
            audio_clip = np.clip(audio_waveform, -1.0, 1.0)
            audio_int16 = (audio_clip * 32767).astype(np.int16)
        else:
            audio_int16 = audio_waveform.astype(np.int16)

        output_path = os.path.join(TEMP_DIR, f"{os.urandom(8).hex()}.wav")
        write_wav(output_path, self.TTS_SAMPLE_RATE, audio_int16)
        return output_path

    def _build_prompt(self, chat_history):
        """Render the system prompt and chat turns as one newline-separated prompt.

        Each turn is a ``(user_msg, assistant_msg)`` pair; empty parts are
        skipped. The prompt ends with an open ``"Assistant: "`` line.
        """
        prompt_lines = [self.SYSTEM_PROMPT.strip(), ""]
        for user_msg, assistant_msg in chat_history or []:
            if user_msg:
                prompt_lines.append("User: " + user_msg)
            if assistant_msg:
                prompt_lines.append("Assistant: " + assistant_msg)
        prompt_lines.append("Assistant: ")
        # Join with newlines: an empty separator would glue every turn (and
        # the system prompt) into one run-on line.
        return "\n".join(prompt_lines)

    def get_llm_response(self, chat_history):
        """Start LLM generation in a background thread and return its streamer.

        Args:
            chat_history: iterable of ``(user_msg, assistant_msg)`` pairs.

        Returns:
            A ``TextIteratorStreamer`` that yields text chunks as generated.
        """
        prompt = self._build_prompt(chat_history)

        inputs = self.llm_tokenizer(prompt, return_tensors="pt")
        try:
            model_device = next(self.llm_model.parameters()).device
        except StopIteration:
            model_device = torch.device("cpu")
        inputs = {k: v.to(model_device) for k, v in inputs.items()}

        streamer = TextIteratorStreamer(self.llm_tokenizer, skip_prompt=True, skip_special_tokens=True)

        generation_kwargs = dict(
            input_ids=inputs["input_ids"],
            attention_mask=inputs.get("attention_mask", None),
            max_new_tokens=512,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            streamer=streamer,
            eos_token_id=getattr(self.llm_tokenizer, "eos_token_id", None),
        )

        def _run_generation():
            try:
                self.llm_model.generate(**generation_kwargs)
            except Exception as e:
                print("Error: LLM generation failed:", e)
                # Close the stream so the consumer iterating over it does not
                # block forever waiting for tokens that will never arrive.
                streamer.end()

        gen_thread = threading.Thread(target=_run_generation, daemon=True)
        gen_thread.start()

        return streamer
|
|
|
|
|
|
|
# Single shared assistant instance used by the Gradio callbacks; constructing
# it loads all models at import time.
assistant = WeeboAssistant()
|
|
|
|
|
|
|
|
|
def s2s_pipeline(audio_input, chat_history):
    """Speech-to-speech turn: transcribe, stream the LLM reply, then synthesize it.

    Args:
        audio_input: ``(sample_rate, samples)`` tuple, or ``None``.
        chat_history: list of ``(user, assistant)`` pairs; ``None`` is treated
            as an empty history.

    Yields:
        ``(chat_history, audio_path_or_None, status_or_reply_text)`` tuples:
        one per streamed text chunk, then a final one carrying the audio path.
    """
    if chat_history is None:
        chat_history = []

    # No recording means nothing to transcribe; skip the model call entirely.
    user_text = assistant.transcribe_audio(audio_input) if audio_input is not None else ""
    # Parenthesized results are status messages, not transcriptions.
    if not user_text or user_text.startswith("("):
        chat_history.append((user_text or "(No valid speech detected)", None))
        yield chat_history, None, "Please record your voice again."
        return

    chat_history.append((user_text, ""))
    yield chat_history, None, "..."

    response_stream = assistant.get_llm_response(chat_history)
    llm_response_text = ""
    for text_chunk in response_stream:
        llm_response_text += text_chunk
        chat_history[-1] = (user_text, llm_response_text)
        yield chat_history, None, llm_response_text

    final_audio_path = assistant.generate_speech(llm_response_text)
    yield chat_history, final_audio_path, llm_response_text
|
|
|
|
|
def t2t_pipeline(text_input, chat_history):
    """Text-to-text turn: append the user message and stream the LLM reply.

    Args:
        text_input: the user's message.
        chat_history: list of ``(user, assistant)`` pairs; ``None`` is treated
            as an empty history.

    Yields:
        The updated chat history after each streamed text chunk. Blank input
        yields the history once, unchanged, without querying the LLM.
    """
    if chat_history is None:
        chat_history = []

    if not text_input or not text_input.strip():
        yield chat_history
        return

    chat_history.append((text_input, ""))
    yield chat_history

    response_stream = assistant.get_llm_response(chat_history)
    llm_response_text = ""
    for text_chunk in response_stream:
        llm_response_text += text_chunk
        chat_history[-1] = (text_input, llm_response_text)
        yield chat_history
|
|
|
|
|
def clear_textbox():
    """Return a Textbox with an empty value, used to reset the chat input."""
    return gr.Textbox(value="")
|
|
|
|
|
|
|
# Gradio UI: three tabs (speech-to-speech, text chat, standalone tools).
# Event listeners are registered inside the Blocks context.
with gr.Blocks(theme=gr.themes.Soft(), title="Msaidizi wa Kiswahili") as demo:
    gr.Markdown("# 🤖 Msaidizi wa Sauti wa Kiswahili (Swahili Voice Assistant)")
    gr.Markdown("Ongea na msaidizi kwa Kiswahili. Toa sauti, andika maandishi, na upate majibu kwa sauti au maandishi.")

    with gr.Tabs():
        with gr.TabItem("🎙️ Sauti-kwa-Sauti (Speech-to-Speech)"):
            with gr.Row():
                with gr.Column(scale=2):
                    s2s_audio_in = gr.Audio(sources=["microphone"], type="numpy", label="Ongea Hapa (Speak Here)")
                    s2s_submit_btn = gr.Button("Tuma (Submit)", variant="primary")
                with gr.Column(scale=3):
                    s2s_chatbot = gr.Chatbot(label="Mazungumzo (Conversation)", bubble_full_width=False, height=400)
                    s2s_audio_out = gr.Audio(type="filepath", label="Jibu la Sauti (Audio Response)", autoplay=True)
                    s2s_text_out = gr.Textbox(label="Jibu la Maandishi (Text Response)", interactive=False)

        with gr.TabItem("⌨️ Maandishi-kwa-Maandishi (Text-to-Text)"):
            t2t_chatbot = gr.Chatbot(label="Mazungumzo (Conversation)", bubble_full_width=False, height=500)
            with gr.Row():
                t2t_text_in = gr.Textbox(show_label=False, placeholder="Habari yako...", scale=4, container=False)
                t2t_submit_btn = gr.Button("Tuma (Submit)", variant="primary", scale=1)

        with gr.TabItem("🛠️ Zana (Tools)"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Unukuzi wa Sauti (Speech Transcription)")
                    tool_s2t_audio_in = gr.Audio(sources=["microphone", "upload"], type="numpy", label="Sauti ya Kuingiza (Input Audio)")
                    tool_s2t_text_out = gr.Textbox(label="Maandishi Yaliyonukuliwa (Transcribed Text)", interactive=False)
                    tool_s2t_btn = gr.Button("Nukuu (Transcribe)")
                with gr.Column():
                    gr.Markdown("### Utengenezaji wa Sauti (Speech Synthesis)")
                    tool_t2s_text_in = gr.Textbox(label="Maandishi ya Kuingiza (Input Text)", placeholder="Andika Kiswahili hapa...")
                    tool_t2s_audio_out = gr.Audio(type="filepath", label="Sauti Iliyotengenezwa (Synthesized Audio)", autoplay=False)
                    tool_t2s_btn = gr.Button("Tengeneza Sauti (Synthesize)")

    # Speech-to-speech: run the pipeline, then clear the recorder for the next turn.
    s2s_submit_btn.click(
        fn=s2s_pipeline,
        inputs=[s2s_audio_in, s2s_chatbot],
        outputs=[s2s_chatbot, s2s_audio_out, s2s_text_out],
        queue=True,
    ).then(
        fn=lambda: gr.Audio(value=None),
        inputs=None,
        outputs=s2s_audio_in,
    )

    # Text chat: the submit button and pressing Enter in the textbox share
    # the same handler chain, which clears the textbox afterwards.
    for t2t_trigger in (t2t_submit_btn.click, t2t_text_in.submit):
        t2t_trigger(
            fn=t2t_pipeline,
            inputs=[t2t_text_in, t2t_chatbot],
            outputs=[t2t_chatbot],
            queue=True,
        ).then(
            fn=clear_textbox,
            inputs=None,
            outputs=t2t_text_in,
        )

    # Standalone tools.
    tool_s2t_btn.click(
        fn=assistant.transcribe_audio,
        inputs=tool_s2t_audio_in,
        outputs=tool_s2t_text_out,
        queue=True,
    )

    tool_t2s_btn.click(
        fn=assistant.generate_speech,
        inputs=tool_t2s_text_in,
        outputs=tool_t2s_audio_out,
        queue=True,
    )

demo.queue().launch(debug=True)
|
|