import os
import io
import time
import torch
import librosa
import requests
import tempfile
import threading
import queue
import traceback
import numpy as np
import soundfile as sf
import gradio as gr
from datetime import datetime
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, pipeline, logging as trf_logging
from huggingface_hub import login, hf_hub_download, scan_cache_dir
import speech_recognition as sr
import openai

# Quick sanity check that a GPU is visible to PyTorch
print("CUDA available:", torch.cuda.is_available())
print("CUDA device:", torch.cuda.current_device() if torch.cuda.is_available() else "None")
|
|
|
|
|
|
|
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "300" |
|
|
|
|
|
trf_logging.set_verbosity_info() |
|
|
|
|
|
HF_TOKEN = os.getenv("HF_TOKEN") |
|
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") |
|
|
|
|
|
openai.api_key = OPENAI_API_KEY |
|
|
|
|
|
if HF_TOKEN: |
|
print("🔐 Logging into Hugging Face with token...") |
|
login(token=HF_TOKEN) |
|
else: |
|
print("⚠️ HF_TOKEN not found. Proceeding without login...") |
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level model handles, populated by load_tts_model_with_retry() / load_asr_model()
tts_model = None
tts_model_wrapper = None
asr_model = None

tts_repo_id = "ai4bharat/IndicF5"
asr_repo_id = "facebook/wav2vec2-large-xlsr-53"
|
|
|
|
|
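# Thin wrapper around the IndicF5 model call so TTS failures are caught and logged in one place.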
class TTSModelWrapper: |
|
def __init__(self, model): |
|
self.model = model |
|
|
|
def generate(self, text, ref_audio_path, ref_text): |
|
try: |
|
if self.model is None: |
|
raise ValueError("Model not initialized") |
|
|
|
output = self.model( |
|
text, |
|
ref_audio_path=ref_audio_path, |
|
ref_text=ref_text |
|
) |
|
return output |
|
except Exception as e: |
|
print(f"Error in TTS generation: {e}") |
|
traceback.print_exc() |
|
return None |
|
|
|
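# Loads the IndicF5 TTS model into the module-level globals, preferring the local HF cache and
# retrying remote downloads with backoff. Note that neither this function nor load_asr_model()
# is invoked at module level in this script; the Gradio app below builds its own model inside
# ConversationEngine.initialize_tts_model().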
def load_tts_model_with_retry(max_retries=3, retry_delay=5): |
|
global tts_model, tts_model_wrapper |
|
|
|
print("Checking if TTS model is in cache...") |
|
try: |
|
cache_info = scan_cache_dir() |
|
model_in_cache = any(tts_repo_id in repo.repo_id for repo in cache_info.repos) |
|
if model_in_cache: |
|
print(f"Model {tts_repo_id} found in cache, loading locally...") |
|
tts_model = AutoModel.from_pretrained( |
|
tts_repo_id, |
|
trust_remote_code=True, |
|
local_files_only=True, |
|
device_map="auto", |
|
torch_dtype=torch.float16 |
|
) |
|
tts_model_wrapper = TTSModelWrapper(tts_model) |
|
print("TTS model loaded from cache successfully!") |
|
return |
|
except Exception as e: |
|
print(f"Cache check failed: {e}") |
|
|
|
for attempt in range(max_retries): |
|
try: |
|
print(f"Loading {tts_repo_id} model (attempt {attempt+1}/{max_retries})...") |
|
tts_model = AutoModel.from_pretrained( |
|
tts_repo_id, |
|
trust_remote_code=True, |
|
revision="main", |
|
use_auth_token=HF_TOKEN, |
|
low_cpu_mem_usage=True, |
|
device_map="auto" |
|
|
|
) |
|
tts_model_wrapper = TTSModelWrapper(tts_model) |
|
print(f"TTS model loaded successfully! Type: {type(tts_model)}") |
|
return |
|
except Exception as e: |
|
print(f"⚠️ Attempt {attempt+1}/{max_retries} failed: {e}") |
|
if attempt < max_retries - 1: |
|
print(f"Waiting {retry_delay} seconds before retrying...") |
|
time.sleep(retry_delay) |
|
retry_delay *= 1.5 |
|
|
|
try: |
|
print("Trying with fallback options...") |
|
tts_model = AutoModel.from_pretrained( |
|
tts_repo_id, |
|
trust_remote_code=True, |
|
revision="main", |
|
local_files_only=False, |
|
use_auth_token=HF_TOKEN, |
|
force_download=False, |
|
resume_download=True, |
|
device_map="auto" |
|
) |
|
tts_model_wrapper = TTSModelWrapper(tts_model) |
|
print("TTS model loaded with fallback options!") |
|
except Exception as e2: |
|
print(f"❌ All attempts to load TTS model failed: {e2}") |
|
print("Will continue without TTS model loaded.") |
|
|
|
|
|
|
|
|
|
|
def load_asr_model():
    global asr_model
    try:
        print(f"Loading ASR model from {asr_repo_id}...")
        # transformers' pipeline() expects a device index: 0 for the first GPU, -1 for CPU
        device_index = 0 if torch.cuda.is_available() else -1
        asr_model = pipeline("automatic-speech-recognition", model=asr_repo_id, device=device_index)
        print("ASR model loaded successfully!")
    except Exception as e:
        print(f"Error loading ASR model: {e}")
        print("Will use Google's speech recognition API instead.")
        asr_model = None
|
|
|
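# Speech-to-text helper: tries the Hugging Face ASR pipeline first (if loaded), then falls back
# to the Google Web Speech API via the speech_recognition package.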
class SpeechRecognizer: |
|
def __init__(self): |
|
self.recognizer = sr.Recognizer() |
|
self.using_huggingface = asr_model is not None |
|
|
|
def recognize_from_file(self, audio_path, language="ml-IN"): |
|
"""Recognize speech from audio file with fallback mechanisms""" |
|
print(f"Recognizing speech from {audio_path}") |
|
try: |
|
|
|
if self.using_huggingface: |
|
try: |
|
result = asr_model(audio_path) |
|
transcription = result["text"] |
|
print(f"HF ASR result: {transcription}") |
|
return transcription |
|
except Exception as e: |
|
print(f"HF ASR failed: {e}, falling back to Google") |
|
|
|
|
|
with sr.AudioFile(audio_path) as source: |
|
audio_data = self.recognizer.record(source) |
|
text = self.recognizer.recognize_google(audio_data, language=language) |
|
print(f"Google ASR result: {text}") |
|
return text |
|
except Exception as e: |
|
print(f"Speech recognition failed: {e}") |
|
return "" |
|
|
|
def recognize_from_microphone(self, language="ml-IN", timeout=5): |
|
"""Recognize speech from microphone""" |
|
print("Listening to microphone...") |
|
try: |
|
with sr.Microphone() as source: |
|
self.recognizer.adjust_for_ambient_noise(source) |
|
print("Speak now...") |
|
try: |
|
audio = self.recognizer.listen(source, timeout=timeout) |
|
print("Processing speech...") |
|
|
|
|
|
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav') |
|
temp_file.close() |
|
|
|
with open(temp_file.name, "wb") as f: |
|
f.write(audio.get_wav_data()) |
|
|
|
|
|
if self.using_huggingface: |
|
try: |
|
result = asr_model(temp_file.name) |
|
text = result["text"] |
|
print(f"HF ASR result: {text}") |
|
os.unlink(temp_file.name) |
|
return text |
|
except Exception as e: |
|
print(f"HF ASR failed: {e}, falling back to Google") |
|
|
|
|
|
text = self.recognizer.recognize_google(audio, language=language) |
|
print(f"Google ASR result: {text}") |
|
os.unlink(temp_file.name) |
|
return text |
|
|
|
except sr.WaitTimeoutError: |
|
print("No speech detected within timeout period") |
|
return "" |
|
except Exception as e: |
|
print(f"Speech recognition error: {e}") |
|
return "" |
|
except Exception as e: |
|
print(f"Microphone access error: {e}") |
|
return "" |
|
|
|
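# Chat-history manager built on the legacy openai<1.0 ChatCompletion interface, matching the
# openai.ChatCompletion.create() calls used elsewhere in this script.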
class ConversationManager: |
|
def __init__(self): |
|
self.conversation_history = [] |
|
self.system_prompt = ( |
|
|
|
|
|
|
|
|
|
"You are a helpful and friendly assistant who speaks Malayalam fluently. " |
|
"Respond like you're talking to a close friend over the phone — casual, warm, and natural. " |
|
"Keep your responses short, to the point, and avoid sounding robotic or formal. " |
|
"Use Malayalam when the user uses Malayalam, and English when the user uses English. " |
|
"Use the kind of expressions and tone you'd use while chatting with someone from Kerala." |
|
) |
|
|
|
def add_message(self, role, content): |
|
self.conversation_history.append({"role": role, "content": content}) |
|
|
|
def get_formatted_history(self): |
|
"""Format conversation history for OpenAI API""" |
|
messages = [{"role": "system", "content": self.system_prompt}] |
|
|
|
for msg in self.conversation_history: |
|
if msg["role"] == "user": |
|
messages.append({"role": "user", "content": msg["content"]}) |
|
else: |
|
messages.append({"role": "assistant", "content": msg["content"]}) |
|
|
|
return messages |
|
|
|
def generate_response(self, user_input): |
|
"""Generate response using GPT-3.5 Turbo""" |
|
if not openai.api_key: |
|
return "I'm sorry, but the language model is not available right now." |
|
|
|
self.add_message("user", user_input) |
|
|
|
try: |
|
|
|
messages = self.get_formatted_history() |
|
print(f"Sending prompt to OpenAI: {len(messages)} messages") |
|
|
|
|
|
response = openai.ChatCompletion.create( |
|
model="gpt-3.5-turbo", |
|
messages=messages, |
|
max_tokens=300, |
|
temperature=0.7, |
|
top_p=0.9, |
|
) |
|
|
|
|
|
response_text = response.choices[0].message["content"].strip() |
|
print(f"GPT-3.5 response: {response_text}") |
|
|
|
|
|
self.add_message("assistant", response_text) |
|
|
|
return response_text |
|
|
|
except Exception as e: |
|
print(f"Error generating response: {e}") |
|
fallback_response = "I'm having trouble thinking right now. Can we try again?" |
|
self.add_message("assistant", fallback_response) |
|
return fallback_response |
|
|
|
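# Audio post-processing helpers: simple noise gate, moving-average smoothing, and level normalization.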
def remove_noise(audio_data, threshold=0.01): |
|
"""Apply simple noise gate to remove low-level noise""" |
|
if audio_data is None: |
|
return np.zeros(1000) |
|
|
|
|
|
if isinstance(audio_data, torch.Tensor): |
|
audio_data = audio_data.detach().cpu().numpy() |
|
if isinstance(audio_data, list): |
|
audio_data = np.array(audio_data) |
|
|
|
|
|
noise_mask = np.abs(audio_data) < threshold |
|
clean_audio = audio_data.copy() |
|
clean_audio[noise_mask] = 0 |
|
|
|
return clean_audio |
|
|
|
def apply_smoothing(audio_data, window_size=5): |
|
"""Apply gentle smoothing to reduce artifacts""" |
|
if audio_data is None or len(audio_data) < window_size*2: |
|
return audio_data |
|
|
|
|
|
kernel = np.ones(window_size) / window_size |
|
smoothed = np.convolve(audio_data, kernel, mode='same') |
|
|
|
|
|
smoothed[:window_size] = audio_data[:window_size] |
|
smoothed[-window_size:] = audio_data[-window_size:] |
|
|
|
return smoothed |
|
|
|
def enhance_audio(audio_data): |
|
"""Process audio to improve quality and reduce noise""" |
|
if audio_data is None: |
|
return np.zeros(1000) |
|
|
|
|
|
if isinstance(audio_data, torch.Tensor): |
|
audio_data = audio_data.detach().cpu().numpy() |
|
if isinstance(audio_data, list): |
|
audio_data = np.array(audio_data) |
|
|
|
|
|
if len(audio_data.shape) > 1: |
|
audio_data = audio_data.flatten() |
|
if audio_data.dtype != np.float32: |
|
audio_data = audio_data.astype(np.float32) |
|
|
|
|
|
if audio_data.size < 100: |
|
return audio_data |
|
|
|
|
|
rms = np.sqrt(np.mean(audio_data**2)) |
|
print(f"Initial RMS: {rms}") |
|
|
|
|
|
if rms < 0.05: |
|
target_rms = 0.2 |
|
gain = target_rms / max(rms, 0.0001) |
|
print(f"Applying gain factor: {gain}") |
|
audio_data = audio_data * gain |
|
|
|
|
|
audio_data = audio_data - np.mean(audio_data) |
|
|
|
|
|
audio_data = remove_noise(audio_data, threshold=0.01) |
|
|
|
|
|
audio_data = apply_smoothing(audio_data, window_size=3) |
|
|
|
|
|
max_amp = np.max(np.abs(audio_data)) |
|
if max_amp > 0.95: |
|
audio_data = 0.95 * audio_data / max_amp |
|
|
|
|
|
audio_data = np.tanh(audio_data * 1.1) * 0.9 |
|
|
|
return audio_data |
|
|
|
def split_into_chunks(text, max_length=8): |
|
"""Split text into smaller chunks based on punctuation and length""" |
|
|
|
sentence_markers = ['.', '?', '!', ';', ':', '।', '॥'] |
|
chunks = [] |
|
current = "" |
|
|
|
|
|
for char in text: |
|
current += char |
|
if char in sentence_markers and current.strip(): |
|
chunks.append(current.strip()) |
|
current = "" |
|
|
|
if current.strip(): |
|
chunks.append(current.strip()) |
|
|
|
|
|
final_chunks = [] |
|
for chunk in chunks: |
|
if len(chunk) <= max_length: |
|
final_chunks.append(chunk) |
|
else: |
|
|
|
comma_splits = chunk.split(',') |
|
current_part = "" |
|
|
|
for part in comma_splits: |
|
if len(current_part) + len(part) <= max_length: |
|
if current_part: |
|
current_part += "," |
|
current_part += part |
|
else: |
|
if current_part: |
|
final_chunks.append(current_part.strip()) |
|
current_part = part |
|
|
|
if current_part: |
|
final_chunks.append(current_part.strip()) |
|
|
|
print(f"Split text into {len(final_chunks)} chunks") |
|
return final_chunks |
|
|
|
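# Chunked "streaming" TTS: text is split on punctuation, each chunk is synthesized separately,
# and the growing result is rewritten to a single WAV file so playback can start before the
# whole response has been generated.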
class StreamingTTS: |
|
def __init__(self): |
|
self.is_generating = False |
|
self.should_stop = False |
|
self.temp_dir = None |
|
self.ref_audio_path = None |
|
self.output_file = None |
|
self.all_chunks = [] |
|
self.sample_rate = 24000 |
|
self.current_text = "" |
|
|
|
|
|
try: |
|
self.temp_dir = tempfile.mkdtemp() |
|
print(f"Created temp directory: {self.temp_dir}") |
|
except Exception as e: |
|
print(f"Error creating temp directory: {e}") |
|
self.temp_dir = "." |
|
|
|
def prepare_ref_audio(self, ref_audio, ref_sr): |
|
"""Prepare reference audio with enhanced quality""" |
|
try: |
|
if self.ref_audio_path is None: |
|
self.ref_audio_path = os.path.join(self.temp_dir, "ref_audio.wav") |
|
|
|
|
|
ref_audio = enhance_audio(ref_audio) |
|
|
|
|
|
sf.write(self.ref_audio_path, ref_audio, ref_sr, format='WAV', subtype='FLOAT') |
|
print(f"Saved reference audio to: {self.ref_audio_path}") |
|
|
|
|
|
if os.path.exists(self.ref_audio_path): |
|
print(f"Reference audio saved successfully: {os.path.getsize(self.ref_audio_path)} bytes") |
|
else: |
|
print("⚠️ Failed to create reference audio file!") |
|
|
|
|
|
if self.output_file is None: |
|
self.output_file = os.path.join(self.temp_dir, "output.wav") |
|
print(f"Output will be saved to: {self.output_file}") |
|
except Exception as e: |
|
print(f"Error preparing reference audio: {e}") |
|
|
|
    def cleanup(self):
        """Clean up temporary files"""
        if self.temp_dir:
            try:
                if self.ref_audio_path and os.path.exists(self.ref_audio_path):
                    os.remove(self.ref_audio_path)
                if self.output_file and os.path.exists(self.output_file):
                    os.remove(self.output_file)
                os.rmdir(self.temp_dir)
                self.temp_dir = None
                print("Cleaned up temporary files")
            except Exception as e:
                print(f"Error cleaning up: {e}")
|
|
|
def generate(self, text, ref_audio, ref_sr, ref_text): |
|
"""Start generation in a new thread with validation""" |
|
if self.is_generating: |
|
print("Already generating speech, please wait") |
|
return |
|
|
|
|
|
self.current_text = text |
|
print(f"Setting current text to: '{self.current_text}'") |
|
|
|
|
|
if tts_model_wrapper is None or tts_model is None: |
|
print("⚠️ Model is not loaded. Cannot generate speech.") |
|
return |
|
|
|
self.is_generating = True |
|
self.should_stop = False |
|
self.all_chunks = [] |
|
|
|
|
|
threading.Thread( |
|
target=self._process_streaming, |
|
args=(text, ref_audio, ref_sr, ref_text), |
|
daemon=True |
|
).start() |
|
|
|
def _process_streaming(self, text, ref_audio, ref_sr, ref_text): |
|
"""Process text in chunks with high-quality audio generation""" |
|
try: |
|
|
|
if text != self.current_text: |
|
print(f"⚠️ Text mismatch detected! Expected: '{self.current_text}', Got: '{text}'") |
|
|
|
text = self.current_text |
|
|
|
|
|
self.prepare_ref_audio(ref_audio, ref_sr) |
|
|
|
|
|
print(f"Processing text: '{text}'") |
|
|
|
|
|
chunks = split_into_chunks(text) |
|
print(f"Processing {len(chunks)} chunks") |
|
|
|
combined_audio = None |
|
total_start_time = time.time() |
|
|
|
|
|
for i, chunk in enumerate(chunks): |
|
if self.should_stop: |
|
print("Stopping generation as requested") |
|
break |
|
|
|
chunk_start = time.time() |
|
print(f"Processing chunk {i+1}/{len(chunks)}: '{chunk}'") |
|
|
|
|
|
try: |
|
|
|
chunk_timeout = 30 |
|
|
|
with torch.inference_mode(): |
|
|
|
chunk_audio = tts_model_wrapper.generate( |
|
text=chunk, |
|
ref_audio_path=self.ref_audio_path, |
|
ref_text=ref_text |
|
) |
|
|
|
if chunk_audio is None or (hasattr(chunk_audio, 'size') and chunk_audio.size == 0): |
|
print("⚠️ Empty audio returned for this chunk") |
|
chunk_audio = np.zeros(int(24000 * 0.5)) |
|
|
|
|
|
chunk_audio = enhance_audio(chunk_audio) |
|
|
|
chunk_time = time.time() - chunk_start |
|
print(f"✓ Chunk {i+1} processed in {chunk_time:.2f}s") |
|
|
|
|
|
silence = np.zeros(int(24000 * 0.1)) |
|
chunk_audio = np.concatenate([chunk_audio, silence]) |
|
|
|
|
|
self.all_chunks.append(chunk_audio) |
|
|
|
|
|
if combined_audio is None: |
|
combined_audio = chunk_audio |
|
else: |
|
combined_audio = np.concatenate([combined_audio, chunk_audio]) |
|
|
|
|
|
processed_audio = enhance_audio(combined_audio) |
|
|
|
|
|
sf.write(self.output_file, processed_audio, 24000, format='WAV', subtype='FLOAT') |
|
|
|
except Exception as e: |
|
print(f"Error processing chunk {i+1}: {str(e)[:100]}") |
|
continue |
|
|
|
total_time = time.time() - total_start_time |
|
print(f"Total generation time: {total_time:.2f}s") |
|
|
|
except Exception as e: |
|
print(f"Error in streaming TTS: {str(e)[:200]}") |
|
|
|
if len(self.all_chunks) > 0: |
|
try: |
|
combined = np.concatenate(self.all_chunks) |
|
sf.write(self.output_file, combined, 24000, format='WAV', subtype='FLOAT') |
|
print("Saved partial output") |
|
except Exception as e2: |
|
print(f"Failed to save partial output: {e2}") |
|
finally: |
|
self.is_generating = False |
|
print("Generation complete") |
|
|
|
def get_current_audio(self): |
|
"""Get current audio file path for Gradio""" |
|
if self.output_file and os.path.exists(self.output_file): |
|
file_size = os.path.getsize(self.output_file) |
|
if file_size > 0: |
|
return self.output_file |
|
return None |
|
|
|
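# NOTE: this ConversationEngine relies on the module-level tts_model / StreamingTTS globals and is
# shadowed by the second ConversationEngine defined further below; the Gradio interface
# instantiates the later definition.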
class ConversationEngine: |
|
def __init__(self): |
|
self.conversation_history = [] |
|
self.system_prompt = "You are a helpful assistant that speaks Malayalam fluently. Always respond in Malayalam script with proper formatting." |
|
self.saved_voice = None |
|
self.saved_voice_text = "" |
|
self.tts_cache = {} |
|
|
|
|
|
self.tts_queue = queue.Queue() |
|
self.tts_thread = threading.Thread(target=self.tts_worker, daemon=True) |
|
self.tts_thread.start() |
|
|
|
|
|
self.streaming_tts = StreamingTTS() |
|
|
|
def tts_worker(self): |
|
"""Background worker to process TTS requests""" |
|
while True: |
|
try: |
|
|
|
text, callback = self.tts_queue.get() |
|
|
|
|
|
audio_path = self._generate_tts(text) |
|
|
|
|
|
if callback: |
|
callback(audio_path) |
|
|
|
|
|
self.tts_queue.task_done() |
|
except Exception as e: |
|
print(f"Error in TTS worker: {e}") |
|
traceback.print_exc() |
|
|
|
def transcribe_audio(self, audio_data, language="ml-IN"): |
|
"""Convert audio to text using speech recognition""" |
|
if audio_data is None: |
|
print("No audio data received") |
|
return "No audio detected", "" |
|
|
|
|
|
try: |
|
if isinstance(audio_data, tuple) and len(audio_data) == 2: |
|
|
|
sample_rate, audio_samples = audio_data |
|
else: |
|
print(f"Unexpected audio format: {type(audio_data)}") |
|
return "Invalid audio format", "" |
|
|
|
if len(audio_samples) == 0: |
|
print("Empty audio samples") |
|
return "No speech detected", "" |
|
|
|
|
|
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") |
|
temp_file.close() |
|
|
|
|
|
sf.write(temp_file.name, audio_samples, sample_rate) |
|
|
|
|
|
recognizer = sr.Recognizer() |
|
with sr.AudioFile(temp_file.name) as source: |
|
audio = recognizer.record(source) |
|
|
|
text = recognizer.recognize_google(audio, language=language) |
|
print(f"Recognized: {text}") |
|
return text, text |
|
|
|
except sr.UnknownValueError: |
|
print("Speech recognition could not understand audio") |
|
return "Could not understand audio", "" |
|
except sr.RequestError as e: |
|
print(f"Could not request results from Google Speech Recognition service: {e}") |
|
return f"Speech recognition service error: {str(e)}", "" |
|
except Exception as e: |
|
print(f"Error processing audio: {e}") |
|
traceback.print_exc() |
|
return f"Error processing audio: {str(e)}", "" |
|
finally: |
|
|
|
if 'temp_file' in locals() and os.path.exists(temp_file.name): |
|
try: |
|
os.unlink(temp_file.name) |
|
except Exception as e: |
|
print(f"Error deleting temporary file: {e}") |
|
|
|
def save_reference_voice(self, audio_data, reference_text): |
|
"""Save the reference voice for future TTS generation""" |
|
if audio_data is None or not reference_text.strip(): |
|
return "Error: Both reference audio and text are required" |
|
|
|
self.saved_voice = audio_data |
|
self.saved_voice_text = reference_text.strip() |
|
|
|
|
|
self.tts_cache.clear() |
|
|
|
|
|
sample_rate, audio_samples = audio_data |
|
print(f"Saved reference voice: {len(audio_samples)} samples at {sample_rate}Hz") |
|
print(f"Reference text: {reference_text}") |
|
|
|
return f"Voice saved successfully! Reference text: {reference_text}" |
|
|
|
def process_text_input(self, text): |
|
"""Process text input from user""" |
|
if text and text.strip(): |
|
return text, text |
|
return "No input provided", "" |
|
|
|
def generate_response(self, input_text): |
|
"""Generate AI response using GPT-3.5 Turbo""" |
|
if not input_text or not input_text.strip(): |
|
return "ഇൻപുട്ട് ലഭിച്ചില്ല. വീണ്ടും ശ്രമിക്കുക.", None |
|
|
|
try: |
|
|
|
messages = [{"role": "system", "content": self.system_prompt}] |
|
|
|
|
|
for entry in self.conversation_history: |
|
role = "user" if entry["role"] == "user" else "assistant" |
|
messages.append({"role": role, "content": entry["content"]}) |
|
|
|
|
|
messages.append({"role": "user", "content": input_text}) |
|
|
|
|
|
response = openai.ChatCompletion.create( |
|
model="gpt-3.5-turbo", |
|
messages=messages, |
|
max_tokens=500, |
|
temperature=0.7 |
|
) |
|
|
|
response_text = response.choices[0].message["content"].strip() |
|
return response_text, None |
|
|
|
except Exception as e: |
|
error_msg = f"എറർ: GPT മോഡലിൽ നിന്ന് ഉത്തരം ലഭിക്കുന്നതിൽ പ്രശ്നമുണ്ടായി: {str(e)}" |
|
print(f"Error in GPT response: {e}") |
|
traceback.print_exc() |
|
return error_msg, None |
|
|
|
def resample_audio(self, audio, orig_sr, target_sr): |
|
"""Resample audio to match target sample rate only if necessary""" |
|
if orig_sr != target_sr: |
|
print(f"Resampling audio from {orig_sr}Hz to {target_sr}Hz") |
|
return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr) |
|
return audio |
|
|
|
def _generate_tts(self, text): |
|
"""Internal method to generate TTS without threading""" |
|
if not text or not text.strip(): |
|
print("No text provided for TTS generation") |
|
return None |
|
|
|
|
|
if text in self.tts_cache: |
|
print("Using cached TTS output") |
|
return self.tts_cache[text] |
|
|
|
try: |
|
|
|
if self.saved_voice is not None and tts_model is not None: |
|
sample_rate, audio_data = self.saved_voice |
|
|
|
|
|
ref_temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") |
|
ref_temp_file.close() |
|
print(f"Saving reference audio to {ref_temp_file.name}") |
|
|
|
|
|
sf.write(ref_temp_file.name, audio_data, sample_rate) |
|
|
|
|
|
output_temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") |
|
output_temp_file.close() |
|
|
|
try: |
|
|
|
print(f"Generating speech with IndicF5. Text: {text[:30]}...") |
|
start_time = time.time() |
|
|
|
|
|
with torch.no_grad(): |
|
|
|
synth_audio = tts_model_wrapper.generate( |
|
text, |
|
ref_audio_path=ref_temp_file.name, |
|
ref_text=self.saved_voice_text |
|
) |
|
|
|
end_time = time.time() |
|
print(f"Speech generation completed in {end_time - start_time:.2f} seconds") |
|
|
|
|
|
synth_audio = enhance_audio(synth_audio) |
|
|
|
|
|
sf.write(output_temp_file.name, synth_audio, 24000) |
|
|
|
|
|
self.tts_cache[text] = output_temp_file.name |
|
|
|
print(f"TTS output saved to {output_temp_file.name}") |
|
return output_temp_file.name |
|
|
|
except Exception as e: |
|
print(f"Error generating speech: {e}") |
|
traceback.print_exc() |
|
return None |
|
finally: |
|
|
|
|
|
try: |
|
os.unlink(ref_temp_file.name) |
|
except Exception as e: |
|
print(f"Error cleaning up reference file: {e}") |
|
else: |
|
print("No saved voice reference or TTS model not loaded") |
|
return None |
|
except Exception as e: |
|
print(f"Error in TTS processing: {e}") |
|
traceback.print_exc() |
|
return None |
|
|
|
def queue_tts_generation(self, text, callback=None): |
|
"""Queue TTS generation in background thread""" |
|
print(f"Queueing TTS generation for text: {text[:30]}...") |
|
self.tts_queue.put((text, callback)) |
|
|
|
def generate_streamed_speech(self, text): |
|
"""Generate speech in a streaming manner for low latency""" |
|
if not self.saved_voice: |
|
print("No reference voice saved") |
|
return None |
|
|
|
if not text or not text.strip(): |
|
print("No text provided for streaming TTS") |
|
return None |
|
|
|
sample_rate, audio_data = self.saved_voice |
|
|
|
|
|
self.streaming_tts.generate( |
|
text=text, |
|
ref_audio=audio_data, |
|
ref_sr=sample_rate, |
|
ref_text=self.saved_voice_text |
|
) |
|
|
|
|
|
return self.streaming_tts.output_file |
|
|
|
def update_history(self, user_input, ai_response): |
|
"""Update conversation history""" |
|
if user_input and user_input.strip(): |
|
self.conversation_history.append({"role": "user", "content": user_input}) |
|
|
|
if ai_response and ai_response.strip(): |
|
self.conversation_history.append({"role": "assistant", "content": ai_response}) |
|
|
|
|
|
if len(self.conversation_history) > 20: |
|
self.conversation_history = self.conversation_history[-20:] |
|
|
|
|
|
# Module-level instances; the Gradio interface below creates its own ConversationEngine.
conversation_engine = ConversationEngine()
speech_recognizer = SpeechRecognizer()
|
|
|
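# Second (effective) ConversationEngine: owns its own IndicF5 instance, GPT-3.5 calls, and a
# gTTS fallback, and is the class actually used by the Gradio interface.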
class ConversationEngine: |
|
def __init__(self): |
|
self.conversation_history = [] |
|
self.system_prompt = "You are a helpful assistant that speaks Malayalam fluently. Always respond in Malayalam script with proper formatting." |
|
self.saved_voice = None |
|
self.saved_voice_text = "" |
|
self.tts_cache = {} |
|
|
|
|
|
self.tts_queue = queue.Queue() |
|
self.tts_thread = threading.Thread(target=self.tts_worker, daemon=True) |
|
self.tts_thread.start() |
|
|
|
|
|
self.tts_model = None |
|
self.device = None |
|
try: |
|
self.initialize_tts_model() |
|
|
|
|
|
if self.tts_model is not None: |
|
print("TTS model initialized successfully") |
|
except Exception as e: |
|
print(f"Error initializing TTS model: {e}") |
|
traceback.print_exc() |
|
|
|
def initialize_tts_model(self): |
|
"""Initialize the IndicF5 TTS model with optimizations""" |
|
try: |
|
|
|
hf_token = os.getenv("HF_TOKEN") |
|
if hf_token: |
|
print("Logging into Hugging Face with the provided token.") |
|
login(token=hf_token) |
|
|
|
if torch.cuda.is_available(): |
|
self.device = torch.device("cuda") |
|
print(f"Using GPU: {torch.cuda.get_device_name(0)}") |
|
else: |
|
self.device = torch.device("cpu") |
|
print("Using CPU") |
|
|
|
|
|
torch.backends.cudnn.benchmark = True |
|
|
|
|
|
print("Loading TTS model from ai4bharat/IndicF5...") |
|
repo_id = "ai4bharat/IndicF5" |
|
self.tts_model = AutoModel.from_pretrained(repo_id, trust_remote_code=True) |
|
self.tts_model = self.tts_model.to(self.device) |
|
|
|
|
|
self.tts_model.eval() |
|
print("TTS model loaded successfully") |
|
except Exception as e: |
|
print(f"Failed to load TTS model: {e}") |
|
self.tts_model = None |
|
traceback.print_exc() |
|
|
|
def tts_worker(self): |
|
"""Background worker to process TTS requests""" |
|
while True: |
|
try: |
|
|
|
text, callback = self.tts_queue.get() |
|
|
|
|
|
audio_path = self._generate_tts(text) |
|
|
|
|
|
if callback: |
|
callback(audio_path) |
|
|
|
|
|
self.tts_queue.task_done() |
|
except Exception as e: |
|
print(f"Error in TTS worker: {e}") |
|
traceback.print_exc() |
|
|
|
def transcribe_audio(self, audio_data, language="ml-IN"): |
|
"""Convert audio to text using speech recognition""" |
|
if audio_data is None: |
|
print("No audio data received") |
|
return "No audio detected", "" |
|
|
|
|
|
try: |
|
if isinstance(audio_data, tuple) and len(audio_data) == 2: |
|
|
|
sample_rate, audio_samples = audio_data |
|
else: |
|
print(f"Unexpected audio format: {type(audio_data)}") |
|
return "Invalid audio format", "" |
|
|
|
if len(audio_samples) == 0: |
|
print("Empty audio samples") |
|
return "No speech detected", "" |
|
|
|
|
|
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") |
|
temp_file.close() |
|
|
|
|
|
sf.write(temp_file.name, audio_samples, sample_rate) |
|
|
|
|
|
recognizer = sr.Recognizer() |
|
with sr.AudioFile(temp_file.name) as source: |
|
audio = recognizer.record(source) |
|
|
|
text = recognizer.recognize_google(audio, language=language) |
|
print(f"Recognized: {text}") |
|
return text, text |
|
|
|
except sr.UnknownValueError: |
|
print("Speech recognition could not understand audio") |
|
return "Could not understand audio", "" |
|
except sr.RequestError as e: |
|
print(f"Could not request results from Google Speech Recognition service: {e}") |
|
return f"Speech recognition service error: {str(e)}", "" |
|
except Exception as e: |
|
print(f"Error processing audio: {e}") |
|
traceback.print_exc() |
|
return f"Error processing audio: {str(e)}", "" |
|
finally: |
|
|
|
if 'temp_file' in locals() and os.path.exists(temp_file.name): |
|
try: |
|
os.unlink(temp_file.name) |
|
except Exception as e: |
|
print(f"Error deleting temporary file: {e}") |
|
|
|
def save_reference_voice(self, audio_data, reference_text): |
|
"""Save the reference voice for future TTS generation""" |
|
if audio_data is None or not reference_text.strip(): |
|
return "Error: Both reference audio and text are required" |
|
|
|
self.saved_voice = audio_data |
|
self.saved_voice_text = reference_text.strip() |
|
|
|
|
|
self.tts_cache.clear() |
|
|
|
|
|
sample_rate, audio_samples = audio_data |
|
print(f"Saved reference voice: {len(audio_samples)} samples at {sample_rate}Hz") |
|
print(f"Reference text: {reference_text}") |
|
|
|
return f"Voice saved successfully! Reference text: {reference_text}" |
|
|
|
def process_text_input(self, text): |
|
"""Process text input from user""" |
|
if text and text.strip(): |
|
return text, text |
|
return "No input provided", "" |
|
|
|
def generate_response(self, input_text): |
|
"""Generate AI response using GPT-3.5 Turbo""" |
|
if not input_text or not input_text.strip(): |
|
return "ഇൻപുട്ട് ലഭിച്ചില്ല. വീണ്ടും ശ്രമിക്കുക.", None |
|
|
|
try: |
|
|
|
messages = [{"role": "system", "content": self.system_prompt}] |
|
|
|
|
|
for entry in self.conversation_history: |
|
role = "user" if entry["role"] == "user" else "assistant" |
|
messages.append({"role": role, "content": entry["content"]}) |
|
|
|
|
|
messages.append({"role": "user", "content": input_text}) |
|
|
|
|
|
response = openai.ChatCompletion.create( |
|
model="gpt-3.5-turbo", |
|
messages=messages, |
|
max_tokens=500, |
|
temperature=0.7 |
|
) |
|
|
|
response_text = response.choices[0].message.content.strip() |
|
return response_text, None |
|
|
|
except Exception as e: |
|
error_msg = f"എറർ: GPT മോഡലിൽ നിന്ന് ഉത്തരം ലഭിക്കുന്നതിൽ പ്രശ്നമുണ്ടായി: {str(e)}" |
|
print(f"Error in GPT response: {e}") |
|
traceback.print_exc() |
|
return error_msg, None |
|
|
|
def resample_audio(self, audio, orig_sr, target_sr): |
|
"""Resample audio to match target sample rate only if necessary""" |
|
if orig_sr != target_sr: |
|
print(f"Resampling audio from {orig_sr}Hz to {target_sr}Hz") |
|
return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr) |
|
return audio |
|
|
|
def _generate_tts(self, text): |
|
"""Internal method to generate TTS without threading""" |
|
if not text or not text.strip(): |
|
print("No text provided for TTS generation") |
|
return None |
|
|
|
|
|
if text in self.tts_cache: |
|
print("Using cached TTS output") |
|
return self.tts_cache[text] |
|
|
|
try: |
|
|
|
if self.saved_voice is not None and self.tts_model is not None: |
|
sample_rate, audio_data = self.saved_voice |
|
|
|
|
|
ref_temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") |
|
ref_temp_file.close() |
|
print(f"Saving reference audio to {ref_temp_file.name}") |
|
|
|
|
|
sf.write(ref_temp_file.name, audio_data, sample_rate) |
|
|
|
|
|
output_temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") |
|
output_temp_file.close() |
|
|
|
try: |
|
|
|
print(f"Generating speech with IndicF5. Text: {text[:30]}...") |
|
start_time = time.time() |
|
|
|
|
|
with torch.no_grad(): |
|
|
|
synth_audio = self.tts_model( |
|
text, |
|
ref_audio_path=ref_temp_file.name, |
|
ref_text=self.saved_voice_text |
|
) |
|
|
|
end_time = time.time() |
|
print(f"Speech generation completed in {(end_time - start_time)} seconds") |
|
|
|
|
|
if synth_audio.dtype == np.int16: |
|
synth_audio = synth_audio.astype(np.float32) / 32768.0 |
|
|
|
|
|
synth_audio = self.resample_audio(synth_audio, orig_sr=24000, target_sr=sample_rate) |
|
|
|
|
|
print(f"Saving synthesized audio to {output_temp_file.name}") |
|
sf.write(output_temp_file.name, synth_audio, sample_rate) |
|
|
|
|
|
self.tts_cache[text] = output_temp_file.name |
|
|
|
print(f"TTS generation successful, output file: {output_temp_file.name}") |
|
return output_temp_file.name |
|
except Exception as e: |
|
print(f"IndicF5 TTS failed with error: {e}") |
|
traceback.print_exc() |
|
|
|
return self.fallback_tts(text, output_temp_file.name) |
|
finally: |
|
|
|
if os.path.exists(ref_temp_file.name): |
|
try: |
|
os.unlink(ref_temp_file.name) |
|
except Exception as e: |
|
print(f"Error deleting temporary file: {e}") |
|
else: |
|
if self.saved_voice is None: |
|
print("No saved voice available for TTS") |
|
if self.tts_model is None: |
|
print("TTS model not initialized") |
|
|
|
|
|
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") |
|
temp_file.close() |
|
return self.fallback_tts(text, temp_file.name) |
|
|
|
except Exception as e: |
|
print(f"Error in TTS processing: {e}") |
|
traceback.print_exc() |
|
return None |
|
|
|
def speak_with_indicf5(self, text, callback=None): |
|
"""Queue text for TTS generation""" |
|
if not text or not text.strip(): |
|
if callback: |
|
callback(None) |
|
return None |
|
|
|
|
|
if text in self.tts_cache: |
|
print("Using cached TTS output") |
|
if callback: |
|
callback(self.tts_cache[text]) |
|
return self.tts_cache[text] |
|
|
|
|
|
if callback is None: |
|
return self._generate_tts(text) |
|
|
|
|
|
self.tts_queue.put((text, callback)) |
|
return None |
|
|
|
def fallback_tts(self, text, output_path): |
|
"""Fallback to Google TTS if IndicF5 fails""" |
|
try: |
|
from gtts import gTTS |
|
|
|
|
|
is_malayalam = any('\u0D00' <= c <= '\u0D7F' for c in text) |
|
lang = 'ml' if is_malayalam else 'en' |
|
|
|
print(f"Using fallback Google TTS with language: {lang}") |
|
tts = gTTS(text=text, lang=lang, slow=False) |
|
tts.save(output_path) |
|
|
|
|
|
self.tts_cache[text] = output_path |
|
print(f"Fallback TTS saved to: {output_path}") |
|
|
|
return output_path |
|
except Exception as e: |
|
print(f"Fallback TTS also failed: {e}") |
|
traceback.print_exc() |
|
return None |
|
|
|
def add_message(self, role, content): |
|
"""Add a message to the conversation history""" |
|
timestamp = datetime.now().strftime("%H:%M:%S") |
|
self.conversation_history.append({ |
|
"role": role, |
|
"content": content, |
|
"timestamp": timestamp |
|
}) |
|
|
|
def clear_conversation(self): |
|
"""Clear the conversation history""" |
|
self.conversation_history = [] |
|
|
|
def cleanup(self): |
|
"""Clean up resources when shutting down""" |
|
print("Cleaning up resources...") |
|
|
|
|
|
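# Download a reference voice sample over HTTP and decode it with soundfile.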
def load_audio_from_url(url): |
|
"""Load audio from a URL""" |
|
try: |
|
response = requests.get(url) |
|
if response.status_code == 200: |
|
audio_data, sample_rate = sf.read(io.BytesIO(response.content)) |
|
return sample_rate, audio_data |
|
except Exception as e: |
|
print(f"Error loading audio from URL: {e}") |
|
return None, None |
|
|
|
|
|
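# Built-in reference voices (Malayalam samples with matching transcripts) used for voice cloning.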
EXAMPLE_VOICES = [ |
|
{ |
|
"name": "Aparna Voice", |
|
"url": "https://raw.githubusercontent.com/Aparna0112/voicerecording-_TTS/main/Aparna%20Voice.wav", |
|
"transcript": "ഞാൻ ഒരു ഫോണിന്റെ കവർ നോക്കുകയാണ്. എനിക്ക് സ്മാർട്ട് ഫോണിന് കവർ വേണം" |
|
}, |
|
{ |
|
"name": "KC Voice", |
|
"url": "https://raw.githubusercontent.com/Aparna0112/voicerecording-_TTS/main/KC%20Voice.wav", |
|
"transcript": "ഹലോ ഇത് അപരനെ അല്ലേ ഞാൻ ജഗദീപ് ആണ് വിളിക്കുന്നത് ഇപ്പോൾ ഫ്രീയാണോ സംസാരിക്കാമോ" |
|
} |
|
] |
|
|
|
|
|
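# Pre-fetch the example voices at import time so the voice-selection dropdown has audio ready.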
for voice in EXAMPLE_VOICES: |
|
sample_rate, audio_data = load_audio_from_url(voice["url"]) |
|
if sample_rate is not None: |
|
voice["audio"] = (sample_rate, audio_data) |
|
print(f"Loaded example voice: {voice['name']}") |
|
else: |
|
print(f"Failed to load voice: {voice['name']}") |
|
|
|
def create_chatbot_interface(): |
|
"""Create a single-page chatbot interface with voice input, output, and voice selection""" |
|
|
|
|
|
engine = ConversationEngine() |
|
|
|
|
|
css = """ |
|
.chatbot-container { |
|
display: flex; |
|
flex-direction: column; |
|
height: 100%; |
|
max-width: 800px; |
|
margin: 0 auto; |
|
} |
|
.chat-window { |
|
flex-grow: 1; |
|
overflow-y: auto; |
|
padding: 1rem; |
|
background: #f5f7f9; |
|
border-radius: 0.5rem; |
|
margin-bottom: 1rem; |
|
min-height: 400px; |
|
} |
|
.input-area { |
|
display: flex; |
|
gap: 0.5rem; |
|
padding: 0.5rem; |
|
align-items: center; |
|
} |
|
.message { |
|
margin-bottom: 1rem; |
|
padding: 0.8rem; |
|
border-radius: 0.5rem; |
|
position: relative; |
|
max-width: 80%; |
|
} |
|
.user-message { |
|
background: #e1f5fe; |
|
align-self: flex-end; |
|
margin-left: auto; |
|
} |
|
.bot-message { |
|
background: #f0f0f0; |
|
align-self: flex-start; |
|
} |
|
.timestamp { |
|
font-size: 0.7rem; |
|
color: #888; |
|
margin-top: 0.2rem; |
|
text-align: right; |
|
} |
|
.chatbot-header { |
|
text-align: center; |
|
color: #333; |
|
margin-bottom: 1rem; |
|
} |
|
.chat-controls { |
|
display: flex; |
|
justify-content: space-between; |
|
margin-bottom: 0.5rem; |
|
} |
|
.voice-selector { |
|
background: #f8f9fa; |
|
padding: 1rem; |
|
border-radius: 0.5rem; |
|
margin-bottom: 1rem; |
|
} |
|
.progress-bar { |
|
height: 4px; |
|
background-color: #e0e0e0; |
|
position: relative; |
|
margin: 10px 0; |
|
border-radius: 2px; |
|
} |
|
.progress-bar-fill { |
|
height: 100%; |
|
background-color: #4CAF50; |
|
border-radius: 2px; |
|
transition: width 0.3s ease-in-out; |
|
} |
|
""" |
|
|
|
with gr.Blocks(css=css, title="Malayalam Voice Chatbot") as interface: |
|
gr.Markdown("# 🤖 Malayalam Voice Chatbot with Voice Selection", elem_classes=["chatbot-header"]) |
|
|
|
|
|
tts_progress_state = gr.State(0) |
|
audio_output_state = gr.State(None) |
|
|
|
with gr.Row(elem_classes=["chatbot-container"]): |
|
with gr.Column(): |
|
|
|
with gr.Accordion("🎤 Voice Selection", open=True): |
|
|
|
voice_selector = gr.Dropdown( |
|
choices=[voice["name"] for voice in EXAMPLE_VOICES], |
|
value=EXAMPLE_VOICES[0]["name"] if EXAMPLE_VOICES else None, |
|
label="Select Voice Example" |
|
) |
|
|
|
|
|
voice_info = gr.Textbox( |
|
value=EXAMPLE_VOICES[0]["transcript"] if EXAMPLE_VOICES else "", |
|
label="Voice Sample Transcript", |
|
lines=2, |
|
interactive=True |
|
) |
|
|
|
|
|
example_audio = gr.Audio( |
|
value=None, |
|
label="Example Voice", |
|
interactive=False |
|
) |
|
|
|
|
|
gr.Markdown("### OR Record Your Own Voice") |
|
|
|
custom_voice = gr.Audio( |
|
sources=["microphone", "upload"], |
|
type="numpy", |
|
label="Record/Upload Your Voice" |
|
) |
|
|
|
custom_transcript = gr.Textbox( |
|
value="", |
|
label="Your Voice Transcript (what you said in Malayalam)", |
|
lines=2 |
|
) |
|
|
|
|
|
save_voice_btn = gr.Button("💾 Save Voice for Chat", variant="primary") |
|
voice_status = gr.Textbox(label="Voice Status", value="No voice saved yet") |
|
|
|
|
|
with gr.Row(elem_classes=["chat-controls"]): |
|
language_selector = gr.Dropdown( |
|
choices=["ml-IN", "en-US", "hi-IN", "ta-IN", "te-IN", "kn-IN"], |
|
value="ml-IN", |
|
label="Speech Recognition Language" |
|
) |
|
clear_btn = gr.Button("🧹 Clear Chat", scale=0) |
|
|
|
|
|
chatbot = gr.Chatbot( |
|
[], |
|
elem_id="chatbox", |
|
bubble_full_width=False, |
|
height=450, |
|
elem_classes=["chat-window"] |
|
) |
|
|
|
|
|
with gr.Row(): |
|
tts_progress = gr.Slider( |
|
minimum=0, |
|
maximum=100, |
|
value=0, |
|
label="TTS Progress", |
|
interactive=False |
|
) |
|
|
|
|
|
audio_output = gr.Audio( |
|
label="Bot's Voice Response", |
|
type="filepath", |
|
autoplay=True, |
|
visible=True |
|
) |
|
|
|
|
|
status_msg = gr.Textbox( |
|
label="Status", |
|
value="Ready", |
|
interactive=False |
|
) |
|
|
|
|
|
with gr.Row(elem_classes=["input-area"]): |
|
audio_msg = gr.Textbox( |
|
label="Message", |
|
placeholder="Type a message or record audio", |
|
lines=1 |
|
) |
|
audio_input = gr.Audio( |
|
sources=["microphone"], |
|
type="numpy", |
|
label="Record", |
|
elem_classes=["audio-input"] |
|
) |
|
submit_btn = gr.Button("🚀 Send", variant="primary") |
|
|
|
|
|
def update_voice_example(voice_name): |
|
for voice in EXAMPLE_VOICES: |
|
if voice["name"] == voice_name and "audio" in voice: |
|
return voice["transcript"], voice["audio"] |
|
return "", None |
|
|
|
|
|
def save_voice_for_tts(example_name, example_audio, custom_audio, example_transcript, custom_transcript): |
|
try: |
|
|
|
if custom_audio is not None: |
|
|
|
if not custom_transcript.strip(): |
|
return "Error: Please provide a transcript for your recorded voice" |
|
|
|
voice_audio = custom_audio |
|
transcript = custom_transcript |
|
source = "custom recording" |
|
elif example_audio is not None: |
|
|
|
voice_audio = example_audio |
|
transcript = example_transcript |
|
source = f"example: {example_name}" |
|
else: |
|
return "Error: No voice selected or recorded" |
|
|
|
|
|
result = engine.save_reference_voice(voice_audio, transcript) |
|
|
|
return f"Voice saved successfully! Using {source}" |
|
except Exception as e: |
|
print(f"Error saving voice: {e}") |
|
traceback.print_exc() |
|
return f"Error saving voice: {str(e)}" |
|
|
|
|
|
def update_tts_progress(progress): |
|
return progress |
|
|
|
|
|
def on_tts_generated(audio_path): |
|
print(f"TTS generation callback received path: {audio_path}") |
|
return audio_path, 100, "Response ready" |
|
|
|
|
|
def process_input(audio, text_input, history, language, progress): |
|
try: |
|
|
|
status = "Processing input..." |
|
|
|
|
|
progress = 0 |
|
|
|
|
|
if audio is not None: |
|
|
|
transcribed_text, input_text = engine.transcribe_audio(audio, language) |
|
if not input_text: |
|
status = "Could not understand audio. Please try again." |
|
return history, None, status, text_input, progress |
|
elif text_input and text_input.strip(): |
|
|
|
input_text = text_input.strip() |
|
transcribed_text = input_text |
|
else: |
|
|
|
status = "No input detected. Please speak or type a message." |
|
return history, None, status, text_input, progress |
|
|
|
|
|
engine.add_message("user", input_text) |
|
|
|
|
|
updated_history = history + [[transcribed_text, None]] |
|
|
|
|
|
status = "Generating response..." |
|
progress = 30 |
|
|
|
|
|
response_text, _ = engine.generate_response(input_text) |
|
|
|
|
|
engine.add_message("assistant", response_text) |
|
|
|
|
|
updated_history = history + [[transcribed_text, response_text]] |
|
|
|
|
|
status = "Generating speech..." |
|
progress = 60 |
|
|
|
|
|
audio_path = engine._generate_tts(response_text) |
|
|
|
if audio_path: |
|
status = f"Response ready: {audio_path}" |
|
progress = 100 |
|
print(f"Audio generated successfully: {audio_path}") |
|
else: |
|
status = "Failed to generate speech" |
|
|
|
|
|
return updated_history, audio_path, status, "", progress |
|
|
|
except Exception as e: |
|
|
|
error_message = f"Error: {str(e)}" |
|
print(error_message) |
|
traceback.print_exc() |
|
return history, None, error_message, text_input, progress |
|
|
|
|
|
def clear_chat(): |
|
engine.clear_conversation() |
|
return [], None, "Chat history cleared", "", 0 |
|
|
|
|
|
|
|
|
|
voice_selector.change( |
|
update_voice_example, |
|
inputs=[voice_selector], |
|
outputs=[voice_info, example_audio] |
|
) |
|
|
|
|
|
save_voice_btn.click( |
|
save_voice_for_tts, |
|
inputs=[voice_selector, example_audio, custom_voice, voice_info, custom_transcript], |
|
outputs=[voice_status] |
|
) |
|
|
|
|
|
submit_btn.click( |
|
process_input, |
|
inputs=[audio_input, audio_msg, chatbot, language_selector, tts_progress_state], |
|
outputs=[chatbot, audio_output, status_msg, audio_msg, tts_progress] |
|
) |
|
|
|
|
|
audio_msg.submit( |
|
process_input, |
|
inputs=[audio_input, audio_msg, chatbot, language_selector, tts_progress_state], |
|
outputs=[chatbot, audio_output, status_msg, audio_msg, tts_progress] |
|
) |
|
|
|
|
|
clear_btn.click( |
|
clear_chat, |
|
inputs=[], |
|
outputs=[chatbot, audio_output, status_msg, audio_msg, tts_progress] |
|
) |
|
|
|
|
|
def exit_handler(): |
|
engine.cleanup() |
|
|
|
import atexit |
|
atexit.register(exit_handler) |
|
|
|
|
|
interface.queue() |
|
|
|
return interface |
|
|
|
|
|
if __name__ == "__main__": |
|
print("Starting Malayalam Voice Chatbot with IndicF5 Voice Selection...") |
|
interface = create_chatbot_interface() |
|
interface.launch(debug=True) |