import asyncio import logging import os import re import gradio as gr import numpy as np from cleantext import clean from dotenv import load_dotenv from fastrtc import ( AdditionalOutputs, AlgoOptions, ReplyOnPause, SileroVadOptions, Stream, audio_to_bytes, get_stt_model, get_tts_model, ) from llama_index.core.workflow import Context from num2words import num2words from openai import OpenAI from scipy import signal from transformers.models.auto.modeling_auto import AutoModelForSpeechSeq2Seq from transformers.models.auto.processing_auto import AutoProcessor from transformers.pipelines import pipeline from transformers.utils.import_utils import is_flash_attn_2_available from chatbot import agent, get_chat_history, update_chat_history from transcription import resample_audio, warmup_model from utils.device import get_device, get_torch_and_np_dtypes # Configure logging logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", ) logger = logging.getLogger(__name__) device = get_device(force_cpu=False) torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False) logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}, np_dtype: {np_dtype}") attention = "flash_attention_2" if is_flash_attn_2_available() else "sdpa" logger.info(f"Using attention: {attention}") stt_model_name = "openai/whisper-large-v2" try: model = AutoModelForSpeechSeq2Seq.from_pretrained( stt_model_name, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, attn_implementation=attention, ) model.to(device) except Exception as e: logger.error(f"Error loading ASR model: {e}") logger.error(f"Are you providing a valid model ID? {stt_model_name}") raise processor = AutoProcessor.from_pretrained(stt_model_name) # Create a custom prompt to guide the model initial_prompt = "LuxDev, Sasan, Jafarnejad, LUXDEV" prompt_ids = processor.get_prompt_ids(initial_prompt, return_tensors="pt").to(device) # warmup_model(processor, model, device, np_dtype, torch_dtype, logger) # Load environment variables from .env file load_dotenv() logger.info("Environment variables loaded") sambanova_client = OpenAI( api_key=os.getenv("OPENAI_API_KEY"), ) logger.info("OpenAI client initialized") tts_model = get_tts_model() logger.info("STT and TTS models initialized") # Create context - moved before the listen function ctx = Context(agent) async def listen(audio: tuple[int, np.ndarray]): sample_rate, audio_array = audio logger.info(f"Sample rate: {sample_rate}Hz, Shape: {audio_array.shape}") # Resample audio to 16kHz if needed audio_array, sample_rate = resample_audio(audio_array, sample_rate) # Process audio input input_features = processor( audio_array, sampling_rate=sample_rate, return_tensors="pt" ).input_features input_features = input_features.to( device=device, dtype=torch_dtype ) # Convert to correct dtype # Generate transcription predicted_ids = model.generate( input_features, # task="transcribe", # language="english", max_length=448, num_beams=5, temperature=0.0, no_repeat_ngram_size=3, length_penalty=1.0, repetition_penalty=1.0, # Use the prompt tokens directly prompt_ids=prompt_ids, ) # Decode the transcription transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[ 0 ].strip() logger.info(f"Transcript: {transcription}") # Check if transcription is empty or too short if not transcription or len(transcription.strip()) < 2: logger.info("Empty or too short transcription, skipping processing") return logger.info("Sending request to OpenAI") try: full_response = await agent.run( transcription, ctx=ctx, ) response = full_response.response.content # Update chat history update_chat_history(transcription, response) logger.info(f"OpenAI response: {response}") if response is None or not response.strip(): logger.warning("Received empty response from OpenAI") return # Preprocess the text for TTS tts_text = preprocess_text_for_tts(response) logger.info(f"Preprocessed text for TTS: {tts_text}") # Check if preprocessed text is empty if not tts_text or not tts_text.strip(): logger.warning("Preprocessed TTS text is empty") return logger.info("Starting TTS streaming") try: chunk_count = 0 async for audio_chunk in tts_model.stream_tts(tts_text): chunk_count += 1 # Add a small delay to prevent overwhelming the connection if chunk_count % 10 == 0: await asyncio.sleep(0.01) yield audio_chunk, AdditionalOutputs(transcription, response) logger.info(f"TTS streaming completed with {chunk_count} chunks") except Exception as e: logger.error(f"TTS streaming error: {e}") # Return empty audio chunk if TTS fails yield (16000, np.array([], dtype=np.float32)), AdditionalOutputs( transcription, f"Error in text-to-speech: {str(e)}" ) except Exception as e: logger.error(f"Error in agent processing: {e}") yield (16000, np.array([], dtype=np.float32)), AdditionalOutputs( transcription, f"Error processing request: {str(e)}" ) def preprocess_text_for_tts(text): """ Preprocess text to make it more suitable for TTS using specialized libraries. """ # Remove markdown formatting with more robust patterns # First, handle the most common cases with a more reliable approach # Remove ** when they appear in pairs (bold text) - handle nested cases while "**" in text: # Find pairs of ** and remove them start = text.find("**") if start == -1: break end = text.find("**", start + 2) if end == -1: break # Extract the content between ** and replace the whole pattern content = text[start + 2 : end] text = text[:start] + content + text[end + 2 :] # Remove * when they appear in pairs (italic text) - but be careful not to remove single * # We need to be more careful here to avoid removing legitimate asterisks # Only remove * if it's clearly markdown formatting while "*" in text: start = text.find("*") if start == -1: break # Look for the next * that's not part of ** end = text.find("*", start + 1) if end == -1: break # Check if this is part of a ** pattern (already handled above) if start > 0 and text[start - 1] == "*": break if end + 1 < len(text) and text[end + 1] == "*": break # Extract the content between * and replace the whole pattern content = text[start + 1 : end] text = text[:start] + content + text[end + 1 :] # Remove __ when they appear in pairs (bold text) while "__" in text: start = text.find("__") if start == -1: break end = text.find("__", start + 2) if end == -1: break content = text[start + 2 : end] text = text[:start] + content + text[end + 2 :] # Remove _ when they appear in pairs (italic text) - but be careful while "_" in text: start = text.find("_") if start == -1: break end = text.find("_", start + 1) if end == -1: break # Check if this is part of a __ pattern (already handled above) if start > 0 and text[start - 1] == "_": break if end + 1 < len(text) and text[end + 1] == "_": break content = text[start + 1 : end] text = text[:start] + content + text[end + 1 :] # Remove ` when they appear in pairs (inline code) while "`" in text: start = text.find("`") if start == -1: break end = text.find("`", start + 1) if end == -1: break content = text[start + 1 : end] text = text[:start] + content + text[end + 1 :] # Remove # at the beginning of lines (headers) lines = text.split("\n") cleaned_lines = [] for line in lines: # Remove leading # characters cleaned_line = line.lstrip("#").lstrip() cleaned_lines.append(cleaned_line) text = "\n".join(cleaned_lines) # Remove markdown links [text](url) -> text while "[" in text and "](" in text: start_bracket = text.find("[") if start_bracket == -1: break end_bracket = text.find("]", start_bracket) if end_bracket == -1: break start_paren = text.find("(", end_bracket) if start_paren == -1: break end_paren = text.find(")", start_paren) if end_paren == -1: break # Extract the link text link_text = text[start_bracket + 1 : end_bracket] # Replace the entire [text](url) with just the text text = text[:start_bracket] + link_text + text[end_paren + 1 :] # Clean the text text = clean( text, fix_unicode=True, to_ascii=True, lower=False, no_line_breaks=False, no_urls=True, no_emails=True, no_phone_numbers=True, no_numbers=False, no_digits=False, no_currency_symbols=True, no_punct=False, replace_with_punct="", replace_with_url="", replace_with_email="", replace_with_phone_number="", replace_with_number="", replace_with_digit="", replace_with_currency_symbol="", ) # Convert numbers to words def replace_numbers(match): try: return num2words(float(match.group()), lang="en") except: return match.group() text = re.sub(r"\b\d+(?:\.\d+)?\b", replace_numbers, text) logger.debug(f"Preprocessed text: {text}") return text.strip() logger.info("Initializing Stream with ReplyOnPause") # stream = Stream(ReplyOnPause(echo), modality="audio", mode="send-receive") logger.info("Initializing FastRTC stream") stream = Stream( handler=ReplyOnPause( listen, algo_options=AlgoOptions( # Duration in seconds of audio chunks (default 0.6) audio_chunk_duration=0.8, # If the chunk has more than started_talking_threshold seconds of speech, the user started talking (default 0.2) started_talking_threshold=0.3, # If, after the user started speaking, there is a chunk with less than speech_threshold seconds of speech, the user stopped speaking. (default 0.1) speech_threshold=0.8, ), model_options=SileroVadOptions( # Threshold for what is considered speech (default 0.5) threshold=0.6, # Final speech chunks shorter min_speech_duration_ms are thrown out (default 250) min_speech_duration_ms=500, # # Max duration of speech chunks, longer will be split (default float('inf')) max_speech_duration_s=30, # Wait for ms at the end of each speech chunk before separating it (default 2000) min_silence_duration_ms=1500, # # Chunk size for VAD model. Can be 512, 1024, 1536 for 16k s.r. (default 1024) window_size_samples=1024, # # Final speech chunks are padded by speech_pad_ms each side (default 400) speech_pad_ms=300, ), ), # send-receive: bidirectional streaming (default) # send: client to server only # receive: server to client only modality="audio", mode="send-receive", concurrency_limit=1, # Limit to one connection at a time additional_outputs=[ gr.Textbox(label="Transcript"), gr.Textbox(label="Chatbot Response"), ], additional_outputs_handler=lambda current_transcript, current_response, new_transcript, new_response: ( ( (current_transcript + " " + new_transcript) if current_transcript else new_transcript ), new_response, # Replace chatbot response with the latest one ), # rtc_configuration=get_rtc_credentials(provider="hf") if APP_MODE == "deployed" else None ui_args={ "title": "Oracle Voice Chatbot", "subtitle": "Ask me anything", }, ) # Create custom Blocks with CSS and title custom_css = """ .footer { display: none !important; } .gradio-container::after { content: "Made in Luxembourg 🇱🇺"; display: block; text-align: center; padding: 10px; color: #666; font-size: 12px; border-top: 1px solid #e0e0e0; margin-top: 20px; } """ # Set custom CSS and title on the existing UI stream.ui.css = custom_css stream.ui.title = "Oracle Voice Chatbot" logger.info("Launching UI") stream.ui.launch( app_kwargs={ "title": "Oracle Voice Chatbot", "docs_url": None, "redoc_url": None, } )