import os import json import gradio as gr from pathlib import Path from typing import List, Tuple from chatbot_config import ChatbotConfig from chatbot_model import RetrievalChatbot from tf_data_pipeline import TFDataPipeline from response_quality_checker import ResponseQualityChecker from environment_setup import EnvironmentSetup from sentence_transformers import SentenceTransformer from logger_config import config_logger logger = config_logger(__name__) def load_pipeline(): """ Loads config, FAISS index, response pool, SentenceTransformer, TFDataPipeline, and sets up the chatbot. """ MODEL_DIR = "models" FAISS_INDICES_DIR = os.path.join(MODEL_DIR, "faiss_indices") FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, "faiss_index_production.index") RESPONSE_POOL_PATH = FAISS_INDEX_PRODUCTION_PATH.replace(".index", "_responses.json") config_path = Path(MODEL_DIR) / "config.json" if config_path.exists(): with open(config_path, "r", encoding="utf-8") as f: config_dict = json.load(f) config = ChatbotConfig.from_dict(config_dict) else: config = ChatbotConfig() # Initialize environment env = EnvironmentSetup() env.initialize() # Load models and data encoder = SentenceTransformer(config.pretrained_model) data_pipeline = TFDataPipeline( config=config, tokenizer=encoder.tokenizer, encoder=encoder, response_pool=[], query_embeddings_cache={}, index_type='IndexFlatIP', faiss_index_file_path=FAISS_INDEX_PRODUCTION_PATH ) # Load FAISS index and response pool if os.path.exists(FAISS_INDEX_PRODUCTION_PATH) and os.path.exists(RESPONSE_POOL_PATH): data_pipeline.load_faiss_index(FAISS_INDEX_PRODUCTION_PATH) with open(RESPONSE_POOL_PATH, "r", encoding="utf-8") as f: data_pipeline.response_pool = json.load(f) data_pipeline.validate_faiss_index() else: logger.warning("FAISS index or responses are missing. The chatbot may not work properly.") chatbot = RetrievalChatbot.load_model(load_dir=MODEL_DIR, mode="inference") quality_checker = ResponseQualityChecker(data_pipeline=data_pipeline) return chatbot, quality_checker # Load the chatbot and quality checker globally chatbot, quality_checker = load_pipeline() def respond(message: str, history: List[List[str]]) -> Tuple[str, List[List[str]]]: """Generate chatbot response using internal context handling.""" if not message.strip(): # Skip return "", history try: response, _, metrics, confidence = chatbot.chat( query=message, conversation_history=None, # Handled internally quality_checker=quality_checker, top_k=10 ) history.append((message, response)) return "", history except Exception as e: logger.error(f"Error generating response: {e}") error_message = "I apologize, but I encountered an error processing your request." history.append((message, error_message)) return "", history def main(): """Initialize and launch Gradio interface.""" with gr.Blocks( title="Chatbot Demo", css=""" .message-wrap { max-height: 800px !important; } .chatbot { min-height: 600px; } """ ) as demo: gr.Markdown( """ # Retrieval-Based Chatbot Demo using Sentence Transformers + FAISS Knowledge areas: restaurants, movie tickets, rideshare, coffee, pizza, and auto repair. """ ) # Chat interface with custom height chatbot = gr.Chatbot( label="Conversation", container=True, height=600, show_label=True, elem_classes="chatbot" ) # Input area with send button with gr.Row(): msg = gr.Textbox( show_label=False, placeholder="Type your message here...", container=False, scale=8 ) send = gr.Button( "Send", variant="primary", scale=1, min_width=100 ) clear = gr.Button("Clear Conversation", variant="secondary") # Event handlers msg.submit(respond, [msg, chatbot], [msg, chatbot], queue=False) send.click(respond, [msg, chatbot], [msg, chatbot], queue=False) clear.click(lambda: ([], []), outputs=[chatbot, msg], queue=False) # Responsive interface msg.change(lambda: None, None, None, queue=False) return demo if __name__ == "__main__": demo = main() demo.launch( server_name="0.0.0.0", server_port=7860, )