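"""Voice chatbot Space.

Records microphone audio with Gradio, transcribes it with Whisper
(openai/whisper-small), generates a reply with facebook/opt-1.3b, and speaks
the reply back with the local pyttsx3 TTS engine.
"""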
import gradio as gr
import numpy as np
import torch
import os
import tempfile
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
    AutoTokenizer,
    AutoModelForCausalLM
)

# Check if CUDA is available, otherwise use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Initialize pyttsx3 for local TTS
def load_local_tts():
    import pyttsx3
    engine = pyttsx3.init()
    engine.setProperty('rate', 150)    # Speed of speech
    engine.setProperty('volume', 0.9)  # Volume
    voices = engine.getProperty('voices')
    if len(voices) > 1:
        # Use the second installed voice if one is available (often a female voice)
        engine.setProperty('voice', voices[1].id)
    return engine

# Initialize the TTS engine
print("Loading local TTS engine...")
tts_engine = load_local_tts()

def text_to_speech_local(text):
    """Convert text to speech using the pyttsx3 local TTS engine.

    Returns a (sample_rate, int16 mono numpy array) tuple suitable for gr.Audio.
    """
    import soundfile as sf

    # Create a temporary file to store the audio
    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
        temp_filename = temp_file.name
    # Generate speech to the temporary file
    tts_engine.save_to_file(text, temp_filename)
    tts_engine.runAndWait()
    # Read the audio file (soundfile returns float samples in [-1, 1])
    audio_data, sample_rate = sf.read(temp_filename)
    # Downmix to mono if the file has more than one channel
    if audio_data.ndim > 1:
        audio_data = audio_data[:, 0]
    # Convert to int16, the sample format Gradio's Audio component expects
    audio_data = (np.clip(audio_data, -1.0, 1.0) * 32767).astype(np.int16)
    # Clean up the temporary file
    os.unlink(temp_filename)
    return (sample_rate, audio_data)

# Load ASR model (Whisper)
def load_asr_model():
    model_id = "openai/whisper-small"
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True
    )
    model.to(device)
    processor = AutoProcessor.from_pretrained(model_id)
    return pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        max_new_tokens=128,
        chunk_length_s=30,
        batch_size=16,
        return_timestamps=False,
        torch_dtype=torch_dtype,
        device=device,
    )

# Load LLM model
def load_llm_model():
    model_id = "facebook/opt-1.3b"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True
    )
    # Resize embeddings in case a pad token was added to the vocabulary
    model.resize_token_embeddings(len(tokenizer))
    model.config.pad_token_id = tokenizer.pad_token_id
    model.to(device)
    return model, tokenizer

# Initialize models
print("Loading ASR model...")
asr_pipeline = load_asr_model()
print("Loading LLM model...")
llm_model, llm_tokenizer = load_llm_model()

# Chat history management
chat_history = []

def generate_response(prompt):
    # If chat history is empty, add a system message
    if not chat_history:
        chat_history.append({"role": "system", "content": "You are a helpful, friendly AI assistant. Keep your responses concise and conversational."})
    # Add user message to history
    chat_history.append({"role": "user", "content": prompt})
    # Build prompt from chat history
    full_prompt = ""
    for message in chat_history:
        if message["role"] == "system":
            full_prompt += f"System: {message['content']}\n"
        elif message["role"] == "user":
            full_prompt += f"User: {message['content']}\n"
        elif message["role"] == "assistant":
            full_prompt += f"Assistant: {message['content']}\n"
    full_prompt += "Assistant: "
    # Encode input
    encoded_input = llm_tokenizer.encode_plus(
        full_prompt,
        return_tensors="pt",
        padding=False,
        add_special_tokens=True,
        return_attention_mask=True
    )
    input_ids = encoded_input["input_ids"].to(device)
    attention_mask = encoded_input["attention_mask"].to(device)
    # Generate response
    with torch.no_grad():
        try:
            output = llm_model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=128,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=llm_tokenizer.pad_token_id,
                eos_token_id=llm_tokenizer.eos_token_id,
                use_cache=True
            )
        except Exception:
            # Fall back to a minimal generation call if the full call fails
            output = llm_model.generate(
                input_ids=input_ids,
                max_new_tokens=128,
                do_sample=True,
                temperature=0.7
            )
    # Decode output and keep only the text after the final "Assistant: " marker
    response_text = llm_tokenizer.decode(output[0], skip_special_tokens=True)
    response_text = response_text.split("Assistant: ")[-1].strip()
    # The base LM may keep writing the next "User:" turn; keep only the assistant's reply
    response_text = response_text.split("\nUser:")[0].strip()
    # Add assistant response to history
    chat_history.append({"role": "assistant", "content": response_text})
    # Keep history manageable: drop the oldest non-system message
    if len(chat_history) > 10:
        chat_history.pop(1)
    return response_text

def demo():
    with gr.Blocks() as demo:
        gr.Markdown("# Voice Chatbot")
        gr.Markdown("Simply speak into the microphone and get an audio response.")
        audio_input = gr.Audio(sources=["microphone"], type="numpy", label="Speak")
        audio_output = gr.Audio(label="Response", autoplay=True)
        transcript_display = gr.Textbox(label="Conversation")

        def process_audio(audio):
            if audio is None:
                return None, "No audio detected."
            # Track conversation for display
            conversation_text = ""
            # Unpack the recorded audio
            sample_rate, audio_array = audio
            # Downmix to mono if needed, then convert int16 samples to float32 for ASR
            if audio_array.ndim > 1:
                audio_array = audio_array.mean(axis=1)
            audio_float32 = audio_array.astype(np.float32) / 32768.0
            # Speech-to-text
            transcript = asr_pipeline({
                "sampling_rate": sample_rate,
                "raw": audio_float32
            })
            prompt = transcript["text"]
            conversation_text += f"You: {prompt}\n"
            print(f"Transcribed: {prompt}")
            # Generate response
            response_text = generate_response(prompt)
            conversation_text += f"Assistant: {response_text}\n"
            print(f"Response: {response_text}")
            # Convert the response to speech (sample rate, int16 mono array)
            tts_sample_rate, tts_audio = text_to_speech_local(response_text)
            return (tts_sample_rate, tts_audio), conversation_text

        audio_input.change(process_audio,
                           inputs=[audio_input],
                           outputs=[audio_output, transcript_display])

        clear_btn = gr.Button("Clear Conversation")
        clear_btn.click(lambda: (None, ""), outputs=[audio_output, transcript_display])

        # Add function to clear chat history
        def reset_chat():
            global chat_history
            chat_history = []
            return None, "Conversation history cleared."

        reset_btn = gr.Button("Reset Chat History")
        reset_btn.click(reset_chat, outputs=[audio_output, transcript_display])

    demo.launch()

if __name__ == "__main__":
    demo()
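
# Note: besides the imports above, this script assumes pyttsx3 and soundfile are
# installed; pyttsx3 also relies on a system speech backend (e.g. espeak on Linux,
# SAPI5 on Windows, NSSpeechSynthesizer on macOS) to synthesize audio.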