JoeArmani
committed on
Commit
·
9b268d0
1
Parent(s):
c7c1b4e
finalize Gradio updates
Browse files- app.py +145 -0
- readme.md +11 -35
- requirements.txt +28 -26
app.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import gradio as gr
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import List, Tuple
|
6 |
+
from chatbot_config import ChatbotConfig
|
7 |
+
from chatbot_model import RetrievalChatbot
|
8 |
+
from tf_data_pipeline import TFDataPipeline
|
9 |
+
from response_quality_checker import ResponseQualityChecker
|
10 |
+
from environment_setup import EnvironmentSetup
|
11 |
+
from sentence_transformers import SentenceTransformer
|
12 |
+
from logger_config import config_logger
|
13 |
+
|
14 |
+
# Module-level logger, created by the project's config_logger helper and named after this module.
logger = config_logger(__name__)
|
15 |
+
|
16 |
+
def load_pipeline():
    """
    Build the inference-time chatbot and its response quality checker.

    Reads ``models/config.json`` when it exists (otherwise a default
    ``ChatbotConfig`` is used), initializes the environment, creates the
    SentenceTransformer encoder and a ``TFDataPipeline`` around it, loads the
    production FAISS index plus its JSON response pool when both files are
    present, and loads the ``RetrievalChatbot`` with ``mode="inference"``.

    Returns:
        A ``(chatbot, quality_checker)`` tuple.
    """
    model_dir = "models"
    index_path = os.path.join(model_dir, "faiss_indices", "faiss_index_production.index")
    # The response pool file name is derived from the index file name.
    responses_path = index_path.replace(".index", "_responses.json")

    # Prefer the saved configuration; fall back to defaults when it is absent.
    config_file = Path(model_dir) / "config.json"
    if not config_file.exists():
        config = ChatbotConfig()
    else:
        with open(config_file, "r", encoding="utf-8") as handle:
            saved_settings = json.load(handle)
        config = ChatbotConfig.from_dict(saved_settings)

    # Initialize environment
    environment = EnvironmentSetup()
    environment.initialize()

    # Load models and data
    sentence_encoder = SentenceTransformer(config.pretrained_model)
    pipeline_kwargs = dict(
        config=config,
        tokenizer=sentence_encoder.tokenizer,
        encoder=sentence_encoder,
        response_pool=[],
        query_embeddings_cache={},
        index_type='IndexFlatIP',
        faiss_index_file_path=index_path,
    )
    pipeline = TFDataPipeline(**pipeline_kwargs)

    # Load FAISS index and response pool (both files must be present)
    if not (os.path.exists(index_path) and os.path.exists(responses_path)):
        logger.warning("FAISS index or responses are missing. The chatbot may not work properly.")
    else:
        pipeline.load_faiss_index(index_path)
        with open(responses_path, "r", encoding="utf-8") as handle:
            pipeline.response_pool = json.load(handle)
        pipeline.validate_faiss_index()

    retrieval_bot = RetrievalChatbot.load_model(load_dir=model_dir, mode="inference")
    checker = ResponseQualityChecker(data_pipeline=pipeline)

    return retrieval_bot, checker
|
63 |
+
|
64 |
+
# Load the chatbot and quality checker globally
|
65 |
+
# NOTE: this runs at import time, so importing the module executes load_pipeline().
chatbot, quality_checker = load_pipeline()
|
66 |
+
|
67 |
+
def respond(message: str, history: List[List[str]]) -> Tuple[str, List[List[str]]]:
    """
    Produce the chatbot's reply to ``message`` and append the turn to ``history``.

    Args:
        message: Text the user submitted. ``None``, empty, or whitespace-only
            input is ignored and leaves the conversation unchanged.
        history: Turns exchanged so far. ``None`` is treated as an empty
            conversation; an existing list is extended in place.

    Returns:
        ``("", history)``: an empty string for the message input and the
        (possibly extended) conversation.
    """
    # The UI may hand over no history at all before the first turn.
    if history is None:
        history = []

    if not message or not message.strip():  # Skip
        return "", history

    try:
        # Only the response text is displayed; the other returned values are unused here.
        response, _, _, _ = chatbot.chat(
            query=message,
            conversation_history=None,  # Handled internally
            quality_checker=quality_checker,
            top_k=10,
        )
    except Exception as e:
        # Top-level UI boundary: log the failure and show an apology instead of crashing.
        logger.error(f"Error generating response: {e}")
        response = "I apologize, but I encountered an error processing your request."

    history.append((message, response))
    return "", history
|
87 |
+
|
88 |
+
def main():
    """
    Build the Gradio Blocks interface for the chatbot demo.

    The layout is a heading, a conversation display, a message box with a
    "Send" button, and a "Clear Conversation" button. Submitting the text box
    or clicking "Send" calls ``respond``; "Clear Conversation" resets both the
    conversation display and the message box.

    Returns:
        The configured ``gr.Blocks`` app. It is not launched here.
    """
    with gr.Blocks(
        title="Chatbot Demo",
        css="""
        .message-wrap { max-height: 800px !important; }
        .chatbot { min-height: 600px; }
        """
    ) as demo:
        gr.Markdown(
            """
            # Retrieval-Based Chatbot Demo using Sentence Transformers + FAISS
            Knowledge areas: restaurants, movie tickets, rideshare, coffee, pizza, and auto repair.
            """
        )

        # Chat interface with custom height.
        # Named chat_display so it is not confused with the module-level
        # `chatbot` model object that respond() uses.
        chat_display = gr.Chatbot(
            label="Conversation",
            container=True,
            height=600,
            show_label=True,
            elem_classes="chatbot"
        )

        # Input area with send button
        with gr.Row():
            msg = gr.Textbox(
                show_label=False,
                placeholder="Type your message here...",
                container=False,
                scale=8
            )
            send = gr.Button(
                "Send",
                variant="primary",
                scale=1,
                min_width=100
            )

        clear = gr.Button("Clear Conversation", variant="secondary")

        # Event handlers
        msg.submit(respond, [msg, chat_display], [msg, chat_display], queue=False)
        send.click(respond, [msg, chat_display], [msg, chat_display], queue=False)
        # Reset the conversation to an empty list and the text box to an empty
        # string (a list value for the text box would be shown as the text "[]").
        clear.click(lambda: ([], ""), outputs=[chat_display, msg], queue=False)

        # Responsive interface
        # NOTE(review): this handler has no inputs or outputs and returns None,
        # so it looks like a no-op on every keystroke; confirm it is needed.
        msg.change(lambda: None, None, None, queue=False)

    return demo
|
139 |
+
|
140 |
+
if __name__ == "__main__":
    # Build the interface, then serve it on all network interfaces at port 7860.
    demo = main()
    demo.launch(server_name="0.0.0.0", server_port=7860)
|
readme.md
CHANGED
@@ -1,42 +1,18 @@
|
|
1 |
-
# Retrieval
|
2 |
|
3 |
-
|
4 |
|
5 |
-
##
|
6 |
-
|
7 |
-
A Python tool to generate high-quality dialog variations.
|
8 |
-
|
9 |
-
This package automatically downloads the following models during installation:
|
10 |
-
|
11 |
-
- Universal Sentence Encoder v4 (TensorFlow Hub)
|
12 |
-
- ChatGPT Paraphraser T5-base
|
13 |
-
- Helsinki-NLP translation models (en-de, de-es, es-en)
|
14 |
-
- spaCy en_core_web_sm, eng_core_web_md
|
15 |
-
- nltk wordnet and averaged_perceptron_tagger_eng models
|
16 |
-
|
17 |
-
## Install package
|
18 |
|
19 |
-
|
20 |
|
21 |
-
##
|
22 |
|
23 |
-
|
24 |
-
|
25 |
-
Two approaches are used for text augmentation: paraphrasing and back-translation. The pipeline also includes quality metrics for evaluating the augmented text.
|
26 |
-
Special handling is implemented for very short text such as greetings and farewells, which are predefined and filtered for quality.
|
27 |
-
The pipeline is designed to process a dataset of dialogues and generate multiple high-quality augmented versions of each dialogue.
|
28 |
-
The pipeline ensures duplicate dialogues are not generated and that the output meets quality thresholds for semantic similarity, grammar, fluency, diversity, and content preservation.
|
29 |
|
30 |
-
##
|
31 |
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
Helsinki-NLP. (2024). Opus-MT [Computer software]. GitHub. <https://github.com/Helsinki-NLP/Opus-MT>
|
36 |
-
Hugging Face. (n.d.). Transformers. Hugging Face. <https://huggingface.co/docs/transformers/en/index>
|
37 |
-
Humarin. (2023). ChatGPT paraphraser on T5-base [Computer software]. Hugging Face. <https://huggingface.co/humarin/chatgpt_paraphraser_on_T5_base>
|
38 |
-
Keita, Z. (2022). Data augmentation in NLP using back-translation with MarianMT. Towards Data Science. <https://towardsdatascience.com/data-augmentation-in-nlp-using-back-translation-with-marianmt-a8939dfea50a>
|
39 |
-
Memgraph. (2023). Cosine similarity in Python with scikit-learn. Memgraph. <https://memgraph.com/blog/cosine-similarity-python-scikit-learn>
|
40 |
-
Morris, J. (n.d.). language-tool-python (Version 2.8.1) [Computer software]. PyPI. <https://pypi.org/project/language-tool-python/>
|
41 |
-
TensorFlow. (n.d.). Universal sentence encoder. TensorFlow Hub. <https://www.tensorflow.org/hub/tutorials/semantic_similarity_with_tf_hub_universal_encoder>
|
42 |
-
Waheed, A. (2023). How to calculate ROUGE score in Python. Python Code. <https://thepythoncode.com/article/calculate-rouge-score-in-python>
|
|
|
1 |
+
# CSC252 Retrieval Chatbot
|
2 |
|
3 |
+
This is a retrieval-based chatbot using Sentence Transformers and FAISS for efficient similarity search.
|
4 |
|
5 |
+
## Description
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
+
The chatbot uses a pre-trained Sentence Transformer model to encode queries and a FAISS index to retrieve relevant responses from a curated response pool (Taskmaster-1 dataset).
|
8 |
|
9 |
+
## Usage
|
10 |
|
11 |
+
Simply type your question in the chat interface and the bot will retrieve the most relevant response from its knowledge base.
|
12 |
+
## Features
|
|
|
|
|
|
|
|
|
13 |
|
14 |
+
- Semantic search using Sentence Transformers
|
15 |
|
16 |
+
- Efficient retrieval using FAISS indexing
|
17 |
+
- Context-aware responses
|
18 |
+
- Quality checking of responses
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
@@ -1,27 +1,29 @@
|
|
1 |
-
faiss-cpu>=1.7.0
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
|
|
|
|
22 |
|
23 |
-
# Dev dependencies
|
24 |
-
black>=22.0.0
|
25 |
-
isort>=5.10.0
|
26 |
-
mypy>=1.0.0
|
27 |
-
pytest>=7.0.0
|
|
|
1 |
+
faiss-cpu>=1.7.0 # Facebook AI Similarity Search
|
2 |
+
gradio>=3.30.0 # Web app creation
|
3 |
+
h5py>=3.1.0 # For saving and loading models
|
4 |
+
ipython>=8.0.0 # Interactive Python
|
5 |
+
loguru>=0.7.0 # Enhanced logging (optional but recommended)
|
6 |
+
matplotlib>=3.5.0 # Validation plotting
|
7 |
+
nlpaug>=1.1.0 # Data augmentation for NLP
|
8 |
+
nltk>=3.6.0 # Natural language toolkit
|
9 |
+
numpy>=1.19.0 # Numerical computation
|
10 |
+
pandas>=1.5.0 # Data handling
|
11 |
+
pyyaml>=6.0.0 # Config management
|
12 |
+
scikit-learn>=1.0.0 # ML tools
|
13 |
+
sacremoses>=0.0.53 # Required for some HuggingFace pipelines
|
14 |
+
sentencepiece>=0.1.99 # Required for Transformers
|
15 |
+
sentence-transformers>=2.2.2 # Sentence embeddings
|
16 |
+
spacy>=3.0.0 # Text processing, tokenization
|
17 |
+
tensorflow>=2.13.0 # TensorFlow
|
18 |
+
tensorflow-hub>=0.12.0 # Pretrained model hub
|
19 |
+
tokenizers>=0.13.0 # HuggingFace tokenizers
|
20 |
+
torch>=2.0.0 # PyTorch
|
21 |
+
tqdm>=4.64.0 # Progress bars
|
22 |
+
transformers>=4.30.0 # Hugging Face Transformers
|
23 |
+
typing-extensions>=4.0.0
|
24 |
|
25 |
+
# Dev dependencies:
|
26 |
+
black>=22.0.0 # Code formatting
|
27 |
+
isort>=5.10.0 # Import sorting
|
28 |
+
mypy>=1.0.0 # Type checking
|
29 |
+
pytest>=7.0.0 # Testing
|