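"""Voice chatbot UI for the oracle-demo Space.

Pipeline: FastRTC microphone audio -> Whisper transcription -> LlamaIndex agent
-> streamed text-to-speech reply, served through a Gradio UI.
"""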
import asyncio
import logging
import os
import re
import gradio as gr
import numpy as np
from cleantext import clean
from dotenv import load_dotenv
from fastrtc import (
    AdditionalOutputs,
    AlgoOptions,
    ReplyOnPause,
    SileroVadOptions,
    Stream,
    get_tts_model,
)
from llama_index.core.workflow import Context
from num2words import num2words
from openai import OpenAI
from transformers.models.auto.modeling_auto import AutoModelForSpeechSeq2Seq
from transformers.models.auto.processing_auto import AutoProcessor
from transformers.utils.import_utils import is_flash_attn_2_available
from chatbot import agent, update_chat_history
from transcription import resample_audio, warmup_model
from utils.device import get_device, get_torch_and_np_dtypes
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)
device = get_device(force_cpu=False)
torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}, np_dtype: {np_dtype}")
attention = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
logger.info(f"Using attention: {attention}")
stt_model_name = "openai/whisper-large-v2"
try:
model = AutoModelForSpeechSeq2Seq.from_pretrained(
stt_model_name,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
use_safetensors=True,
attn_implementation=attention,
)
model.to(device)
except Exception as e:
logger.error(f"Error loading ASR model: {e}")
logger.error(f"Are you providing a valid model ID? {stt_model_name}")
raise
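# The processor bundles Whisper's feature extractor and tokenizer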
processor = AutoProcessor.from_pretrained(stt_model_name)
# Initial prompt biases Whisper toward domain-specific names and spellings
initial_prompt = "LuxDev, Sasan, Jafarnejad, LUXDEV"
prompt_ids = processor.get_prompt_ids(initial_prompt, return_tensors="pt").to(device)
# warmup_model(processor, model, device, np_dtype, torch_dtype, logger)
# Load environment variables from .env file
load_dotenv()
logger.info("Environment variables loaded")
sambanova_client = OpenAI(
api_key=os.getenv("OPENAI_API_KEY"),
)
logger.info("OpenAI client initialized")
tts_model = get_tts_model()
logger.info("STT and TTS models initialized")
# Shared workflow context so the agent keeps conversation state across turns
ctx = Context(agent)
async def listen(audio: tuple[int, np.ndarray]):
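    """Transcribe one user utterance, run it through the agent, and stream the spoken reply."""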
sample_rate, audio_array = audio
logger.info(f"Sample rate: {sample_rate}Hz, Shape: {audio_array.shape}")
# Resample audio to 16kHz if needed
audio_array, sample_rate = resample_audio(audio_array, sample_rate)
    # Extract Whisper input features and move them to the model's device and dtype
    input_features = processor(
        audio_array, sampling_rate=sample_rate, return_tensors="pt"
    ).input_features
    input_features = input_features.to(device=device, dtype=torch_dtype)
# Generate transcription
predicted_ids = model.generate(
input_features,
# task="transcribe",
# language="english",
max_length=448,
num_beams=5,
temperature=0.0,
no_repeat_ngram_size=3,
length_penalty=1.0,
repetition_penalty=1.0,
# Use the prompt tokens directly
prompt_ids=prompt_ids,
)
# Decode the transcription
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[
0
].strip()
logger.info(f"Transcript: {transcription}")
# Check if transcription is empty or too short
if not transcription or len(transcription.strip()) < 2:
logger.info("Empty or too short transcription, skipping processing")
return
logger.info("Sending request to OpenAI")
try:
full_response = await agent.run(
transcription,
ctx=ctx,
)
        response = full_response.response.content
        logger.info(f"Agent response: {response}")
        if response is None or not response.strip():
            logger.warning("Received empty response from the agent")
            return
        # Record the exchange only once we have a non-empty response
        update_chat_history(transcription, response)
# Preprocess the text for TTS
tts_text = preprocess_text_for_tts(response)
logger.info(f"Preprocessed text for TTS: {tts_text}")
# Check if preprocessed text is empty
if not tts_text or not tts_text.strip():
logger.warning("Preprocessed TTS text is empty")
return
logger.info("Starting TTS streaming")
try:
chunk_count = 0
async for audio_chunk in tts_model.stream_tts(tts_text):
chunk_count += 1
# Add a small delay to prevent overwhelming the connection
if chunk_count % 10 == 0:
await asyncio.sleep(0.01)
yield audio_chunk, AdditionalOutputs(transcription, response)
logger.info(f"TTS streaming completed with {chunk_count} chunks")
except Exception as e:
logger.error(f"TTS streaming error: {e}")
# Return empty audio chunk if TTS fails
yield (16000, np.array([], dtype=np.float32)), AdditionalOutputs(
transcription, f"Error in text-to-speech: {str(e)}"
)
except Exception as e:
logger.error(f"Error in agent processing: {e}")
yield (16000, np.array([], dtype=np.float32)), AdditionalOutputs(
transcription, f"Error processing request: {str(e)}"
)
def preprocess_text_for_tts(text):
    """
    Prepare agent output for speech synthesis: strip markdown formatting,
    normalize the text with cleantext, and spell out numbers with num2words.
    """
    # Strip markdown emphasis and inline code, keeping the inner text.
    # Bold (**text**, __text__) is handled before italic so the single-character
    # markers never split a double-character pair.
    text = re.sub(r"\*\*(.+?)\*\*", r"\1", text, flags=re.DOTALL)
    text = re.sub(r"__(.+?)__", r"\1", text, flags=re.DOTALL)
    # Italic (*text*, _text_); the lookarounds avoid touching stray single markers
    text = re.sub(r"(?<!\*)\*([^*]+)\*(?!\*)", r"\1", text)
    text = re.sub(r"(?<!_)_([^_]+)_(?!_)", r"\1", text)
    # Inline code (`text`)
    text = re.sub(r"`([^`]+)`", r"\1", text)
    # Headers: drop leading '#' characters (and indentation) from each line
    text = "\n".join(line.lstrip("#").lstrip() for line in text.split("\n"))
    # Markdown links: [text](url) -> text
    text = re.sub(r"\[([^\]]*)\]\([^)]*\)", r"\1", text)
# Clean the text
text = clean(
text,
fix_unicode=True,
to_ascii=True,
lower=False,
no_line_breaks=False,
no_urls=True,
no_emails=True,
no_phone_numbers=True,
no_numbers=False,
no_digits=False,
no_currency_symbols=True,
no_punct=False,
replace_with_punct="",
replace_with_url="",
replace_with_email="",
replace_with_phone_number="",
replace_with_number="",
replace_with_digit="",
replace_with_currency_symbol="",
)
    # Convert numbers to words so the TTS engine reads them naturally
    def replace_numbers(match):
        try:
            return num2words(float(match.group()), lang="en")
        except (ValueError, OverflowError):
            return match.group()

    text = re.sub(r"\b\d+(?:\.\d+)?\b", replace_numbers, text)
logger.debug(f"Preprocessed text: {text}")
return text.strip()
logger.info("Initializing Stream with ReplyOnPause")
# stream = Stream(ReplyOnPause(echo), modality="audio", mode="send-receive")
logger.info("Initializing FastRTC stream")
stream = Stream(
handler=ReplyOnPause(
listen,
        algo_options=AlgoOptions(
            # Duration in seconds of each audio chunk (default 0.6)
            audio_chunk_duration=0.8,
            # A chunk with more than this many seconds of speech means the user started talking (default 0.2)
            started_talking_threshold=0.3,
            # Once the user is talking, a chunk with less than this many seconds of speech means they stopped (default 0.1)
            speech_threshold=0.8,
        ),
        model_options=SileroVadOptions(
            # Probability threshold above which audio is considered speech (default 0.5)
            threshold=0.6,
            # Final speech chunks shorter than min_speech_duration_ms are thrown out (default 250)
            min_speech_duration_ms=500,
            # Max duration of speech chunks; longer chunks are split (default float('inf'))
            max_speech_duration_s=30,
            # Wait this many ms of silence at the end of each speech chunk before separating it (default 2000)
            min_silence_duration_ms=1500,
            # Chunk size for the VAD model: 512, 1024, or 1536 for a 16 kHz sample rate (default 1024)
            window_size_samples=1024,
            # Final speech chunks are padded by speech_pad_ms on each side (default 400)
            speech_pad_ms=300,
        ),
),
    modality="audio",
    # send-receive: bidirectional streaming (default)
    # send: client to server only
    # receive: server to client only
    mode="send-receive",
concurrency_limit=1, # Limit to one connection at a time
additional_outputs=[
gr.Textbox(label="Transcript"),
gr.Textbox(label="Chatbot Response"),
],
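    # Accumulate the transcript across turns, but show only the latest response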
additional_outputs_handler=lambda current_transcript, current_response, new_transcript, new_response: (
(
(current_transcript + " " + new_transcript)
if current_transcript
else new_transcript
),
new_response, # Replace chatbot response with the latest one
),
# rtc_configuration=get_rtc_credentials(provider="hf") if APP_MODE == "deployed" else None
ui_args={
"title": "Oracle Voice Chatbot",
"subtitle": "Ask me anything",
},
)
# Create custom Blocks with CSS and title
custom_css = """
.footer {
display: none !important;
}
.gradio-container::after {
content: "Made in Luxembourg 🇱🇺";
display: block;
text-align: center;
padding: 10px;
color: #666;
font-size: 12px;
border-top: 1px solid #e0e0e0;
margin-top: 20px;
}
"""
# Set custom CSS and title on the existing UI
stream.ui.css = custom_css
stream.ui.title = "Oracle Voice Chatbot"
logger.info("Launching UI")
stream.ui.launch(
app_kwargs={
"title": "Oracle Voice Chatbot",
"docs_url": None,
"redoc_url": None,
}
)