JoeArmani committed · Commit 7a0020b · 1 parent: d53c64b
updates - new iteration with type token
Browse files
- .gitignore +7 -1
- build_faiss_index.py +161 -0
- chatbot_model.py +286 -202
- chatbot_validator.py +141 -68
- conversation_summarizer.py +1 -1
- cross_encoder_reranker.py +39 -13
- new_iteration/pipeline_config.py +9 -0
- new_iteration/run_taskmaster_processor.py +39 -0
- new_iteration/taskmaster_processor.py +177 -0
- prepare_data.py +137 -73
- response_quality_checker.py +266 -68
- tf_data_pipeline.py +293 -159
- validate_model.py +71 -42
.gitignore
CHANGED
@@ -180,4 +180,10 @@ cache/*
 !cache/.gitkeep
 training_data/*
 !training_data/.gitkeep
-augmented_dialogues.json
+augmented_dialogues.json
+
+checkpoints_old_REMOVE/*
+new_iteration/cache/*
+new_iteration/data_prep_iterative_models/*
+new_iteration/training_data/*
+new_iteration/processed_outputs/*
build_faiss_index.py
ADDED
@@ -0,0 +1,161 @@
import os
import json
from pathlib import Path

import faiss
import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModel
from tqdm.auto import tqdm

from chatbot_model import ChatbotConfig, EncoderModel
from tf_data_pipeline import TFDataPipeline
from logger_config import config_logger

logger = config_logger(__name__)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def sanity_check(encoder: EncoderModel, tokenizer: AutoTokenizer, config: ChatbotConfig):
    """
    Perform a quick sanity check to ensure the model is loaded correctly.
    """
    sample_response = "This is a test response."
    encoded_sample = tokenizer(
        [sample_response],
        padding=True,
        truncation=True,
        max_length=config.max_context_token_limit,
        return_tensors='tf'
    )

    # Get embedding
    sample_embedding = encoder(encoded_sample['input_ids'], training=False).numpy()

    # Check shape
    if sample_embedding.shape[1] != config.embedding_dim:
        logger.error(
            f"Embedding dimension mismatch: Expected {config.embedding_dim}, "
            f"got {sample_embedding.shape[1]}"
        )
        raise ValueError("Embedding dimension mismatch.")
    else:
        logger.info("Embedding dimension matches the configuration.")

    # Check normalization
    embedding_norm = np.linalg.norm(sample_embedding, axis=1)
    if not np.allclose(embedding_norm, 1.0, atol=1e-5):
        logger.error("Embeddings are not properly normalized.")
        raise ValueError("Embeddings are not normalized.")
    else:
        logger.info("Embeddings are properly normalized.")

    logger.info("Sanity check passed: Model loaded correctly and outputs are as expected.")

def build_faiss_index():
    """
    Rebuild the FAISS index by:
    1) Loading your config.json
    2) Initializing encoder + loading submodule & custom weights
    3) Loading tokenizer from disk
    4) Creating a TFDataPipeline
    5) Setting the pipeline's response_pool from a JSON file
    6) Using pipeline.compute_and_index_response_embeddings()
    7) Saving the FAISS index
    """
    # Directories
    MODELS_DIR = Path("models")
    FAISS_DIR = MODELS_DIR / "faiss_indices"
    FAISS_INDEX_PATH = FAISS_DIR / "faiss_index_production.index"
    RESPONSES_PATH = FAISS_DIR / "faiss_index_production_responses.json"
    TOKENIZER_DIR = MODELS_DIR / "tokenizer"
    SHARED_ENCODER_DIR = MODELS_DIR / "shared_encoder"
    CUSTOM_WEIGHTS_PATH = MODELS_DIR / "encoder_custom_weights.weights.h5"

    # 1) Load ChatbotConfig
    config_path = MODELS_DIR / "config.json"
    if config_path.exists():
        with open(config_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        config = ChatbotConfig.from_dict(config_dict)
        logger.info(f"Loaded ChatbotConfig from {config_path}")
    else:
        config = ChatbotConfig()
        logger.warning(f"No config.json found at {config_path}. Using default ChatbotConfig.")

    # 2) Initialize the EncoderModel
    encoder = EncoderModel(config=config)
    logger.info("EncoderModel instantiated (empty).")

    # Overwrite the submodule from 'shared_encoder' directory
    if SHARED_ENCODER_DIR.exists():
        logger.info(f"Loading DistilBERT submodule from {SHARED_ENCODER_DIR}...")
        encoder.pretrained = TFAutoModel.from_pretrained(str(SHARED_ENCODER_DIR))
        logger.info("Loaded HF submodule into encoder.pretrained.")
    else:
        logger.warning(f"No shared_encoder directory at {SHARED_ENCODER_DIR}. Using default pretrained model.")

    # Build model once, then load custom weights (projection, etc.)
    dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
    _ = encoder(dummy_input, training=False)  # builds the layers

    if CUSTOM_WEIGHTS_PATH.exists():
        logger.info(f"Loading custom top-level weights from {CUSTOM_WEIGHTS_PATH}")
        encoder.load_weights(str(CUSTOM_WEIGHTS_PATH))
        logger.info("Custom top-level weights loaded successfully.")
    else:
        logger.warning(f"Custom weights file not found at {CUSTOM_WEIGHTS_PATH}.")

    # 3) Load tokenizer
    if TOKENIZER_DIR.exists():
        logger.info(f"Loading tokenizer from {TOKENIZER_DIR}")
        tokenizer = AutoTokenizer.from_pretrained(str(TOKENIZER_DIR))
    else:
        logger.warning(f"No tokenizer dir at {TOKENIZER_DIR}, falling back to default HF tokenizer.")
        tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
        # tokenizer.add_special_tokens({'additional_special_tokens': ['<EMPTY_NEGATIVE>']})

    # 4) Quick sanity check
    sanity_check(encoder, tokenizer, config)

    # 5) Prepare a TFDataPipeline
    pipeline = TFDataPipeline(
        config=config,
        tokenizer=tokenizer,
        encoder=encoder,
        index_file_path=str(FAISS_INDEX_PATH),
        response_pool=[],
        max_length=config.max_context_token_limit,
        query_embeddings_cache={},
        neg_samples=config.neg_samples,
        index_type='IndexFlatIP',
        nlist=100,
        max_retries=config.max_retries
    )

    # 6) Load the existing response pool
    if not RESPONSES_PATH.exists():
        logger.error(f"Response pool JSON file not found at {RESPONSES_PATH}")
        raise FileNotFoundError(f"No response pool JSON at {RESPONSES_PATH}")

    with open(RESPONSES_PATH, "r", encoding="utf-8") as f:
        response_pool = json.load(f)
    logger.info(f"Loaded {len(response_pool)} responses from {RESPONSES_PATH}")

    pipeline.response_pool = response_pool  # assign to pipeline

    # 7) Build (or rebuild) the FAISS index from pipeline method
    #    This does all the compute-embeddings + index.add in one place
    logger.info("Starting to compute and index response embeddings via TFDataPipeline...")
    pipeline.compute_and_index_response_embeddings()

    # 8) Save the rebuilt FAISS index
    pipeline.save_faiss_index(str(FAISS_INDEX_PATH))

    # Verify
    loaded_index = faiss.read_index(str(FAISS_INDEX_PATH))
    logger.info(f"Verified the rebuilt FAISS index has {loaded_index.ntotal} vectors.")

    return loaded_index, pipeline.response_pool

if __name__ == "__main__":
    build_faiss_index()
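The script above only writes the index; it assumes the response pool JSON already lines up with it one-to-one. A quick way to smoke-test the artifact it produces is to load the saved index and response pool directly and run a single query vector against them. The snippet below is a minimal sketch of that check, assuming the same file layout used above and an already-computed, L2-normalized query embedding (the query_vec argument is a stand-in, not a helper this repository provides).

# Minimal sketch: confirm the rebuilt IndexFlatIP answers a query sensibly.
import json
import faiss
import numpy as np

index = faiss.read_index("models/faiss_indices/faiss_index_production.index")
with open("models/faiss_indices/faiss_index_production_responses.json", "r", encoding="utf-8") as f:
    responses = json.load(f)

def top_matches(query_vec: np.ndarray, top_k: int = 5):
    """query_vec: shape (1, embedding_dim), float32, L2-normalized (as the sanity check expects)."""
    # With normalized vectors, inner-product scores from IndexFlatIP are cosine similarities.
    scores, ids = index.search(query_vec.astype("float32"), top_k)
    return [(responses[i], float(s)) for i, s in zip(ids[0], scores[0]) if i != -1]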
chatbot_model.py
CHANGED
@@ -10,6 +10,8 @@ from pathlib import Path
 import datetime
 import faiss
 import gc
+
+import re
 from tf_data_pipeline import TFDataPipeline
 from response_quality_checker import ResponseQualityChecker
 from cross_encoder_reranker import CrossEncoderReranker
@@ -31,7 +33,7 @@ class ChatbotConfig:
     num_attention_heads: int = 8
     dropout_rate: float = 0.2
     l2_reg_weight: float = 0.001
-    learning_rate: float = 0.
+    learning_rate: float = 0.001
     min_text_length: int = 3
     max_context_turns: int = 5
     warmup_steps: int = 200
@@ -41,7 +43,7 @@ class ChatbotConfig:
     embedding_batch_size: int = 64
     search_batch_size: int = 64
     max_batch_size: int = 64
-    neg_samples: int =
+    neg_samples: int = 10
    max_retries: int = 3

    def to_dict(self) -> Dict:
@@ -54,7 +56,7 @@ class ChatbotConfig:
        """Create config from dictionary."""
        return cls(**{k: v for k, v in config_dict.items()
                      if k in cls.__dataclass_fields__})
-
+
 class EncoderModel(tf.keras.Model):
    """Dual encoder model with pretrained embeddings."""
    def __init__(
@@ -154,7 +156,7 @@ class RetrievalChatbot(DeviceAwareModel):
            config=self.config,
            tokenizer=self.tokenizer,
            encoder=self.encoder,
-            index_file_path='
+            index_file_path='new_iteration/data_prep_iterative_models/faiss_indices/faiss_index_production.index',
            response_pool=[],
            max_length=self.config.max_context_token_limit,
            query_embeddings_cache={},
@@ -260,32 +262,49 @@ class RetrievalChatbot(DeviceAwareModel):
    def load_model(cls, load_dir: Union[str, Path], mode: str = 'training') -> 'RetrievalChatbot':
        """
        Load saved models and configuration.
-
-        Args:
-            load_dir (Union[str, Path]): Directory containing saved model files
-            mode (str): Either 'training' or 'inference'. In inference mode,
-                also loads FAISS index and response pool.
        """
        load_dir = Path(load_dir)

-        # Load config
+        # 1) Load config
        with open(load_dir / "config.json", "r") as f:
            config = ChatbotConfig.from_dict(json.load(f))

-        # Initialize chatbot
+        # 2) Initialize chatbot
        chatbot = cls(config, mode=mode)

-        # Load
+        # 3) Load DistilBERT from huggingface folder
        chatbot.encoder.pretrained = TFAutoModel.from_pretrained(
            load_dir / "shared_encoder",
            config=config
        )

+        dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
+        _ = chatbot.encoder(dummy_input, training=False)
+
+        # # Then load your custom weights
+        # custom_weights_path = load_dir / "encoder_custom_weights.weights.h5"
+        # if custom_weights_path.exists():
+        #     logger.info(f"Loading custom top-level weights from {custom_weights_path}")
+        #     chatbot.encoder.load_weights(str(custom_weights_path))
+        #     logger.info("Custom top-level weights loaded successfully.")
+        # else:
+        #     logger.warning(f"Custom weights file not found at {custom_weights_path}.")
+
+        # 4) Load tokenizer
        chatbot.tokenizer = AutoTokenizer.from_pretrained(load_dir / "tokenizer")
        logger.info(f"Models and tokenizer loaded from {load_dir}")

+        # 5) Load the custom top layers' weights
+        custom_weights_path = load_dir / "encoder_custom_weights.weights.h5"
+        if custom_weights_path.exists():
+            chatbot.encoder.load_weights(str(custom_weights_path))
+            logger.info("Loaded custom encoder weights for projection/dropout/etc.")
+        else:
+            logger.warning(f"No custom encoder weights found at {custom_weights_path}. The top-level projection layer won't have learned parameters.")
+
+        # 6) If in inference mode, load FAISS, etc.
        if mode == 'inference':
            cls._prepare_model_for_inference(chatbot, load_dir)

@@ -296,7 +315,7 @@ class RetrievalChatbot(DeviceAwareModel):
        """Internal method to load inference components."""
        try:
            # Load FAISS index
-            faiss_path = load_dir / '
+            faiss_path = load_dir / 'faiss_indices/faiss_index_production.index'
            if faiss_path.exists():
                chatbot.index = faiss.read_index(str(faiss_path))
                logger.info("FAISS index loaded successfully")
@@ -304,7 +323,7 @@ class RetrievalChatbot(DeviceAwareModel):
                raise FileNotFoundError(f"FAISS index not found at {faiss_path}")

            # Load response pool
-            response_pool_path = load_dir / '
+            response_pool_path = load_dir / 'faiss_indices/faiss_index_production_responses.json'
            if response_pool_path.exists():
                with open(response_pool_path, 'r') as f:
                    chatbot.response_pool = json.load(f)
@@ -332,9 +351,12 @@ class RetrievalChatbot(DeviceAwareModel):
        with open(save_dir / "config.json", "w") as f:
            json.dump(self.config.to_dict(), f, indent=2)

-        # Save
+        # Save the HF DistilBERT submodule:
        self.encoder.pretrained.save_pretrained(save_dir / "shared_encoder")

+        # ALSO save custom top-level layers' weights
+        self.encoder.save_weights(save_dir / "encoder_custom_weights.weights.h5")
+
        # Save tokenizer
        self.tokenizer.save_pretrained(save_dir / "tokenizer")

@@ -343,139 +365,270 @@ class RetrievalChatbot(DeviceAwareModel):
    def retrieve_responses_cross_encoder(
        self,
        query: str,
-        top_k: int,
+        top_k: int = 10,
        reranker: Optional[CrossEncoderReranker] = None,
        summarizer: Optional[Summarizer] = None,
-        summarize_threshold: int = 512
+        summarize_threshold: int = 512
    ) -> List[Tuple[str, float]]:
        """
-        Retrieve top-k
+        Retrieve top-k responses with optional domain-based boosting
+        and cross-encoder re-ranking.
+
+        Args:
+            query: The user's input text.
+            top_k: Number of final results to return.
+            reranker: CrossEncoderReranker for refined scoring, if available.
+            summarizer: Summarizer for long queries, if desired.
+            summarize_threshold: Summarize if query wordcount > threshold.
+
+        Returns:
+            List of (response_text, final_score).
        """
-        reranker = self.reranker
-        if summarizer is None:
-            summarizer = self.summarizer
-
-        # Optional summarization
+        # 1) Optional query summarization
        if summarizer and len(query.split()) > summarize_threshold:
-            logger.info(f"Query is long
+            logger.info(f"Query is long ({len(query.split())} words). Summarizing.")
            query = summarizer.summarize_text(query)
-            logger.info(f"Summarized
+            logger.info(f"Summarized Query: {query}")
-
-        # 2) Dense retrieval
-        dense_topk = self.retrieve_responses_faiss(query, top_k=top_k)  # [(resp, dense_score), ...]
-
-        # top_k: int,
-        # reranker: Optional[CrossEncoderReranker] = None,
-        # summarizer: Optional[Summarizer] = None,
-        # summarize_threshold: int = 512  # Summarize over 512 tokens
-        # ) -> List[Tuple[str, float]]:
-        # """
-        # Retrieve top-k from FAISS, then re-rank them with a cross-encoder.
-        # Optionally summarize the user query if it's too long.
-        # """
-        # if reranker is None:
-        #     reranker = self.reranker
-        # if summarizer is None:
-        #     summarizer = self.summarizer
-
-        # # Optional summarization
-        # if summarizer and len(query.split()) > summarize_threshold:
-        #     logger.info(f"Query is long. Summarizing before cross-encoder. Original length: {len(query.split())}")
-        #     query = summarizer.summarize_text(query)
-        #     logger.info(f"Summarized query: {query}")
-
-        # # Sort descending by cross-encoder score
-        # combined.sort(key=lambda x: x[1], reverse=True)
-
-        distances, indices = self.data_pipeline.index.search(q_emb_np, top_k)
-
-        # Map indices to responses and distances to similarities
-        top_responses = []
-        for i, idx in enumerate(indices[0]):
-            if idx < len(self.data_pipeline.response_pool):
-                top_responses.append((self.data_pipeline.response_pool[idx], float(distances[0][i])))
-
-        # def retrieve_responses_faiss(
-        # # Encode the query
-        # q_emb = self.encode_query(query)
-        # q_emb_np = q_emb.

+        detected_domain = self.detect_domain_from_query(query)
+        logger.debug(f"Detected domain '{detected_domain}' for query: {query}")

+        # 2) Retrieve more initial candidates from FAISS
+        initial_k = min(top_k * 10, len(self.data_pipeline.response_pool))
+        dense_candidates = self.retrieve_responses_faiss(query, domain=detected_domain, top_k=initial_k)

+        boosted_candidates = dense_candidates
+
+        # 4) If we have a cross-encoder, re-rank these boosted candidates
+        if not reranker:
+            logger.warning("No CrossEncoderReranker provided; creating a default one.")
+            reranker = CrossEncoderReranker(model_name="cross-encoder/ms-marco-MiniLM-L-12-v2")
+
+        texts = [item[0] for item in boosted_candidates]
+        ce_scores = reranker.rerank(query, texts, max_length=256)
+
+        # Combine cross-encoder score with the base FAISS score (simple multiplicative approach)
+        final_candidates = []
+        for (resp_text, faiss_score), ce_score in zip(boosted_candidates, ce_scores):
+            # TODO: dial this in.
+            alpha = 0.8
+            combined_score = alpha * ce_score + (1 - alpha) * faiss_score
+            length_adjusted_score = self.length_adjust_score(resp_text, combined_score)
+            # combined_score = ce_score * faiss_score
+            final_candidates.append((resp_text, combined_score))
+
+        # Sort descending by combined score
+        final_candidates.sort(key=lambda x: x[1], reverse=True)
+
+        # Return top_k
+        return final_candidates[:top_k]
+
+    DOMAIN_KEYWORDS = {
+        'restaurant': ['restaurant', 'dining', 'food', 'dine', 'reservation', 'table', 'menu', 'cuisine', 'eat', 'place to eat', 'hungry', 'chef', 'dish', 'meal', 'brunch', 'bistro', 'buffet', 'catering', 'gourmet', 'fast food', 'fine dining', 'takeaway', 'delivery', 'restaurant booking'],
+        'movie': ['movie', 'cinema', 'film', 'ticket', 'showtime', 'showing', 'theater', 'flick', 'screening', 'film ticket', 'film show', 'blockbuster', 'premiere', 'trailer', 'director', 'actor', 'actress', 'plot', 'genre', 'screen', 'sequel', 'animation', 'documentary'],
+        'ride_share': ['ride', 'taxi', 'uber', 'lyft', 'car service', 'pickup', 'dropoff', 'driver', 'cab', 'hailing', 'rideshare', 'ride hailing', 'carpool', 'chauffeur', 'transit', 'transportation', 'hail ride'],
+        'coffee': ['coffee', 'café', 'cafe', 'starbucks', 'espresso', 'latte', 'mocha', 'americano', 'barista', 'brew', 'cappuccino', 'macchiato', 'iced coffee', 'cold brew', 'espresso machine', 'coffee shop', 'tea', 'chai', 'java', 'bean', 'roast', 'decaf'],
+        'pizza': ['pizza', 'delivery', 'order food', 'pepperoni', 'topping', 'pizzeria', 'slice', 'pie', 'margherita', 'deep dish', 'thin crust', 'cheese', 'oven', 'tossed', 'sauce', 'garlic bread', 'calzone'],
+        'auto': ['car', 'vehicle', 'repair', 'maintenance', 'mechanic', 'oil change', 'garage', 'auto shop', 'tire', 'check engine', 'battery', 'transmission', 'brake', 'engine diagnostics', 'carwash', 'detail', 'alignment', 'exhaust', 'spark plug', 'dashboard'],
+    }
+
+    def extract_keywords(self, query: str) -> List[str]:
+        """
+        Extract keywords from the query based on DOMAIN_KEYWORDS.
+        """
+        query_lower = query.lower()
+        keywords = set()
+        for domain, kws in self.DOMAIN_KEYWORDS.items():
+            for kw in kws:
+                if kw in query_lower:
+                    keywords.add(kw)
+        return list(keywords)
+
+    def length_adjust_score(resp_text: str, base_score: float) -> float:
+        # Apply a short penalty
+        words = len(resp_text.split())
+        if words < 3:
+            # big penalty or skip entirely
+            return base_score * 0.1  # or base_score - 0.01
+
+        # Add a mild bonus for lines that exceed 12 words:
+        if words > 12:
+            # e.g. +0.002 * (words - 12)
+            bonus = 0.002 * (words - 12)
+            base_score += bonus
+
+        return base_score
+
+    def detect_domain_from_query(self, query: str) -> str:
+        """
+        Detect the domain of the query based on keywords.
+        """
+        domain_patterns = {
+            'restaurant': r'\b(restaurant|dining|food|dine|reservation|table|menu|cuisine|eat|place\s?to\s?eat|hungry|chef|dish|meal|fork|knife|spoon|brunch|bistro|buffet|catering|gourmet|fast\s?food|fine\s?dining|takeaway|delivery|restaurant\s?booking)\b',
+            'movie': r'\b(movie|cinema|film|ticket|showtime|showing|theater|flick|screening|film\s?ticket|film\s?show|blockbuster|premiere|trailer|director|actor|actress|plot|genre|screen|sequel|animation|documentary)\b',
+            'ride_share': r'\b(ride|taxi|uber|lyft|car\s?service|pickup|dropoff|driver|cab|hailing|rideshare|ride\s?hailing|carpool|chauffeur|transit|transportation|hail\s?ride)\b',
+            'coffee': r'\b(coffee|café|cafe|starbucks|espresso|latte|mocha|americano|barista|brew|cappuccino|macchiato|iced\s?coffee|cold\s?brew|espresso\s?machine|coffee\s?shop|tea|chai|java|bean|roast|decaf)\b',
+            'pizza': r'\b(pizza|delivery|order\s?food|pepperoni|topping|pizzeria|slice|pie|margherita|deep\s?dish|thin\s?crust|cheese|oven|tossed|sauce|garlic\s?bread|calzone)\b',
+            'auto': r'\b(car|vehicle|repair|maintenance|mechanic|oil\s?change|garage|auto\s?shop|tire|check\s?engine|battery|transmission|brake|engine\s?diagnostics|carwash|detail|alignment|exhaust|spark\s?plug|dashboard)\b',
+        }
+
+        # Check for matches
+        for domain, pattern in domain_patterns.items():
+            if re.search(pattern, query.lower()):
+                return domain
+
+        return 'other'
+
+    def is_numeric_response(text: str) -> bool:
+        """
+        Return True if `text` is purely digits (and/or spaces).
+        e.g.: "4 3 13" -> True, " 42 " -> True, "hello 42" -> False
+        """
+        pattern = r'^\s*[0-9]+(\s+[0-9]+)*\s*$'
+        return bool(re.match(pattern, text))
+
+    def retrieve_responses_faiss(
+        self,
+        query: str,
+        domain: str = 'other',
+        top_k: int = 5,
+        boost_factor: float = 1.3
+    ) -> List[Tuple[str, float]]:
+        """
+        Retrieve top-k responses from the FAISS index (IndexFlatIP) given a user query.
+
+        Args:
+            query (str): The user input text.
+            domain (str, optional): The detected domain. Defaults to 'other'.
+            top_k (int, optional): Number of top results to return. Defaults to 5.
+            boost_factor (float, optional): Factor to boost scores for keyword matches. Defaults to 1.3.
+
+        Returns:
+            List[Tuple[str, float]]: List of (response_text, similarity) sorted by descending similarity.
+        """
+        # Encode the query
+        q_emb = self.data_pipeline.encode_query(query)
+        q_emb_np = q_emb.reshape(1, -1).astype('float32')
+
+        # Search the index
+        distances, indices = self.data_pipeline.index.search(q_emb_np, top_k * 10)  # Adjust multiplier as needed
+
+        # IndexFlatIP: 'distances' are inner products (cosine similarities for normalized vectors)
+        candidates = []
+        for rank, idx in enumerate(indices[0]):
+            if idx == -1:
+                continue  # FAISS may return -1 for invalid indices
+            response = self.data_pipeline.response_pool[idx]
+            text = response.get('text', '')
+            cand_domain = response.get('domain', 'other')
+            score = distances[0][rank]
+
+            # Filter out numeric responses and very short texts
+            if not self.is_numeric_response(text) and len(text.split()) >= self.config.min_text_length:
+                candidates.append((text, cand_domain, score))
+
+        if not candidates:
+            logger.warning("No valid candidates found after initial filtering.")
+            return []
+
+        # Sort candidates by score descending
+        candidates.sort(key=lambda x: x[2], reverse=True)
+
+        # Filter in-domain responses
+        if domain != 'other':
+            in_domain_responses = [c for c in candidates if c[1] == domain]
+            if not in_domain_responses:
+                logger.info(f"No in-domain responses found for domain '{domain}'. Falling back to all candidates.")
+                in_domain_responses = candidates
+        else:
+            in_domain_responses = candidates
+
+        # Boost responses containing query keywords
+        query_keywords = self.extract_keywords(query)
+        boosted_responses = []
+        for resp_text, resp_domain, score in in_domain_responses:
+            if any(kw in resp_text.lower() for kw in query_keywords):
+                boosted_score = score * boost_factor
+                logger.debug(f"Boosting response: '{resp_text}' by factor {boost_factor}")
            else:
+                boosted_score = score
+            boosted_responses.append((resp_text, boosted_score))
+
+        # Sort boosted responses
+        boosted_responses.sort(key=lambda x: x[1], reverse=True)
+
+        # Select top_k responses
+        top_responses = boosted_responses[:top_k]
+        logger.debug(f"Top {top_k} responses selected.")
+
        return top_responses
+    # def retrieve_responses_faiss(
+    #     self,
+    #     query: str,
+    #     domain: str = 'other',
+    #     top_k: int = 5,
+    #     boost_factor: float = 1.3
+    # ) -> List[Tuple[str, float]]:
+    #     """
+    #     Retrieve top-k responses from the FAISS index (IndexFlatIP) given a user query.
+
+    #     Args:
+    #         query: The user input text
+    #         top_k: Number of top results to return
+
+    #     Returns:
+    #         List of (response_text, similarity) sorted by descending similarity
+    #     """
     # # Encode the query
+    #     q_emb = self.data_pipeline.encode_query(query)
+    #     q_emb_np = q_emb.reshape(1, -1).astype('float32')
+
+    #     # Search the index
+    #     distances, indices = self.data_pipeline.index.search(q_emb_np, top_k * 10)  # distances: shape [1, k], indices: shape [1, k]
+
+    #     # IndexFlatIP: 'distances' are cosine similarities in [-1, +1].
+    #     candidates = []
+    #     for rank, idx in enumerate(indices[0]):
+    #         text = self.response_pool[idx]['text']
+    #         cand_domain = self.response_pool[idx]['domain']
+    #         score = distances[0][rank]
+
+    #         # filter out responses with only numbers or too few words
+    #         word_count = len(text.split())
+    #         if not self.is_numeric_resonse(text) and word_count >= 2:
+    #             candidates.append((text, cand_domain, score))
+
+    #     # Filter to in-domain responses
+    #     candidates.sort(key=lambda x: x[2], reverse=True)
+    #     in_domain_responses = [(text, score) for (text, cand_domain, score) in candidates if cand_domain == domain]

+    #     # Boost keyword matching responses
+    #     query_keywords = self.extract_keywords(query)
+    #     boosted_responses = []
+    #     for (resp_text, domain, score) in in_domain_responses:
+    #         # Check if any keyword is present in the response text
+    #         for kw in query_keywords:
+    #             if kw in resp_text.lower():
+    #                 boosted_score = score * boost_factor
+    #                 print(f"Boosting response: '{resp_text}' by factor {boost_factor}")
+    #                 break
     # else:
+    #             boosted_score = score
+    #         boosted_responses.append((resp_text, domain, boosted_score))
+
+    #     # Debug
+    #     logger.debug("\nFAISS Search Results (top 15 for debug):")
+    #     for i, (resp, score) in enumerate(boosted_responses[:15], start=1):
+    #         logger.debug(f"{i}. Score: {score:.4f} -> {resp[:60]}")
+
+    #     return boosted_responses[:top_k]

    def chat(
        self,
        query: str,
        conversation_history: Optional[List[Tuple[str, str]]] = None,
        quality_checker: Optional['ResponseQualityChecker'] = None,
-        top_k: int =
+        top_k: int = 10,
    ) -> Tuple[str, List[Tuple[str, float]], Dict[str, Any]]:
        """
        Example chat method that always uses cross-encoder re-ranking
@@ -516,52 +669,6 @@ class RetrievalChatbot(DeviceAwareModel):
            return results[0][0], results, {}

        return get_response(self, query)
-    # def chat(
-    #     self,
-    #     query: str,
-    #     conversation_history: Optional[List[Tuple[str, str]]] = None,
-    #     quality_checker: Optional['ResponseQualityChecker'] = None,
-    #     top_k: int = 5,
-    # ) -> Tuple[str, List[Tuple[str, float]], Dict[str, Any]]:
-    #     """
-    #     Example chat method that always uses cross-encoder re-ranking
-    #     if self.reranker is available.
-    #     """
-    #     @self.run_on_device
-    #     def get_response(self_arg, query_arg):  # Add parameters that match decorator's expectations
-    #         # 1) Build conversation context string
-    #         conversation_str = self_arg._build_conversation_context(query_arg, conversation_history)

-    #         # 2) Retrieve + cross-encoder re-rank
-    #         results = self_arg.retrieve_responses_cross_encoder(
-    #             query=conversation_str,
-    #             top_k=top_k,
-    #             reranker=self_arg.reranker,
-    #             summarizer=self_arg.summarizer,
-    #             summarize_threshold=512
-    #         )

-    #         # 3) Handle empty or confidence
-    #         if not results:
-    #             return (
-    #                 "I'm sorry, but I couldn't find a relevant response.",
-    #                 [],
-    #                 {}
-    #             )

-    #         if quality_checker:
-    #             metrics = quality_checker.check_response_quality(query_arg, results)
-    #             if not metrics.get('is_confident', False):
-    #                 return (
-    #                     "I need more information to provide a good answer. Could you please clarify?",
-    #                     results,
-    #                     metrics
-    #                 )
-    #             return results[0][0], results, metrics

-    #         return results[0][0], results, {}

-    #     return get_response(self, query)

    def _build_conversation_context(
        self,
@@ -581,24 +688,6 @@ class RetrievalChatbot(DeviceAwareModel):

        conversation_parts.append(f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}")
        return "\n".join(conversation_parts)
-    # def _build_conversation_context(
-    #     self,
-    #     query: str,
-    #     conversation_history: Optional[List[Tuple[str, str]]]
-    # ) -> str:
-    #     """Build conversation context with better memory management."""
-    #     if not conversation_history:
-    #         return f"{self.special_tokens['user']} {query}"

-    #     conversation_parts = []
-    #     for user_txt, assistant_txt in conversation_history:
-    #         conversation_parts.extend([
-    #             f"{self.special_tokens['user']} {user_txt}",
-    #             f"{self.special_tokens['assistant']} {assistant_txt}"
-    #         ])

-    #     conversation_parts.append(f"{self.special_tokens['user']} {query}")
-    #     return "\n".join(conversation_parts)

    def train_model(
        self,
@@ -707,23 +796,14 @@ class RetrievalChatbot(DeviceAwareModel):
        self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}

        if latest_checkpoint and not test_mode:
-            # Debug info before restore
-            logger.info("\nEncoder Variables:")
-            for var in self.encoder.variables:
-                logger.info(f"{var.name}: {var.dtype} - Shape: {var.shape}")

-            logger.info("\nOptimizer Variables:")
-            for var in self.optimizer.variables:
-                logger.info(f"{var.name}: {var.dtype} - Shape: {var.shape}")

            # Add checkpoint inspection
-            logger.info("\nTrying to load checkpoint from: "
+            logger.info(f"\nTrying to load checkpoint from: {latest_checkpoint}")
            reader = tf.train.load_checkpoint(latest_checkpoint)
-            shape_from_key = reader.get_variable_to_shape_map()
-            dtype_from_key = reader.get_variable_to_dtype_map()
-            logger.info("\nCheckpoint Variables:")
-            for key in shape_from_key:
+            # shape_from_key = reader.get_variable_to_shape_map()
+            # dtype_from_key = reader.get_variable_to_dtype_map()
+            # logger.info("\nCheckpoint Variables:")
+            # for key in shape_from_key:
+            #     logger.info(f"{key}: dtype={dtype_from_key[key]} - Shape: {shape_from_key[key]}")

            status = checkpoint.restore(latest_checkpoint)
            status.assert_consumed()
@@ -754,6 +834,10 @@ class RetrievalChatbot(DeviceAwareModel):
            logger.info(f"Loaded previous training history from {history_path}")
        except Exception as e:
            logger.warning(f"Could not load history, starting fresh: {e}")
+
+            # Fix for custom weights not being saved in the full model.
+            self.save_models(Path(checkpoint_dir) / "pretrained_full_model")
+            logger.info(f"Manually saved custom weights after restore.")
        else:
            logger.info("Starting training from scratch")
            checkpoint.epoch.assign(tf.cast(0, tf.int32))
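The new retrieval path above blends the cross-encoder score with the FAISS similarity as alpha * ce_score + (1 - alpha) * faiss_score with alpha = 0.8, after keyword boosting (boost_factor = 1.3) and with a word-count adjustment. As written in the diff, length_adjust_score is declared without a self parameter and its result is never used, so the sketch below folds the length adjustment directly into the blended score to show the intended arithmetic in isolation; it is a standalone illustration, not the class method itself.

from typing import List, Tuple

def blend_scores(
    candidates: List[Tuple[str, float]],  # (response_text, faiss_score) pairs from dense retrieval
    ce_scores: List[float],               # cross-encoder scores, aligned with candidates
    alpha: float = 0.8,                   # weight on the cross-encoder, as in the diff
) -> List[Tuple[str, float]]:
    """Blend cross-encoder and FAISS scores, then apply the word-count adjustment."""
    blended = []
    for (text, faiss_score), ce_score in zip(candidates, ce_scores):
        score = alpha * ce_score + (1 - alpha) * faiss_score
        words = len(text.split())
        if words < 3:        # heavy penalty for one- or two-word replies
            score *= 0.1
        elif words > 12:     # mild bonus for fuller replies
            score += 0.002 * (words - 12)
        blended.append((text, score))
    return sorted(blended, key=lambda item: item[1], reverse=True)

# Example: with equal FAISS scores, the fuller answer wins once the cross-encoder prefers it.
print(blend_scores(
    [("Sure.", 0.91), ("I booked a table for four at 7pm at the Italian place downtown.", 0.91)],
    [0.20, 0.70],
))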
chatbot_validator.py
CHANGED
@@ -1,30 +1,41 @@
|
|
1 |
from typing import Dict, List, Tuple, Any, Optional
|
2 |
import numpy as np
|
3 |
-
|
4 |
from logger_config import config_logger
|
|
|
|
|
5 |
logger = config_logger(__name__)
|
6 |
|
|
|
7 |
class ChatbotValidator:
|
8 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
def __init__(self, chatbot, quality_checker):
|
11 |
"""
|
12 |
Initialize the validator.
|
13 |
|
14 |
Args:
|
15 |
-
chatbot: RetrievalChatbot instance
|
16 |
quality_checker: ResponseQualityChecker instance
|
17 |
"""
|
18 |
self.chatbot = chatbot
|
19 |
self.quality_checker = quality_checker
|
20 |
|
21 |
-
#
|
|
|
22 |
self.domain_queries = {
|
23 |
'restaurant': [
|
24 |
"I'd like to make a reservation for dinner tonight.",
|
25 |
-
"Can you book a table for 4
|
26 |
-
"
|
27 |
-
"I
|
28 |
"What's the wait time for a table right now?"
|
29 |
],
|
30 |
'movie_tickets': [
|
@@ -38,8 +49,8 @@ class ChatbotValidator:
|
|
38 |
"I need a ride from the airport to downtown.",
|
39 |
"How much would it cost to get to the mall?",
|
40 |
"Can you book a car for tomorrow morning?",
|
41 |
-
"Is there a driver available now?",
|
42 |
-
"What's the estimated arrival time?"
|
43 |
],
|
44 |
'services': [
|
45 |
"I need to schedule an oil change for my car.",
|
@@ -61,7 +72,9 @@ class ChatbotValidator:
|
|
61 |
self,
|
62 |
num_examples: int = 5,
|
63 |
top_k: int = 10,
|
64 |
-
domains: Optional[List[str]] = None
|
|
|
|
|
65 |
) -> Dict[str, Any]:
|
66 |
"""
|
67 |
Run comprehensive validation across specified domains.
|
@@ -69,36 +82,55 @@ class ChatbotValidator:
|
|
69 |
Args:
|
70 |
num_examples: Number of test queries per domain
|
71 |
top_k: Number of responses to retrieve for each query
|
72 |
-
domains: Optional list of
|
|
|
|
|
73 |
|
74 |
Returns:
|
75 |
Dict containing detailed validation metrics and domain-specific performance
|
76 |
"""
|
77 |
logger.info("\n=== Running Enhanced Automatic Validation ===")
|
78 |
|
79 |
-
# Select domains to test
|
80 |
test_domains = domains if domains else list(self.domain_queries.keys())
|
|
|
|
|
81 |
metrics_history = []
|
82 |
domain_metrics = {}
|
|
|
|
|
|
|
|
|
|
|
83 |
|
84 |
# Run validation for each domain
|
85 |
for domain in test_domains:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
domain_metrics[domain] = []
|
87 |
-
queries = self.domain_queries[domain][:num_examples]
|
88 |
|
89 |
logger.info(f"\n=== Testing {domain.title()} Domain ===")
|
90 |
|
91 |
for i, query in enumerate(queries, 1):
|
92 |
-
logger.info(f"\nTest Case {i}:")
|
93 |
-
logger.info(f"Query: {query}")
|
94 |
|
95 |
-
#
|
96 |
-
responses = self.chatbot.retrieve_responses_cross_encoder(query, top_k=top_k)
|
97 |
|
98 |
-
#
|
99 |
quality_metrics = self.quality_checker.check_response_quality(query, responses)
|
100 |
|
101 |
-
#
|
102 |
quality_metrics['domain'] = domain
|
103 |
metrics_history.append(quality_metrics)
|
104 |
domain_metrics[domain].append(quality_metrics)
|
@@ -106,11 +138,12 @@ class ChatbotValidator:
|
|
106 |
# Detailed logging
|
107 |
self._log_validation_results(query, responses, quality_metrics, i)
|
108 |
|
109 |
-
#
|
110 |
aggregate_metrics = self._calculate_aggregate_metrics(metrics_history)
|
111 |
domain_analysis = self._analyze_domain_performance(domain_metrics)
|
112 |
confidence_analysis = self._analyze_confidence_distribution(metrics_history)
|
113 |
|
|
|
114 |
aggregate_metrics.update({
|
115 |
'domain_performance': domain_analysis,
|
116 |
'confidence_analysis': confidence_analysis
|
@@ -120,48 +153,74 @@ class ChatbotValidator:
|
|
120 |
return aggregate_metrics
|
121 |
|
122 |
def _calculate_aggregate_metrics(self, metrics_history: List[Dict]) -> Dict[str, float]:
|
123 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
metrics = {
|
125 |
'num_queries_tested': len(metrics_history),
|
126 |
-
'avg_top_response_score': np.mean(
|
127 |
-
'avg_diversity': np.mean([m.get('response_diversity', 0) for m in metrics_history]),
|
128 |
-
'avg_relevance': np.mean([m.get('query_response_relevance', 0) for m in metrics_history]),
|
129 |
-
'avg_length_score': np.mean([m.get('response_length_score', 0) for m in metrics_history]),
|
130 |
-
'avg_score_gap': np.mean([m.get('top_3_score_gap', 0) for m in metrics_history]),
|
131 |
-
'confidence_rate': np.mean([m.get('is_confident', False)
|
|
|
132 |
|
133 |
# Additional statistical metrics
|
134 |
-
'median_top_score': np.median(
|
135 |
-
'score_std': np.std(
|
136 |
-
'min_score': np.min(
|
137 |
-
'max_score': np.max(
|
138 |
}
|
139 |
return metrics
|
140 |
|
141 |
-
def _analyze_domain_performance(self, domain_metrics: Dict[str, List[Dict]]) -> Dict[str, Dict]:
|
142 |
-
"""
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
}
|
153 |
|
154 |
-
return
|
155 |
|
156 |
def _analyze_confidence_distribution(self, metrics_history: List[Dict]) -> Dict[str, float]:
|
157 |
-
"""
|
158 |
-
|
|
|
|
|
|
|
|
|
159 |
|
|
|
160 |
return {
|
161 |
-
'percentile_25': np.percentile(scores, 25),
|
162 |
-
'percentile_50': np.percentile(scores, 50),
|
163 |
-
'percentile_75': np.percentile(scores, 75),
|
164 |
-
'percentile_90': np.percentile(scores, 90)
|
165 |
}
|
166 |
|
167 |
def _log_validation_results(
|
@@ -171,37 +230,51 @@ class ChatbotValidator:
|
|
171 |
metrics: Dict[str, Any],
|
172 |
case_num: int
|
173 |
):
|
174 |
-
"""
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
|
|
|
|
183 |
|
184 |
-
logger.info("
|
185 |
-
for i, (
|
186 |
-
logger.info(f"{i}
|
187 |
-
if i == 1 and not
|
188 |
-
logger.info(" [Low Confidence]")
|
189 |
|
190 |
def _log_validation_summary(self, metrics: Dict[str, Any]):
|
191 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
logger.info("\n=== Validation Summary ===")
|
193 |
|
|
|
194 |
logger.info("\nOverall Metrics:")
|
195 |
for metric, value in metrics.items():
|
|
|
196 |
if isinstance(value, (int, float)):
|
197 |
logger.info(f"{metric}: {value:.4f}")
|
198 |
|
|
|
|
|
199 |
logger.info("\nDomain Performance:")
|
200 |
-
for domain,
|
201 |
logger.info(f"\n{domain.title()}:")
|
202 |
-
for metric, value in
|
203 |
logger.info(f" {metric}: {value:.4f}")
|
204 |
|
|
|
|
|
205 |
logger.info("\nConfidence Distribution:")
|
206 |
-
for
|
207 |
-
logger.info(f"{
|
|
|
1 |
from typing import Dict, List, Tuple, Any, Optional
|
2 |
import numpy as np
|
3 |
+
import random
|
4 |
from logger_config import config_logger
|
5 |
+
from cross_encoder_reranker import CrossEncoderReranker
|
6 |
+
|
7 |
logger = config_logger(__name__)
|
8 |
|
9 |
+
|
10 |
class ChatbotValidator:
|
11 |
+
"""
|
12 |
+
Handles automated validation and performance analysis for the chatbot.
|
13 |
+
|
14 |
+
This validator executes domain-specific test queries, obtains candidate
|
15 |
+
responses via the chatbot, then evaluates them with a quality checker.
|
16 |
+
It aggregates metrics across queries and domains, logs intermediate
|
17 |
+
results, and returns a comprehensive summary.
|
18 |
+
"""
|
19 |
|
20 |
def __init__(self, chatbot, quality_checker):
|
21 |
"""
|
22 |
Initialize the validator.
|
23 |
|
24 |
Args:
|
25 |
+
chatbot: RetrievalChatbot instance for inference
|
26 |
quality_checker: ResponseQualityChecker instance
|
27 |
"""
|
28 |
self.chatbot = chatbot
|
29 |
self.quality_checker = quality_checker
|
30 |
|
31 |
+
# Basic domain-specific test queries (easy examples)
|
32 |
+
# Taskmaster-1 and Schema-Guided style
|
33 |
self.domain_queries = {
|
34 |
'restaurant': [
|
35 |
"I'd like to make a reservation for dinner tonight.",
|
36 |
+
"Can you book a table for 4 at an Italian restaurant?",
|
37 |
+
"Is there any availability to dine tomorrow at 7pm?",
|
38 |
+
"I'd like to cancel my reservation for tonight.",
|
39 |
"What's the wait time for a table right now?"
|
40 |
],
|
41 |
'movie_tickets': [
|
|
|
49 |
"I need a ride from the airport to downtown.",
|
50 |
"How much would it cost to get to the mall?",
|
51 |
"Can you book a car for tomorrow morning?",
|
52 |
+
"Is there a driver available right now?",
|
53 |
+
"What's the estimated arrival time for the driver?"
|
54 |
],
|
55 |
'services': [
|
56 |
"I need to schedule an oil change for my car.",
|
|
|
72 |
self,
|
73 |
num_examples: int = 5,
|
74 |
top_k: int = 10,
|
75 |
+
domains: Optional[List[str]] = None,
|
76 |
+
randomize: bool = False,
|
77 |
+
seed: int = 42
|
78 |
) -> Dict[str, Any]:
|
79 |
"""
|
80 |
Run comprehensive validation across specified domains.
|
|
|
82 |
Args:
|
83 |
num_examples: Number of test queries per domain
|
84 |
top_k: Number of responses to retrieve for each query
|
85 |
+
domains: Optional list of domain keys to test. If None, test all.
|
86 |
+
randomize: If True, randomly select queries from the domain lists
|
87 |
+
seed: Random seed for consistent sampling if randomize=True
|
88 |
|
89 |
Returns:
|
90 |
Dict containing detailed validation metrics and domain-specific performance
|
91 |
"""
|
92 |
logger.info("\n=== Running Enhanced Automatic Validation ===")
|
93 |
|
94 |
+
# Select which domains to test
|
95 |
test_domains = domains if domains else list(self.domain_queries.keys())
|
96 |
+
|
97 |
+
# Initialize results
|
98 |
metrics_history = []
|
99 |
domain_metrics = {}
|
100 |
+
|
101 |
+
reranker = CrossEncoderReranker(model_name="cross-encoder/ms-marco-MiniLM-L-12-v2")
|
102 |
+
|
103 |
+
# Prepare random selection if needed
|
104 |
+
rng = random.Random(seed)
|
105 |
|
106 |
# Run validation for each domain
|
107 |
for domain in test_domains:
|
108 |
+
# Avoid errors if domain key missing
|
109 |
+
if domain not in self.domain_queries:
|
110 |
+
logger.warning(f"Domain '{domain}' not found in domain_queries. Skipping.")
|
111 |
+
continue
|
112 |
+
|
113 |
+
all_queries = self.domain_queries[domain]
|
114 |
+
if randomize:
|
115 |
+
queries = rng.sample(all_queries, min(num_examples, len(all_queries)))
|
116 |
+
else:
|
117 |
+
queries = all_queries[:num_examples]
|
118 |
+
|
119 |
+
# Store domain-level metrics
|
120 |
domain_metrics[domain] = []
|
|
|
121 |
|
122 |
logger.info(f"\n=== Testing {domain.title()} Domain ===")
|
123 |
|
124 |
for i, query in enumerate(queries, 1):
|
125 |
+
logger.info(f"\nTest Case {i}: {query}")
|
|
|
126 |
|
127 |
+
# Retrieve top_k responses (including cross-encoder re-ranking if available)
|
128 |
+
responses = self.chatbot.retrieve_responses_cross_encoder(query, top_k=top_k, reranker=reranker)
|
129 |
|
130 |
+
# Evaluate with quality checker
|
131 |
quality_metrics = self.quality_checker.check_response_quality(query, responses)
|
132 |
|
133 |
+
# Save domain info
|
134 |
quality_metrics['domain'] = domain
|
135 |
metrics_history.append(quality_metrics)
|
136 |
domain_metrics[domain].append(quality_metrics)
|
|
|
138 |
# Detailed logging
|
139 |
self._log_validation_results(query, responses, quality_metrics, i)
|
140 |
|
141 |
+
# Final aggregation
|
142 |
aggregate_metrics = self._calculate_aggregate_metrics(metrics_history)
|
143 |
domain_analysis = self._analyze_domain_performance(domain_metrics)
|
144 |
confidence_analysis = self._analyze_confidence_distribution(metrics_history)
|
145 |
|
146 |
+
# Combine into one dictionary
|
147 |
aggregate_metrics.update({
|
148 |
'domain_performance': domain_analysis,
|
149 |
'confidence_analysis': confidence_analysis
|
|
|
153 |
return aggregate_metrics
|
154 |
|
155 |
     def _calculate_aggregate_metrics(self, metrics_history: List[Dict]) -> Dict[str, float]:
+        """
+        Calculate comprehensive aggregate metrics over all tested queries.
+        """
+        if not metrics_history:
+            logger.warning("No metrics to aggregate. Returning empty summary.")
+            return {}
+
+        top_scores = [m.get('top_score', 0.0) for m in metrics_history]
+
+        # The length-based metrics are robust to missing or zero-length data
         metrics = {
             'num_queries_tested': len(metrics_history),
+            'avg_top_response_score': np.mean(top_scores),
+            'avg_diversity': np.mean([m.get('response_diversity', 0.0) for m in metrics_history]),
+            'avg_relevance': np.mean([m.get('query_response_relevance', 0.0) for m in metrics_history]),
+            'avg_length_score': np.mean([m.get('response_length_score', 0.0) for m in metrics_history]),
+            'avg_score_gap': np.mean([m.get('top_3_score_gap', 0.0) for m in metrics_history]),
+            'confidence_rate': np.mean([1.0 if m.get('is_confident', False) else 0.0
+                                        for m in metrics_history]),

             # Additional statistical metrics
+            'median_top_score': np.median(top_scores),
+            'score_std': np.std(top_scores),
+            'min_score': np.min(top_scores),
+            'max_score': np.max(top_scores)
         }
         return metrics

+    def _analyze_domain_performance(self, domain_metrics: Dict[str, List[Dict]]) -> Dict[str, Dict[str, float]]:
+        """
+        Analyze performance by domain, returning a nested dict.
+        """
+        analysis = {}
+
+        for domain, metrics_list in domain_metrics.items():
+            if not metrics_list:
+                analysis[domain] = {}
+                continue
+
+            top_scores = [m.get('top_score', 0.0) for m in metrics_list]
+
+            analysis[domain] = {
+                'confidence_rate': np.mean([1.0 if m.get('is_confident', False) else 0.0
+                                            for m in metrics_list]),
+                'avg_relevance': np.mean([m.get('query_response_relevance', 0.0)
+                                          for m in metrics_list]),
+                'avg_diversity': np.mean([m.get('response_diversity', 0.0)
+                                          for m in metrics_list]),
+                'avg_top_score': np.mean(top_scores),
+                'num_samples': len(metrics_list)
             }

+        return analysis

     def _analyze_confidence_distribution(self, metrics_history: List[Dict]) -> Dict[str, float]:
+        """
+        Analyze the distribution of top scores to gauge system confidence levels.
+        """
+        if not metrics_history:
+            return {'percentile_25': 0.0, 'percentile_50': 0.0,
+                    'percentile_75': 0.0, 'percentile_90': 0.0}

+        scores = [m.get('top_score', 0.0) for m in metrics_history]
         return {
+            'percentile_25': float(np.percentile(scores, 25)),
+            'percentile_50': float(np.percentile(scores, 50)),
+            'percentile_75': float(np.percentile(scores, 75)),
+            'percentile_90': float(np.percentile(scores, 90))
         }

     def _log_validation_results(
         metrics: Dict[str, Any],
         case_num: int
     ):
+        """
+        Log detailed validation results for each test case.
+        """
+        domain = metrics.get('domain', 'Unknown')
+        is_confident = metrics.get('is_confident', False)
+
+        logger.info(f"Domain: {domain} | Confidence: {'Yes' if is_confident else 'No'}")
+        logger.info("Quality Metrics:")
+        for k, v in metrics.items():
+            if isinstance(v, (int, float)):
+                logger.info(f"  {k}: {v:.4f}")

+        logger.info("Top 3 Responses:")
+        for i, (resp_text, score) in enumerate(responses[:3], 1):
+            logger.info(f"{i}) Score: {score:.4f} | {resp_text}")
+            if i == 1 and not is_confident:
+                logger.info("   [Low Confidence on Top Response]")

     def _log_validation_summary(self, metrics: Dict[str, Any]):
+        """
+        Log a summary of all validation metrics and domain performance.
+        """
+        if not metrics:
+            logger.info("No metrics to summarize.")
+            return
+
         logger.info("\n=== Validation Summary ===")

+        # Overall
         logger.info("\nOverall Metrics:")
         for metric, value in metrics.items():
+            # Skip sub-dicts here
             if isinstance(value, (int, float)):
                 logger.info(f"{metric}: {value:.4f}")

+        # Domain performance
+        domain_perf = metrics.get('domain_performance', {})
         logger.info("\nDomain Performance:")
+        for domain, domain_stats in domain_perf.items():
             logger.info(f"\n{domain.title()}:")
+            for metric, value in domain_stats.items():
                 logger.info(f"  {metric}: {value:.4f}")

+        # Confidence distribution
+        conf_analysis = metrics.get('confidence_analysis', {})
         logger.info("\nConfidence Distribution:")
+        for pct, val in conf_analysis.items():
+            logger.info(f"  {pct}: {val:.4f}")
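The aggregate helpers above are plain NumPy reductions over a list of per-query metric dicts, so their arithmetic can be sanity-checked outside the validator. A minimal standalone sketch (the free function and the toy history values are illustrative, not part of the repository):

import numpy as np
from typing import Dict, List

def aggregate(history: List[Dict]) -> Dict[str, float]:
    # Mirrors _calculate_aggregate_metrics: simple means/medians over per-query metrics
    top = [m.get('top_score', 0.0) for m in history]
    return {
        'num_queries_tested': len(history),
        'avg_top_response_score': float(np.mean(top)),
        'confidence_rate': float(np.mean([1.0 if m.get('is_confident', False) else 0.0 for m in history])),
        'median_top_score': float(np.median(top)),
        'score_std': float(np.std(top)),
    }

history = [
    {'top_score': 0.62, 'is_confident': True},
    {'top_score': 0.41, 'is_confident': False},
    {'top_score': 0.55, 'is_confident': True},
]
print(aggregate(history))
# e.g. {'num_queries_tested': 3, 'avg_top_response_score': 0.526..., 'confidence_rate': 0.666..., ...}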
conversation_summarizer.py
CHANGED
@@ -9,7 +9,7 @@ logger = logging.getLogger(__name__)
 @dataclass
 class ChatConfig:
     max_sequence_length: int = 512
-    default_top_k: int =
+    default_top_k: int = 10
     chunk_size: int = 512
     chunk_overlap: int = 256
     min_confidence_score: float = 0.7
cross_encoder_reranker.py
CHANGED
@@ -1,19 +1,28 @@
 from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
 import tensorflow as tf
-from typing import List
+from typing import List
+import numpy as np
 
 from logger_config import config_logger
 logger = config_logger(__name__)
 
 class CrossEncoderReranker:
     """
-    Cross-Encoder Re-Ranker
-    outputs a single relevance score
+    Cross-Encoder Re-Ranker that takes (query, candidate) pairs and
+    outputs a single relevance score in [0, 1].
     """
     def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-12-v2"):
+        """
+        Initialize the cross-encoder with a pretrained model.
+
+        Args:
+            model_name: Name of a HF cross-encoder model. Must be
+                compatible with TFAutoModelForSequenceClassification.
+        """
+        logger.info(f"Initializing CrossEncoderReranker with {model_name}...")
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
+        logger.info("Cross encoder model loaded successfully.")
 
     def rerank(
         self,
@@ -22,13 +31,21 @@ class CrossEncoderReranker:
         max_length: int = 256
     ) -> List[float]:
         """
+        Compute relevance scores for each candidate w.r.t. the query.
+
+        Args:
+            query: User's query text.
+            candidates: List of candidate response texts.
+            max_length: Max token length for each (query, candidate) pair.
+
+        Returns:
+            A list of float scores in [0, 1], one per candidate,
+            indicating the model's predicted relevance.
         """
-        # Build (query, candidate) pairs
+        # 1) Build (query, candidate) pairs
         pair_texts = [(query, candidate) for candidate in candidates]
 
-        # Tokenize the entire batch
+        # 2) Tokenize the entire batch
         encodings = self.tokenizer(
             pair_texts,
             padding=True,
@@ -37,15 +54,24 @@ class CrossEncoderReranker:
             return_tensors="tf"
         )
 
-        # Forward pass -> logits shape [batch_size, 1]
+        # 3) Forward pass -> logits shape [batch_size, 1]
         outputs = self.model(
             input_ids=encodings["input_ids"],
             attention_mask=encodings["attention_mask"],
-            token_type_ids=encodings.get("token_type_ids")
+            token_type_ids=encodings.get("token_type_ids")  # Some models need token_type_ids
         )
 
-        logits = outputs.logits
-        #
+        logits = outputs.logits  # shape [batch_size, 1]
+        # 4) Convert logits -> [0,1] range via sigmoid
+        # If the cross-encoder is a single-logit regression to [0,1],
+        # this is a typical interpretation.
+        scores = tf.nn.sigmoid(logits)  # shape [batch_size, 1]
+
+        # 5) Flatten to a 1D NumPy array of floats
+        scores = tf.reshape(scores, [-1])
+        scores = scores.numpy().astype(float)
+
+        # logger.debug(f"Cross-Encoder raw logits: {logits.numpy().flatten().tolist()}")
+        # logger.debug(f"Cross-Encoder sigmoid scores: {scores.tolist()}")
 
         return scores.tolist()
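A usage sketch for the re-ranker above, assuming the repository root is on PYTHONPATH so cross_encoder_reranker is importable; the query and candidate strings are made up:

from cross_encoder_reranker import CrossEncoderReranker

reranker = CrossEncoderReranker()  # downloads cross-encoder/ms-marco-MiniLM-L-12-v2 on first use
query = "Can I book a table for two tonight?"
candidates = [
    "Sure, what time would you like the reservation?",
    "The movie starts at 8 pm.",
    "I can order a pepperoni pizza for you.",
]
scores = reranker.rerank(query, candidates, max_length=256)

# Pair each candidate with its sigmoid relevance score and sort best-first
for text, score in sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True):
    print(f"{score:.3f}  {text}")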
new_iteration/pipeline_config.py
ADDED
@@ -0,0 +1,9 @@
from dataclasses import dataclass

@dataclass
class PipelineConfig:
    """Minimal pipeline config."""
    max_length: int = 512     # max length if you want to skip long utterances
    min_turns: int = 4        # minimum total turns (user + assistant)
    min_user_words: int = 3   # min words in each user turn
    debug: bool = True        # enable debug prints
new_iteration/run_taskmaster_processor.py
ADDED
@@ -0,0 +1,39 @@
import json
from datetime import datetime
from pathlib import Path

from pipeline_config import PipelineConfig
from taskmaster_processor import TaskmasterProcessor

def main():
    # 1) Setup config
    config = PipelineConfig(
        max_length=512,
        min_turns=3,
        min_user_words=3,
        debug=True
    )

    # 2) Instantiate processor
    base_dir = "datasets/taskmaster"
    processor = TaskmasterProcessor(config)

    # 3) Load raw dialogues
    dialogues = processor.load_taskmaster_dataset(base_dir=base_dir, max_examples=None)

    # 4) Filter & convert to final structure
    final_dialogues = processor.filter_and_convert(dialogues)

    # 5) Save final data
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = Path("processed_outputs")
    output_dir.mkdir(parents=True, exist_ok=True)
    out_file = output_dir / f"taskmaster_only_{timestamp}.json"

    with open(out_file, 'w', encoding='utf-8') as f:
        json.dump(final_dialogues, f, indent=2)

    print(f"[Taskmaster Only] Kept {len(final_dialogues)} dialogues => {out_file}")

if __name__ == "__main__":
    main()
new_iteration/taskmaster_processor.py
ADDED
@@ -0,0 +1,177 @@
import json
import re
from pathlib import Path
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field

from pipeline_config import PipelineConfig

@dataclass
class TaskmasterDialogue:
    """Structured representation of a Taskmaster-1 dialogue."""
    conversation_id: str
    instruction_id: Optional[str]
    scenario: Optional[str]
    domain: str
    turns: List[Dict[str, Any]] = field(default_factory=list)

    def validate(self) -> bool:
        """Check if this dialogue has an ID and a list of turns."""
        return bool(self.conversation_id and isinstance(self.turns, list))

class TaskmasterProcessor:
    """
    Loads Taskmaster-1 dialogues, extracts domain from scenario,
    filters them, and outputs a final pipeline-friendly format.
    """
    def __init__(self, config: PipelineConfig):
        self.config = config

    def load_taskmaster_dataset(self, base_dir: str, max_examples: Optional[int] = None) -> List[TaskmasterDialogue]:
        """
        Load and parse Taskmaster JSON for self-dialogs & woz-dialogs (Taskmaster-1).
        Combines scenario text + conversation utterances to detect domain more robustly.
        """
        required_files = {
            "self-dialogs": "self-dialogs.json",
            "woz-dialogs": "woz-dialogs.json",
            "ontology": "ontology.json",  # we might not actively use this, but let's expect it
        }
        # Check for missing
        missing = [k for k, v in required_files.items() if not Path(base_dir, v).exists()]
        if missing:
            raise FileNotFoundError(f"Missing Taskmaster files: {missing}")

        # Load ontology (optional usage)
        ontology_path = Path(base_dir, required_files["ontology"])
        with open(ontology_path, 'r', encoding='utf-8') as f:
            ontology = json.load(f)
        if self.config.debug:
            print(f"[TaskmasterProcessor] Loaded ontology with {len(ontology.keys())} top-level keys (unused).")

        dialogues: List[TaskmasterDialogue] = []

        # We'll read the 2 main files
        file_keys = ["self-dialogs", "woz-dialogs"]
        for file_key in file_keys:
            file_path = Path(base_dir, required_files[file_key])
            with open(file_path, 'r', encoding='utf-8') as f:
                raw_data = json.load(f)

            for d in raw_data:
                conversation_id = d.get("conversation_id", "")
                instruction_id = d.get("instruction_id", None)
                scenario_text = d.get("scenario", "")  # old scenario approach

                # Collect utterances -> turns
                utterances = d.get("utterances", [])
                turns = self._process_utterances(utterances)

                # Instead of only using scenario_text, we combine scenario + turn texts.
                # We'll pass everything to _extract_domain
                domain = self._extract_domain(
                    scenario_text,
                    turns  # pass the entire turn list so we can pick up domain keywords
                )

                # Create a structured object
                new_dlg = TaskmasterDialogue(
                    conversation_id=conversation_id,
                    instruction_id=instruction_id,
                    scenario=scenario_text,
                    domain=domain,
                    turns=turns
                )
                dialogues.append(new_dlg)

                if max_examples and len(dialogues) >= max_examples:
                    break

        if self.config.debug:
            print(f"[TaskmasterProcessor] Loaded {len(dialogues)} total dialogues from Taskmaster-1.")
        return dialogues

    def _extract_domain(self, scenario: str, turns: List[Dict[str, str]]) -> str:
        """
        Combine scenario text + all turn texts to detect the domain more robustly.
        """
        # 1) Combine scenario + conversation text
        combined_text = scenario.lower()
        for turn in turns:
            text = turn.get('text', '').strip().lower()
            combined_text += " " + text

        # 2) Expanded domain patterns (edit or expand as you wish)
        domain_patterns = {
            'restaurant': r'\b(restaurant|dining|food|reservation|table|menu|cuisine|eat)\b',
            'movie': r'\b(movie|cinema|film|ticket|showtime|theater)\b',
            'ride_share': r'\b(ride|taxi|uber|lyft|car\s?service|pickup|dropoff)\b',
            'coffee': r'\b(coffee|café|cafe|starbucks|espresso|latte|mocha|americano)\b',
            'pizza': r'\b(pizza|delivery|order\s?food|pepperoni|topping|pizzeria)\b',
            'auto': r'\b(car|vehicle|repair|maintenance|mechanic|oil\s?change)\b'
        }

        # 3) Return first matched domain or 'other'
        for dom, pattern in domain_patterns.items():
            if re.search(pattern, combined_text):
                print(f"Matched domain: {dom}")
                return dom

        print("No domain match, returning 'other'")
        return 'other'

    def _process_utterances(self, utterances: List[Dict[str, Any]]) -> List[Dict[str, str]]:
        """Map speaker to user/assistant, store text."""
        turns = []
        for utt in utterances:
            speaker = 'assistant' if utt.get('speaker') == 'ASSISTANT' else 'user'
            text = utt.get('text', '').strip()
            turns.append({
                'speaker': speaker,
                'text': text
            })
        return turns

    def filter_and_convert(self, dialogues: List[TaskmasterDialogue]) -> List[Dict]:
        """
        Filter out dialogues that don't meet min turns / min user words,
        then convert them to final pipeline dict:

        {
          "dialogue_id": "...",
          "domain": "...",
          "turns": [
            {"speaker": "user", "text": "..."},
            ...
          ]
        }
        """
        results = []
        for dlg in dialogues:
            if not dlg.validate():
                continue

            if len(dlg.turns) < self.config.min_turns:
                continue

            # Check user-turn min words
            keep = True
            for turn in dlg.turns:
                if turn['speaker'] == 'user':
                    word_count = len(turn['text'].split())
                    if word_count < self.config.min_user_words:
                        keep = False
                        break
            if not keep:
                continue

            pipeline_dlg = {
                'dialogue_id': dlg.conversation_id,
                'domain': dlg.domain,
                'turns': dlg.turns  # or you can refine further if needed
            }
            results.append(pipeline_dlg)

        if self.config.debug:
            print(f"[TaskmasterProcessor] Filtered down to {len(results)} dialogues.")
        return results
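Domain detection in _extract_domain is a first-match keyword search over the lower-cased scenario plus turn texts, so pattern order acts as the tie-breaker. A small standalone sketch of the same idea (only two of the patterns, toy utterances):

import re

domain_patterns = {
    'restaurant': r'\b(restaurant|dining|food|reservation|table|menu|cuisine|eat)\b',
    'movie': r'\b(movie|cinema|film|ticket|showtime|theater)\b',
}

def detect_domain(text: str) -> str:
    # First pattern that matches wins, mirroring _extract_domain
    for dom, pattern in domain_patterns.items():
        if re.search(pattern, text.lower()):
            return dom
    return 'other'

print(detect_domain("I need a table for four at an Italian restaurant"))  # restaurant
print(detect_domain("Two tickets for the late showtime please"))          # movie
print(detect_domain("What's the weather tomorrow?"))                      # other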
prepare_data.py
CHANGED
@@ -3,10 +3,13 @@ import sys
 import faiss
 import json
 import pickle
+import tensorflow as tf
+from transformers import AutoTokenizer, TFAutoModel
 from tqdm.auto import tqdm
+from pathlib import Path
+
+# Your existing modules
 from chatbot_model import ChatbotConfig, EncoderModel
-from environment_setup import EnvironmentSetup
 from tf_data_pipeline import TFDataPipeline
 from logger_config import config_logger
 
@@ -14,32 +17,23 @@ logger = config_logger(__name__)
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
-def cleanup_test_indices(faiss_dir, test_prefix='test_'):
-    test_files = [f for f in os.listdir(faiss_dir) if f.startswith(test_prefix)]
-    for file in test_files:
-        file_path = os.path.join(faiss_dir, file)
-        os.remove(file_path)
-        logger.info(f"Removed test FAISS index file: {file_path}")
-
 def main():
     # Constants
-    MODELS_DIR = '
-    PROCESSED_DATA_DIR = 'processed_outputs'
-    CACHE_DIR = 'cache'
+    MODELS_DIR = 'new_iteration/data_prep_iterative_models'
+    PROCESSED_DATA_DIR = 'new_iteration/processed_outputs'
+    CACHE_DIR = 'new_iteration/cache'
     TOKENIZER_DIR = os.path.join(MODELS_DIR, 'tokenizer')
     FAISS_INDICES_DIR = os.path.join(MODELS_DIR, 'faiss_indices')
-    TF_RECORD_DIR = 'training_data'
+    TF_RECORD_DIR = 'new_iteration/training_data'
     FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_production.index')
-
-    ENVIRONMENT = 'production' # or 'test'
-    if ENVIRONMENT == 'test':
-        FAISS_INDEX_PATH = FAISS_INDEX_TEST_PATH
-    else:
-        FAISS_INDEX_PATH = FAISS_INDEX_PRODUCTION_PATH
-    JSON_TRAINING_DATA_PATH = os.path.join(PROCESSED_DATA_DIR, 'augmented_dialogues.json')
+    JSON_TRAINING_DATA_PATH = os.path.join(PROCESSED_DATA_DIR, 'taskmaster_dialogues.json')
     CACHE_FILE = os.path.join(CACHE_DIR, 'query_embeddings_cache.pkl')
-    TF_RECORD_PATH = os.path.join(TF_RECORD_DIR, '
+    TF_RECORD_PATH = os.path.join(TF_RECORD_DIR, 'training_data_3.tfrecord')
+
+    # Decide whether to load the **custom** fine-tuned model or just base DistilBERT.
+    # True for custom, False for base DistilBERT.
+    LOAD_CUSTOM_MODEL = True
+    NUM_NEG_SAMPLES = 10
 
     # Ensure output directories exist
     os.makedirs(MODELS_DIR, exist_ok=True)
@@ -49,58 +43,120 @@ def main():
     os.makedirs(FAISS_INDICES_DIR, exist_ok=True)
     os.makedirs(TF_RECORD_DIR, exist_ok=True)
 
-    # Initialize
+    # Initialize config
+    config_json = Path(MODELS_DIR) / "config.json"
+    if config_json.exists():
+        with open(config_json, "r", encoding="utf-8") as f:
+            config_dict = json.load(f)
+        config = ChatbotConfig.from_dict(config_dict)
+        logger.info(f"Loaded ChatbotConfig from {config_json}")
+    else:
+        config = ChatbotConfig()
+        logger.warning("No config.json found. Using default ChatbotConfig.")
+
+    config.neg_samples = NUM_NEG_SAMPLES
 
-    #
+    # Load or initialize tokenizer
     try:
+        # If the directory has a valid tokenizer
+        if Path(TOKENIZER_DIR).exists() and list(Path(TOKENIZER_DIR).iterdir()):
+            logger.info(f"Loading tokenizer from {TOKENIZER_DIR}")
+            tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)
+        else:
+            # Initialize from base DistilBERT
+            logger.info(f"Loading base tokenizer for {config.pretrained_model}")
+            tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
+
+            # Save to disk
+            Path(TOKENIZER_DIR).mkdir(parents=True, exist_ok=True)
+            tokenizer.save_pretrained(TOKENIZER_DIR)
+            logger.info(f"New tokenizer saved to {TOKENIZER_DIR}")
     except Exception as e:
-        logger.error(f"Failed to load tokenizer: {e}")
+        logger.error(f"Failed to load or create tokenizer: {e}")
         sys.exit(1)
 
-    # Initialize encoder
+    # Initialize the encoder
    try:
         encoder = EncoderModel(config=config)
         logger.info("EncoderModel initialized successfully.")
+
+        if LOAD_CUSTOM_MODEL:
+            # Load the DistilBERT submodule from 'shared_encoder'
+            shared_encoder_path = Path(MODELS_DIR) / "shared_encoder"
+            if shared_encoder_path.exists():
+                logger.info(f"Loading DistilBERT submodule from {shared_encoder_path}")
+                encoder.pretrained = TFAutoModel.from_pretrained(shared_encoder_path)
+            else:
+                logger.warning(f"No shared_encoder found at {shared_encoder_path}, using base DistilBERT instead.")
+
+            # Load top-level custom .weights.h5 (projection, dropout, etc.)
+            custom_weights_path = Path(MODELS_DIR) / "encoder_custom_weights.weights.h5"
+            if custom_weights_path.exists():
+                logger.info(f"Loading custom top-level weights from {custom_weights_path}")
+                # Build model layers with a dummy forward pass
+                dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
+                _ = encoder(dummy_input, training=False)
+                encoder.load_weights(str(custom_weights_path))
+                logger.info("Custom encoder weights loaded successfully.")
+            else:
+                logger.warning(f"Custom weights file not found at {custom_weights_path}. Using only submodule weights.")
+        else:
+            # Just base DistilBERT with special tokens resized
+            logger.info("Using the base DistilBERT without loading custom weights.")
+
+        # Resize token embeddings in case we added special tokens
         encoder.pretrained.resize_token_embeddings(len(tokenizer))
         logger.info(f"Token embeddings resized to: {len(tokenizer)}")
+
     except Exception as e:
         logger.error(f"Failed to initialize EncoderModel: {e}")
         sys.exit(1)
 
     # Load JSON dialogues
     try:
+        if not Path(JSON_TRAINING_DATA_PATH).exists():
+            logger.warning(f"No dialogues found at {JSON_TRAINING_DATA_PATH}, skipping.")
+            dialogues = []
+        else:
+            dialogues = TFDataPipeline.load_json_training_data(JSON_TRAINING_DATA_PATH, debug_samples=None)
+            logger.info(f"Loaded {len(dialogues)} dialogues from {JSON_TRAINING_DATA_PATH}.")
     except Exception as e:
         logger.error(f"Failed to load dialogues: {e}")
         sys.exit(1)
 
     # Load or initialize query_embeddings_cache
+    query_embeddings_cache = {}
+    if os.path.exists(CACHE_FILE):
+        try:
             with open(CACHE_FILE, 'rb') as f:
                 query_embeddings_cache = pickle.load(f)
             logger.info(f"Loaded {len(query_embeddings_cache)} query embeddings from {CACHE_FILE}.")
-        logger.error(f"Failed to load or initialize query embeddings cache: {e}")
-        sys.exit(1)
+        except Exception as e:
+            logger.warning(f"Failed to load query embeddings cache: {e}")
+    else:
+        logger.info("No existing query embeddings cache found. Starting fresh.")
 
     # Initialize TFDataPipeline
     try:
+        # Determine if FAISS index should be loaded or initialized
+        if Path(FAISS_INDEX_PRODUCTION_PATH).exists():
+            # Load existing index
+            logger.info(f"Loading existing FAISS index from {FAISS_INDEX_PRODUCTION_PATH}...")
+            faiss_index = faiss.read_index(FAISS_INDEX_PRODUCTION_PATH)
+            logger.info("FAISS index loaded successfully.")
+        else:
+            # Initialize a new FAISS index
+            logger.info("No existing FAISS index found. Initializing a new index.")
+            dimension = config.embedding_dim  # Ensure this matches your encoder's output
+            faiss_index = faiss.IndexFlatIP(dimension)  # Using Inner Product for cosine similarity
+            logger.info(f"Initialized new FAISS index with dimension {dimension}.")
+
+        # Initialize TFDataPipeline with the FAISS index
         data_pipeline = TFDataPipeline(
             config=config,
             tokenizer=tokenizer,
             encoder=encoder,
-            index_file_path=
+            index_file_path=FAISS_INDEX_PRODUCTION_PATH,
             response_pool=[],
             max_length=config.max_context_token_limit,
             neg_samples=config.neg_samples,
@@ -114,48 +170,55 @@ def main():
         logger.error(f"Failed to initialize TFDataPipeline: {e}")
         sys.exit(1)
 
-    # Collect unique assistant responses from dialogues
+    # 7) Collect unique assistant responses from dialogues
     try:
+        if dialogues:
+            response_pool = data_pipeline.collect_responses_with_domain(dialogues)
+            data_pipeline.response_pool = response_pool
+            logger.info(f"Collected {len(response_pool)} unique assistant responses from dialogues.")
+        else:
+            logger.warning("No dialogues loaded. response_pool remains empty.")
     except Exception as e:
         logger.error(f"Failed to collect responses: {e}")
         sys.exit(1)
 
-    #
+    # 8) Build the FAISS index with response embeddings
+    # Instead of manually computing embeddings, we use the pipeline method
     try:
+        if data_pipeline.response_pool:
+            data_pipeline.build_text_to_domain_map()
+            logger.info("Computing and adding response embeddings to FAISS index using TFDataPipeline...")
+            data_pipeline.compute_and_index_response_embeddings()
+            logger.info("Response embeddings computed and added to FAISS index.")
 
-    # Save FAISS index and response pool
-    try:
-        logger.info(f"Saving FAISS index to {FAISS_INDEX_PATH}...")
-        faiss.write_index(data_pipeline.index, FAISS_INDEX_PATH)
-        logger.info("FAISS index saved successfully.")
-
-        response_pool_path = FAISS_INDEX_PATH.replace('.index', '_responses.json')
-        with open(response_pool_path, 'w', encoding='utf-8') as f:
-            json.dump(data_pipeline.response_pool, f, indent=2)
-        logger.info(f"Response pool saved to {response_pool_path}.")
+            # Save the updated FAISS index
+            data_pipeline.save_faiss_index(FAISS_INDEX_PRODUCTION_PATH)
+
+            # Also save the response pool JSON
+            response_pool_path = FAISS_INDEX_PRODUCTION_PATH.replace('.index', '_responses.json')
+            with open(response_pool_path, 'w', encoding='utf-8') as f:
+                json.dump(data_pipeline.response_pool, f, indent=2)
+            logger.info(f"Response pool saved to {response_pool_path}.")
+        else:
+            logger.warning("No responses to embed. Skipping FAISS indexing.")
     except Exception as e:
-        logger.error(f"Failed to
+        logger.error(f"Failed to compute or add response embeddings: {e}")
         sys.exit(1)
 
-    # Prepare and save training data as TFRecords
+    # 9) Prepare and save training data as TFRecords
     try:
+        if dialogues:
+            logger.info("Starting data preparation and saving as TFRecord...")
+            data_pipeline.prepare_and_save_data(dialogues, TF_RECORD_PATH)
+            logger.info(f"Data saved as TFRecord at {TF_RECORD_PATH}.")
+        else:
+            logger.warning("No dialogues to build TFRecord from. Skipping TFRecord creation.")
     except Exception as e:
         logger.error(f"Failed during data preparation and saving: {e}")
         sys.exit(1)
 
-    # Save query embeddings cache
+    # 10) Save query embeddings cache
     try:
         with open(CACHE_FILE, 'wb') as f:
             pickle.dump(data_pipeline.query_embeddings_cache, f)
@@ -164,7 +227,7 @@ def main():
         logger.error(f"Failed to save query embeddings cache: {e}")
         sys.exit(1)
 
-    # Save Tokenizer
+    # Save Tokenizer
     try:
         tokenizer.save_pretrained(TOKENIZER_DIR)
         logger.info(f"Tokenizer saved to {TOKENIZER_DIR}.")
@@ -173,6 +236,7 @@ def main():
         sys.exit(1)
 
     logger.info("Data preparation pipeline completed successfully.")
-
+
+
 if __name__ == "__main__":
-    main()
+    main()
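The script pairs faiss.IndexFlatIP with L2-normalized embeddings, so that inner product equals cosine similarity. A self-contained sketch of that pattern (random vectors stand in for encoder outputs; the dimension is arbitrary):

import faiss
import numpy as np

dim = 8                                  # stand-in for config.embedding_dim
index = faiss.IndexFlatIP(dim)           # inner-product index (cosine after normalization)

# Fake "response" embeddings, normalized in place so IP == cosine similarity
responses = np.random.rand(100, dim).astype('float32')
faiss.normalize_L2(responses)
index.add(responses)

# Fake "query" embedding, normalized the same way
query = np.random.rand(1, dim).astype('float32')
faiss.normalize_L2(query)

scores, ids = index.search(query, 5)     # top-5 cosine scores and row ids
print(ids[0], scores[0])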
response_quality_checker.py
CHANGED
@@ -9,27 +9,41 @@ if TYPE_CHECKING:
 from tf_data_pipeline import TFDataPipeline
 
 class ResponseQualityChecker:
-    """
+    """
+    Enhanced quality checking that calculates:
+      - Relevance between query & responses
+      - Diversity among top responses
+      - Response length scoring
+      - Confidence determination based on multiple thresholds
+    """
 
     def __init__(
         self,
         data_pipeline: 'TFDataPipeline',
-        confidence_threshold: float = 0.
+        confidence_threshold: float = 0.45,
         diversity_threshold: float = 0.15,
         min_response_length: int = 5,
-        similarity_cap: float = 0.85
+        similarity_cap: float = 0.85,
     ):
+        """
+        Args:
+            data_pipeline: Reference to TFDataPipeline for encoding
+            confidence_threshold: Minimum top_score for a 'confident' result
+            diversity_threshold: Minimum required diversity among top responses
+            min_response_length: Minimum words for a decent response
+            similarity_cap: Cap on pairwise similarity for diversity calc
+        """
         self.confidence_threshold = confidence_threshold
         self.diversity_threshold = diversity_threshold
         self.min_response_length = min_response_length
         self.similarity_cap = similarity_cap
-        self.data_pipeline = data_pipeline
+        self.data_pipeline = data_pipeline
 
-        #
+        # Additional thresholds for more refined checks
         self.thresholds = {
-            'relevance': 0.
-            'length_score': 0.
-            'score_gap': 0.
+            'relevance': 0.30,      # Slightly relaxed
+            'length_score': 0.80,   # Stricter length requirement
+            'score_gap': 0.05       # Gap between top scores
         }
 
     def check_response_quality(
@@ -38,14 +52,14 @@ class ResponseQualityChecker:
         responses: List[Tuple[str, float]]
     ) -> Dict[str, Any]:
         """
-        Evaluate the quality of responses
+        Evaluate the quality of a set of ranked responses for a given query.
 
         Args:
-            query: The user's query
-            responses: List of (response_text, score)
+            query: The user's original query
+            responses: List of (response_text, score) sorted by descending score
 
         Returns:
-
+            Dictionary of metrics, including 'is_confident' and others
         """
         if not responses:
             return {
@@ -57,98 +71,282 @@ class ResponseQualityChecker:
                 'top_3_score_gap': 0.0
             }
 
-        # Calculate
-        metrics = {
-            'top_score': responses[0][1],
-            'top_3_score_gap': self._calculate_score_gap([score for _, score in responses], top_n=3)
-        }
+        # 1) Calculate relevant metrics
+        metrics = {}
+        metrics['response_diversity'] = self.calculate_diversity(responses)
+        metrics['query_response_relevance'] = self.calculate_relevance(query, responses)
+        metrics['response_length_score'] = self._average_length_score(responses)
+        metrics['top_score'] = responses[0][1]
+        metrics['top_3_score_gap'] = self._calculate_score_gap([s for _, s in responses], top_n=3)
 
-        # Determine confidence
+        # 2) Determine confidence
         metrics['is_confident'] = self._determine_confidence(metrics)
         logger.info(f"Quality metrics: {metrics}")
         return metrics
 
     def calculate_relevance(self, query: str, responses: List[Tuple[str, float]]) -> float:
-        """
+        """
+        Compute an overall 'relevance' metric between the query and the top responses.
+        Uses an exponential transform on the similarity to penalize weaker matches.
+        """
         if not responses:
             return 0.0
 
-        #
+        # Encode query and responses
+        query_emb = self.data_pipeline.encode_query(query)
+        resp_texts = [r for r, _ in responses]
+        resp_embs = self.data_pipeline.encode_responses(resp_texts)
 
-        #
+        # Normalize embeddings
+        query_emb = query_emb / (np.linalg.norm(query_emb) + 1e-12)
+        resp_norms = np.linalg.norm(resp_embs, axis=1, keepdims=True) + 1e-12
+        resp_embs = resp_embs / resp_norms
 
-        #
+        # Cosine similarity
+        sims = cosine_similarity([query_emb], resp_embs)[0]
 
+        # Exponential transform: higher sims remain close to 1, lower sims drop quickly
+        sims = np.exp(sims - 1.0)
+
+        # Weighted average: give heavier weighting to higher-ranked items
+        weights = np.exp(-np.arange(len(sims)) / 2.0)
+        weighted_avg = np.average(sims, weights=weights)
+        return float(weighted_avg)
 
     def calculate_diversity(self, responses: List[Tuple[str, float]]) -> float:
-        """
-        if len(embeddings) < 2:
-            return 1.0
-
-        #
-        np.fill_diagonal(
-
-        #
-
-        #
-        num_pairs = len(
-
-        #
-        return diversity_score
+        """
+        Calculate how 'different' the top responses are from each other.
+        Diversity = 1 - avg_cosine_similarity (capped).
+        """
+        if len(responses) < 2:
+            return 1.0  # Single response is trivially 'unique'
+
+        resp_texts = [r for r, _ in responses]
+        embs = self.data_pipeline.encode_responses(resp_texts)
+
+        # Pairwise similarity
+        sim_matrix = cosine_similarity(embs, embs)
+        np.fill_diagonal(sim_matrix, 0.0)
+
+        # Cap similarity to avoid outliers
+        sim_matrix = np.minimum(sim_matrix, self.similarity_cap)
+
+        # Mean off-diagonal similarity
+        sum_sims = np.sum(sim_matrix)
+        num_pairs = len(resp_texts) * (len(resp_texts) - 1)
+        avg_sim = sum_sims / num_pairs if num_pairs > 0 else 0.0
+
+        # Invert to get diversity
+        return 1.0 - avg_sim
 
     def _determine_confidence(self, metrics: Dict[str, float]) -> bool:
-        """
+        """
+        Decide if we're 'confident' based on multiple metric thresholds.
+        """
         primary_conditions = [
             metrics['top_score'] >= self.confidence_threshold,
             metrics['response_diversity'] >= self.diversity_threshold,
             metrics['response_length_score'] >= self.thresholds['length_score']
         ]
 
-        # Secondary conditions (majority must be met)
         secondary_conditions = [
             metrics['query_response_relevance'] >= self.thresholds['relevance'],
             metrics['top_3_score_gap'] >= self.thresholds['score_gap'],
-            metrics['top_score'] >= (self.confidence_threshold
+            metrics['top_score'] >= (self.confidence_threshold + 0.05)  # Extra buffer
         ]
 
+        # Must pass all primary checks, and at least 2 of the 3 secondary
+        return all(primary_conditions) and (sum(secondary_conditions) >= 2)
 
-    def
-        """
+    def _average_length_score(self, responses: List[Tuple[str, float]]) -> float:
+        """
+        Compute an average length score across all responses.
+        """
+        length_scores = []
+        for response, _ in responses:
+            length_scores.append(self._length_score(response))
+        return float(np.mean(length_scores)) if length_scores else 0.0
 
+    def _length_score(self, text: str) -> float:
+        """
+        Calculate how well the text meets our length requirement.
+        Scores 1.0 if text is >= min_response_length and not too long,
+        else it scales down.
+        """
+        words = len(text.split())
         if words < self.min_response_length:
-            return words / self.min_response_length
-        elif words >
-            return
+            return words / float(self.min_response_length)
+        elif words > 60:
+            return max(0.5, 60.0 / words)  # Slight penalty for very long
         return 1.0
 
     def _calculate_score_gap(self, scores: List[float], top_n: int = 3) -> float:
-        """
+        """
+        Calculate the average gap between consecutive scores in the top N.
+        """
+        if len(scores) < 2:
             return 0.0
+        top_n = min(len(scores), top_n)
+        gaps = []
+        for i in range(top_n - 1):
+            gaps.append(scores[i] - scores[i + 1])
+        return float(np.mean(gaps)) if gaps else 0.0
+
+# import numpy as np
+# from typing import List, Tuple, Dict, Any, TYPE_CHECKING
+# from sklearn.metrics.pairwise import cosine_similarity
+
+# from logger_config import config_logger
+# logger = config_logger(__name__)
+
+# if TYPE_CHECKING:
+#     from tf_data_pipeline import TFDataPipeline
+
+# class ResponseQualityChecker:
+#     """Enhanced quality checking with dynamic thresholds."""
+
+#     def __init__(
+#         self,
+#         data_pipeline: 'TFDataPipeline',
+#         confidence_threshold: float = 0.4,
+#         diversity_threshold: float = 0.15,
+#         min_response_length: int = 5,
+#         similarity_cap: float = 0.85  # Renamed from max_similarity_ratio and used in diversity calc
+#     ):
+#         self.confidence_threshold = confidence_threshold
+#         self.diversity_threshold = diversity_threshold
+#         self.min_response_length = min_response_length
+#         self.similarity_cap = similarity_cap
+#         self.data_pipeline = data_pipeline  # Reference to TFDataPipeline
+
+#         # Dynamic thresholds based on response patterns
+#         self.thresholds = {
+#             'relevance': 0.35,
+#             'length_score': 0.85,
+#             'score_gap': 0.04
+#         }

+#     def check_response_quality(
+#         self,
+#         query: str,
+#         responses: List[Tuple[str, float]]
+#     ) -> Dict[str, Any]:
+#         """
+#         Evaluate the quality of responses based on various metrics.
+
+#         Args:
+#             query: The user's query
+#             responses: List of (response_text, score) tuples
+
+#         Returns:
+#             Dict containing quality metrics and confidence assessment
+#         """
+#         if not responses:
+#             return {
+#                 'response_diversity': 0.0,
+#                 'query_response_relevance': 0.0,
+#                 'is_confident': False,
+#                 'top_score': 0.0,
+#                 'response_length_score': 0.0,
+#                 'top_3_score_gap': 0.0
+#             }
+
+#         # Calculate core metrics
+#         metrics = {
+#             'response_diversity': self.calculate_diversity(responses),
+#             'query_response_relevance': self.calculate_relevance(query, responses),
+#             'response_length_score': np.mean([
+#                 self._calculate_length_score(response) for response, _ in responses
+#             ]),
+#             'top_score': responses[0][1],
+#             'top_3_score_gap': self._calculate_score_gap([score for _, score in responses], top_n=3)
+#         }
+
+#         # Determine confidence using thresholds
+#         metrics['is_confident'] = self._determine_confidence(metrics)
+
+#         logger.info(f"Quality metrics: {metrics}")
+#         return metrics
+
+#     def calculate_relevance(self, query: str, responses: List[Tuple[str, float]]) -> float:
+#         """Calculate relevance with stricter scoring."""
+#         if not responses:
+#             return 0.0
+
+#         query_embedding = self.data_pipeline.encode_query(query)
+#         response_texts = [resp for resp, _ in responses]
+#         response_embeddings = self.data_pipeline.encode_responses(response_texts)
+
+#         # Normalize embeddings
+#         query_embedding = query_embedding / np.linalg.norm(query_embedding)
+#         response_embeddings = response_embeddings / np.linalg.norm(response_embeddings, axis=1)[:, np.newaxis]
+
+#         # Compute similarities with exponential decay for far matches
+#         similarities = cosine_similarity([query_embedding], response_embeddings)[0]
+#         similarities = np.exp(similarities - 1)  # Penalize lower similarities more strongly
+
+#         # Apply stronger position weighting
+#         weights = np.exp(-np.arange(len(similarities)) / 2)
+
+#         return float(np.average(similarities, weights=weights))
+
+#     def calculate_diversity(self, responses: List[Tuple[str, float]]) -> float:
+#         """Calculate diversity with length normalization and similarity capping."""
+#         if not responses:
+#             return 0.0
+
+#         response_texts = [resp for resp, _ in responses]
+#         embeddings = self.data_pipeline.encode_responses(response_texts)
+#         if len(embeddings) < 2:
+#             return 1.0
+
+#         # Calculate pairwise cosine similarities
+#         similarity_matrix = cosine_similarity(embeddings)
+#         np.fill_diagonal(similarity_matrix, 0)  # Exclude self-similarity
+
+#         # Apply similarity cap
+#         similarity_matrix = np.minimum(similarity_matrix, self.similarity_cap)
+
+#         # Calculate average similarity
+#         sum_similarities = np.sum(similarity_matrix)
+#         num_pairs = len(embeddings) * (len(embeddings) - 1)
+#         avg_similarity = sum_similarities / num_pairs if num_pairs > 0 else 0.0
+
+#         # Diversity is inversely related to average similarity
+#         diversity_score = 1 - avg_similarity
+#         return diversity_score
+
+#     def _determine_confidence(self, metrics: Dict[str, float]) -> bool:
+#         """Determine confidence using primary and secondary conditions."""
+#         # Primary conditions (must all be met)
+#         primary_conditions = [
+#             metrics['top_score'] >= self.confidence_threshold,
+#             metrics['response_diversity'] >= self.diversity_threshold,
+#             metrics['response_length_score'] >= self.thresholds['length_score']
+#         ]
+
+#         # Secondary conditions (majority must be met)
+#         secondary_conditions = [
+#             metrics['query_response_relevance'] >= self.thresholds['relevance'],
+#             metrics['top_3_score_gap'] >= self.thresholds['score_gap'],
+#             metrics['top_score'] >= (self.confidence_threshold * 1.1)  # Extra confidence boost
+#         ]
+
+#         return all(primary_conditions) and sum(secondary_conditions) >= 2
+
+#     def _calculate_length_score(self, response: str) -> float:
+#         """Calculate length score with penalty for very short or long responses."""
+#         words = len(response.split())
+
+#         if words < self.min_response_length:
+#             return words / self.min_response_length
+#         elif words > 50:  # Penalty for very long responses
+#             return min(1.0, 50 / words)
+#         return 1.0
+
+#     def _calculate_score_gap(self, scores: List[float], top_n: int = 3) -> float:
+#         """Calculate average gap between top N scores."""
+#         if len(scores) < top_n + 1:
+#             return 0.0
+#         gaps = [scores[i] - scores[i + 1] for i in range(min(len(scores) - 1, top_n))]
+#         return np.mean(gaps)
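The diversity metric above reduces to one minus the capped mean of the off-diagonal pairwise cosine similarities. That arithmetic can be checked without an encoder (the toy 2-D embeddings are made up):

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

similarity_cap = 0.85
embs = np.array([
    [1.0, 0.0],   # three stand-in response embeddings
    [0.9, 0.1],
    [0.0, 1.0],
])

sim = cosine_similarity(embs, embs)
np.fill_diagonal(sim, 0.0)              # ignore self-similarity
sim = np.minimum(sim, similarity_cap)   # cap outliers, as in calculate_diversity

num_pairs = len(embs) * (len(embs) - 1)
avg_sim = sim.sum() / num_pairs
print(round(1.0 - avg_sim, 3))          # higher value = more diverse responses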
tf_data_pipeline.py
CHANGED
@@ -8,11 +8,12 @@ import math
|
|
8 |
from tqdm import tqdm
|
9 |
import json
|
10 |
from pathlib import Path
|
11 |
-
from typing import Union, Optional, List, Tuple, Generator
|
12 |
from transformers import AutoTokenizer
|
13 |
from typing import List, Tuple, Generator
|
14 |
from transformers import AutoTokenizer
|
15 |
from gpu_monitor import GPUMemoryMonitor
|
|
|
16 |
|
17 |
from logger_config import config_logger
|
18 |
logger = config_logger(__name__)
|
@@ -27,7 +28,7 @@ class TFDataPipeline:
|
|
27 |
response_pool: List[str],
|
28 |
max_length: int,
|
29 |
query_embeddings_cache: dict,
|
30 |
-
neg_samples: int =
|
31 |
index_type: str = 'IndexFlatIP',
|
32 |
nlist: int = 100,
|
33 |
max_retries: int = 3
|
@@ -47,6 +48,10 @@ class TFDataPipeline:
|
|
47 |
self.max_batch_size = 16 if len(response_pool) < 100 else 64
|
48 |
self.memory_monitor = GPUMemoryMonitor()
|
49 |
self.max_retries = max_retries
|
|
|
|
|
|
|
|
|
50 |
|
51 |
if os.path.exists(index_file_path):
|
52 |
logger.info(f"Loading existing FAISS index from {index_file_path}...")
|
@@ -135,21 +140,49 @@ class TFDataPipeline:
|
|
135 |
|
136 |
logger.info(f"Loaded {len(dialogues)} dialogues.")
|
137 |
return dialogues
|
138 |
-
|
139 |
-
def
|
140 |
-
"""
|
141 |
-
|
|
|
|
|
|
|
|
|
|
|
142 |
for dialogue in tqdm(dialogues, desc="Processing Dialogues", unit="dialogue"):
|
|
|
|
|
143 |
turns = dialogue.get('turns', [])
|
144 |
for turn in turns:
|
145 |
speaker = turn.get('speaker')
|
146 |
text = turn.get('text', '').strip()
|
147 |
if speaker == 'assistant' and text:
|
148 |
-
# Ensure we don't exclude valid shorter responses
|
149 |
if len(text) <= self.max_length:
|
150 |
-
|
151 |
-
|
152 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
|
154 |
def _extract_pairs_from_dialogue(self, dialogue: dict) -> List[Tuple[str, str]]:
|
155 |
"""Extract query-response pairs from a dialogue."""
|
@@ -173,113 +206,101 @@ class TFDataPipeline:
|
|
173 |
|
174 |
def compute_and_index_response_embeddings(self):
|
175 |
"""
|
176 |
-
Computes embeddings for the response pool and adds them to the FAISS index
|
|
|
177 |
"""
|
178 |
logger.info("Computing embeddings for the response pool...")
|
179 |
|
180 |
-
#
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
# Tokenization
|
186 |
-
logger.info("Tokenizing responses...")
|
187 |
-
encoded_responses = self.tokenizer(
|
188 |
-
self.response_pool,
|
189 |
-
padding=True,
|
190 |
-
truncation=True,
|
191 |
-
max_length=self.max_length,
|
192 |
-
return_tensors='tf'
|
193 |
-
)
|
194 |
-
response_ids = encoded_responses['input_ids']
|
195 |
-
|
196 |
-
# Compute embeddings in batches with progress bar
|
197 |
-
batch_size = getattr(self, 'embedding_batch_size', 64) # Default to 64 if not set
|
198 |
-
total_responses = len(response_ids)
|
199 |
-
logger.info(f"Computing embeddings in batches of {batch_size}...")
|
200 |
embeddings = []
|
201 |
|
202 |
-
with tqdm(total=
|
203 |
-
for i in range(0,
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
self.index.add(batch)
|
226 |
-
pbar_index.update(len(batch))
|
227 |
-
|
228 |
-
logger.info("Response embeddings added to FAISS index.")
|
229 |
-
else:
|
230 |
-
logger.warning("No embeddings to add to FAISS index.")
|
231 |
-
|
232 |
-
# **Sanity Check:** Verify the number of embeddings in FAISS index
|
233 |
-
logger.info(f"Total embeddings in FAISS index after addition: {self.index.ntotal}")
|
234 |
|
235 |
def _find_hard_negatives_batch(self, queries: List[str], positives: List[str]) -> List[List[str]]:
|
236 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
237 |
retry_count = 0
|
238 |
total_responses = len(self.response_pool)
|
239 |
-
|
240 |
-
|
241 |
-
k = self.neg_samples + 0
|
242 |
|
243 |
while retry_count < self.max_retries:
|
244 |
try:
|
245 |
-
#
|
246 |
-
batch_size = 128 # Example sub-batch size; adjust as needed
|
247 |
query_embeddings = []
|
248 |
for i in range(0, len(queries), batch_size):
|
249 |
-
sub_queries = queries[i:i + batch_size]
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
query_embeddings.append(sub_embeddings)
|
255 |
-
query_embeddings = np.vstack(query_embeddings)
|
256 |
|
257 |
-
|
258 |
query_embeddings = np.ascontiguousarray(query_embeddings)
|
259 |
|
260 |
-
# Perform FAISS search
|
261 |
distances, indices = self.index.search(query_embeddings, k)
|
262 |
|
263 |
all_negatives = []
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
seen = {
|
|
|
|
|
|
|
268 |
|
|
|
269 |
for idx in query_indices:
|
270 |
-
if
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
|
|
276 |
break
|
277 |
|
278 |
-
# If not enough negatives
|
279 |
-
|
280 |
-
|
|
|
|
|
|
|
281 |
|
282 |
-
all_negatives.append(
|
283 |
|
284 |
return all_negatives
|
285 |
|
@@ -288,123 +309,236 @@ class TFDataPipeline:
|
|
288 |
logger.warning(f"Hard negative search attempt {retry_count} failed due to missing embeddings: {ke}")
|
289 |
if retry_count == self.max_retries:
|
290 |
logger.error("Max retries reached for hard negative search due to missing embeddings.")
|
291 |
-
return
|
292 |
-
# Perform memory cleanup
|
293 |
gc.collect()
|
294 |
if tf.config.list_physical_devices('GPU'):
|
295 |
tf.keras.backend.clear_session()
|
|
|
296 |
except Exception as e:
|
297 |
retry_count += 1
|
298 |
logger.warning(f"Hard negative search attempt {retry_count} failed: {e}")
|
299 |
if retry_count == self.max_retries:
|
300 |
logger.error("Max retries reached for hard negative search.")
|
301 |
-
return
|
302 |
-
# Perform memory cleanup
|
303 |
gc.collect()
|
304 |
if tf.config.list_physical_devices('GPU'):
|
305 |
tf.keras.backend.clear_session()
|
306 |
|
307 |
-
def
|
|
|
|
|
|
308 |
"""
|
309 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
310 |
|
|
|
|
|
|
311 |
311           Args:
312 -             query
313 -             context
315           Returns:
316 -             np.ndarray
317           """
318 -         # Prepare
319           if context:
320-325 -         ...
326 -                 f" {
327           else:
328-330 -         ...
331           encodings = self.tokenizer(
332 -             [
333               padding='max_length',
334               truncation=True,
335               max_length=self.max_length,
336 -             return_tensors='np' #
337           )
338           input_ids = encodings['input_ids']
340 -         #
341           max_id = np.max(input_ids)
342-344 -     ...
345 -             logger.error(f"Token ID {max_id} exceeds the vocabulary size {new_vocab_size}.")
346               raise ValueError("Token ID exceeds vocabulary size.")
348 -         # Get embeddings from the
349           embeddings = self.encoder(input_ids, training=False).numpy()
350-353 -     ...
354 -         return embeddings[0]  # Return as a 1D array
356 -     def encode_responses(
357           """
358 -         Encode
360           Args:
361 -             responses
362 -             context
364           Returns:
365 -             np.ndarray
366           """
367 -         #
368           if context:
369-377 -         ...
378           else:
379-381 -         ...
382           ]
384 -         # Tokenize
385           encodings = self.tokenizer(
387               padding='max_length',
388               truncation=True,
389               max_length=self.max_length,
390 -             return_tensors='np'
391           )
392           input_ids = encodings['input_ids']
394 -         #
395           max_id = np.max(input_ids)
396-398 -     ...
399 -             logger.error(f"Token ID {max_id} exceeds the vocabulary size {new_vocab_size}.")
400               raise ValueError("Token ID exceeds vocabulary size.")
402 -         #
403           embeddings = self.encoder(input_ids, training=False).numpy()
406 -         faiss.normalize_L2(embeddings)
408           return embeddings.astype('float32')
|
409 |
|
410 |
def prepare_and_save_data(self, dialogues: List[dict], tf_record_path: str, batch_size: int = 32):
|
|
|
8 |
from tqdm import tqdm
|
9 |
import json
|
10 |
from pathlib import Path
|
11 |
+
from typing import Union, Optional, Dict, List, Tuple, Generator
|
12 |
from transformers import AutoTokenizer
|
13 |
from typing import List, Tuple, Generator
|
14 |
from transformers import AutoTokenizer
|
15 |
from gpu_monitor import GPUMemoryMonitor
|
16 |
+
import random
|
17 |
|
18 |
from logger_config import config_logger
|
19 |
logger = config_logger(__name__)
|
|
|
28 |
response_pool: List[str],
|
29 |
max_length: int,
|
30 |
query_embeddings_cache: dict,
|
31 |
+
neg_samples: int = 5,
|
32 |
index_type: str = 'IndexFlatIP',
|
33 |
nlist: int = 100,
|
34 |
max_retries: int = 3
|
|
|
48 |
self.max_batch_size = 16 if len(response_pool) < 100 else 64
|
49 |
self.memory_monitor = GPUMemoryMonitor()
|
50 |
self.max_retries = max_retries
|
51 |
+
|
52 |
+
# Build a quick text->domain map for O(1) domain lookups
|
53 |
+
self._text_domain_map = {}
|
54 |
+
self.build_text_to_domain_map()
|
55 |
|
56 |
if os.path.exists(index_file_path):
|
57 |
logger.info(f"Loading existing FAISS index from {index_file_path}...")
|
|
|
140 |
|
141 |
logger.info(f"Loaded {len(dialogues)} dialogues.")
|
142 |
return dialogues
|
143 |
+
|
144 |
+
def collect_responses_with_domain(self, dialogues: List[dict]) -> List[Dict[str, str]]:
|
145 |
+
"""
|
146 |
+
Extract unique assistant responses from dialogues, along with the domain.
|
147 |
+
Returns a list of dicts: [{'domain': str, 'text': str}, ...]
|
148 |
+
"""
|
149 |
+
response_set = set() # We'll store (domain, text) tuples to keep them unique
|
150 |
+
results = []
|
151 |
+
|
152 |
for dialogue in tqdm(dialogues, desc="Processing Dialogues", unit="dialogue"):
|
153 |
+
# domain is stored at the top level in your new JSON format
|
154 |
+
domain = dialogue.get('domain', 'other')
|
155 |
turns = dialogue.get('turns', [])
|
156 |
for turn in turns:
|
157 |
speaker = turn.get('speaker')
|
158 |
text = turn.get('text', '').strip()
|
159 |
if speaker == 'assistant' and text:
|
|
|
160 |
if len(text) <= self.max_length:
|
161 |
+
# Use a tuple as a "set" key to ensure uniqueness
|
162 |
+
key = (domain, text)
|
163 |
+
if key not in response_set:
|
164 |
+
response_set.add(key)
|
165 |
+
results.append({
|
166 |
+
"domain": domain,
|
167 |
+
"text": text
|
168 |
+
})
|
169 |
+
|
170 |
+
logger.info(f"Collected {len(results)} unique assistant responses from dialogues.")
|
171 |
+
return results
|
172 |
+
# def collect_responses(self, dialogues: List[dict]) -> List[str]:
|
173 |
+
# """Extract unique assistant responses from dialogues."""
|
174 |
+
# response_set = set()
|
175 |
+
# for dialogue in tqdm(dialogues, desc="Processing Dialogues", unit="dialogue"):
|
176 |
+
# turns = dialogue.get('turns', [])
|
177 |
+
# for turn in turns:
|
178 |
+
# speaker = turn.get('speaker')
|
179 |
+
# text = turn.get('text', '').strip()
|
180 |
+
# if speaker == 'assistant' and text:
|
181 |
+
# # Ensure we don't exclude valid shorter responses
|
182 |
+
# if len(text) <= self.max_length:
|
183 |
+
# response_set.add(text)
|
184 |
+
# logger.info(f"Collected {len(response_set)} unique assistant responses from dialogues.")
|
185 |
+
# return list(response_set)
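For reference, a minimal sketch of the contract of collect_responses_with_domain. The dialogue below is illustrative (made-up values), but it uses exactly the keys the method reads: domain, turns, speaker, text.

    # Hypothetical input dialogue in the new JSON layout:
    dialogues = [
        {
            "domain": "restaurant",
            "turns": [
                {"speaker": "user", "text": "Book a table for two."},
                {"speaker": "assistant", "text": "Sure, for what time?"},
            ],
        }
    ]
    # Expected output: unique assistant turns paired with their domain, e.g.
    # [{"domain": "restaurant", "text": "Sure, for what time?"}]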
|
186 |
|
187 |
def _extract_pairs_from_dialogue(self, dialogue: dict) -> List[Tuple[str, str]]:
|
188 |
"""Extract query-response pairs from a dialogue."""
|
|
|
206 |
|
207 |
def compute_and_index_response_embeddings(self):
|
208 |
"""
|
209 |
+
Computes embeddings for the response pool and adds them to the FAISS index.
|
210 |
+
self.response_pool is now List[Dict[str, str]] with keys "domain" and "text".
|
211 |
"""
|
212 |
logger.info("Computing embeddings for the response pool...")
|
213 |
|
214 |
+
# Extract just the assistant text
|
215 |
+
texts = [resp["text"] for resp in self.response_pool]
|
216 |
+
logger.debug(f"Total texts to embed: {len(texts)}")
|
217 |
+
|
218 |
+
batch_size = getattr(self, 'embedding_batch_size', 64)
|
|
|
|
|
|
|
|
219 |
embeddings = []
|
220 |
|
221 |
+
with tqdm(total=len(texts), desc="Computing Embeddings", unit="response") as pbar:
|
222 |
+
for i in range(0, len(texts), batch_size):
|
223 |
+
batch_texts = texts[i:i+batch_size]
|
224 |
+
encodings = self.tokenizer(
|
225 |
+
batch_texts,
|
226 |
+
padding=True,
|
227 |
+
truncation=True,
|
228 |
+
max_length=self.max_length,
|
229 |
+
return_tensors='tf'
|
230 |
+
)
|
231 |
+
batch_embeds = self.encoder(encodings['input_ids'], training=False).numpy()
|
232 |
+
|
233 |
+
embeddings.append(batch_embeds)
|
234 |
+
pbar.update(len(batch_texts))
|
235 |
+
|
236 |
+
# Combine embeddings and add to FAISS
|
237 |
+
all_embeddings = np.vstack(embeddings).astype(np.float32)
|
238 |
+
logger.info(f"Adding {len(all_embeddings)} response embeddings to FAISS index...")
|
239 |
+
self.index.add(all_embeddings)
|
240 |
+
|
241 |
+
# For debugging or repeated usage, you might store them:
|
242 |
+
self.response_embeddings = all_embeddings
|
243 |
+
logger.info(f"FAISS index now has {self.index.ntotal} vectors.")
|
|
|
|
|
|
244 |
|
245 |
def _find_hard_negatives_batch(self, queries: List[str], positives: List[str]) -> List[List[str]]:
|
246 |
+
"""
|
247 |
+
Find hard negatives for a batch of queries using FAISS search.
|
248 |
+
Falls back to random negatives if we run out of tries or can't find enough.
|
249 |
+
Uses domain-based fallback if possible.
|
250 |
+
"""
|
251 |
+
import random
|
252 |
+
import gc
|
253 |
+
|
254 |
retry_count = 0
|
255 |
total_responses = len(self.response_pool)
|
256 |
+
k = self.neg_samples # Number of negatives to retrieve from FAISS
|
257 |
+
batch_size = 128
|
|
|
258 |
|
259 |
while retry_count < self.max_retries:
|
260 |
try:
|
261 |
+
# 1) Build query embeddings from the cache
|
|
|
262 |
query_embeddings = []
|
263 |
for i in range(0, len(queries), batch_size):
|
264 |
+
sub_queries = queries[i : i + batch_size]
|
265 |
+
sub_embeds = [self.query_embeddings_cache[q] for q in sub_queries]
|
266 |
+
sub_embeds = np.vstack(sub_embeds).astype(np.float32)
|
267 |
+
faiss.normalize_L2(sub_embeds) # If not already normalized
|
268 |
+
query_embeddings.append(sub_embeds)
|
|
|
|
|
269 |
|
270 |
+
query_embeddings = np.vstack(query_embeddings)
|
271 |
query_embeddings = np.ascontiguousarray(query_embeddings)
|
272 |
|
273 |
+
# 2) Perform FAISS search
|
274 |
distances, indices = self.index.search(query_embeddings, k)
|
275 |
|
276 |
all_negatives = []
|
277 |
+
# For each query, find domain from the corresponding positive if possible
|
278 |
+
for query_indices, query_text, pos_text in zip(indices, queries, positives):
|
279 |
+
negative_list = []
|
280 |
+
seen = {pos_text.strip()}
|
281 |
+
|
282 |
+
# Attempt to detect the domain of the positive text
|
283 |
+
domain_of_positive = self._detect_domain_for_text(pos_text)
|
284 |
|
285 |
+
# Collect hard negatives from FAISS
|
286 |
for idx in query_indices:
|
287 |
+
if 0 <= idx < total_responses:
|
288 |
+
candidate_dict = self.response_pool[idx] # e.g. {domain, text}
|
289 |
+
candidate_text = candidate_dict["text"].strip()
|
290 |
+
if candidate_text and candidate_text not in seen:
|
291 |
+
seen.add(candidate_text)
|
292 |
+
negative_list.append(candidate_text)
|
293 |
+
if len(negative_list) >= self.neg_samples:
|
294 |
break
|
295 |
|
296 |
+
# If not enough negatives, fallback to random domain-based
|
297 |
+
if len(negative_list) < self.neg_samples:
|
298 |
+
needed = self.neg_samples - len(negative_list)
|
299 |
+
# Pass in domain_of_positive to your updated `_get_random_negatives(...)`
|
300 |
+
random_negatives = self._get_random_negatives(needed, seen, domain=domain_of_positive)
|
301 |
+
negative_list.extend(random_negatives)
|
302 |
|
303 |
+
all_negatives.append(negative_list)
|
304 |
|
305 |
return all_negatives
|
306 |
|
|
|
309 |
logger.warning(f"Hard negative search attempt {retry_count} failed due to missing embeddings: {ke}")
|
310 |
if retry_count == self.max_retries:
|
311 |
logger.error("Max retries reached for hard negative search due to missing embeddings.")
|
312 |
+
return self._fallback_negatives(queries, positives, reason="key_error")
|
|
|
313 |
gc.collect()
|
314 |
if tf.config.list_physical_devices('GPU'):
|
315 |
tf.keras.backend.clear_session()
|
316 |
+
|
317 |
except Exception as e:
|
318 |
retry_count += 1
|
319 |
logger.warning(f"Hard negative search attempt {retry_count} failed: {e}")
|
320 |
if retry_count == self.max_retries:
|
321 |
logger.error("Max retries reached for hard negative search.")
|
322 |
+
return self._fallback_negatives(queries, positives, reason="generic_error")
|
|
|
323 |
gc.collect()
|
324 |
if tf.config.list_physical_devices('GPU'):
|
325 |
tf.keras.backend.clear_session()
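Stripped of retries, caching, and the domain fallback, the retrieval step above reduces to the following standalone sketch (the `index` and `response_texts` arguments stand in for the pipeline's FAISS index and a flat list of response strings; both names are assumptions for the example):

    import numpy as np
    import faiss
    from typing import List

    def hard_negative_sketch(query_embs: np.ndarray, positives: List[str],
                             index: faiss.Index, response_texts: List[str],
                             k: int = 5) -> List[List[str]]:
        """Nearest FAISS neighbors per query, minus each query's own positive."""
        query_embs = np.ascontiguousarray(query_embs, dtype=np.float32)
        faiss.normalize_L2(query_embs)
        _, idx = index.search(query_embs, k + 1)  # +1 so the positive can be dropped
        results = []
        for row, pos in zip(idx, positives):
            negs = [response_texts[i] for i in row
                    if 0 <= i < len(response_texts)
                    and response_texts[i].strip() != pos.strip()]
            results.append(negs[:k])
        return results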
|
326 |
|
327 |
+
def _detect_domain_for_text(self, text: str) -> Optional[str]:
|
328 |
+
"""
|
329 |
+
O(1) domain detection by looking up text in our dictionary.
|
330 |
+
Returns the domain if found, else None.
|
331 |
+
"""
|
332 |
+
stripped_text = text.strip()
|
333 |
+
return self._text_domain_map.get(stripped_text, None)
|
334 |
+
|
335 |
+
def _get_random_negatives(self, needed: int, seen: set, domain: Optional[str] = None) -> List[str]:
|
336 |
+
"""
|
337 |
+
Return a list of 'needed' random negative texts from the same domain if possible,
|
338 |
+
otherwise fallback to all-domain.
|
339 |
+
"""
|
340 |
+
# 1) Filter response_pool for domain if provided
|
341 |
+
if domain:
|
342 |
+
domain_texts = [r["text"] for r in self.response_pool if r["domain"] == domain]
|
343 |
+
# fallback to entire set if insufficient domain_texts
|
344 |
+
if len(domain_texts) < needed * 2: # pick some threshold
|
345 |
+
domain_texts = [r["text"] for r in self.response_pool]
|
346 |
+
else:
|
347 |
+
domain_texts = [r["text"] for r in self.response_pool]
|
348 |
+
|
349 |
+
negatives = []
|
350 |
+
tries = 0
|
351 |
+
max_tries = needed * 10
|
352 |
+
while len(negatives) < needed and tries < max_tries:
|
353 |
+
tries += 1
|
354 |
+
candidate = random.choice(domain_texts).strip()
|
355 |
+
if candidate and candidate not in seen:
|
356 |
+
negatives.append(candidate)
|
357 |
+
seen.add(candidate)
|
358 |
+
|
359 |
+
# If still not enough, we do the best we can
|
360 |
+
if len(negatives) < needed:
|
361 |
+
logger.warning(f"Could not find enough domain-based random negatives; needed {needed}, got {len(negatives)}.")
|
362 |
+
|
363 |
+
return negatives
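Illustrative call (values made up): topping up three negatives for a 'restaurant' positive that FAISS could not cover.

    extra = self._get_random_negatives(needed=3, seen={"Sure, for what time?"}, domain="restaurant")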
|
364 |
+
|
365 |
+
def _fallback_negatives(self, queries: List[str], positives: List[str], reason: str) -> List[List[str]]:
|
366 |
+
"""
|
367 |
+
Called if FAISS fails or embeddings are missing.
|
368 |
+
We use entirely random negatives for each query, ignoring FAISS,
|
369 |
+
but still attempt domain-based selection if possible.
|
370 |
"""
|
371 |
+
logger.error(f"Falling back to random negatives due to: {reason}")
|
372 |
+
all_negatives = []
|
373 |
+
|
374 |
+
for pos_text in positives:
|
375 |
+
# Build a 'seen' set with the positive
|
376 |
+
seen = {pos_text.strip()}
|
377 |
+
|
378 |
+
# Attempt to detect the domain of the positive text
|
379 |
+
domain_of_positive = self._detect_domain_for_text(pos_text)
|
380 |
+
|
381 |
+
# Use domain-based random negatives if available
|
382 |
+
negs = self._get_random_negatives(self.neg_samples, seen, domain=domain_of_positive)
|
383 |
+
all_negatives.append(negs)
|
384 |
|
385 |
+
return all_negatives
|
386 |
+
|
387 |
+
def build_text_to_domain_map(self):
|
388 |
+
"""
|
389 |
+
Build an O(1) lookup dict: text -> domain,
|
390 |
+
so we don't have to scan the entire self.response_pool each time.
|
391 |
+
"""
|
392 |
+
self._text_domain_map = {}
|
393 |
+
|
394 |
+
for item in self.response_pool:
|
395 |
+
# e.g., item = {"domain": "restaurant", "text": "some text..."}
|
396 |
+
stripped_text = item["text"].strip()
|
397 |
+
domain = item["domain"]
|
398 |
+
|
399 |
+
# If the same text appears multiple times with the same domain, no big deal.
|
400 |
+
# If it appears with a different domain, you can decide how to handle collisions.
|
401 |
+
if stripped_text in self._text_domain_map:
|
402 |
+
existing_domain = self._text_domain_map[stripped_text]
|
403 |
+
if existing_domain != domain:
|
404 |
+
# Log a warning or decide on a policy:
|
405 |
+
logger.warning(
|
406 |
+
f"Collision detected: text '{stripped_text}' found with domains "
|
407 |
+
f"'{existing_domain}' and '{domain}'. Keeping the first."
|
408 |
+
)
|
409 |
+
# By default, keep the first domain or overwrite. We'll skip overwriting:
|
410 |
+
continue
|
411 |
+
else:
|
412 |
+
# Insert into the dict
|
413 |
+
self._text_domain_map[stripped_text] = domain
|
414 |
+
|
415 |
+
logger.info(f"Built text->domain map with {len(self._text_domain_map)} unique text entries.")
|
416 |
+
|
417 |
+
def encode_query(
|
418 |
+
self,
|
419 |
+
query: str,
|
420 |
+
context: Optional[List[Tuple[str, str]]] = None
|
421 |
+
) -> np.ndarray:
|
422 |
+
"""
|
423 |
+
Encode a user query (and optional conversation context) into an embedding vector.
|
424 |
+
|
425 |
Args:
|
426 |
+
query: The user query.
|
427 |
+
context: Optional conversation history as a list of (user_text, assistant_text).
|
428 |
|
429 |
Returns:
|
430 |
+
np.ndarray of shape [embedding_dim], typically L2-normalized already.
|
431 |
"""
|
432 |
+
# 1) Prepare context (if any) by concatenating user/assistant pairs
|
433 |
if context:
|
434 |
+
# Take the last N turns
|
435 |
+
relevant_history = context[-self.config.max_context_turns:]
|
436 |
+
context_str_parts = []
|
437 |
+
for (u_text, a_text) in relevant_history:
|
438 |
+
context_str_parts.append(
|
439 |
+
f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {u_text} "
|
440 |
+
f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {a_text}"
|
441 |
+
)
|
442 |
+
context_str = " ".join(context_str_parts)
|
443 |
+
|
444 |
+
# Append the user's new query
|
445 |
+
full_query = (
|
446 |
+
f"{context_str} "
|
447 |
+
f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"
|
448 |
+
)
|
449 |
else:
|
450 |
+
# Just a single user turn
|
451 |
+
full_query = (
|
452 |
+
f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"
|
453 |
+
)
|
454 |
+
|
455 |
+
# 2) Tokenize
|
456 |
encodings = self.tokenizer(
|
457 |
+
[full_query],
|
458 |
padding='max_length',
|
459 |
truncation=True,
|
460 |
max_length=self.max_length,
|
461 |
+
return_tensors='np' # to keep it compatible with FAISS
|
462 |
)
|
463 |
input_ids = encodings['input_ids']
|
464 |
|
465 |
+
# 3) Check for out-of-vocab IDs
|
466 |
max_id = np.max(input_ids)
|
467 |
+
vocab_size = len(self.tokenizer)
|
468 |
+
if max_id >= vocab_size:
|
469 |
+
logger.error(f"Token ID {max_id} exceeds tokenizer vocab size {vocab_size}.")
|
|
|
470 |
raise ValueError("Token ID exceeds vocabulary size.")
|
471 |
|
472 |
+
# 4) Get embeddings from the model
|
473 |
embeddings = self.encoder(input_ids, training=False).numpy()
|
474 |
+
# Typically your custom model already L2-normalizes the final embeddings.
|
475 |
+
|
476 |
+
# 5) Return the single embedding as 1D array
|
477 |
+
return embeddings[0]
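For clarity, with one prior exchange in the context, the string handed to the tokenizer looks like this (illustrative values; assumes '<USER>' and '<ASSISTANT>' are registered as additional special tokens):

    # context = [("I need a hotel", "For which city?")], query = "Boston, please"
    # full_query == "<USER> I need a hotel <ASSISTANT> For which city? <USER> Boston, please"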
|
|
|
478 |
|
479 |
+
def encode_responses(
|
480 |
+
self,
|
481 |
+
responses: List[str],
|
482 |
+
context: Optional[List[Tuple[str, str]]] = None
|
483 |
+
) -> np.ndarray:
|
484 |
"""
|
485 |
+
Encode multiple response texts into embedding vectors.
|
486 |
|
487 |
Args:
|
488 |
+
responses: List of raw assistant responses.
|
489 |
+
context: Optional conversation context (last N turns).
|
490 |
|
491 |
Returns:
|
492 |
+
np.ndarray of shape [num_responses, embedding_dim].
|
493 |
"""
|
494 |
+
# 1) If you want to incorporate context into response encoding
|
495 |
+
# Usually for retrieval we might skip this. But if you want it:
|
496 |
if context:
|
497 |
+
relevant_history = context[-self.config.max_context_turns:]
|
498 |
+
prepared = []
|
499 |
+
for resp in responses:
|
500 |
+
context_str_parts = []
|
501 |
+
for (u_text, a_text) in relevant_history:
|
502 |
+
context_str_parts.append(
|
503 |
+
f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {u_text} "
|
504 |
+
f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {a_text}"
|
505 |
+
)
|
506 |
+
context_str = " ".join(context_str_parts)
|
507 |
+
|
508 |
+
# Now treat resp as an assistant turn
|
509 |
+
full_resp = (
|
510 |
+
f"{context_str} "
|
511 |
+
f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {resp}"
|
512 |
+
)
|
513 |
+
prepared.append(full_resp)
|
514 |
else:
|
515 |
+
# By default, just mark each response as from the assistant
|
516 |
+
prepared = [
|
517 |
+
f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {r}"
|
518 |
+
for r in responses
|
519 |
]
|
520 |
+
|
521 |
+
# 2) Tokenize
|
522 |
encodings = self.tokenizer(
|
523 |
+
prepared,
|
524 |
padding='max_length',
|
525 |
truncation=True,
|
526 |
max_length=self.max_length,
|
527 |
+
return_tensors='np'
|
528 |
)
|
529 |
input_ids = encodings['input_ids']
|
530 |
+
|
531 |
+
# 3) Check for out-of-vocab
|
532 |
max_id = np.max(input_ids)
|
533 |
+
vocab_size = len(self.tokenizer)
|
534 |
+
if max_id >= vocab_size:
|
535 |
+
logger.error(f"Token ID {max_id} exceeds tokenizer vocab size {vocab_size}.")
|
|
|
536 |
raise ValueError("Token ID exceeds vocabulary size.")
|
537 |
+
|
538 |
+
# 4) Model forward
|
539 |
embeddings = self.encoder(input_ids, training=False).numpy()
|
540 |
+
# Typically already L2-normalized if your final layer is normalized.
|
541 |
+
|
|
|
|
|
542 |
return embeddings.astype('float32')
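Since this version no longer calls faiss.normalize_L2 here and relies on the encoder's own normalization, a quick check along these lines can confirm that assumption (a sketch, not part of the pipeline; `pipeline` stands for a built TFDataPipeline instance):

    import numpy as np

    embs = pipeline.encode_responses(["Sure, for what time?", "Your table is booked."])
    norms = np.linalg.norm(embs, axis=1)
    assert np.allclose(norms, 1.0, atol=1e-5), "encoder output is not L2-normalized"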
|
543 |
|
544 |
def prepare_and_save_data(self, dialogues: List[dict], tf_record_path: str, batch_size: int = 32):
|
validate_model.py
CHANGED
@@ -1,16 +1,17 @@
|
|
1 |
import os
|
2 |
import json
|
|
|
3 |
from chatbot_model import ChatbotConfig, RetrievalChatbot
|
4 |
from response_quality_checker import ResponseQualityChecker
|
5 |
from chatbot_validator import ChatbotValidator
|
6 |
from plotter import Plotter
|
7 |
from environment_setup import EnvironmentSetup
|
8 |
-
|
9 |
from logger_config import config_logger
|
|
|
10 |
logger = config_logger(__name__)
|
11 |
|
12 |
def run_interactive_chat(chatbot, quality_checker):
|
13 |
-
"""Separate function for interactive chat loop"""
|
14 |
while True:
|
15 |
try:
|
16 |
user_input = input("You: ")
|
@@ -18,7 +19,7 @@ def run_interactive_chat(chatbot, quality_checker):
|
|
18 |
print("\nAssistant: Goodbye!")
|
19 |
break
|
20 |
|
21 |
-
if user_input.lower() in [
|
22 |
print("Assistant: Goodbye!")
|
23 |
break
|
24 |
|
@@ -26,69 +27,97 @@ def run_interactive_chat(chatbot, quality_checker):
|
|
26 |
query=user_input,
|
27 |
conversation_history=None,
|
28 |
quality_checker=quality_checker,
|
29 |
-
top_k=
|
30 |
)
|
31 |
|
32 |
print(f"Assistant: {response}")
|
33 |
|
34 |
-
if
|
|
|
35 |
print("\nAlternative responses:")
|
36 |
for resp, score in candidates[1:4]:
|
37 |
print(f"Score: {score:.4f} - {resp}")
|
38 |
else:
|
39 |
print("\n[Low Confidence]: Consider rephrasing your query for better assistance.")
|
40 |
|
41 |
-
# TODO:
|
42 |
def validate_chatbot():
|
43 |
# Initialize environment
|
44 |
env = EnvironmentSetup()
|
45 |
env.initialize()
|
46 |
|
47 |
-
MODEL_DIR =
|
48 |
-
FAISS_INDICES_DIR = os.path.join(MODEL_DIR,
|
49 |
-
FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR,
|
50 |
-
FAISS_INDEX_TEST_PATH = os.path.join(FAISS_INDICES_DIR,
|
51 |
-
|
52 |
-
|
53 |
-
ENVIRONMENT =
|
54 |
-
if ENVIRONMENT ==
|
55 |
FAISS_INDEX_PATH = FAISS_INDEX_TEST_PATH
|
56 |
-
RESPONSE_POOL_PATH =
|
57 |
else:
|
58 |
FAISS_INDEX_PATH = FAISS_INDEX_PRODUCTION_PATH
|
59 |
-
RESPONSE_POOL_PATH =
|
60 |
-
|
61 |
-
# Load config
|
62 |
-
|
63 |
-
|
64 |
-
|
|
|
|
|
|
|
|
|
65 |
try:
|
66 |
-
chatbot = RetrievalChatbot(
|
67 |
-
logger.info("RetrievalChatbot
|
68 |
except Exception as e:
|
69 |
-
logger.error(f"Failed to
|
70 |
return
|
71 |
|
72 |
-
#
|
73 |
if not os.path.exists(FAISS_INDEX_PATH) or not os.path.exists(RESPONSE_POOL_PATH):
|
74 |
logger.error("FAISS index or response pool file is missing.")
|
75 |
return
|
76 |
|
|
|
77 |
try:
|
|
|
78 |
chatbot.data_pipeline.load_faiss_index(FAISS_INDEX_PATH)
|
79 |
logger.info(f"FAISS index loaded from {FAISS_INDEX_PATH}.")
|
80 |
-
|
81 |
-
|
|
|
|
|
|
|
|
|
|
|
82 |
chatbot.data_pipeline.response_pool = json.load(f)
|
83 |
logger.info(f"Response pool loaded from {RESPONSE_POOL_PATH}.")
|
84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
chatbot.data_pipeline.validate_faiss_index()
|
86 |
logger.info("FAISS index and response pool validated successfully.")
|
|
|
87 |
except Exception as e:
|
88 |
-
logger.error(f"Failed to load FAISS index: {e}")
|
89 |
return
|
90 |
|
91 |
-
#
|
92 |
quality_checker = ResponseQualityChecker(data_pipeline=chatbot.data_pipeline)
|
93 |
validator = ChatbotValidator(chatbot=chatbot, quality_checker=quality_checker)
|
94 |
logger.info("ResponseQualityChecker and ChatbotValidator initialized.")
|
@@ -101,17 +130,17 @@ def validate_chatbot():
|
|
101 |
logger.error(f"Validation process failed: {e}")
|
102 |
return
|
103 |
|
104 |
-
# Plot
|
105 |
-
try:
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
except Exception as e:
|
110 |
-
|
111 |
|
112 |
-
# Run interactive chat
|
113 |
-
logger.info("\nStarting interactive chat session...")
|
114 |
-
run_interactive_chat(chatbot, quality_checker)
|
115 |
|
116 |
-
if __name__ ==
|
117 |
-
validate_chatbot()
|
|
|
1 |
import os
|
2 |
import json
|
3 |
+
|
4 |
from chatbot_model import ChatbotConfig, RetrievalChatbot
|
5 |
from response_quality_checker import ResponseQualityChecker
|
6 |
from chatbot_validator import ChatbotValidator
|
7 |
from plotter import Plotter
|
8 |
from environment_setup import EnvironmentSetup
|
|
|
9 |
from logger_config import config_logger
|
10 |
+
|
11 |
logger = config_logger(__name__)
|
12 |
|
13 |
def run_interactive_chat(chatbot, quality_checker):
|
14 |
+
"""Separate function for interactive chat loop."""
|
15 |
while True:
|
16 |
try:
|
17 |
user_input = input("You: ")
|
|
|
19 |
print("\nAssistant: Goodbye!")
|
20 |
break
|
21 |
|
22 |
+
if user_input.lower() in ["quit", "exit", "bye"]:
|
23 |
print("Assistant: Goodbye!")
|
24 |
break
|
25 |
|
|
|
27 |
query=user_input,
|
28 |
conversation_history=None,
|
29 |
quality_checker=quality_checker,
|
30 |
+
top_k=10
|
31 |
)
|
32 |
|
33 |
print(f"Assistant: {response}")
|
34 |
|
35 |
+
# Show alternative responses if confident
|
36 |
+
if metrics.get("is_confident", False):
|
37 |
print("\nAlternative responses:")
|
38 |
for resp, score in candidates[1:4]:
|
39 |
print(f"Score: {score:.4f} - {resp}")
|
40 |
else:
|
41 |
print("\n[Low Confidence]: Consider rephrasing your query for better assistance.")
|
42 |
|
|
|
43 |
def validate_chatbot():
|
44 |
# Initialize environment
|
45 |
env = EnvironmentSetup()
|
46 |
env.initialize()
|
47 |
|
48 |
+
MODEL_DIR = "new_iteration/data_prep_iterative_models"
|
49 |
+
FAISS_INDICES_DIR = os.path.join(MODEL_DIR, "faiss_indices")
|
50 |
+
FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, "faiss_index_production.index")
|
51 |
+
FAISS_INDEX_TEST_PATH = os.path.join(FAISS_INDICES_DIR, "faiss_index_test.index")
|
52 |
+
|
53 |
+
# Toggle 'production' or 'test' env
|
54 |
+
ENVIRONMENT = "production"
|
55 |
+
if ENVIRONMENT == "test":
|
56 |
FAISS_INDEX_PATH = FAISS_INDEX_TEST_PATH
|
57 |
+
RESPONSE_POOL_PATH = FAISS_INDEX_TEST_PATH.replace(".index", "_responses.json")
|
58 |
else:
|
59 |
FAISS_INDEX_PATH = FAISS_INDEX_PRODUCTION_PATH
|
60 |
+
RESPONSE_POOL_PATH = FAISS_INDEX_PRODUCTION_PATH.replace(".index", "_responses.json")
|
61 |
+
|
62 |
+
# Load the config
|
63 |
+
config_path = os.path.join(MODEL_DIR, "config.json")
|
64 |
+
if os.path.exists(config_path):
|
65 |
+
with open(config_path, "r", encoding="utf-8") as f:
|
66 |
+
config_dict = json.load(f)
|
67 |
+
config = ChatbotConfig.from_dict(config_dict)
|
68 |
+
logger.info(f"Loaded ChatbotConfig from {config_path}")
|
69 |
+
else:
|
70 |
+
config = ChatbotConfig()
|
71 |
+
logger.warning("No config.json found. Using default ChatbotConfig.")
|
72 |
+
|
73 |
+
# Load RetrievalChatbot in 'inference' mode using the classmethod
|
74 |
+
# This:
|
75 |
+
# - Loads shared_encoder submodule
|
76 |
+
# - Loads encoder_custom_weights.weights.h5
|
77 |
+
# - Loads tokenizer
|
78 |
+
# - Prepares the model for inference
|
79 |
try:
|
80 |
+
chatbot = RetrievalChatbot.load_model(load_dir=MODEL_DIR, mode="inference")
|
81 |
+
logger.info("RetrievalChatbot loaded in 'inference' mode successfully.")
|
82 |
except Exception as e:
|
83 |
+
logger.error(f"Failed to load RetrievalChatbot: {e}")
|
84 |
return
|
85 |
|
86 |
+
# Confirm FAISS index & response pool exist
|
87 |
if not os.path.exists(FAISS_INDEX_PATH) or not os.path.exists(RESPONSE_POOL_PATH):
|
88 |
logger.error("FAISS index or response pool file is missing.")
|
89 |
return
|
90 |
|
91 |
+
# Load specific FAISS index and response pool
|
92 |
try:
|
93 |
+
# Even though load_model might auto-load an index, we override here with the specific file
|
94 |
chatbot.data_pipeline.load_faiss_index(FAISS_INDEX_PATH)
|
95 |
logger.info(f"FAISS index loaded from {FAISS_INDEX_PATH}.")
|
96 |
+
|
97 |
+
print("FAISS dimensions:", chatbot.data_pipeline.index.d)
|
98 |
+
print("FAISS index type:", type(chatbot.data_pipeline.index))
|
99 |
+
print("FAISS index total vectors:", chatbot.data_pipeline.index.ntotal)
|
100 |
+
print("FAISS is_trained:", chatbot.data_pipeline.index.is_trained)
|
101 |
+
|
102 |
+
with open(RESPONSE_POOL_PATH, "r", encoding="utf-8") as f:
|
103 |
chatbot.data_pipeline.response_pool = json.load(f)
|
104 |
logger.info(f"Response pool loaded from {RESPONSE_POOL_PATH}.")
|
105 |
+
|
106 |
+
print("Sample from response pool (first 10):")
|
107 |
+
for i, response in enumerate(chatbot.data_pipeline.response_pool[:10]):
|
108 |
+
print(f"{i}: {response}")
|
109 |
+
|
110 |
+
print("\nTotal responses in pool:", len(chatbot.data_pipeline.response_pool))
|
111 |
+
|
112 |
+
# Validate dimension consistency
|
113 |
chatbot.data_pipeline.validate_faiss_index()
|
114 |
logger.info("FAISS index and response pool validated successfully.")
|
115 |
+
|
116 |
except Exception as e:
|
117 |
+
logger.error(f"Failed to load or validate FAISS index: {e}")
|
118 |
return
|
119 |
|
120 |
+
# Init QualityChecker and Validator
|
121 |
quality_checker = ResponseQualityChecker(data_pipeline=chatbot.data_pipeline)
|
122 |
validator = ChatbotValidator(chatbot=chatbot, quality_checker=quality_checker)
|
123 |
logger.info("ResponseQualityChecker and ChatbotValidator initialized.")
|
|
|
130 |
logger.error(f"Validation process failed: {e}")
|
131 |
return
|
132 |
|
133 |
+
# Plot metrics
|
134 |
+
# try:
|
135 |
+
# plotter = Plotter(save_dir=env.training_dirs["plots"])
|
136 |
+
# plotter.plot_validation_metrics(validation_metrics)
|
137 |
+
# logger.info("Validation metrics plotted successfully.")
|
138 |
+
# except Exception as e:
|
139 |
+
# logger.error(f"Failed to plot validation metrics: {e}")
|
140 |
|
141 |
+
# Run interactive chat loop
|
142 |
+
# logger.info("\nStarting interactive chat session...")
|
143 |
+
# run_interactive_chat(chatbot, quality_checker)
|
144 |
|
145 |
+
if __name__ == "__main__":
|
146 |
+
validate_chatbot()
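Typical invocation, assuming the artifacts under new_iteration/data_prep_iterative_models (config, encoder weights, FAISS index, response pool) are already in place:

    python validate_model.py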
|