import os
import json
import gradio as gr
from pathlib import Path
from typing import List, Tuple
from chatbot_config import ChatbotConfig
from chatbot_model import RetrievalChatbot
from tf_data_pipeline import TFDataPipeline
from response_quality_checker import ResponseQualityChecker
from environment_setup import EnvironmentSetup
from sentence_transformers import SentenceTransformer
from logger_config import config_logger
logger = config_logger(__name__)
def load_pipeline():
"""
Loads config, FAISS index, response pool, SentenceTransformer, TFDataPipeline, and sets up the chatbot.
"""
MODEL_DIR = "models"
FAISS_INDICES_DIR = os.path.join(MODEL_DIR, "faiss_indices")
FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, "faiss_index_production.index")
RESPONSE_POOL_PATH = FAISS_INDEX_PRODUCTION_PATH.replace(".index", "_responses.json")
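    # The response pool (the candidate texts that index rows map back to) is
    # stored alongside the index, derived from the same base filename.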
config_path = Path(MODEL_DIR) / "config.json"
if config_path.exists():
with open(config_path, "r", encoding="utf-8") as f:
config_dict = json.load(f)
config = ChatbotConfig.from_dict(config_dict)
else:
config = ChatbotConfig()
# Initialize environment
env = EnvironmentSetup()
env.initialize()
# Load models and data
encoder = SentenceTransformer(config.pretrained_model)
data_pipeline = TFDataPipeline(
config=config,
tokenizer=encoder.tokenizer,
encoder=encoder,
response_pool=[],
query_embeddings_cache={},
index_type='IndexFlatIP',
faiss_index_file_path=FAISS_INDEX_PRODUCTION_PATH
)
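    # The pipeline shares the SentenceTransformer's tokenizer and encoder;
    # response_pool starts empty and is filled from the saved JSON below,
    # while the query-embedding cache is populated at runtime.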
# Load FAISS index and response pool
if os.path.exists(FAISS_INDEX_PRODUCTION_PATH) and os.path.exists(RESPONSE_POOL_PATH):
data_pipeline.load_faiss_index(FAISS_INDEX_PRODUCTION_PATH)
with open(RESPONSE_POOL_PATH, "r", encoding="utf-8") as f:
data_pipeline.response_pool = json.load(f)
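        # validate_faiss_index() presumably checks that the loaded index and
        # response pool are consistent (e.g. matching entry counts).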
data_pipeline.validate_faiss_index()
else:
logger.warning("FAISS index or responses are missing. The chatbot may not work properly.")
chatbot = RetrievalChatbot.load_model(load_dir=MODEL_DIR, mode="inference")
quality_checker = ResponseQualityChecker(data_pipeline=data_pipeline)
return chatbot, quality_checker
# Load the chatbot and quality checker globally
chatbot, quality_checker = load_pipeline()
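# Loading at import time means the heavy models are initialized once and
# shared across all Gradio sessions rather than reloaded per request.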
def respond(message: str, history: List[List[str]]) -> Tuple[str, List[List[str]]]:
"""Generate chatbot response using internal context handling."""
    if not message.strip():  # ignore empty or whitespace-only input
return "", history
try:
response, _, metrics, confidence = chatbot.chat(
query=message,
conversation_history=None, # Handled internally
quality_checker=quality_checker,
top_k=10
)
        history.append([message, response])
return "", history
except Exception as e:
logger.error(f"Error generating response: {e}")
error_message = "I apologize, but I encountered an error processing your request."
        history.append([message, error_message])
return "", history
def main():
"""Initialize and launch Gradio interface."""
with gr.Blocks(
title="Chatbot Demo",
css="""
.message-wrap { max-height: 800px !important; }
.chatbot { min-height: 600px; }
"""
) as demo:
gr.Markdown(
"""
# Retrieval-Based Chatbot Demo using Sentence Transformers + FAISS
Knowledge areas: restaurants, movie tickets, rideshare, coffee, pizza, and auto repair.
"""
)
# Chat interface with custom height
        chat_window = gr.Chatbot(
label="Conversation",
container=True,
height=600,
show_label=True,
elem_classes="chatbot"
)
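        # gr.Chatbot renders the history as [user, bot] message pairs,
        # matching what respond() appends.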
# Input area with send button
with gr.Row():
msg = gr.Textbox(
show_label=False,
placeholder="Type your message here...",
container=False,
scale=8
)
send = gr.Button(
"Send",
variant="primary",
scale=1,
min_width=100
)
clear = gr.Button("Clear Conversation", variant="secondary")
# Event handlers
        msg.submit(respond, [msg, chat_window], [msg, chat_window], queue=False)
        send.click(respond, [msg, chat_window], [msg, chat_window], queue=False)
        clear.click(lambda: ([], ""), outputs=[chat_window, msg], queue=False)
return demo
if __name__ == "__main__":
demo = main()
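    # Bind to all interfaces on Gradio's default port so the demo is reachable
    # from outside a container; pass share=True to launch() for a public link.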
demo.launch(
server_name="0.0.0.0",
server_port=7860,
)