File size: 4,854 Bytes
9b268d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import os
import json
import gradio as gr
from pathlib import Path
from typing import List, Tuple
from chatbot_config import ChatbotConfig
from chatbot_model import RetrievalChatbot
from tf_data_pipeline import TFDataPipeline
from response_quality_checker import ResponseQualityChecker
from environment_setup import EnvironmentSetup
from sentence_transformers import SentenceTransformer
from logger_config import config_logger

logger = config_logger(__name__)

def load_pipeline():
    """Assemble and return the inference chatbot plus its quality checker.

    Reads the persisted ChatbotConfig when one exists (library defaults
    otherwise), prepares the runtime environment, builds the sentence
    encoder and TFDataPipeline, then attaches the production FAISS index
    and response pool when both artifacts are present on disk.
    """
    model_dir = "models"
    indices_dir = os.path.join(model_dir, "faiss_indices")
    index_path = os.path.join(indices_dir, "faiss_index_production.index")
    responses_path = index_path.replace(".index", "_responses.json")

    # Prefer the saved config; fall back to defaults when none was written.
    cfg_file = Path(model_dir) / "config.json"
    if cfg_file.exists():
        with open(cfg_file, "r", encoding="utf-8") as fh:
            config = ChatbotConfig.from_dict(json.load(fh))
    else:
        config = ChatbotConfig()

    # Initialize environment
    env = EnvironmentSetup()
    env.initialize()

    # Load models and data
    encoder = SentenceTransformer(config.pretrained_model)

    data_pipeline = TFDataPipeline(
        config=config,
        tokenizer=encoder.tokenizer,
        encoder=encoder,
        response_pool=[],
        query_embeddings_cache={},
        index_type='IndexFlatIP',
        faiss_index_file_path=index_path
    )

    # Attach the index and its response pool only when both files exist;
    # otherwise warn and continue with an empty pipeline.
    have_index = os.path.exists(index_path)
    have_responses = os.path.exists(responses_path)
    if have_index and have_responses:
        data_pipeline.load_faiss_index(index_path)
        with open(responses_path, "r", encoding="utf-8") as fh:
            data_pipeline.response_pool = json.load(fh)
        data_pipeline.validate_faiss_index()
    else:
        logger.warning("FAISS index or responses are missing. The chatbot may not work properly.")

    chatbot = RetrievalChatbot.load_model(load_dir=model_dir, mode="inference")
    quality_checker = ResponseQualityChecker(data_pipeline=data_pipeline)
    return chatbot, quality_checker

# Load the chatbot and quality checker once at import time so every Gradio
# callback reuses the same instances (model/index loading is expensive).
chatbot, quality_checker = load_pipeline()

def respond(message: str, history: List[Tuple[str, str]]) -> Tuple[str, List[Tuple[str, str]]]:
    """Gradio callback: answer *message* and append the exchange to *history*.

    Args:
        message: Raw user input from the textbox.
        history: Chat transcript as (user, bot) pairs, as kept by gr.Chatbot.
            (The original annotation said List[List[str]], but the code
            appends tuples — the annotation is corrected here.)

    Returns:
        A pair of (cleared textbox value, updated history). Whitespace-only
        input is skipped and the history is returned unchanged.
    """
    if not message.strip():  # Skip empty/whitespace-only input
        return "", history

    try:
        # conversation_history=None: context is handled internally by the model.
        response, _, _, _ = chatbot.chat(
            query=message,
            conversation_history=None,
            quality_checker=quality_checker,
            top_k=10
        )
        history.append((message, response))
        return "", history
    except Exception:
        # UI boundary: never crash the interface; log the traceback and
        # surface a friendly message instead.
        logger.exception("Error generating response")
        error_message = "I apologize, but I encountered an error processing your request."
        history.append((message, error_message))
        return "", history

def main():
    """Build the Gradio Blocks interface and return the (unlaunched) demo.

    Wires both Enter-in-textbox and the Send button to respond(), and the
    clear button to a handler that empties the transcript and the input.
    """
    with gr.Blocks(
        title="Chatbot Demo",
        css="""
            .message-wrap { max-height: 800px !important; }
            .chatbot { min-height: 600px; }
        """
    ) as demo:
        gr.Markdown(
            """
            # Retrieval-Based Chatbot Demo using Sentence Transformers + FAISS
            Knowledge areas: restaurants, movie tickets, rideshare, coffee, pizza, and auto repair.
            """
        )

        # Chat transcript. Named chat_window (not "chatbot") so it does not
        # shadow the module-level chatbot model that respond() uses.
        chat_window = gr.Chatbot(
            label="Conversation",
            container=True,
            height=600,
            show_label=True,
            elem_classes="chatbot"
        )

        # Input area with send button
        with gr.Row():
            msg = gr.Textbox(
                show_label=False,
                placeholder="Type your message here...",
                container=False,
                scale=8
            )
            send = gr.Button(
                "Send",
                variant="primary",
                scale=1,
                min_width=100
            )

        clear = gr.Button("Clear Conversation", variant="secondary")

        # Enter key and Send button both trigger the chatbot.
        msg.submit(respond, [msg, chat_window], [msg, chat_window], queue=False)
        send.click(respond, [msg, chat_window], [msg, chat_window], queue=False)
        # Reset the transcript to [] and the textbox to "" (previously the
        # handler returned [] for the Textbox, which is not a valid value).
        clear.click(lambda: ([], ""), outputs=[chat_window, msg], queue=False)

    return demo

if __name__ == "__main__":
    # Bind to all interfaces (0.0.0.0) so the app is reachable from outside
    # the host/container, on Gradio's default port 7860.
    demo = main()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )