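"""Local voice chatbot.

Pipeline: Whisper (ASR) -> facebook/opt-1.3b (text generation) -> gTTS (speech synthesis),
served over fastrtc, with a Gradio demo as a fallback interface.
"""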
from fastrtc import (
ReplyOnPause, AdditionalOutputs, Stream,
audio_to_bytes, aggregate_bytes_to_16bit
)
import gradio as gr
import time
import numpy as np
import torch
import os
import tempfile
from transformers import (
AutoModelForSpeechSeq2Seq,
AutoProcessor,
pipeline,
AutoTokenizer,
AutoModelForCausalLM
)
from gtts import gTTS
from scipy.io import wavfile
# Check if CUDA is available, otherwise use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Step 1: Audio transcription with Whisper
def load_asr_model():
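    """Load Whisper-small as a transformers ASR pipeline with long-form chunking."""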
model_id = "openai/whisper-small"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
return pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
chunk_length_s=30,
batch_size=16,
return_timestamps=False,
torch_dtype=torch_dtype,
device=device,
)
# Step 2: Text generation with a smaller LLM
def load_llm_model():
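    """Load facebook/opt-1.3b and its tokenizer, ensuring a distinct pad token."""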
model_id = "facebook/opt-1.3b"
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Print initial configuration
print(f"Initial pad token ID: {tokenizer.pad_token_id}, EOS token ID: {tokenizer.eos_token_id}")
# Load model
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True
)
# THE KEY FIX: Set pad token consistently in both tokenizer and model config
if tokenizer.pad_token_id is None or tokenizer.pad_token_id == tokenizer.eos_token_id:
        # Add a dedicated [PAD] token so padding does not alias the EOS token
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model.resize_token_embeddings(len(tokenizer))
# Make sure model config has consistent pad token ID
model.config.pad_token_id = tokenizer.pad_token_id
# Important: Also set these token IDs in model config
if hasattr(model.config, 'decoder_start_token_id') and model.config.decoder_start_token_id is None:
model.config.decoder_start_token_id = tokenizer.pad_token_id
print(f"Modified token IDs - PAD: {tokenizer.pad_token_id}, EOS: {tokenizer.eos_token_id}")
print(f"Model config - PAD: {model.config.pad_token_id}, EOS: {model.config.eos_token_id}")
# Double-check that model config has pad token ID set
if not hasattr(model.config, 'pad_token_id') or model.config.pad_token_id is None:
model.config.pad_token_id = tokenizer.pad_token_id
# Move model to the right device
model.to(device)
return model, tokenizer
# Step 3: Text-to-Speech with gTTS (Google Text-to-Speech)
def gtts_text_to_speech(text):
"""Convert text to speech using gTTS and ensure proper WAV format."""
# Create absolute paths for temporary files
temp_dir = tempfile.gettempdir()
mp3_filename = os.path.join(temp_dir, f"tts_temp_{os.getpid()}_{time.time()}.mp3")
wav_filename = os.path.join(temp_dir, f"tts_temp_{os.getpid()}_{time.time()}.wav")
try:
# Make sure text is not empty
if not text or text.isspace():
text = "I don't have a response for that."
# Create gTTS object and save to MP3
tts = gTTS(text=text, lang='en', slow=False)
tts.save(mp3_filename)
print(f"MP3 file created: {mp3_filename}, size: {os.path.getsize(mp3_filename)}")
# Try multiple methods to convert MP3 to WAV
wav_created = False
# Method 1: Try ffmpeg (most reliable)
try:
import subprocess
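            # pcm_s16le -> 16-bit PCM, -ar 24000 -> 24 kHz, -ac 1 -> mono (matches the format used downstream)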
cmd = ['ffmpeg', '-y', '-i', mp3_filename, '-acodec', 'pcm_s16le', '-ar', '24000', '-ac', '1', wav_filename]
print(f"Running ffmpeg command: {' '.join(cmd)}")
result = subprocess.run(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
check=True
)
if os.path.exists(wav_filename) and os.path.getsize(wav_filename) > 100:
print(f"WAV file successfully created with ffmpeg: {wav_filename}, size: {os.path.getsize(wav_filename)}")
wav_created = True
else:
print(f"ffmpeg ran but WAV file is missing or too small: {wav_filename}")
except Exception as e:
print(f"ffmpeg conversion failed: {str(e)}")
# Method 2: Try pydub if ffmpeg failed
if not wav_created:
try:
from pydub import AudioSegment
print("Converting MP3 to WAV using pydub...")
sound = AudioSegment.from_mp3(mp3_filename)
sound = sound.set_frame_rate(24000).set_channels(1)
sound.export(wav_filename, format="wav")
if os.path.exists(wav_filename) and os.path.getsize(wav_filename) > 100:
print(f"WAV file successfully created with pydub: {wav_filename}, size: {os.path.getsize(wav_filename)}")
wav_created = True
else:
print(f"pydub ran but WAV file is missing or too small")
except Exception as e:
print(f"pydub conversion failed: {str(e)}")
        # Method 3: Synthesize a placeholder tone directly (last resort; this is not real speech)
        if not wav_created:
            try:
                print("Generating synthetic placeholder audio directly...")
# Generate a simple speech-like tone pattern
sample_rate = 24000
duration = len(text) * 0.075 # Approx timing
t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
# Create a speech-like tone with some variation
frequencies = [220, 440, 330, 550]
audio = np.zeros_like(t)
for i, freq in enumerate(frequencies):
audio += 0.2 * np.sin(2 * np.pi * freq * t + i)
# Add some envelope
envelope = np.ones_like(t)
attack = int(0.01 * sample_rate)
release = int(0.1 * sample_rate)
envelope[:attack] = np.linspace(0, 1, attack)
envelope[-release:] = np.linspace(1, 0, release)
audio = audio * envelope
# Normalize and convert to int16
audio = audio / np.max(np.abs(audio))
audio = (audio * 32767).astype(np.int16)
# Save as WAV
wavfile.write(wav_filename, sample_rate, audio)
if os.path.exists(wav_filename) and os.path.getsize(wav_filename) > 100:
print(f"WAV file successfully created directly: {wav_filename}, size: {os.path.getsize(wav_filename)}")
wav_created = True
except Exception as e:
print(f"Direct WAV creation failed: {str(e)}")
# Read the WAV file if it was created
if wav_created:
try:
# Add a small delay to ensure the file is fully written
time.sleep(0.1)
# Read WAV file with scipy
print(f"Reading WAV file: {wav_filename}")
sample_rate, audio_data = wavfile.read(wav_filename)
# Convert to expected format
audio_data = audio_data.reshape(1, -1).astype(np.int16)
print(f"WAV file read successfully, shape: {audio_data.shape}, sample rate: {sample_rate}")
return (sample_rate, audio_data)
except Exception as e:
print(f"Error reading WAV file: {str(e)}")
# If all else fails, generate a simple tone
print("All methods failed. Falling back to synthetic audio tone")
sample_rate = 24000
duration_sec = max(1, len(text) * 0.1)
tone_length = int(sample_rate * duration_sec)
audio_data = np.sin(2 * np.pi * np.arange(tone_length) * 440 / sample_rate)
audio_data = (audio_data * 32767).astype(np.int16)
audio_data = audio_data.reshape(1, -1)
return (sample_rate, audio_data)
except Exception as e:
print(f"Unexpected error in text-to-speech: {str(e)}")
# Generate a simple tone as last resort
sample_rate = 24000
audio_data = np.sin(2 * np.pi * np.arange(sample_rate) * 440 / sample_rate)
audio_data = (audio_data * 32767).astype(np.int16)
audio_data = audio_data.reshape(1, -1)
return (sample_rate, audio_data)
finally:
# Clean up temporary files
for filename in [mp3_filename, wav_filename]:
try:
if os.path.exists(filename):
os.remove(filename)
except Exception as e:
print(f"Failed to remove temporary file {filename}: {str(e)}")
# Initialize models
print("Loading ASR model...")
asr_pipeline = load_asr_model()
print("Loading LLM model...")
llm_model, llm_tokenizer = load_llm_model()
# Chat history management
chat_history = []
def generate_response(prompt):
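    """Append the user prompt to chat_history, run the local LLM, and return its reply."""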
# If chat history is empty, add a system message
if not chat_history:
chat_history.append({"role": "system", "content": "You are a helpful, friendly AI assistant. Keep your responses concise and conversational."})
# Add user message to history
chat_history.append({"role": "user", "content": prompt})
# Prepare input for the model
full_prompt = ""
for message in chat_history:
if message["role"] == "system":
full_prompt += f"System: {message['content']}\n"
elif message["role"] == "user":
full_prompt += f"User: {message['content']}\n"
elif message["role"] == "assistant":
full_prompt += f"Assistant: {message['content']}\n"
full_prompt += "Assistant: "
    # Encode the prompt without padding and build the attention mask explicitly
input_ids = llm_tokenizer.encode(full_prompt, return_tensors='pt')
# Create attention mask manually (all 1's)
attention_mask = torch.ones_like(input_ids)
# Move to device
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
# Generate response with completely explicit parameters
with torch.no_grad():
output = llm_model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=128,
do_sample=True,
temperature=0.7,
top_p=0.9,
pad_token_id=llm_tokenizer.pad_token_id,
eos_token_id=llm_tokenizer.eos_token_id,
use_cache=True,
no_repeat_ngram_size=3
)
response_text = llm_tokenizer.decode(output[0], skip_special_tokens=True)
response_text = response_text.split("Assistant: ")[-1].strip()
# Add assistant response to history
chat_history.append({"role": "assistant", "content": response_text})
    # Keep history at a reasonable size
    if len(chat_history) > 10:
        # Drop the oldest non-system message; the system prompt stays at index 0
        chat_history.pop(1)
return response_text
def response(audio: tuple[int, np.ndarray]):
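    """fastrtc handler: transcribe the utterance, generate a reply, and stream the TTS audio in chunks."""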
# Step 1: Convert audio to float32 before passing to ASR
sample_rate, audio_data = audio
# Convert int16 audio to float32
audio_float32 = audio_data.flatten().astype(np.float32) / 32768.0 # Normalize to [-1.0, 1.0]
# Speech-to-Text with correct data type
transcript = asr_pipeline({
"sampling_rate": sample_rate,
"raw": audio_float32
})
prompt = transcript["text"]
print(f"Transcribed: {prompt}")
# Step 2: Generate text response
response_text = generate_response(prompt)
print(f"Response: {response_text}")
# Step 3: Text-to-Speech using gTTS
sample_rate, audio_array = gtts_text_to_speech(response_text)
# Convert to expected format and yield chunks
chunk_size = int(sample_rate * 0.2) # 200ms chunks
for i in range(0, audio_array.shape[1], chunk_size):
chunk = audio_array[:, i:i+chunk_size]
if chunk.size > 0: # Ensure we don't yield empty chunks
yield (sample_rate, chunk)
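# ReplyOnPause invokes `response` whenever the caller stops speaking and streams the yielded audio back.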
stream = Stream(
modality="audio",
mode="send-receive",
handler=ReplyOnPause(response),
)
# For testing without WebRTC
def demo():
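    """Gradio-based fallback UI for testing the pipeline without WebRTC."""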
with gr.Blocks() as demo:
gr.Markdown("# Local Voice Chatbot")
audio_input = gr.Audio(sources=["microphone"], type="numpy")
audio_output = gr.Audio()
def process_audio(audio):
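            """Handle one recorded clip: ASR -> LLM -> gTTS, returning (sample_rate, waveform)."""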
if audio is None:
return None
sample_rate, audio_array = audio
# Convert to float32 for ASR
audio_float32 = audio_array.flatten().astype(np.float32) / 32768.0
transcript = asr_pipeline({
"sampling_rate": sample_rate,
"raw": audio_float32
})
prompt = transcript["text"]
print(f"Transcribed: {prompt}")
response_text = generate_response(prompt)
print(f"Response: {response_text}")
sample_rate, audio_array = gtts_text_to_speech(response_text)
return (sample_rate, audio_array[0])
audio_input.change(process_audio, inputs=[audio_input], outputs=[audio_output])
demo.launch()
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--demo", action="store_true", help="Run Gradio demo instead of WebRTC")
args = parser.parse_args()
    # Workaround for Hugging Face Spaces issues: always launch the Gradio demo for now
    demo()
# if args.demo:
# demo()
# else:
# # For running with FastRTC
# # You would need to add your FastRTC server code here
# pass