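"""Gradio demo for a retrieval-based chatbot.

Loads a SentenceTransformer encoder, a production FAISS index, and its
response pool, then serves a simple chat UI on port 7860.
"""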
import json
import os
from pathlib import Path
from typing import List, Tuple

import gradio as gr
from sentence_transformers import SentenceTransformer

from chatbot_config import ChatbotConfig
from chatbot_model import RetrievalChatbot
from environment_setup import EnvironmentSetup
from logger_config import config_logger
from response_quality_checker import ResponseQualityChecker
from tf_data_pipeline import TFDataPipeline

logger = config_logger(__name__)
def load_pipeline():
    """
    Load the config, FAISS index, response pool, SentenceTransformer, and
    TFDataPipeline, then assemble the chatbot and its quality checker.
    """
    MODEL_DIR = "models"
    FAISS_INDICES_DIR = os.path.join(MODEL_DIR, "faiss_indices")
    FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, "faiss_index_production.index")
    RESPONSE_POOL_PATH = FAISS_INDEX_PRODUCTION_PATH.replace(".index", "_responses.json")

    # Load the saved config if present; otherwise fall back to defaults
    config_path = Path(MODEL_DIR) / "config.json"
    if config_path.exists():
        with open(config_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        config = ChatbotConfig.from_dict(config_dict)
    else:
        config = ChatbotConfig()

    # Initialize environment
    env = EnvironmentSetup()
    env.initialize()

    # Load models and data
    encoder = SentenceTransformer(config.pretrained_model)
    data_pipeline = TFDataPipeline(
        config=config,
        tokenizer=encoder.tokenizer,
        encoder=encoder,
        response_pool=[],
        query_embeddings_cache={},
        index_type='IndexFlatIP',
        faiss_index_file_path=FAISS_INDEX_PRODUCTION_PATH
    )
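    # IndexFlatIP does exact inner-product search; with the L2-normalized
    # embeddings that SentenceTransformer models typically produce, inner
    # product is equivalent to cosine similarity.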
    # Load FAISS index and response pool if both artifacts exist
    if os.path.exists(FAISS_INDEX_PRODUCTION_PATH) and os.path.exists(RESPONSE_POOL_PATH):
        data_pipeline.load_faiss_index(FAISS_INDEX_PRODUCTION_PATH)
        with open(RESPONSE_POOL_PATH, "r", encoding="utf-8") as f:
            data_pipeline.response_pool = json.load(f)
        data_pipeline.validate_faiss_index()
    else:
        logger.warning("FAISS index or response pool is missing. The chatbot may not work properly.")

    chatbot = RetrievalChatbot.load_model(load_dir=MODEL_DIR, mode="inference")
    quality_checker = ResponseQualityChecker(data_pipeline=data_pipeline)
    return chatbot, quality_checker
# Load the chatbot and quality checker once at import time so all requests share them
chatbot, quality_checker = load_pipeline()
def respond(message: str, history: List[List[str]]) -> Tuple[str, List[List[str]]]:
    """Generate a chatbot response; conversation context is handled internally."""
    if not message.strip():  # Ignore empty or whitespace-only input
        return "", history
    try:
        response, _, metrics, confidence = chatbot.chat(
            query=message,
            conversation_history=None,  # Context is tracked inside the chatbot
            quality_checker=quality_checker,
            top_k=10
        )
        history.append([message, response])
        return "", history
    except Exception as e:
        logger.error(f"Error generating response: {e}")
        error_message = "I apologize, but I encountered an error processing your request."
        history.append([message, error_message])
        return "", history
def main():
    """Build and return the Gradio interface."""
    with gr.Blocks(
        title="Chatbot Demo",
        css="""
        .message-wrap { max-height: 800px !important; }
        .chatbot { min-height: 600px; }
        """
    ) as demo:
        gr.Markdown(
            """
            # Retrieval-Based Chatbot Demo using Sentence Transformers + FAISS
            Knowledge areas: restaurants, movie tickets, rideshare, coffee, pizza, and auto repair.
            """
        )

        # Chat window ("chat_window" avoids shadowing the global chatbot model)
        chat_window = gr.Chatbot(
            label="Conversation",
            container=True,
            height=600,
            show_label=True,
            elem_classes="chatbot"
        )

        # Input area with send button
        with gr.Row():
            msg = gr.Textbox(
                show_label=False,
                placeholder="Type your message here...",
                container=False,
                scale=8
            )
            send = gr.Button(
                "Send",
                variant="primary",
                scale=1,
                min_width=100
            )
        clear = gr.Button("Clear Conversation", variant="secondary")

        # Event handlers
        msg.submit(respond, [msg, chat_window], [msg, chat_window], queue=False)
        send.click(respond, [msg, chat_window], [msg, chat_window], queue=False)
        # The Textbox output needs an empty string, not an empty list
        clear.click(lambda: ([], ""), outputs=[chat_window, msg], queue=False)
    return demo
if __name__ == "__main__":
    demo = main()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )
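# To run locally (file name assumed; adjust to match):
#   python app.py
# then open http://localhost:7860 in a browser.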