# audio1test/app.py
import gradio as gr
import numpy as np
import torch
import os
import tempfile
import soundfile as sf
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
    AutoTokenizer,
    AutoModelForCausalLM
)
# Check if CUDA is available, otherwise use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
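# Note: float16 roughly halves GPU memory use for both models; on CPU the
# code falls back to float32, since half-precision inference is poorly
# supported by most PyTorch CPU kernels.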
# Initialize pyttsx3 for local TTS
def load_local_tts():
    import pyttsx3
    engine = pyttsx3.init()
    engine.setProperty('rate', 150)    # Speech rate (words per minute)
    engine.setProperty('volume', 0.9)  # Volume (0.0 to 1.0)
    voices = engine.getProperty('voices')
    if len(voices) > 1:
        # Use the second installed voice when available (often a female
        # voice on Windows SAPI5, but the ordering is platform-dependent)
        engine.setProperty('voice', voices[1].id)
    return engine
# Initialize the TTS engine
print("Loading local TTS engine...")
tts_engine = load_local_tts()
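# Note: pyttsx3 wraps the platform speech engine (SAPI5 on Windows,
# NSSpeechSynthesizer on macOS, eSpeak on Linux); on a headless Linux host
# the eSpeak packages must be installed or pyttsx3.init() will fail.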
def text_to_speech_local(text):
    """Convert text to speech using the pyttsx3 local TTS engine."""
    # Create a temporary file to store the audio
    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
        temp_filename = temp_file.name
    # Generate speech to the temporary file
    tts_engine.save_to_file(text, temp_filename)
    tts_engine.runAndWait()
    # Read the audio file back as float samples in [-1.0, 1.0]
    audio_data, sample_rate = sf.read(temp_filename)
    # Down-mix to mono if the engine produced multi-channel audio
    if audio_data.ndim > 1:
        audio_data = audio_data[:, 0]
    # gr.Audio expects a 1-D int16 array
    audio_data = (audio_data * 32767).astype(np.int16)
    # Clean up the temporary file
    os.unlink(temp_filename)
    return (sample_rate, audio_data)
# Load ASR model (Whisper)
def load_asr_model():
    model_id = "openai/whisper-small"
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True
    )
    model.to(device)
    processor = AutoProcessor.from_pretrained(model_id)
    return pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        max_new_tokens=128,
        chunk_length_s=30,
        batch_size=16,
        return_timestamps=False,
        torch_dtype=torch_dtype,
        device=device,
    )
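# chunk_length_s=30 makes the pipeline split long recordings into 30-second
# windows and stitch the transcripts back together, so inputs longer than
# Whisper's native 30-second context are still handled.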
# Load LLM model
def load_llm_model():
    model_id = "facebook/opt-1.3b"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True
    )
    # Resize embeddings in case a new pad token was added above
    model.resize_token_embeddings(len(tokenizer))
    model.config.pad_token_id = tokenizer.pad_token_id
    model.to(device)
    return model, tokenizer
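# Note: facebook/opt-1.3b is a base model, not an instruction-tuned one, so
# it can drift from the System/User/Assistant framing used below; an
# instruction-tuned checkpoint would follow the chat format more reliably.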
# Initialize models
print("Loading ASR model...")
asr_pipeline = load_asr_model()
print("Loading LLM model...")
llm_model, llm_tokenizer = load_llm_model()
# Chat history management
chat_history = []
def generate_response(prompt):
    # If chat history is empty, add a system message
    if not chat_history:
        chat_history.append({"role": "system", "content": "You are a helpful, friendly AI assistant. Keep your responses concise and conversational."})
    # Add user message to history
    chat_history.append({"role": "user", "content": prompt})
    # Build a plain-text prompt from the chat history
    full_prompt = ""
    for message in chat_history:
        if message["role"] == "system":
            full_prompt += f"System: {message['content']}\n"
        elif message["role"] == "user":
            full_prompt += f"User: {message['content']}\n"
        elif message["role"] == "assistant":
            full_prompt += f"Assistant: {message['content']}\n"
    full_prompt += "Assistant: "
    # Encode input (calling the tokenizer directly is the current API;
    # with padding=False the attention mask it returns is all ones)
    encoded_input = llm_tokenizer(
        full_prompt,
        return_tensors="pt",
        padding=False,
        add_special_tokens=True,
        return_attention_mask=True
    )
    input_ids = encoded_input["input_ids"].to(device)
    attention_mask = encoded_input["attention_mask"].to(device)
    # Generate response
    with torch.no_grad():
        try:
            output = llm_model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=128,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=llm_tokenizer.pad_token_id,
                eos_token_id=llm_tokenizer.eos_token_id,
                use_cache=True
            )
        except Exception as e:
            # Fall back to a minimal generate call if the full one fails
            print(f"Generation failed ({e}); retrying with minimal arguments")
            output = llm_model.generate(
                input_ids=input_ids,
                max_new_tokens=128,
                do_sample=True,
                temperature=0.7
            )
    # Decode only the newly generated tokens (the output echoes the prompt)
    new_tokens = output[0][input_ids.shape[1]:]
    response_text = llm_tokenizer.decode(new_tokens, skip_special_tokens=True)
    # A base LM tends to keep writing the dialogue, so cut at the next turn marker
    response_text = response_text.split("\nUser:")[0].split("\nSystem:")[0].strip()
    # Add assistant response to history
    chat_history.append({"role": "assistant", "content": response_text})
    # Keep history manageable (index 0 holds the system message)
    while len(chat_history) > 10:
        chat_history.pop(1)
    return response_text
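# Note: chat_history is a module-level global, so every connected user shares
# one conversation; per-session state (e.g. gr.State) would be needed for a
# multi-user deployment.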
def demo():
    with gr.Blocks() as demo:
        gr.Markdown("# Voice Chatbot")
        gr.Markdown("Simply speak into the microphone and get an audio response.")
        audio_input = gr.Audio(sources=["microphone"], type="numpy", label="Speak")
        audio_output = gr.Audio(label="Response", autoplay=True)
        transcript_display = gr.Textbox(label="Conversation")
        def process_audio(audio):
            if audio is None:
                return None, "No audio detected."
            # Track conversation for display
            conversation_text = ""
            # Unpack the (sample_rate, ndarray) tuple Gradio provides
            sample_rate, audio_array = audio
            # Take one channel if the recording is stereo, then convert the
            # int16 PCM samples to float32 in [-1.0, 1.0] for the ASR pipeline
            if audio_array.ndim > 1:
                audio_array = audio_array[:, 0]
            audio_float32 = audio_array.astype(np.float32) / 32768.0
            # Speech-to-text
            transcript = asr_pipeline({
                "sampling_rate": sample_rate,
                "raw": audio_float32
            })
            prompt = transcript["text"]
            conversation_text += f"You: {prompt}\n"
            print(f"Transcribed: {prompt}")
            # Generate response
            response_text = generate_response(prompt)
            conversation_text += f"Assistant: {response_text}\n"
            print(f"Response: {response_text}")
            # Convert to speech; text_to_speech_local already returns a
            # (sample_rate, int16 array) tuple in the format gr.Audio expects
            sample_rate, audio_array = text_to_speech_local(response_text)
            return (sample_rate, audio_array), conversation_text
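        # gr.Audio fires `change` once recording stops and the component value
        # updates, so each finished recording triggers a full ASR -> LLM -> TTS
        # round trip.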
        audio_input.change(process_audio,
                           inputs=[audio_input],
                           outputs=[audio_output, transcript_display])

        clear_btn = gr.Button("Clear Conversation")
        clear_btn.click(lambda: (None, ""), outputs=[audio_output, transcript_display])

        # Add function to clear chat history
        def reset_chat():
            global chat_history
            chat_history = []
            return None, "Conversation history cleared."

        reset_btn = gr.Button("Reset Chat History")
        reset_btn.click(reset_chat, outputs=[audio_output, transcript_display])
    demo.launch()

if __name__ == "__main__":
    demo()