Update app.py
app.py
CHANGED
@@ -902,394 +902,11 @@ class ConversationEngine:
 conversation_engine = ConversationEngine()
 speech_recognizer = SpeechRecognizer()

-class ConversationEngine:
-    def __init__(self):
-        self.conversation_history = []
-        self.system_prompt = "You are a helpful assistant that speaks Malayalam fluently. Always respond in Malayalam script with proper formatting."
-        self.saved_voice = None
-        self.saved_voice_text = ""
-        self.tts_cache = {}  # Cache for TTS outputs
-
-        # TTS background processing queue
-        self.tts_queue = queue.Queue()
-        self.tts_thread = threading.Thread(target=self.tts_worker, daemon=True)
-        self.tts_thread.start()
-
-        # Initialize IndicF5 TTS model if available
-        self.tts_model = None
-        self.device = None
-        try:
-            self.initialize_tts_model()
-
-            # Test the model if it was loaded successfully
-            if self.tts_model is not None:
-                print("TTS model initialized successfully")
-        except Exception as e:
-            print(f"Error initializing TTS model: {e}")
-            traceback.print_exc()
-
-    def initialize_tts_model(self):
-        """Initialize the IndicF5 TTS model with optimizations"""
-        try:
-            # Check for HF token in environment and use it if available
-            hf_token = os.getenv("HF_TOKEN")
-            if hf_token:
-                print("Logging into Hugging Face with the provided token.")
-                login(token=hf_token)
-
-            if torch.cuda.is_available():
-                self.device = torch.device("cuda")
-                print(f"Using GPU: {torch.cuda.get_device_name(0)}")
-            else:
-                self.device = torch.device("cpu")
-                print("Using CPU")
-
-            # Enable performance optimizations
-            torch.backends.cudnn.benchmark = True
-
-            # Load TTS model and move it to the appropriate device (GPU/CPU)
-            print("Loading TTS model from ai4bharat/IndicF5...")
-            repo_id = "ai4bharat/IndicF5"
-            self.tts_model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
-            self.tts_model = self.tts_model.to(self.device)
-
-            # Set model to evaluation mode for faster inference
-            self.tts_model.eval()
-            print("TTS model loaded successfully")
-        except Exception as e:
-            print(f"Failed to load TTS model: {e}")
-            self.tts_model = None
-            traceback.print_exc()
-
-    def tts_worker(self):
-        """Background worker to process TTS requests"""
-        while True:
-            try:
-                # Get text and callback from queue
-                text, callback = self.tts_queue.get()
-
-                # Generate speech
-                audio_path = self._generate_tts(text)
-
-                # Execute callback with result
-                if callback:
-                    callback(audio_path)
-
-                # Mark task as done
-                self.tts_queue.task_done()
-            except Exception as e:
-                print(f"Error in TTS worker: {e}")
-                traceback.print_exc()
-
-    def transcribe_audio(self, audio_data, language="ml-IN"):
-        """Convert audio to text using speech recognition"""
-        if audio_data is None:
-            print("No audio data received")
-            return "No audio detected", ""
-
-        # Make sure we have audio data in the expected format
-        try:
-            if isinstance(audio_data, tuple) and len(audio_data) == 2:
-                # Expected format: (sample_rate, audio_samples)
-                sample_rate, audio_samples = audio_data
-            else:
-                print(f"Unexpected audio format: {type(audio_data)}")
-                return "Invalid audio format", ""
-
-            if len(audio_samples) == 0:
-                print("Empty audio samples")
-                return "No speech detected", ""
-
-            # Save the audio temporarily
-            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
-            temp_file.close()
-
-            # Save the audio data to the temporary file
-            sf.write(temp_file.name, audio_samples, sample_rate)
-
-            # Use speech recognition on the file
-            recognizer = sr.Recognizer()
-            with sr.AudioFile(temp_file.name) as source:
-                audio = recognizer.record(source)
-
-            text = recognizer.recognize_google(audio, language=language)
-            print(f"Recognized: {text}")
-            return text, text
-
-        except sr.UnknownValueError:
-            print("Speech recognition could not understand audio")
-            return "Could not understand audio", ""
-        except sr.RequestError as e:
-            print(f"Could not request results from Google Speech Recognition service: {e}")
-            return f"Speech recognition service error: {str(e)}", ""
-        except Exception as e:
-            print(f"Error processing audio: {e}")
-            traceback.print_exc()
-            return f"Error processing audio: {str(e)}", ""
-        finally:
-            # Clean up temporary file
-            if 'temp_file' in locals() and os.path.exists(temp_file.name):
-                try:
-                    os.unlink(temp_file.name)
-                except Exception as e:
-                    print(f"Error deleting temporary file: {e}")
-
-    def save_reference_voice(self, audio_data, reference_text):
-        """Save the reference voice for future TTS generation"""
-        if audio_data is None or not reference_text.strip():
-            return "Error: Both reference audio and text are required"
-
-        self.saved_voice = audio_data
-        self.saved_voice_text = reference_text.strip()
-
-        # Clear TTS cache when voice changes
-        self.tts_cache.clear()
-
-        # Debug info
-        sample_rate, audio_samples = audio_data
-        print(f"Saved reference voice: {len(audio_samples)} samples at {sample_rate}Hz")
-        print(f"Reference text: {reference_text}")
-
-        return f"Voice saved successfully! Reference text: {reference_text}"
-
-    def process_text_input(self, text):
-        """Process text input from user"""
-        if text and text.strip():
-            return text, text
-        return "No input provided", ""
-
-    def generate_response(self, input_text):
-        """Generate AI response using GPT-3.5 Turbo"""
-        if not input_text or not input_text.strip():
-            return "ഇൻപുട്ട് ലഭിച്ചില്ല. വീണ്ടും ശ്രമിക്കുക.", None  # "No input received. Please try again."
-
-        try:
-            # Prepare conversation context from history
-            messages = [{"role": "system", "content": self.system_prompt}]
-
-            # Add previous conversations for context
-            for entry in self.conversation_history:
-                role = "user" if entry["role"] == "user" else "assistant"
-                messages.append({"role": role, "content": entry["content"]})
-
-            # Add current input
-            messages.append({"role": "user", "content": input_text})
-
-            # Call OpenAI API
-            response = openai.ChatCompletion.create(
-                model="gpt-3.5-turbo",
-                messages=messages,
-                max_tokens=500,
-                temperature=0.7
-            )
-
-            response_text = response.choices[0].message.content.strip()
-            return response_text, None
-
-        except Exception as e:
-            error_msg = f"എറർ: GPT മോഡലിൽ നിന്ന് ഉത്തരം ലഭിക്കുന്നതിൽ പ്രശ്നമുണ്ടായി: {str(e)}"
-            print(f"Error in GPT response: {e}")
-            traceback.print_exc()
-            return error_msg, None
-
-    def resample_audio(self, audio, orig_sr, target_sr):
-        """Resample audio to match target sample rate only if necessary"""
-        if orig_sr != target_sr:
-            print(f"Resampling audio from {orig_sr}Hz to {target_sr}Hz")
-            return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
-        return audio
-
-    def _generate_tts(self, text):
-        """Internal method to generate TTS without threading"""
-        if not text or not text.strip():
-            print("No text provided for TTS generation")
-            return None
-
-        # Check cache first
-        if text in self.tts_cache:
-            print("Using cached TTS output")
-            return self.tts_cache[text]
-
-        try:
-            # Check if we have a saved voice and the TTS model
-            if self.saved_voice is not None and self.tts_model is not None:
-                sample_rate, audio_data = self.saved_voice
-
-                # Create a temporary file for the reference audio
-                ref_temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
-                ref_temp_file.close()
-                print(f"Saving reference audio to {ref_temp_file.name}")
-
-                # Save the reference audio data
-                sf.write(ref_temp_file.name, audio_data, sample_rate)
-
-                # Create a temporary file for the output audio
-                output_temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
-                output_temp_file.close()
-
-                try:
-                    # Generate speech using IndicF5 - simplified approach from second file
-                    print(f"Generating speech with IndicF5. Text: {text[:30]}...")
-                    start_time = time.time()
-
-                    # Use torch.no_grad() to save memory and computation
-                    with torch.no_grad():
-                        # Run the inference - directly use the model as in the second file
-                        synth_audio = self.tts_model(
-                            text,
-                            ref_audio_path=ref_temp_file.name,
-                            ref_text=self.saved_voice_text
-                        )
-
-                    end_time = time.time()
-                    print(f"Speech generation completed in {(end_time - start_time)} seconds")
-
-                    # Normalize output if needed
-                    if synth_audio.dtype == np.int16:
-                        synth_audio = synth_audio.astype(np.float32) / 32768.0
-
-                    # Resample the generated audio to match the reference audio's sample rate
-                    synth_audio = self.resample_audio(synth_audio, orig_sr=24000, target_sr=sample_rate)
-
-                    # Save the synthesized audio
-                    print(f"Saving synthesized audio to {output_temp_file.name}")
-                    sf.write(output_temp_file.name, synth_audio, sample_rate)
-
-                    # Cache the result
-                    self.tts_cache[text] = output_temp_file.name
-
-                    print(f"TTS generation successful, output file: {output_temp_file.name}")
-                    return output_temp_file.name
-                except Exception as e:
-                    print(f"IndicF5 TTS failed with error: {e}")
-                    traceback.print_exc()
-                    # Fall back to Google TTS
-                    return self.fallback_tts(text, output_temp_file.name)
-                finally:
-                    # Clean up reference audio file
-                    if os.path.exists(ref_temp_file.name):
-                        try:
-                            os.unlink(ref_temp_file.name)
-                        except Exception as e:
-                            print(f"Error deleting temporary file: {e}")
-            else:
-                if self.saved_voice is None:
-                    print("No saved voice available for TTS")
-                if self.tts_model is None:
-                    print("TTS model not initialized")
-
-                # No saved voice or TTS model, use fallback
-                temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
-                temp_file.close()
-                return self.fallback_tts(text, temp_file.name)
-
-        except Exception as e:
-            print(f"Error in TTS processing: {e}")
-            traceback.print_exc()
-            return None
-
-    def speak_with_indicf5(self, text, callback=None):
-        """Queue text for TTS generation"""
-        if not text or not text.strip():
-            if callback:
-                callback(None)
-            return None
-
-        # Check cache first for immediate response
-        if text in self.tts_cache:
-            print("Using cached TTS output")
-            if callback:
-                callback(self.tts_cache[text])
-            return self.tts_cache[text]
-
-        # If no callback provided, generate synchronously
-        if callback is None:
-            return self._generate_tts(text)
-
-        # Otherwise, queue for async processing
-        self.tts_queue.put((text, callback))
-        return None
-
-    def fallback_tts(self, text, output_path):
-        """Fallback to Google TTS if IndicF5 fails"""
-        try:
-            from gtts import gTTS
-
-            # Determine if text is Malayalam
-            is_malayalam = any('\u0D00' <= c <= '\u0D7F' for c in text)
-            lang = 'ml' if is_malayalam else 'en'
-
-            print(f"Using fallback Google TTS with language: {lang}")
-            tts = gTTS(text=text, lang=lang, slow=False)
-            tts.save(output_path)
-
-            # Cache the result
-            self.tts_cache[text] = output_path
-            print(f"Fallback TTS saved to: {output_path}")
-
-            return output_path
-        except Exception as e:
-            print(f"Fallback TTS also failed: {e}")
-            traceback.print_exc()
-            return None
-
-    def add_message(self, role, content):
-        """Add a message to the conversation history"""
-        timestamp = datetime.now().strftime("%H:%M:%S")
-        self.conversation_history.append({
-            "role": role,
-            "content": content,
-            "timestamp": timestamp
-        })
-
-    def clear_conversation(self):
-        """Clear the conversation history"""
-        self.conversation_history = []
-
-    def cleanup(self):
-        """Clean up resources when shutting down"""
-        print("Cleaning up resources...")
-
-# Load example Malayalam voices
-def load_audio_from_url(url):
-    """Load audio from a URL"""
-    try:
-        response = requests.get(url)
-        if response.status_code == 200:
-            audio_data, sample_rate = sf.read(io.BytesIO(response.content))
-            return sample_rate, audio_data
-    except Exception as e:
-        print(f"Error loading audio from URL: {e}")
-    return None, None
-
-# Malayalam voice examples
-EXAMPLE_VOICES = [
-    {
-        "name": "Aparna Voice",
-        "url": "https://raw.githubusercontent.com/Aparna0112/voicerecording-_TTS/main/Aparna%20Voice.wav",
-        "transcript": "ഞാൻ ഒരു ഫോണിന്റെ കവർ നോക്കുകയാണ്. എനിക്ക് സ്മാർട്ട് ഫോണിന് കവർ വേണം"
-    },
-    {
-        "name": "KC Voice",
-        "url": "https://raw.githubusercontent.com/Aparna0112/voicerecording-_TTS/main/KC%20Voice.wav",
-        "transcript": "ഹലോ ഇത് അപരനെ അല്ലേ ഞാൻ ജഗദീപ് ആണ് വിളിക്കുന്നത് ഇപ്പോൾ ഫ്രീയാണോ സംസാരിക്കാമോ"
-    }
-]
-
-# Preload example voices
-for voice in EXAMPLE_VOICES:
-    sample_rate, audio_data = load_audio_from_url(voice["url"])
-    if sample_rate is not None:
-        voice["audio"] = (sample_rate, audio_data)
-        print(f"Loaded example voice: {voice['name']}")
-    else:
-        print(f"Failed to load voice: {voice['name']}")
-
 def create_chatbot_interface():
     """Create a single-page chatbot interface with voice input, output, and voice selection"""

-    #
-
+    # Use global conversation engine
+    global conversation_engine, speech_recognizer

     # CSS for styling the chat interface
     css = """
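Note on the removed block: the ConversationEngine deleted here duplicated the class already defined earlier in app.py (the instance created on line 902 survives). Its TTS pipeline hands synthesis jobs to a daemon thread through a queue.Queue and returns results via callbacks. A minimal, self-contained sketch of that worker pattern; the names here are illustrative, not from app.py:

    import queue
    import threading

    def start_tts_worker(generate):
        """Run TTS jobs on a background daemon thread, one at a time."""
        jobs = queue.Queue()

        def worker():
            while True:
                text, callback = jobs.get()   # blocks until a job is queued
                try:
                    callback(generate(text))  # deliver the audio path to the caller
                finally:
                    jobs.task_done()          # mark the job finished even on error

        threading.Thread(target=worker, daemon=True).start()
        return jobs

    # Usage: enqueue text, receive the result asynchronously.
    jobs = start_tts_worker(lambda text: f"/tmp/{abs(hash(text))}.wav")
    jobs.put(("നമസ്കാരം", print))  # "Hello" in Malayalam
    jobs.join()                    # wait until the queue drains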
@@ -1297,7 +914,7 @@ def create_chatbot_interface():
     display: flex;
     flex-direction: column;
     height: 100%;
-    max-width:
+    max-width: 1000px;
     margin: 0 auto;
 }
 .chat-window {
@@ -1307,7 +924,7 @@ def create_chatbot_interface():
     background: #f5f7f9;
     border-radius: 0.5rem;
     margin-bottom: 1rem;
-    min-height:
+    min-height: 450px;
 }
 .input-area {
     display: flex;
@@ -1316,11 +933,12 @@ def create_chatbot_interface():
     align-items: center;
 }
 .message {
-    margin-bottom:
-    padding: 0.
+    margin-bottom: 0.8rem;
+    padding: 0.7rem;
     border-radius: 0.5rem;
     position: relative;
     max-width: 80%;
+    font-size: 0.95rem;
 }
 .user-message {
     background: #e1f5fe;
@@ -1341,10 +959,11 @@ def create_chatbot_interface():
     text-align: center;
     color: #333;
     margin-bottom: 1rem;
+    font-size: 1.8rem;
 }
 .chat-controls {
     display: flex;
-    justify-content:
+    justify-content: flex-end;
     margin-bottom: 0.5rem;
 }
 .voice-selector {
@@ -1353,46 +972,64 @@ def create_chatbot_interface():
     border-radius: 0.5rem;
     margin-bottom: 1rem;
 }
-.
-
-
-
-
-
+button.primary {
+    background-color: #4f46e5;
+    color: white;
+    padding: 0.6rem 1.2rem;
+    border-radius: 0.375rem;
+    font-weight: 500;
 }
-.
-
-
-
-
+button.secondary {
+    background-color: #e5e7eb;
+    color: #374151;
+    padding: 0.6rem 1.2rem;
+    border-radius: 0.375rem;
+    font-weight: 500;
+}
+.audio-player {
+    margin-top: 0.5rem;
+    margin-bottom: 1rem;
+}
+/* Customizing Gradio's default elements */
+.gradio-container {
+    max-width: 1000px !important;
+}
+.message-bubble {
+    font-size: 0.95rem !important;
+}
+.message-wrap {
+    margin-bottom: 8px !important;
 }
 """

     with gr.Blocks(css=css, title="Malayalam Voice Chatbot") as interface:
-        gr.Markdown("# 🤖 Malayalam Voice Chatbot
+        gr.Markdown("# 🤖 Malayalam Voice Chatbot", elem_classes=["chatbot-header"])

-        # Create a state variable for TTS progress
+        # Create a state variable for TTS progress (hidden but needed for functionality)
         tts_progress_state = gr.State(0)
         audio_output_state = gr.State(None)

         with gr.Row(elem_classes=["chatbot-container"]):
             with gr.Column():
-                # Voice selection section
+                # Voice selection section
                 with gr.Accordion("🎤 Voice Selection", open=True):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                    with gr.Row():
+                        # Select from example voices
+                        with gr.Column(scale=1):
+                            voice_selector = gr.Dropdown(
+                                choices=[voice["name"] for voice in EXAMPLE_VOICES],
+                                value=EXAMPLE_VOICES[0]["name"] if EXAMPLE_VOICES else None,
+                                label="Select Example Voice"
+                            )
+
+                        # Display selected voice info
+                        with gr.Column(scale=2):
+                            voice_info = gr.Textbox(
+                                value=EXAMPLE_VOICES[0]["transcript"] if EXAMPLE_VOICES else "",
+                                label="Voice Sample Transcript",
+                                lines=2,
+                                interactive=True
+                            )

                     # Play selected example voice
                     example_audio = gr.Audio(
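The dropdown added above only becomes useful once its change event feeds update_voice_example, which presumably happens in the "Connect event handlers" section outside the visible hunks. A runnable miniature of that pattern; the voice data here is a stand-in, not the real EXAMPLE_VOICES:

    import gradio as gr

    VOICES = {"Aparna Voice": "sample transcript A", "KC Voice": "sample transcript B"}

    with gr.Blocks() as demo:
        selector = gr.Dropdown(choices=list(VOICES), value="Aparna Voice", label="Select Example Voice")
        info = gr.Textbox(label="Voice Sample Transcript")
        # Selecting a voice pushes its transcript into the textbox,
        # mirroring voice_selector -> update_voice_example in app.py.
        selector.change(fn=lambda name: VOICES[name], inputs=selector, outputs=info)

    if __name__ == "__main__":
        demo.launch()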
@@ -1401,33 +1038,40 @@ def create_chatbot_interface():
                         interactive=False
                     )

-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                    gr.Markdown("### 🎙️ Record Your Own Voice")
+
+                    with gr.Row():
+                        # Or record your own voice
+                        with gr.Column(scale=1):
+                            custom_voice = gr.Audio(
+                                sources=["microphone", "upload"],
+                                type="numpy",
+                                label="Record/Upload Voice"
+                            )
+
+                        # Transcript for custom voice
+                        with gr.Column(scale=2):
+                            custom_transcript = gr.Textbox(
+                                value="",
+                                label="Your Voice Transcript (what you said in Malayalam)",
+                                lines=2
+                            )

                     # Button to save the selected/recorded voice
-
-
+                    with gr.Row():
+                        save_voice_btn = gr.Button("💾 Save Voice for Chat", variant="primary")
+                        voice_status = gr.Textbox(label="Voice Status", value="No voice selected yet")

-                #
+                # Chat controls row (just the clear button)
                 with gr.Row(elem_classes=["chat-controls"]):
+                    # Hidden language selector (kept for functionality)
                     language_selector = gr.Dropdown(
                         choices=["ml-IN", "en-US", "hi-IN", "ta-IN", "te-IN", "kn-IN"],
                         value="ml-IN",
-                        label="Speech Recognition Language"
+                        label="Speech Recognition Language",
+                        visible=False
                     )
-                    clear_btn = gr.Button("🧹 Clear Chat",
+                    clear_btn = gr.Button("🧹 Clear Chat", variant="secondary")

                 # Chat display area
                 chatbot = gr.Chatbot(
@@ -1438,54 +1082,61 @@ def create_chatbot_interface():
                     elem_classes=["chat-window"]
                 )

-                #
-
-
-
-
-
-
-
-
+                # Hidden progress bar (kept for functionality)
+                tts_progress = gr.Slider(
+                    minimum=0,
+                    maximum=100,
+                    value=0,
+                    label="TTS Progress",
+                    interactive=False,
+                    visible=False
+                )

                 # Audio output for the bot's response
                 audio_output = gr.Audio(
                     label="Bot's Voice Response",
                     type="filepath",
                     autoplay=True,
-                    visible=True
+                    visible=True,
+                    elem_classes=["audio-player"]
                 )

-                #
+                # Hidden status message (kept for functionality)
                 status_msg = gr.Textbox(
                     label="Status",
                     value="Ready",
-                    interactive=False
+                    interactive=False,
+                    visible=False
                 )

                 # Input area with separate components
                 with gr.Row(elem_classes=["input-area"]):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                    with gr.Column(scale=4):
+                        audio_msg = gr.Textbox(
+                            placeholder="Type a message in Malayalam...",
+                            lines=1,
+                            label=None,
+                            show_label=False
+                        )
+                    with gr.Column(scale=1):
+                        with gr.Row():
+                            audio_input = gr.Audio(
+                                sources=["microphone"],
+                                type="numpy",
+                                label=None,
+                                show_label=False,
+                                elem_classes=["audio-input"]
+                            )
+                        submit_btn = gr.Button("🚀 Send", variant="primary")
+
+        # Function to update voice example info (unchanged)
         def update_voice_example(voice_name):
             for voice in EXAMPLE_VOICES:
                 if voice["name"] == voice_name and "audio" in voice:
                     return voice["transcript"], voice["audio"]
             return "", None

-        # Function to save voice for TTS
+        # Function to save voice for TTS (unchanged)
         def save_voice_for_tts(example_name, example_audio, custom_audio, example_transcript, custom_transcript):
             try:
                 # Check if we're using an example voice or custom recorded voice
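The slider and status textbox above stay in the layout with visible=False so they can keep serving as event-handler outputs. A compact runnable example of that hidden-component pattern:

    import gradio as gr

    def respond(message, history):
        history = history + [[message, f"Echo: {message}"]]
        return history, "", 100  # updated chat, cleared textbox, finished progress

    with gr.Blocks() as demo:
        chat = gr.Chatbot()
        msg = gr.Textbox(show_label=False, placeholder="Type a message...")
        # Hidden but still addressable as an event output, like tts_progress above.
        progress = gr.Slider(minimum=0, maximum=100, value=0, visible=False)
        send = gr.Button("🚀 Send", variant="primary")
        send.click(respond, inputs=[msg, chat], outputs=[chat, msg, progress])

    if __name__ == "__main__":
        demo.launch()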
@@ -1506,7 +1157,7 @@ def create_chatbot_interface():
                     return "Error: No voice selected or recorded"

                 # Save the voice in the engine
-                result =
+                result = conversation_engine.save_reference_voice(voice_audio, transcript)

                 return f"Voice saved successfully! Using {source}"
             except Exception as e:
@@ -1514,16 +1165,16 @@ def create_chatbot_interface():
                 traceback.print_exc()
                 return f"Error saving voice: {str(e)}"

-        # Function to update TTS progress
+        # Function to update TTS progress (unchanged)
         def update_tts_progress(progress):
             return progress

-        # Audio generated callback
+        # Audio generated callback (unchanged)
         def on_tts_generated(audio_path):
             print(f"TTS generation callback received path: {audio_path}")
             return audio_path, 100, "Response ready"  # audio path, 100% progress, status message

-        # Function to process user input and generate response
+        # Function to process user input and generate response (updated to use global engine)
         def process_input(audio, text_input, history, language, progress):
             try:
                 # Update status
@@ -1535,7 +1186,7 @@ def create_chatbot_interface():
                 # Check which input mode we're using
                 if audio is not None:
                     # Audio input
-                    transcribed_text, input_text =
+                    transcribed_text, input_text = speech_recognizer.transcribe_audio(audio, language)
                     if not input_text:
                         status = "Could not understand audio. Please try again."
                         return history, None, status, text_input, progress
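SpeechRecognizer itself is defined elsewhere in app.py; judging by the transcribe_audio method deleted in the first hunk, it wraps the speech_recognition package's Google recognizer. A minimal sketch of such a wrapper under that assumption:

    import os
    import tempfile

    import soundfile as sf
    import speech_recognition as sr

    def transcribe(audio, language="ml-IN"):
        """Transcribe a Gradio-style (sample_rate, samples) tuple with Google STT."""
        sample_rate, samples = audio
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        tmp.close()
        try:
            sf.write(tmp.name, samples, sample_rate)
            recognizer = sr.Recognizer()
            with sr.AudioFile(tmp.name) as source:
                recorded = recognizer.record(source)
            text = recognizer.recognize_google(recorded, language=language)
            return text, text  # (display text, engine input), matching the call above
        except sr.UnknownValueError:
            return "Could not understand audio", ""
        finally:
            os.unlink(tmp.name)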
@@ -1549,7 +1200,7 @@ def create_chatbot_interface():
                     return history, None, status, text_input, progress

                 # Add user message to conversation history
-
+                conversation_engine.add_message("user", input_text)

                 # Update the Gradio chatbot display immediately with user message
                 updated_history = history + [[transcribed_text, None]]
@@ -1559,10 +1210,10 @@ def create_chatbot_interface():
                 progress = 30

                 # Generate response
-                response_text, _ =
+                response_text, _ = conversation_engine.generate_response(input_text)

                 # Add assistant response to conversation history
-
+                conversation_engine.add_message("assistant", response_text)

                 # Update the Gradio chatbot with the assistant's response
                 updated_history = history + [[transcribed_text, response_text]]
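generate_response (on the surviving ConversationEngine) rebuilds the model context from conversation_history on every turn, which is why both add_message calls above matter. The payload it assembles for the legacy openai.ChatCompletion API looks like this:

    # Shape of the context sent to gpt-3.5-turbo each turn (openai 0.x style,
    # as in the code removed in the first hunk).
    system_prompt = "You are a helpful assistant that speaks Malayalam fluently."
    history = [
        {"role": "user", "content": "സുഖമാണോ?"},             # "How are you?"
        {"role": "assistant", "content": "സുഖമാണ്, നന്ദി!"},  # "I'm fine, thanks!"
    ]
    messages = [{"role": "system", "content": system_prompt}]
    messages += [{"role": e["role"], "content": e["content"]} for e in history]
    messages.append({"role": "user", "content": "ഇന്ന് മഴ പെയ്യുമോ?"})  # "Will it rain today?"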
@@ -1572,7 +1223,7 @@ def create_chatbot_interface():
                 progress = 60

                 # Generate speech for response synchronously (for better debugging)
-                audio_path =
+                audio_path = conversation_engine._generate_tts(response_text)

                 if audio_path:
                     status = f"Response ready: {audio_path}"
@@ -1591,9 +1242,9 @@ def create_chatbot_interface():
                 traceback.print_exc()
                 return history, None, error_message, text_input, progress

-        # Function to clear chat history
+        # Function to clear chat history (updated to use global engine)
         def clear_chat():
-
+            conversation_engine.clear_conversation()
             return [], None, "Chat history cleared", "", 0

         # Connect event handlers
@@ -1635,7 +1286,7 @@ def create_chatbot_interface():

         # Setup cleanup on exit
         def exit_handler():
-
+            conversation_engine.cleanup()

         import atexit
         atexit.register(exit_handler)
@@ -1643,10 +1294,4 @@ def create_chatbot_interface():
         # Enable queueing for better responsiveness
         interface.queue()

-    return interface
-
-# Start the interface
-if __name__ == "__main__":
-    print("Starting Malayalam Voice Chatbot with IndicF5 Voice Selection...")
-    interface = create_chatbot_interface()
-    interface.launch(debug=True)  # Enable debug mode to see errors in the console
+    return interface
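With the old __main__ block deleted, this module no longer launches itself; create_chatbot_interface() now just returns the queued Blocks app. The launch call has to come from whatever imports the module, for example (a hypothetical entry point, not part of this commit):

    # run.py - hypothetical launcher; the commit removes the in-module one.
    from app import create_chatbot_interface

    interface = create_chatbot_interface()
    interface.launch(debug=True)  # debug=True surfaces errors in the console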