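"""Local voice chatbot: Whisper ASR -> OPT-1.3b chat responses -> gTTS speech synthesis.

Audio arrives over a FastRTC stream (or the Gradio demo), is transcribed with Whisper,
answered by the language model, and spoken back as streamed gTTS audio.
"""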
from fastrtc import (
    ReplyOnPause, AdditionalOutputs, Stream,
    audio_to_bytes, aggregate_bytes_to_16bit
)
import gradio as gr
import numpy as np
import torch
import os
import tempfile
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
    AutoTokenizer,
    AutoModelForCausalLM
)
from gtts import gTTS
from scipy.io import wavfile
# Check if CUDA is available, otherwise use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Step 1: Audio transcription with Whisper
def load_asr_model():
    model_id = "openai/whisper-small"
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True
    )
    model.to(device)
    processor = AutoProcessor.from_pretrained(model_id)
    return pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        max_new_tokens=128,
        chunk_length_s=30,
        batch_size=16,
        return_timestamps=False,
        torch_dtype=torch_dtype,
        device=device,
    )
# Step 2: Text generation with a smaller LLM
def load_llm_model():
    model_id = "facebook/opt-1.3b"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Ensure the pad token is distinct from the EOS token
    if tokenizer.pad_token is None:
        # Use a dedicated special token as the padding token
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True
        )
        # Resize the token embeddings since we added a new token
        model.resize_token_embeddings(len(tokenizer))
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True
        )
    model.to(device)
    return model, tokenizer
# Step 3: Text-to-Speech with gTTS (Google Text-to-Speech)
def gtts_text_to_speech(text):
    # Create a temporary WAV file for the converted audio
    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
        tmp_filename = f.name
    # Use gTTS to convert text to speech
    tts = gTTS(text=text, lang='en', slow=False)
    # Save as MP3 first (gTTS only outputs MP3)
    mp3_filename = tmp_filename.replace('.wav', '.mp3')
    tts.save(mp3_filename)
    # Convert MP3 to 24 kHz mono WAV with FFmpeg if available, otherwise fall back to pydub
    try:
        import subprocess
        subprocess.run(
            ['ffmpeg', '-y', '-i', mp3_filename, '-acodec', 'pcm_s16le', '-ar', '24000', '-ac', '1', tmp_filename],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True
        )
    except (subprocess.CalledProcessError, FileNotFoundError):
        # Fallback if FFmpeg is missing or fails
        from pydub import AudioSegment
        sound = AudioSegment.from_mp3(mp3_filename)
        sound = sound.set_frame_rate(24000).set_channels(1)
        sound.export(tmp_filename, format="wav")
    # Read the WAV file
    sample_rate, audio_data = wavfile.read(tmp_filename)
    # Clean up temporary files
    os.remove(mp3_filename)
    os.remove(tmp_filename)
    # Reshape to (1, num_samples) int16, the layout the streaming code expects
    audio_data = audio_data.reshape(1, -1).astype(np.int16)
    return (sample_rate, audio_data)
# Initialize models
print("Loading ASR model...")
asr_pipeline = load_asr_model()
print("Loading LLM model...")
llm_model, llm_tokenizer = load_llm_model()
# Chat history management
chat_history = []
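# Each entry is a {"role": ..., "content": ...} dict; the list is shared across turns.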
def generate_response(prompt):
    # If chat history is empty, add a system message
    if not chat_history:
        chat_history.append({"role": "system", "content": "You are a helpful, friendly AI assistant. Keep your responses concise and conversational."})
    # Add user message to history
    chat_history.append({"role": "user", "content": prompt})
    # Build a plain-text prompt from the chat history
    full_prompt = ""
    for message in chat_history:
        if message["role"] == "system":
            full_prompt += f"System: {message['content']}\n"
        elif message["role"] == "user":
            full_prompt += f"User: {message['content']}\n"
        elif message["role"] == "assistant":
            full_prompt += f"Assistant: {message['content']}\n"
    full_prompt += "Assistant: "
    # Tokenize and let the tokenizer build the proper attention mask
    tokenized_inputs = llm_tokenizer(
        full_prompt,
        return_tensors="pt",
        padding=True,
        return_attention_mask=True
    )
    # Move to device
    input_ids = tokenized_inputs["input_ids"].to(device)
    attention_mask = tokenized_inputs["attention_mask"].to(device)
    # Generate response
    with torch.no_grad():
        output = llm_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,  # Use the tokenizer's attention mask
            max_new_tokens=128,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=llm_tokenizer.pad_token_id  # Pad with the dedicated pad token, not EOS
        )
    response_text = llm_tokenizer.decode(output[0], skip_special_tokens=True)
    # Keep only the text generated after the final "Assistant: " marker
    response_text = response_text.split("Assistant: ")[-1].strip()
    # Add assistant response to history
    chat_history.append({"role": "assistant", "content": response_text})
    # Keep history at a reasonable size: drop the oldest non-system message
    if len(chat_history) > 10:
        chat_history.pop(1)
    return response_text
def response(audio: tuple[int, np.ndarray]):
    # Step 1: Convert audio to float32 before passing to ASR
    sample_rate, audio_data = audio
    # Convert int16 audio to float32, normalized to [-1.0, 1.0]
    audio_float32 = audio_data.flatten().astype(np.float32) / 32768.0
    # Speech-to-text with the correct data type
    transcript = asr_pipeline({
        "sampling_rate": sample_rate,
        "raw": audio_float32
    })
    prompt = transcript["text"]
    print(f"Transcribed: {prompt}")
    # Step 2: Generate text response
    response_text = generate_response(prompt)
    print(f"Response: {response_text}")
    # Step 3: Text-to-speech using gTTS
    sample_rate, audio_array = gtts_text_to_speech(response_text)
    # Yield the audio in 200 ms chunks
    chunk_size = int(sample_rate * 0.2)
    for i in range(0, audio_array.shape[1], chunk_size):
        chunk = audio_array[:, i:i + chunk_size]
        if chunk.size > 0:  # Ensure we don't yield empty chunks
            yield (sample_rate, chunk)
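# ReplyOnPause invokes `response` when it detects a pause in the caller's speech,
# and the send-receive stream plays the yielded audio chunks back to the client.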
stream = Stream(
    modality="audio",
    mode="send-receive",
    handler=ReplyOnPause(response),
)
# For testing without WebRTC
def demo():
    with gr.Blocks() as demo:
        gr.Markdown("# Local Voice Chatbot")
        audio_input = gr.Audio(sources=["microphone"], type="numpy")
        audio_output = gr.Audio()

        def process_audio(audio):
            if audio is None:
                return None
            sample_rate, audio_array = audio
            # Convert to float32 for ASR
            audio_float32 = audio_array.flatten().astype(np.float32) / 32768.0
            transcript = asr_pipeline({
                "sampling_rate": sample_rate,
                "raw": audio_float32
            })
            prompt = transcript["text"]
            print(f"Transcribed: {prompt}")
            response_text = generate_response(prompt)
            print(f"Response: {response_text}")
            sample_rate, audio_array = gtts_text_to_speech(response_text)
            # Return (sample_rate, 1-D int16 array) for gr.Audio; drop the channel dimension
            return (sample_rate, audio_array[0])

        audio_input.change(process_audio, inputs=[audio_input], outputs=[audio_output])
    demo.launch()
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--demo", action="store_true", help="Run Gradio demo instead of WebRTC")
    args = parser.parse_args()
    # Hugging Face Spaces issues: always launch the Gradio demo for now
    demo()
    # if args.demo:
    #     demo()
    # else:
    #     # For running with FastRTC
    #     # You would need to add your FastRTC server code here
    #     pass