JoeArmani committed · c7c1b4e
Parent: 64e7c31

chat refinements

Files changed:
- chatbot_config.py +8 -4
- chatbot_model.py +28 -118
- cross_encoder_reranker.py +2 -1
- run_chatbot_chat.py +47 -23
- run_chatbot_validation.py +7 -16
- tf_data_pipeline.py +28 -38
chatbot_config.py CHANGED
@@ -4,19 +4,23 @@ from typing import Dict
 
 @dataclass
 class ChatbotConfig:
-    """
-
-
+    """
+    All config params for the chatbot
+    """
+    max_context_length: int = 512
+    embedding_dim: int = 384  # Sentence Transformer dim
     learning_rate: float = 0.0005
     min_text_length: int = 3
-    max_context_turns: int =
+    max_context_turns: int = 24
     pretrained_model: str = 'sentence-transformers/all-MiniLM-L6-v2'
     cross_encoder_model: str = 'cross-encoder/ms-marco-MiniLM-L-12-v2'
     summarizer_model: str = 't5-small'
     embedding_batch_size: int = 64
     search_batch_size: int = 64
     max_batch_size: int = 64
+    neg_samples: int = 10
     max_retries: int = 3
+    nlist: int = 100
 
     def to_dict(self) -> Dict:
         """Convert config to dictionary."""
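Note: a quick sketch of consuming the updated config. Field roles are inferred from how the rest of this commit uses them; nlist is presumably meant for IVF-style FAISS indices, since IndexFlatIP ignores it.

import json

from chatbot_config import ChatbotConfig

# Defaults introduced or changed in this commit.
config = ChatbotConfig()
assert config.max_context_length == 512   # token budget for contexts
assert config.embedding_dim == 384        # all-MiniLM-L6-v2 output size
assert config.neg_samples == 10           # hard negatives per query
assert config.nlist == 100                # FAISS IVF cell count (assumption)

# to_dict() makes the config easy to persist as the config.json
# that the run scripts look for next to the model checkpoint.
print(json.dumps(config.to_dict(), indent=2))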
chatbot_model.py CHANGED
@@ -22,6 +22,9 @@ from tqdm.auto import tqdm
 
 absl.logging.set_verbosity(absl.logging.WARNING)
 logger = config_logger(__name__)
+logger.setLevel("WARNING")
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+tqdm(disable=True)
 
 class RetrievalChatbot(DeviceAwareModel):
     """
@@ -59,7 +62,6 @@ class RetrievalChatbot(DeviceAwareModel):
             tokenizer=self.tokenizer,
             encoder=self.encoder,
             response_pool=[],
-            max_length=self.config.max_context_token_limit,
             query_embeddings_cache={},
         )
 
@@ -96,7 +98,7 @@ class RetrievalChatbot(DeviceAwareModel):
         return Summarizer(
             tokenizer=self.tokenizer,
             model_name=self.config.summarizer_model,
-            max_summary_length=self.config.
+            max_summary_length=self.config.max_context_length // 4,
             device=self.device,
             max_summary_rounds=2
         )
@@ -218,7 +220,6 @@ class RetrievalChatbot(DeviceAwareModel):
     ) -> List[Tuple[str, float]]:
         """
         Retrieve top-k responses using FAISS and cross-encoder re-ranking.
-
         Args:
             query: The user's input text.
             top_k: Number of responses to return.
@@ -226,7 +227,6 @@ class RetrievalChatbot(DeviceAwareModel):
             summarizer: Optional summarizer for long queries.
             summarize_threshold: Threshold to summarize long queries.
             boost_factor: Factor to boost scores for keyword matches.
-
         Returns:
             List of (response_text, final_score).
         """
@@ -241,18 +241,27 @@ class RetrievalChatbot(DeviceAwareModel):
 
         # Detect domain for query
         detected_domain = self.detect_domain_from_query(query)
+        #logger.info(f"Detected domain: {detected_domain}")
 
-        #
-        logger.info("Retrieving initial candidates from FAISS...")
+        # Retrieve candidates from FAISS
+        #logger.info("Retrieving initial candidates from FAISS...")
         faiss_candidates = self.data_pipeline.retrieve_responses(query, top_k=top_k * 10)
 
         if not faiss_candidates:
             logger.warning("No candidates retrieved from FAISS.")
             return []
 
-        #
-
-
+        # Filter out-of-domain responses
+        if detected_domain != 'other':
+            in_domain_candidates = [c for c in faiss_candidates if c[0]["domain"] == detected_domain]
+            if in_domain_candidates:
+                faiss_candidates = in_domain_candidates
+            else:
+                logger.info(f"No in-domain responses found for '{query}'. Using all candidates.")
+
+        # Re-rank candidates using Cross-Encoder
+        #logger.info("Re-ranking candidates using Cross-Encoder...")
+        texts = [item[0]["text"] for item in faiss_candidates]  # Extract response texts
         faiss_scores = [item[1] for item in faiss_candidates]
 
         if reranker is None:
@@ -277,9 +286,10 @@ class RetrievalChatbot(DeviceAwareModel):
 
             final_candidates.append((resp_text, length_adjusted_score))
 
-        #
+        # Sort and return top-k results
        final_candidates.sort(key=lambda x: x[1], reverse=True)
-        logger.info(f"Returning top-{top_k} re-ranked responses.")
+        #logger.info(f"Returning top-{top_k} re-ranked responses.")
+
         return final_candidates[:top_k]
 
     def extract_keywords(self, query: str) -> List[str]:
@@ -323,7 +333,7 @@ class RetrievalChatbot(DeviceAwareModel):
 
     def detect_domain_from_query(self, query: str) -> str:
         """
-        Detect the domain of the query based on keywords. Used for
+        Detect the domain of the query based on keywords. Used for filtering FAISS search.
         """
         domain_patterns = {
             'restaurant': r'\b(restaurant|restaurants?|dining|food|foods?|dine|reservation|reservations?|table|tables?|menu|menus?|cuisine|cuisines?|eat|eats?|place\s?to\s?eat|places\s?to\s?eat|hungry|chef|chefs?|dish|dishes?|meal|meals?|fork|forks?|knife|knives?|spoon|spoons?|brunch|bistro|buffet|buffets?|catering|caterings?|gourmet|fast\s?food|fine\s?dining|takeaway|takeaways?|delivery|deliveries|restaurant\s?booking)\b',
@@ -348,85 +358,6 @@ class RetrievalChatbot(DeviceAwareModel):
         pattern = r'^[\s]*[\d]+([\s.,\d]+)*[\s]*$'
         return bool(re.match(pattern, text.strip()))
 
-    def faiss_search(
-        self,
-        query: str,
-        domain: str = 'other',
-        top_k: int = 10,
-        boost_factor: float = 1.15
-    ) -> List[Tuple[str, float]]:
-        """
-        Retrieve top-k responses from the FAISS index (IndexFlatIP) given a user query.
-        Args:
-            query (str): The user input text.
-            domain (str): The detected domain from possible domains: ['restaurant', 'movie', 'ride_share', 'coffee', 'pizza', 'auto', 'other']
-            top_k (int): Number of top results to return.
-            boost_factor (float, optional): Factor to boost scores for keyword matches.
-        Returns:
-            List[Tuple[str, float]]: List of (response_text, similarity) sorted by descending similarity.
-        """
-        # Encode the query
-        q_emb = self.data_pipeline.encode_query(query)
-        q_emb_np = q_emb.reshape(1, -1).astype('float32')
-
-        # Search the index
-        distances, indices = self.data_pipeline.index.search(q_emb_np, top_k * 10)
-
-        # IndexFlatIP: 'distances' are inner products (cosine similarities for normalized vectors).
-        candidates = []
-        for rank, idx in enumerate(indices[0]):
-            if idx < 0:
-                continue
-            text_dict = self.data_pipeline.response_pool[idx]
-            text = text_dict.get('text', '').strip()
-            cand_domain = text_dict.get('domain', 'other')
-            score = distances[0][rank]
-
-            # Skip purely numeric or extremely short text (fewer than 3 words):
-            words = text.split()
-            if len(words) < 4:
-                continue
-            if self.is_numeric_response(text):
-                continue
-
-            candidates.append((text, cand_domain, score))
-
-        if not candidates:
-            logger.warning("No valid candidates found after initial numeric/length filtering.")
-            return []
-
-        # Sort candidates by score descending
-        candidates.sort(key=lambda x: x[2], reverse=True)
-
-        # Filter in-domain responses
-        in_domain = [c for c in candidates if c[1] == domain]
-        if not in_domain:
-            logger.info(f"No in-domain responses found for '{domain}'. Using all candidates.")
-            in_domain = candidates
-
-        # Boost responses containing query keywords
-        query_keywords = self.extract_keywords(query)
-        boosted = []
-        for (resp_text, resp_domain, score) in in_domain:
-            new_score = score
-            # If the domain is known AND the response text shares any query keywords, boost it
-            if query_keywords and any(kw in resp_text.lower() for kw in query_keywords):
-                new_score *= boost_factor
-
-            # Apply length penalty/bonus
-            new_score = self.length_adjust_score(resp_text, new_score)
-
-            boosted.append((resp_text, new_score))
-
-        # Sort boosted responses
-        boosted.sort(key=lambda x: x[1], reverse=True)
-
-        # Debug logging (see FAISS responses)
-        # for resp, score in boosted[:100]:
-        #     logger.debug(f"Candidate: '{resp}' with score {score}")
-
-        return boosted[:top_k]
-
     def introduction_message(self) -> None:
         """Print an introduction message to introduce the chatbot."""
         print(
@@ -453,7 +384,7 @@ class RetrievalChatbot(DeviceAwareModel):
                 print("\nAssistant: Goodbye!")
                 break
 
-            response, candidates, metrics = self.chat(
+            response, candidates, metrics, top_response_score = self.chat(
                 query=user_input,
                 conversation_history=None,
                 quality_checker=quality_checker,
@@ -466,7 +397,7 @@ class RetrievalChatbot(DeviceAwareModel):
                 print("\n Alternative responses:")
                 for resp, score in candidates[1:4]:
                     print(f"  Score: {score:.4f} - {resp}")
-
+            elif top_response_score < 0.7:
                 print("\n[Low Confidence]: Consider rephrasing your query for better assistance.")
 
     def chat(
@@ -504,10 +435,10 @@ class RetrievalChatbot(DeviceAwareModel):
 
             # if uncertain, ask for clarification
            if not is_confident or top_response_score < 0.5:
-                return ("I need more information to provide a good answer. Could you please clarify?", responses, metrics)
+                return ("I need more information to provide a good answer. Could you please clarify?", responses, metrics, top_response_score)
 
             # Return the top response
-            return responses[0][0], responses, metrics
+            return responses[0][0], responses, metrics, top_response_score
 
         return get_response(self, query)
 
@@ -535,27 +466,6 @@ class RetrievalChatbot(DeviceAwareModel):
         conversation_parts.append(f"{USER_TOKEN} {query}")
         return "\n".join(conversation_parts)
 
-    # def _build_conversation_context(
-    #     self,
-    #     query: str,
-    #     conversation_history: Optional[List[Tuple[str, str]]]
-    # ) -> str:
-    #     """
-    #     Build conversation context string from conversation history.
-    #     """
-    #     if not conversation_history:
-    #         return f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"
-
-    #     conversation_parts = []
-    #     for user_txt, assistant_txt in conversation_history:
-    #         conversation_parts.extend([
-    #             f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {user_txt}",
-    #             f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {assistant_txt}"
-    #         ])
-
-    #     conversation_parts.append(f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}")
-    #     return "\n".join(conversation_parts)
-
     def train_model(
         self,
         tfrecord_file_path: str,
@@ -633,7 +543,7 @@ class RetrievalChatbot(DeviceAwareModel):
         logger.info("Using fixed learning rate.")
 
         # Dummy step to force initialization
-        dummy_input = tf.zeros((1, self.config.
+        dummy_input = tf.zeros((1, self.config.max_context_length), dtype=tf.int32)
         with tf.GradientTape() as tape:
             dummy_output = self.encoder(dummy_input)
             dummy_loss = tf.cast(tf.reduce_mean(dummy_output), tf.float32)
@@ -747,7 +657,7 @@ class RetrievalChatbot(DeviceAwareModel):
         logger.info(f"New validation pairs: {val_size}")
 
         dataset = dataset.map(
-            lambda x: parse_tfrecord_fn(x, self.config.
+            lambda x: parse_tfrecord_fn(x, self.config.max_context_length, self.data_pipeline.neg_samples),
             num_parallel_calls=tf.data.AUTOTUNE
         )
 
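Note on the new retrieval flow in retrieve_top_responses: FAISS recall at top_k * 10, an optional domain filter, then cross-encoder re-ranking. Below is a self-contained sketch of that pipeline under stated assumptions: response_pool entries are {"text", "domain"} dicts (as the diff indexes them), embeddings are normalized for IndexFlatIP, and sentence_transformers.CrossEncoder stands in for this repo's TF-based CrossEncoderReranker. The keyword-boost and length-adjustment steps are omitted for brevity.

import faiss
from sentence_transformers import SentenceTransformer, CrossEncoder

# Assumed pool layout, matching how this commit treats candidates
# as (candidate_dict, faiss_score) pairs.
response_pool = [
    {"text": "I can book a table for two at 7pm.", "domain": "restaurant"},
    {"text": "Your ride is three minutes away.", "domain": "ride_share"},
]

encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")

# Inner-product index over normalized embeddings = cosine similarity.
emb = encoder.encode([r["text"] for r in response_pool], normalize_embeddings=True)
index = faiss.IndexFlatIP(emb.shape[1])
index.add(emb.astype("float32"))

def retrieve(query: str, top_k: int = 5, domain: str = "other"):
    q = encoder.encode([query], normalize_embeddings=True).astype("float32")
    scores, ids = index.search(q, min(top_k * 10, len(response_pool)))
    candidates = [(response_pool[i], float(s)) for i, s in zip(ids[0], scores[0]) if i >= 0]

    # Domain filter with fall-back to all candidates, as in the diff.
    if domain != "other":
        in_domain = [c for c in candidates if c[0]["domain"] == domain]
        candidates = in_domain or candidates

    # Cross-encoder re-rank on (query, response) pairs.
    texts = [c[0]["text"] for c in candidates]
    ce_scores = reranker.predict([(query, t) for t in texts])
    ranked = sorted(zip(texts, ce_scores), key=lambda x: x[1], reverse=True)
    return ranked[:top_k]

print(retrieve("good place to eat nearby?", top_k=1, domain="restaurant"))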
cross_encoder_reranker.py CHANGED
@@ -42,7 +42,8 @@ class CrossEncoderReranker:
             padding=True,
             truncation=True,
             max_length=max_length,
-            return_tensors="tf"
+            return_tensors="tf",
+            verbose=False
         )
 
         # Forward pass, logits shape [batch_size, 1]
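For context, a minimal sketch of the scoring step around this call: batch-tokenize (query, candidate) pairs, run a TF sequence-classification head, and read the [batch_size, 1] logits. Only the tokenizer call appears in the diff; the surrounding wiring here is an assumption, and verbose=False simply silences tokenizer warnings.

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = TFAutoModelForSequenceClassification.from_pretrained(model_name)

query = "good pizza nearby?"
candidates = ["There is a pizzeria two blocks away.", "Your car is a blue sedan."]

enc = tokenizer(
    [query] * len(candidates),
    candidates,
    padding=True,
    truncation=True,
    max_length=256,
    return_tensors="tf",
    verbose=False,
)

# Forward pass, logits shape [batch_size, 1]; a higher logit means a better match.
logits = model(dict(enc)).logits
scores = tf.squeeze(logits, axis=-1).numpy()
print(dict(zip(candidates, scores)))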
run_chatbot_chat.py CHANGED
@@ -1,12 +1,19 @@
 import os
 import json
-from
+from tqdm.auto import tqdm
 from chatbot_config import ChatbotConfig
+from chatbot_model import RetrievalChatbot
+from sentence_transformers import SentenceTransformer
+from tf_data_pipeline import TFDataPipeline
 from response_quality_checker import ResponseQualityChecker
 from environment_setup import EnvironmentSetup
 from logger_config import config_logger
 
 logger = config_logger(__name__)
+logger.setLevel("WARNING")
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+tqdm(disable=True)
 
 def run_chatbot_chat():
     env = EnvironmentSetup()
@@ -37,38 +44,55 @@ def run_chatbot_chat():
     config = ChatbotConfig()
     logger.warning("No config.json found. Using default ChatbotConfig.")
 
-    #
+    # Init SentenceTransformer
     try:
-
+        encoder = SentenceTransformer(config.pretrained_model)
+        logger.info(f"Loaded SentenceTransformer model: {config.pretrained_model}")
     except Exception as e:
-        logger.error(f"Failed to load
-        return
-
-    # Confirm FAISS index & response pool exist
-    if not os.path.exists(FAISS_INDEX_PATH) or not os.path.exists(RESPONSE_POOL_PATH):
-        logger.error("FAISS index or response pool file is missing.")
+        logger.error(f"Failed to load SentenceTransformer: {e}")
         return
-
+
     # Load FAISS index and response pool
     try:
-
-
+        # Initialize TFDataPipeline
+        data_pipeline = TFDataPipeline(
+            config=config,
+            tokenizer=encoder.tokenizer,
+            encoder=encoder,
+            response_pool=[],
+            query_embeddings_cache={},
+            index_type='IndexFlatIP',
+            faiss_index_file_path=FAISS_INDEX_PATH
+        )
+
+        if not os.path.exists(FAISS_INDEX_PATH) or not os.path.exists(RESPONSE_POOL_PATH):
+            logger.error("FAISS index or response pool file is missing.")
+            return
+
+        data_pipeline.load_faiss_index(FAISS_INDEX_PATH)
+        logger.info(f"FAISS index loaded from {FAISS_INDEX_PATH}.")
+
         with open(RESPONSE_POOL_PATH, "r", encoding="utf-8") as f:
-
-
+            data_pipeline.response_pool = json.load(f)
+        logger.info(f"Response pool loaded from {RESPONSE_POOL_PATH}.")
+        logger.info(f"Total responses in pool: {len(data_pipeline.response_pool)}")
+
         # Validate dimension consistency
-
-
+        data_pipeline.validate_faiss_index()
+        logger.info("FAISS index and response pool validated successfully.")
     except Exception as e:
         logger.error(f"Failed to load or validate FAISS index: {e}")
         return
-
-    #
-
-
-
-
-
+
+    # Run interactive chat
+    try:
+        chatbot = RetrievalChatbot.load_model(load_dir=MODEL_DIR, mode="inference")
+        quality_checker = ResponseQualityChecker(data_pipeline=data_pipeline)
+
+        logger.info("\nStarting interactive chat session...")
+        chatbot.run_interactive_chat(quality_checker=quality_checker, show_alternatives=False)
+    except Exception as e:
+        logger.error(f"Interactive chat session failed: {e}")
 
 if __name__ == "__main__":
     run_chatbot_chat()
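The "Validate dimension consistency" step amounts to checking that the index, the encoder, and the response pool agree. A standalone sketch of the assumed logic (validate_faiss_index itself is not shown in this diff):

import faiss
import numpy as np

def validate_faiss_index_sketch(index: faiss.Index, response_pool: list, embedding_dim: int) -> None:
    """Assumed validation logic: dimensions and counts must line up."""
    if index.d != embedding_dim:
        raise ValueError(f"Index dim {index.d} != expected embedding dim {embedding_dim}")
    if index.ntotal != len(response_pool):
        raise ValueError(f"Index has {index.ntotal} vectors but pool has {len(response_pool)} responses")

# Example: a 384-dim IndexFlatIP with two vectors and a two-entry pool passes.
idx = faiss.IndexFlatIP(384)
idx.add(np.random.rand(2, 384).astype("float32"))
validate_faiss_index_sketch(idx, [{"text": "a"}, {"text": "b"}], 384)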
run_chatbot_validation.py CHANGED
@@ -44,9 +44,8 @@ def run_chatbot_validation():
 
     # Init SentenceTransformer
     try:
-
-
-        logger.info(f"Loaded SentenceTransformer model: {model_name}")
+        encoder = SentenceTransformer(config.pretrained_model)
+        logger.info(f"Loaded SentenceTransformer model: {config.pretrained_model}")
     except Exception as e:
         logger.error(f"Failed to load SentenceTransformer: {e}")
         return
@@ -108,18 +107,10 @@ def run_chatbot_validation():
     # Run interactive chat loop
     try:
         logger.info("\nStarting interactive chat session...")
-
-
-
-
-
-
-        responses = data_pipeline.retrieve_responses(user_input, top_k=3)
-        print("Top Responses:")
-        for i, (response, score) in enumerate(responses, start=1):
-            print(f"{i}. {response} (Score: {score:.4f})")
-    except KeyboardInterrupt:
-        logger.info("Interactive chat session interrupted by user.")
-
+        chatbot.run_interactive_chat(quality_checker=quality_checker, show_alternatives=True)
+    except Exception as e:
+        logger.error(f"Interactive chat session failed: {e}")
+
+
 if __name__ == "__main__":
     run_chatbot_validation()
tf_data_pipeline.py CHANGED
@@ -6,7 +6,7 @@ import h5py
 import math
 import random
 import gc
-from tqdm import tqdm
+from tqdm.auto import tqdm
 import json
 from pathlib import Path
 from typing import Union, Optional, Dict, List, Tuple, Generator
@@ -28,31 +28,25 @@ class TFDataPipeline:
         encoder: SentenceTransformer,
         response_pool: List[str],
         query_embeddings_cache: dict,
-        model_name: str = 'sentence-transformers/all-MiniLM-L6-v2',
-        max_length: int = 512,
-        neg_samples: int = 10,
         index_type: str = 'IndexFlatIP',
         faiss_index_file_path: str = 'models/faiss_indices/faiss_index_production.index',
-        dimension: int = 384,
-        nlist: int = 100,
-        max_retries: int = 3
     ):
         self.config = config
         self.tokenizer = tokenizer
         self.encoder = encoder
-        self.model = SentenceTransformer(
+        self.model = SentenceTransformer(config.pretrained_model)
         self.faiss_index_file_path = faiss_index_file_path
         self.response_pool = response_pool
-        self.max_length = max_length
-        self.neg_samples = neg_samples
         self.query_embeddings_cache = query_embeddings_cache  # In-memory cache for embeddings
-        self.dimension = config.embedding_dim
         self.index_type = index_type
-        self.
-        self.
-        self.
-        self.
-        self.
+        self.neg_samples = config.neg_samples
+        self.nlist = config.nlist
+        self.dimension = config.embedding_dim
+        self.max_context_length = config.max_context_length
+        self.embedding_batch_size = config.embedding_batch_size
+        self.search_batch_size = config.search_batch_size
+        self.max_batch_size = config.max_batch_size
+        self.max_retries = config.max_retries
 
         # Build text -> domain map for O(1) domain lookups (hard negative sampling)
         self._text_domain_map = {}
@@ -159,7 +153,7 @@ class TFDataPipeline:
             speaker = turn.get('speaker')
             text = turn.get('text', '').strip()
             if speaker == 'assistant' and text:
-                if len(text) <= self.
+                if len(text) <= self.max_context_length:
                     # Use tuple as set key to ensure uniqueness
                     key = (domain, text)
                     if key not in response_set:
@@ -388,7 +382,7 @@ class TFDataPipeline:
                 # f"Collision detected: text '{stripped_text}' found with domains "
                 # f"'{existing_domain}' and '{domain}'. Keeping the first."
                 # )
-                # By default, keep the first domain or overwrite.
+                # By default, keep the first domain or overwrite. Skip overwriting:
                 continue
             else:
                 # Insert into the dict
@@ -434,7 +428,7 @@ class TFDataPipeline:
             prepared,
             padding='max_length',
             truncation=True,
-            max_length=self.
+            max_length=self.max_context_length,
            return_tensors='np'
         )
         input_ids = encodings['input_ids']
@@ -454,23 +448,19 @@ class TFDataPipeline:
     def retrieve_responses(self, query: str, top_k: int = 10) -> List[Tuple[str, float]]:
         """
         Retrieve top-k responses for a query using FAISS.
-
-        Args:
-            query: User's query text.
-            top_k: Number of responses to return.
-
-        Returns:
-            List of tuples (response text, similarity score).
         """
         query_embedding = self.encode_query(query).reshape(1, -1).astype("float32")
         distances, indices = self.index.search(query_embedding, top_k)
 
         results = []
-        for idx, dist in
+        for idx, dist in tqdm(
+            zip(indices[0], distances[0]),
+            disable=True  # Silence tqdm
+        ):
             if idx < 0:
                 continue
             response = self.response_pool[idx]
-            results.append((response
+            results.append((response, dist))
 
         return results
 
@@ -496,7 +486,7 @@ class TFDataPipeline:
         for dialogue in batch_dialogues:
             pairs = self._extract_pairs_from_dialogue(dialogue)
             for query, positive in pairs:
-                if len(query) <= self.
+                if len(query) <= self.max_context_length and len(positive) <= self.max_context_length:
                     queries.append(query)
                     positives.append(positive)
 
@@ -524,14 +514,14 @@ class TFDataPipeline:
         try:
             encoded_queries = self.tokenizer.batch_encode_plus(
                 queries,
-                max_length=self.config.
+                max_length=self.config.max_context_length,
                 truncation=True,
                 padding='max_length',
                 return_tensors='tf'
             )
             encoded_positives = self.tokenizer.batch_encode_plus(
                 positives,
-                max_length=self.config.
+                max_length=self.config.max_context_length,
                 truncation=True,
                 padding='max_length',
                 return_tensors='tf'
@@ -547,7 +537,7 @@ class TFDataPipeline:
             flattened_negatives = [neg for sublist in hard_negatives for neg in sublist]
             encoded_negatives = self.tokenizer.batch_encode_plus(
                 flattened_negatives,
-                max_length=self.config.
+                max_length=self.config.max_context_length,
                 truncation=True,
                 padding='max_length',
                 return_tensors='tf'
@@ -555,7 +545,7 @@ class TFDataPipeline:
 
             # Reshape to [num_queries, num_negatives, max_length]
             num_negatives = self.config.neg_samples
-            reshaped_negatives = encoded_negatives['input_ids'].numpy().reshape(-1, num_negatives, self.config.
+            reshaped_negatives = encoded_negatives['input_ids'].numpy().reshape(-1, num_negatives, self.config.max_context_length)
         except Exception as e:
             logger.error(f"Error during negatives tokenization: {e}")
             pbar.update(1)
@@ -600,7 +590,7 @@ class TFDataPipeline:
             batch_queries,
             padding=True,
             truncation=True,
-            max_length=self.
+            max_length=self.max_context_length,
            return_tensors='tf'
         )
         batch_embeddings = self.encoder(encoded['input_ids'], training=False).numpy()
@@ -667,14 +657,14 @@ class TFDataPipeline:
         # Use tf.py_function, limit parallelism
         q_ids, p_ids, n_ids = tf.py_function(
            func=self._tokenize_triple_py,
-            inp=[q, p, n, tf.constant(self.
+            inp=[q, p, n, tf.constant(self.max_context_length), tf.constant(self.neg_samples)],
             Tout=[tf.int32, tf.int32, tf.int32]
         )
 
         # Set shape info for the output tensors
-        q_ids.set_shape([None, self.
-        p_ids.set_shape([None, self.
-        n_ids.set_shape([None, self.neg_samples, self.
+        q_ids.set_shape([None, self.max_context_length])  # [batch_size, max_length]
+        p_ids.set_shape([None, self.max_context_length])  # [batch_size, max_length]
+        n_ids.set_shape([None, self.neg_samples, self.max_context_length])  # [batch_size, neg_samples, max_length]
 
         return q_ids, p_ids, n_ids
 
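Note on the shape bookkeeping at the end of this file: tf.py_function erases static shape information, so the set_shape calls restore [batch, max_len] and [batch, neg_samples, max_len] for downstream graph code. A toy, runnable sketch of the same pattern; the zero-filled stand-in tokenizer is an assumption (the real code calls the HF tokenizer):

import numpy as np
import tensorflow as tf

MAX_LEN = 512      # config.max_context_length in this commit
NEG_SAMPLES = 10   # config.neg_samples

def _tokenize_triple_py(q, p, n, max_len, neg_samples):
    """Toy stand-in for the HF tokenizer: returns int32 token-id arrays."""
    batch = int(q.shape[0])
    q_ids = np.zeros((batch, int(max_len)), dtype=np.int32)
    p_ids = np.zeros((batch, int(max_len)), dtype=np.int32)
    n_ids = np.zeros((batch, int(neg_samples), int(max_len)), dtype=np.int32)
    return q_ids, p_ids, n_ids

def tokenize_triple(q, p, n):
    q_ids, p_ids, n_ids = tf.py_function(
        func=_tokenize_triple_py,
        inp=[q, p, n, tf.constant(MAX_LEN), tf.constant(NEG_SAMPLES)],
        Tout=[tf.int32, tf.int32, tf.int32],
    )
    # py_function loses shape info; restore it for downstream graph code.
    q_ids.set_shape([None, MAX_LEN])
    p_ids.set_shape([None, MAX_LEN])
    n_ids.set_shape([None, NEG_SAMPLES, MAX_LEN])
    return q_ids, p_ids, n_ids

ds = tf.data.Dataset.from_tensor_slices((["hi"], ["hello"], [["a"] * NEG_SAMPLES]))
ds = ds.batch(1).map(tokenize_triple, num_parallel_calls=tf.data.AUTOTUNE)
print(ds.element_spec)  # shows the pinned (None, 512) / (None, 10, 512) specs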