from fastrtc import (
    ReplyOnPause,
    AdditionalOutputs,
    Stream,
    audio_to_bytes,
    aggregate_bytes_to_16bit
)
import gradio as gr
import time
import numpy as np
import torch
import os
import tempfile
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
    AutoTokenizer,
    AutoModelForCausalLM
)
from gtts import gTTS
from scipy.io import wavfile

# Check if CUDA is available, otherwise use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32


# Step 1: Audio transcription with Whisper
def load_asr_model():
    model_id = "openai/whisper-small"
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True
    )
    model.to(device)
    processor = AutoProcessor.from_pretrained(model_id)
    return pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        max_new_tokens=128,
        chunk_length_s=30,
        batch_size=16,
        return_timestamps=False,
        torch_dtype=torch_dtype,
        device=device,
    )


# Step 2: Text generation with a smaller LLM
def load_llm_model():
    model_id = "facebook/opt-1.3b"

    # Load tokenizer with special attention to the padding token
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Print initial configuration
    print(f"Initial pad token ID: {tokenizer.pad_token_id}, EOS token ID: {tokenizer.eos_token_id}")

    # For OPT models specifically - configure the tokenizer before loading the model
    if tokenizer.pad_token is None:
        # Use a completely different token as pad token - must be done before model loading
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        # Ensure the pad token really differs from the EOS token
        assert tokenizer.pad_token_id != tokenizer.eos_token_id, "Pad token still same as EOS token!"
        print(f"Added special PAD token with ID {tokenizer.pad_token_id} (different from EOS: {tokenizer.eos_token_id})")

    # Load the model, knowing the tokenizer may have been modified
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True
    )

    # Resize embeddings to match the (possibly extended) tokenizer vocabulary
    model.resize_token_embeddings(len(tokenizer))

    # CRITICAL: make sure the model config knows about the pad token
    model.config.pad_token_id = tokenizer.pad_token_id

    # OPT models need this explicit configuration
    if hasattr(model.config, "word_embed_proj_dim"):
        model.config._remove_wrong_keys = False

    # Move model to device
    model.to(device)

    print(f"Final token setup - Pad token: '{tokenizer.pad_token}' (ID: {tokenizer.pad_token_id})")
    print(f"Model config pad_token_id: {model.config.pad_token_id}")

    return model, tokenizer


# Step 3: Text-to-Speech with gTTS (Google Text-to-Speech)
def gtts_text_to_speech(text):
    """Convert text to speech using gTTS and ensure proper WAV format."""
    # Import numpy and wavfile at the function level to ensure they're
    # available in all code paths
    import numpy as np
    from scipy.io import wavfile

    # Create absolute paths for temporary files
    temp_dir = tempfile.gettempdir()
    mp3_filename = os.path.join(temp_dir, f"tts_temp_{os.getpid()}_{time.time()}.mp3")
    wav_filename = os.path.join(temp_dir, f"tts_temp_{os.getpid()}_{time.time()}.wav")

    try:
        # Make sure text is not empty
        if not text or text.isspace():
            text = "I don't have a response for that."
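
        # NOTE: gTTS calls Google's hosted TTS endpoint, so this step needs
        # network access and returns MP3. The code below converts that MP3 to
        # WAV, trying ffmpeg first, then pydub, then synthesizing a
        # placeholder tone pattern as a last resort.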
        # Create gTTS object and save to MP3
        tts = gTTS(text=text, lang='en', slow=False)
        tts.save(mp3_filename)
        print(f"MP3 file created: {mp3_filename}, size: {os.path.getsize(mp3_filename)}")

        # Try multiple methods to convert MP3 to WAV
        wav_created = False

        # Method 1: Try ffmpeg (most reliable)
        try:
            import subprocess
            cmd = [
                'ffmpeg', '-y', '-i', mp3_filename,
                '-acodec', 'pcm_s16le', '-ar', '24000', '-ac', '1',
                wav_filename
            ]
            print(f"Running ffmpeg command: {' '.join(cmd)}")
            subprocess.run(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                check=True
            )
            if os.path.exists(wav_filename) and os.path.getsize(wav_filename) > 100:
                print(f"WAV file successfully created with ffmpeg: {wav_filename}, size: {os.path.getsize(wav_filename)}")
                wav_created = True
            else:
                print(f"ffmpeg ran but WAV file is missing or too small: {wav_filename}")
        except Exception as e:
            print(f"ffmpeg conversion failed: {str(e)}")

        # Method 2: Try pydub if ffmpeg failed
        if not wav_created:
            try:
                from pydub import AudioSegment
                print("Converting MP3 to WAV using pydub...")
                sound = AudioSegment.from_mp3(mp3_filename)
                sound = sound.set_frame_rate(24000).set_channels(1)
                sound.export(wav_filename, format="wav")
                if os.path.exists(wav_filename) and os.path.getsize(wav_filename) > 100:
                    print(f"WAV file successfully created with pydub: {wav_filename}, size: {os.path.getsize(wav_filename)}")
                    wav_created = True
                else:
                    print("pydub ran but WAV file is missing or too small")
            except Exception as e:
                print(f"pydub conversion failed: {str(e)}")

        # Method 3: Direct WAV creation
        if not wav_created:
            try:
                print("Generating synthetic speech directly...")
                # Generate a simple speech-like tone pattern
                sample_rate = 24000
                duration = len(text) * 0.075  # Approximate speech timing
                t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)

                # Create a speech-like tone with some variation
                frequencies = [220, 440, 330, 550]
                audio = np.zeros_like(t)
                for i, freq in enumerate(frequencies):
                    audio += 0.2 * np.sin(2 * np.pi * freq * t + i)

                # Add an attack/release envelope
                envelope = np.ones_like(t)
                attack = int(0.01 * sample_rate)
                release = int(0.1 * sample_rate)
                envelope[:attack] = np.linspace(0, 1, attack)
                envelope[-release:] = np.linspace(1, 0, release)
                audio = audio * envelope

                # Normalize and convert to int16
                audio = audio / np.max(np.abs(audio))
                audio = (audio * 32767).astype(np.int16)

                # Save as WAV
                wavfile.write(wav_filename, sample_rate, audio)
                if os.path.exists(wav_filename) and os.path.getsize(wav_filename) > 100:
                    print(f"WAV file successfully created directly: {wav_filename}, size: {os.path.getsize(wav_filename)}")
                    wav_created = True
            except Exception as e:
                print(f"Direct WAV creation failed: {str(e)}")

        # Read the WAV file if it was created
        if wav_created:
            try:
                # Add a small delay to ensure the file is fully written
                time.sleep(0.1)

                # Read WAV file with scipy
                print(f"Reading WAV file: {wav_filename}")
                sample_rate, audio_data = wavfile.read(wav_filename)

                # Convert to the expected (1, N) int16 format
                audio_data = audio_data.reshape(1, -1).astype(np.int16)
                print(f"WAV file read successfully, shape: {audio_data.shape}, sample rate: {sample_rate}")
                return (sample_rate, audio_data)
            except Exception as e:
                print(f"Error reading WAV file: {str(e)}")

        # If all else fails, generate a simple tone
        print("All methods failed. Falling back to synthetic audio tone")
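        # Final fallback: a plain 440 Hz sine, scaled to roughly the length
        # of the text, so callers always receive playable audio and failures
        # are audible during testing.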
        sample_rate = 24000
        duration_sec = max(1, len(text) * 0.1)
        tone_length = int(sample_rate * duration_sec)
        audio_data = np.sin(2 * np.pi * np.arange(tone_length) * 440 / sample_rate)
        audio_data = (audio_data * 32767).astype(np.int16)
        audio_data = audio_data.reshape(1, -1)
        return (sample_rate, audio_data)

    except Exception as e:
        print(f"Unexpected error in text-to-speech: {str(e)}")
        # Generate a simple tone as a last resort
        sample_rate = 24000
        audio_data = np.sin(2 * np.pi * np.arange(sample_rate) * 440 / sample_rate)
        audio_data = (audio_data * 32767).astype(np.int16)
        audio_data = audio_data.reshape(1, -1)
        return (sample_rate, audio_data)

    finally:
        # Clean up temporary files
        for filename in [mp3_filename, wav_filename]:
            try:
                if os.path.exists(filename):
                    os.remove(filename)
            except Exception as e:
                print(f"Failed to remove temporary file {filename}: {str(e)}")


# Initialize models
print("Loading ASR model...")
asr_pipeline = load_asr_model()
print("Loading LLM model...")
llm_model, llm_tokenizer = load_llm_model()

# Chat history management
chat_history = []


def generate_response(prompt):
    # If chat history is empty, add a system message
    if not chat_history:
        chat_history.append({
            "role": "system",
            "content": "You are a helpful, friendly AI assistant. Keep your responses concise and conversational."
        })

    # Add user message to history
    chat_history.append({"role": "user", "content": prompt})

    # Build the full prompt from chat history
    full_prompt = ""
    for message in chat_history:
        if message["role"] == "system":
            full_prompt += f"System: {message['content']}\n"
        elif message["role"] == "user":
            full_prompt += f"User: {message['content']}\n"
        elif message["role"] == "assistant":
            full_prompt += f"Assistant: {message['content']}\n"
    full_prompt += "Assistant: "

    # Use encode_plus, which offers more control
    encoded_input = llm_tokenizer.encode_plus(
        full_prompt,
        return_tensors="pt",
        padding=False,  # Don't pad here - we handle it manually
        add_special_tokens=True,
        return_attention_mask=True
    )

    # Extract and move tensors to device
    input_ids = encoded_input["input_ids"].to(device)

    # Create attention mask explicitly - all 1s for a non-padded sequence
    attention_mask = torch.ones_like(input_ids).to(device)

    # Print for debugging
    print(f"Input shape: {input_ids.shape}, Attention mask shape: {attention_mask.shape}")

    # Generate with very explicit parameters for OPT models
    with torch.no_grad():
        try:
            output = llm_model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,  # Explicitly pass attention mask
                max_new_tokens=128,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=llm_tokenizer.pad_token_id,  # Explicitly set pad token ID
                eos_token_id=llm_tokenizer.eos_token_id,  # Explicitly set EOS token ID
                use_cache=True,
                no_repeat_ngram_size=3,
                # These parameters are set specifically for OPT
                forced_bos_token_id=None,
                forced_eos_token_id=None,
                num_beams=1  # Single-beam sampling, no beam search
            )
        except Exception as e:
            print(f"Error during generation: {e}")
            # Fallback with simpler parameters
            output = llm_model.generate(
                input_ids=input_ids,
                max_new_tokens=128,
                do_sample=True,
                temperature=0.7
            )

    # Decode only the generated part, not the input prompt
    generated_tokens = output[0][input_ids.shape[-1]:]
    response_text = llm_tokenizer.decode(generated_tokens, skip_special_tokens=True)
    response_text = response_text.split("Assistant: ")[-1].strip()

    # Add assistant response to history
    chat_history.append({"role": "assistant", "content": response_text})

    # Keep history manageable: keep the system message plus the 9 most recent messages
    if len(chat_history) > 10:
        chat_history.pop(1)
    return response_text


def response(audio: tuple[int, np.ndarray]):
    # Step 1: Convert audio to float32 before passing it to the ASR pipeline
    sample_rate, audio_data = audio

    # Convert int16 audio to float32, normalized to [-1.0, 1.0]
    audio_float32 = audio_data.flatten().astype(np.float32) / 32768.0

    # Speech-to-text with the correct data type
    transcript = asr_pipeline({
        "sampling_rate": sample_rate,
        "raw": audio_float32
    })
    prompt = transcript["text"]
    print(f"Transcribed: {prompt}")

    # Step 2: Generate the text response
    response_text = generate_response(prompt)
    print(f"Response: {response_text}")

    # Step 3: Text-to-speech using gTTS
    sample_rate, audio_array = gtts_text_to_speech(response_text)

    # Yield the audio back in chunks of the expected format
    chunk_size = int(sample_rate * 0.2)  # 200 ms chunks
    for i in range(0, audio_array.shape[1], chunk_size):
        chunk = audio_array[:, i:i + chunk_size]
        if chunk.size > 0:  # Ensure we don't yield empty chunks
            yield (sample_rate, chunk)


stream = Stream(
    modality="audio",
    mode="send-receive",
    handler=ReplyOnPause(response),
)


# For testing without WebRTC
def demo():
    with gr.Blocks() as demo:
        gr.Markdown("# Local Voice Chatbot")
        audio_input = gr.Audio(sources=["microphone"], type="numpy")
        audio_output = gr.Audio()

        def process_audio(audio):
            if audio is None:
                return None
            sample_rate, audio_array = audio

            # Convert to float32 for ASR
            audio_float32 = audio_array.flatten().astype(np.float32) / 32768.0
            transcript = asr_pipeline({
                "sampling_rate": sample_rate,
                "raw": audio_float32
            })
            prompt = transcript["text"]
            print(f"Transcribed: {prompt}")

            response_text = generate_response(prompt)
            print(f"Response: {response_text}")

            sample_rate, audio_array = gtts_text_to_speech(response_text)
            return (sample_rate, audio_array[0])

        audio_input.change(process_audio, inputs=[audio_input], outputs=[audio_output])

    demo.launch()


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--demo", action="store_true", help="Run Gradio demo instead of WebRTC")
    args = parser.parse_args()

    # Hugging Face issues: always run the Gradio demo for now
    demo()
    # if args.demo:
    #     demo()
    # else:
    #     # For running with FastRTC
    #     # You would need to add your FastRTC server code here
    #     pass
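
# A minimal sketch of the WebRTC path, assuming fastrtc's `stream.ui`
# (a built-in Gradio UI) and `stream.mount()` helpers; verify these names
# against the fastrtc docs for your installed version before relying on them.
#
#   # Option A: launch fastrtc's built-in UI
#   stream.ui.launch()
#
#   # Option B: mount the stream on a FastAPI app and serve it with uvicorn
#   from fastapi import FastAPI
#   import uvicorn
#   app = FastAPI()
#   stream.mount(app)
#   uvicorn.run(app, host="0.0.0.0", port=7860)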