Spaces:
Sleeping
Sleeping
import asyncio | |
import logging | |
import os | |
import re | |
import gradio as gr | |
import numpy as np | |
from cleantext import clean | |
from dotenv import load_dotenv | |
from fastrtc import ( | |
AdditionalOutputs, | |
AlgoOptions, | |
ReplyOnPause, | |
SileroVadOptions, | |
Stream, | |
audio_to_bytes, | |
get_stt_model, | |
get_tts_model, | |
) | |
from llama_index.core.workflow import Context | |
from num2words import num2words | |
from openai import OpenAI | |
from scipy import signal | |
from transformers.models.auto.modeling_auto import AutoModelForSpeechSeq2Seq | |
from transformers.models.auto.processing_auto import AutoProcessor | |
from transformers.pipelines import pipeline | |
from transformers.utils.import_utils import is_flash_attn_2_available | |
from chatbot import agent, get_chat_history, update_chat_history | |
from transcription import resample_audio, warmup_model | |
from utils.device import get_device, get_torch_and_np_dtypes | |
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)

# Load environment variables from .env *before* any model or client is set up.
# Previously this only ran after the Whisper model had been downloaded/loaded,
# so anything in .env that model loading or device selection reads from the
# environment (presumably e.g. Hugging Face token / cache variables used by
# from_pretrained — confirm) was not in effect at that point.
load_dotenv()
logger.info("Environment variables loaded")

device = get_device(force_cpu=False)
torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}, np_dtype: {np_dtype}")

# Prefer FlashAttention-2 when transformers reports it as available, otherwise
# fall back to PyTorch scaled-dot-product attention.
attention = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
logger.info(f"Using attention: {attention}")

# Whisper speech-to-text model plus its processor; `listen` uses the processor
# both to build input features and to decode generated token ids.
stt_model_name = "openai/whisper-large-v2"
try:
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        stt_model_name,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        attn_implementation=attention,
    )
    model.to(device)
    # Loaded under the same error handling as the model, so a bad model ID or a
    # download failure here gets the same diagnostics instead of a bare crash.
    processor = AutoProcessor.from_pretrained(stt_model_name)
except Exception as e:
    logger.exception(f"Error loading ASR model: {e}")
    logger.error(f"Are you providing a valid model ID? {stt_model_name}")
    raise

# Decoder prompt biasing Whisper towards domain-specific names
# (passed as `prompt_ids` to model.generate in `listen`).
initial_prompt = "LuxDev, Sasan, Jafarnejad, LUXDEV"
prompt_ids = processor.get_prompt_ids(initial_prompt, return_tensors="pt").to(device)
# Disabled: warmup_model(processor, model, device, np_dtype, torch_dtype, logger)
# presumably pre-runs the ASR model once at startup — re-enable if first-call
# latency matters.

# NOTE(review): despite its name this is a plain OpenAI client, and the code
# visible here never uses it (LLM calls go through `agent`) — confirm before
# removing.
sambanova_client = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY"),
)
logger.info("OpenAI client initialized")

tts_model = get_tts_model()
logger.info("STT and TTS models initialized")

# Agent context created once at module level and passed to every `listen`
# call, so conversation state carries over between utterances (and, as far as
# this file shows, between successive connections as well).
ctx = Context(agent)
def _transcribe(audio_array, sample_rate):
    """Transcribe one utterance with the module-level Whisper model (blocking).

    Args:
        audio_array: Audio samples, as returned by ``resample_audio``.
        sample_rate: Sample rate of ``audio_array`` in Hz.

    Returns:
        The decoded transcript with surrounding whitespace stripped.
    """
    input_features = processor(
        audio_array, sampling_rate=sample_rate, return_tensors="pt"
    ).input_features
    # Match the model's device and dtype.
    input_features = input_features.to(device=device, dtype=torch_dtype)
    # task/language are deliberately not passed, so Whisper's defaults apply.
    predicted_ids = model.generate(
        input_features,
        max_length=448,
        num_beams=5,
        temperature=0.0,
        no_repeat_ngram_size=3,
        length_penalty=1.0,
        repetition_penalty=1.0,
        # Bias decoding towards the domain vocabulary in `initial_prompt`.
        prompt_ids=prompt_ids,
    )
    return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0].strip()


async def listen(audio: tuple[int, np.ndarray]):
    """Handle one user utterance: speech -> text -> agent reply -> speech.

    Registered as the ReplyOnPause handler. Yields ``(audio_chunk,
    AdditionalOutputs(transcript, reply_text))`` pairs for the audio stream and
    the UI textboxes. Transcripts shorter than two characters and empty replies
    produce no output at all. Agent or TTS failures are logged and reported as
    a single empty audio chunk whose reply text carries the error message.

    Args:
        audio: ``(sample_rate, samples)`` tuple delivered by fastrtc.
    """
    sample_rate, audio_array = audio
    logger.info(f"Sample rate: {sample_rate}Hz, Shape: {audio_array.shape}")
    # Resample audio to 16kHz if needed
    audio_array, sample_rate = resample_audio(audio_array, sample_rate)

    # Whisper beam search is blocking and can take seconds; run it in a worker
    # thread so the event loop (and the WebRTC connection) keeps being serviced.
    transcription = await asyncio.to_thread(_transcribe, audio_array, sample_rate)
    logger.info(f"Transcript: {transcription}")
    # `_transcribe` already strips, so a length check covers "empty" too.
    if len(transcription) < 2:
        logger.info("Empty or too short transcription, skipping processing")
        return

    logger.info("Sending request to OpenAI")
    try:
        full_response = await agent.run(
            transcription,
            ctx=ctx,
        )
        response = full_response.response.content
        logger.info(f"OpenAI response: {response}")
        if response is None or not response.strip():
            logger.warning("Received empty response from OpenAI")
            return

        # Record only complete exchanges, so the history never stores a
        # None/blank reply (previously updated before this check).
        update_chat_history(transcription, response)

        # Preprocess the text for TTS
        tts_text = preprocess_text_for_tts(response)
        logger.info(f"Preprocessed text for TTS: {tts_text}")
        if not tts_text or not tts_text.strip():
            logger.warning("Preprocessed TTS text is empty")
            return

        logger.info("Starting TTS streaming")
        try:
            chunk_count = 0
            async for audio_chunk in tts_model.stream_tts(tts_text):
                chunk_count += 1
                # Add a small delay to prevent overwhelming the connection
                if chunk_count % 10 == 0:
                    await asyncio.sleep(0.01)
                yield audio_chunk, AdditionalOutputs(transcription, response)
            logger.info(f"TTS streaming completed with {chunk_count} chunks")
        except Exception as e:
            logger.exception(f"TTS streaming error: {e}")
            # Empty audio chunk so the UI still receives the error text.
            yield (16000, np.array([], dtype=np.float32)), AdditionalOutputs(
                transcription, f"Error in text-to-speech: {str(e)}"
            )
    except Exception as e:
        logger.exception(f"Error in agent processing: {e}")
        yield (16000, np.array([], dtype=np.float32)), AdditionalOutputs(
            transcription, f"Error processing request: {str(e)}"
        )
def preprocess_text_for_tts(text):
    """Convert an LLM reply (typically Markdown) into plain text for TTS.

    Processing order:
      1. Markdown links/images ``[label](url)`` / ``![alt](url)`` become their
         label; the URL is dropped.
      2. Backticks (inline code and code fences) are removed; the code itself
         is kept.
      3. Leading whitespace is stripped from every line, then ATX header
         markers (``# Title``) and list bullets (``-``, ``*``, ``+``) at the
         start of a line are removed.
      4. Bold/italic markers are unwrapped: ``**x**``, ``__x__``, ``*x*``,
         ``_x_``. Underscore emphasis is only recognised at word boundaries,
         so identifiers like ``snake_case`` survive. Single-asterisk emphasis
         must hug its text, so arithmetic such as ``2 * 3 * 4`` is untouched.
      5. ``cleantext.clean`` transliterates to ASCII and removes URLs, e-mail
         addresses, phone numbers and currency symbols.
      6. Numbers (optionally with ``,`` thousands separators and a decimal
         part) are spelled out with ``num2words``. A number that cannot be
         converted is left as digits.

    Args:
        text: Raw response text.

    Returns:
        The cleaned text with surrounding whitespace stripped.
    """
    # 1. Links/images. The brackets and parentheses must be adjacent; the old
    #    find()-based scan could pair a "[" with an unrelated later "(" and
    #    silently delete everything in between.
    text = re.sub(r"!?\[([^\]]*)\]\([^)]*\)", r"\1", text)

    # 2. Inline code / code fences: keep the code, drop the delimiters.
    text = text.replace("`", "")

    # 3. Line-leading markup: indentation, header markers, list bullets.
    cleaned_lines = []
    for line in text.split("\n"):
        line = line.lstrip()
        line = re.sub(r"^#+(?=\s|$)", "", line)  # "# Title" -> " Title"
        line = re.sub(r"^[-*+]\s+", "", line)  # "- item" -> "item"
        cleaned_lines.append(line.lstrip())
    text = "\n".join(cleaned_lines)

    # 4. Emphasis. Bold first so "**x**" is not consumed as two italics.
    text = re.sub(r"\*\*(.+?)\*\*", r"\1", text, flags=re.DOTALL)
    text = re.sub(r"(?<!\w)__(.+?)__(?!\w)", r"\1", text, flags=re.DOTALL)
    text = re.sub(r"\*(?=\S)(.+?)(?<=\S)\*", r"\1", text)
    text = re.sub(r"(?<!\w)_(?=\S)(.+?)(?<=\S)_(?!\w)", r"\1", text)

    # 5. Clean the text
    text = clean(
        text,
        fix_unicode=True,
        to_ascii=True,
        lower=False,
        no_line_breaks=False,
        no_urls=True,
        no_emails=True,
        no_phone_numbers=True,
        no_numbers=False,
        no_digits=False,
        no_currency_symbols=True,
        no_punct=False,
        replace_with_punct="",
        replace_with_url="",
        replace_with_email="",
        replace_with_phone_number="",
        replace_with_number="",
        replace_with_digit="",
        replace_with_currency_symbol="",
    )

    # 6. Numbers -> words. The first alternative keeps "1,250" together (the
    #    old pattern split it into "1" and "250"); integers go through int()
    #    so long digit strings are not rounded by float().
    def replace_numbers(match):
        token = match.group().replace(",", "")
        try:
            value = float(token) if "." in token else int(token)
            return num2words(value, lang="en")
        except Exception:
            # Best effort (e.g. OverflowError for huge values): keep the digits
            # and let the TTS engine read them as-is.
            return match.group()

    text = re.sub(
        r"\b\d{1,3}(?:,\d{3})+(?:\.\d+)?\b|\b\d+(?:\.\d+)?\b", replace_numbers, text
    )
    logger.debug(f"Preprocessed text: {text}")
    return text.strip()
logger.info("Initializing Stream with ReplyOnPause")
logger.info("Initializing FastRTC stream")


def _merge_additional_outputs(current_transcript, current_response, new_transcript, new_response):
    """Fold a new (transcript, reply) pair into the two UI textboxes.

    The transcript box accumulates every utterance, space-separated. The reply
    box is simply replaced with the newest chatbot response.
    """
    transcript = new_transcript
    if current_transcript:
        transcript = current_transcript + " " + new_transcript
    return transcript, new_response


# Turn-taking: how ReplyOnPause decides the user started / stopped talking.
# fastrtc defaults are given in parentheses.
_turn_detection = AlgoOptions(
    # Duration in seconds of each analysed audio chunk (0.6).
    audio_chunk_duration=0.8,
    # More than this many seconds of speech in a chunk => user started talking (0.2).
    started_talking_threshold=0.3,
    # After starting, a chunk with less than this many seconds of speech =>
    # user stopped talking (0.1).
    # NOTE(review): this equals audio_chunk_duration, so per that description
    # any chunk that is not entirely speech ends the turn — confirm intended.
    speech_threshold=0.8,
)

# Silero VAD settings (defaults in parentheses).
_vad_settings = SileroVadOptions(
    threshold=0.6,  # probability above which audio counts as speech (0.5)
    min_speech_duration_ms=500,  # final speech chunks shorter than this are dropped (250)
    max_speech_duration_s=30,  # longer speech chunks are split (inf)
    min_silence_duration_ms=1500,  # wait this long at the end of speech before splitting (2000)
    window_size_samples=1024,  # VAD window; 512, 1024 or 1536 at 16 kHz (1024)
    speech_pad_ms=300,  # padding added on each side of final speech chunks (400)
)

stream = Stream(
    handler=ReplyOnPause(
        listen,
        algo_options=_turn_detection,
        model_options=_vad_settings,
    ),
    modality="audio",
    # "send-receive" = bidirectional (default); "send" = client -> server only;
    # "receive" = server -> client only.
    mode="send-receive",
    concurrency_limit=1,  # Limit to one connection at a time
    additional_outputs=[
        gr.Textbox(label="Transcript"),
        gr.Textbox(label="Chatbot Response"),
    ],
    additional_outputs_handler=_merge_additional_outputs,
    # TODO(deploy): pass rtc_configuration=get_rtc_credentials(provider="hf")
    # when APP_MODE == "deployed" (neither name is imported or defined here).
    ui_args={
        "title": "Oracle Voice Chatbot",
        "subtitle": "Ask me anything",
    },
)
# Cosmetic overrides for the generated UI: hide the ".footer" element and
# append a small attribution line after the app container.
custom_css = """
.footer {
    display: none !important;
}
.gradio-container::after {
    content: "Made in Luxembourg 🇱🇺";
    display: block;
    text-align: center;
    padding: 10px;
    color: #666;
    font-size: 12px;
    border-top: 1px solid #e0e0e0;
    margin-top: 20px;
}
"""

# Patch styling and page title onto the UI that `stream` already built
# (css is assigned first, then title).
stream.ui.css, stream.ui.title = custom_css, "Oracle Voice Chatbot"

logger.info("Launching UI")
# app_kwargs: page title, and docs_url/redoc_url=None to hide the API docs
# pages (these look like FastAPI settings — confirm against Gradio's
# launch(app_kwargs=...) documentation).
launch_app_kwargs = dict(
    title="Oracle Voice Chatbot",
    docs_url=None,
    redoc_url=None,
)
stream.ui.launch(app_kwargs=launch_app_kwargs)