JoeArmani committed
Commit · d7fc7a7
Parent(s): c111c20

more structural updates

Browse files
- .gitignore +2 -0
- chatbot_model.py +145 -171
- chatbot_validator.py +1 -1
- {data_augmentation → data_augmentation_code}/augmentation_processing_pipeline.py +0 -0
- {data_augmentation → data_augmentation_code}/back_translator.py +0 -0
- {data_augmentation → data_augmentation_code}/dialogue_augmenter.py +0 -0
- {data_augmentation → data_augmentation_code}/main.py +0 -0
- {data_augmentation → data_augmentation_code}/paraphraser.py +0 -0
- {data_augmentation → data_augmentation_code}/pipeline_config.py +0 -0
- {data_augmentation → data_augmentation_code}/quality_metrics.py +0 -0
- {data_augmentation → data_augmentation_code}/schema_guided_dialogue_processor.py +0 -0
- {data_augmentation → data_augmentation_code}/taskmaster_processor.py +0 -0
- validate_model.py → run_chatbot_validation.py +7 -7
- tf_data_pipeline.py +16 -16
.gitignore CHANGED
@@ -187,3 +187,5 @@ new_iteration/cache/*
 new_iteration/data_prep_iterative_models/*
 new_iteration/training_data/*
 new_iteration/processed_outputs/*
+raw_datasets/*
+
chatbot_model.py CHANGED
@@ -24,25 +24,25 @@ logger = config_logger(__name__)
 
 @dataclass
 class ChatbotConfig:
-    """
+    """RetrievalChatbot Config"""
     max_context_token_limit: int = 512
     embedding_dim: int = 768
     encoder_units: int = 256
     num_attention_heads: int = 8
     dropout_rate: float = 0.2
     l2_reg_weight: float = 0.001
-    learning_rate: float = 0.
+    learning_rate: float = 0.0005
     min_text_length: int = 3
-    max_context_turns: int =
+    max_context_turns: int = 20
     warmup_steps: int = 200
     pretrained_model: str = 'distilbert-base-uncased'
     cross_encoder_model: str = 'cross-encoder/ms-marco-MiniLM-L-12-v2'
+    summarizer_model: str = 't5-small'
     dtype: str = 'float32'
     freeze_embeddings: bool = False
     embedding_batch_size: int = 64
     search_batch_size: int = 64
     max_batch_size: int = 64
-    neg_samples: int = 10
     max_retries: int = 3
 
     def to_dict(self) -> Dict:
@@ -57,7 +57,7 @@ class ChatbotConfig:
         if k in cls.__dataclass_fields__})
 
 class EncoderModel(tf.keras.Model):
-    """Dual encoder model with pretrained embeddings."""
+    """Dual encoder model with pretrained DistilBERT embeddings."""
     def __init__(
         self,
         config: ChatbotConfig,
@@ -71,7 +71,7 @@ class EncoderModel(tf.keras.Model):
         self.pretrained = TFAutoModel.from_pretrained(config.pretrained_model)
         self._freeze_layers()
 
-        # Add
+        # Add Global Average Pooling, Projection, Dropout, and Normalization layers
         self.pooler = tf.keras.layers.GlobalAveragePooling1D()
         self.projection = tf.keras.layers.Dense(
             config.embedding_dim,
@@ -86,7 +86,7 @@ class EncoderModel(tf.keras.Model):
         )
 
     def _freeze_layers(self):
-        """Freeze layers of the pretrained model
+        """Freeze n layers of the pretrained model"""
         if self.config.freeze_embeddings:
             self.pretrained.trainable = False
             logger.info("All pretrained layers frozen.")
@@ -95,29 +95,29 @@ class EncoderModel(tf.keras.Model):
             for i, layer in enumerate(self.pretrained.layers):
                 if isinstance(layer, tf.keras.layers.Layer):
                     if hasattr(layer, 'trainable'):
-                        # Freeze the first transformer block
                         if i < 1:
                             layer.trainable = False
                             logger.info(f"Layer {i} frozen.")
                         else:
                             layer.trainable = True
+                            logger.info(f"Layer {i} trainable.")
 
     def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
         """Forward pass."""
         # Get pretrained embeddings
         pretrained_outputs = self.pretrained(inputs, training=training)
-        x = pretrained_outputs.last_hidden_state
+        x = pretrained_outputs.last_hidden_state  # Shape: [batch_size, seq_len, embedding_dim]
 
         # Apply pooling, projection, dropout, and normalization
-        x = self.pooler(x)
-        x = self.projection(x)
+        x = self.pooler(x)      # Shape: [batch_size, 768]
+        x = self.projection(x)  # Shape: [batch_size, 768]
         x = self.dropout(x, training=training)
-        x = self.normalize(x)
+        x = self.normalize(x)   # Shape: [batch_size, 768]
 
         return x
 
     def get_config(self) -> dict:
-        """Return the config
+        """Return the model config"""
         config = super().get_config()
         config.update({
             "config": self.config.to_dict(),
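The hunk above pins down the encoder head: average-pool the token states, project, apply dropout, then normalize. A minimal standalone sketch of that head, assuming `self.normalize` is an L2-normalization step (the diff never shows how that layer is built), so that inner products between outputs equal cosine similarities:

```python
import tensorflow as tf

embedding_dim = 768

pooler = tf.keras.layers.GlobalAveragePooling1D()
projection = tf.keras.layers.Dense(embedding_dim)
dropout = tf.keras.layers.Dropout(0.2)

def encode(last_hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
    x = pooler(last_hidden_state)   # [batch, seq_len, dim] -> [batch, dim]
    x = projection(x)               # [batch, dim]
    x = dropout(x, training=training)
    # Assumption: L2-normalize so rows are unit vectors.
    return tf.math.l2_normalize(x, axis=-1)

# Toy usage: two "sentences" of 10 tokens each stand in for DistilBERT outputs.
toy = tf.random.normal((2, 10, embedding_dim))
emb = encode(toy)
print(tf.einsum('bd,cd->bc', emb, emb).numpy())  # cosine-similarity matrix
```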
@@ -126,7 +126,10 @@ class EncoderModel(tf.keras.Model):
         return config
 
 class RetrievalChatbot(DeviceAwareModel):
-    """
+    """
+    Retrieval-based learning chatbot model.
+    Uses trained embeddings and FAISS for similarity search.
+    """
     def __init__(
         self,
         config: ChatbotConfig,
@@ -142,7 +145,7 @@ class RetrievalChatbot(DeviceAwareModel):
         self.device = device or self._setup_default_device()
         self.mode = mode.lower()
 
-        # Initialize reranker, summarizer, tokenizer,
+        # Initialize reranker, summarizer, tokenizer, and encoder
         self.reranker = reranker or self._initialize_reranker()
         self.tokenizer = self._initialize_tokenizer()
         self.encoder = self._initialize_encoder()
@@ -154,14 +157,9 @@ class RetrievalChatbot(DeviceAwareModel):
             config=self.config,
             tokenizer=self.tokenizer,
             encoder=self.encoder,
-            index_file_path='new_iteration/data_prep_iterative_models/faiss_indices/faiss_index_production.index',
             response_pool=[],
             max_length=self.config.max_context_token_limit,
             query_embeddings_cache={},
-            neg_samples=self.config.neg_samples,
-            index_type='IndexFlatIP',
-            nlist=100,  # Not used with IndexFlatIP
-            max_retries=self.config.max_retries
         )
 
         # Collect unique responses from dialogues
@@ -197,7 +195,7 @@ class RetrievalChatbot(DeviceAwareModel):
         """Initialize the Summarizer."""
         return Summarizer(
             tokenizer=self.tokenizer,
-            model_name=
+            model_name=self.config.summarizer_model,
             max_summary_length=self.config.max_context_token_limit // 4,
             device=self.device,
             max_summary_rounds=2
@@ -229,17 +227,18 @@ class RetrievalChatbot(DeviceAwareModel):
         new_vocab_size = len(self.tokenizer)
         encoder.pretrained.resize_token_embeddings(new_vocab_size)
         logger.info(f"Token embeddings resized to: {new_vocab_size}")
+
         return encoder
 
     def _load_faiss_index_and_responses(self) -> None:
         """Load FAISS index and response pool for inference."""
         try:
-            logger.info(f"Loading FAISS index from {self.data_pipeline.
-            self.data_pipeline.load_faiss_index(self.data_pipeline.
+            logger.info(f"Loading FAISS index from {self.data_pipeline.faiss_index_file_path}...")
+            self.data_pipeline.load_faiss_index(self.data_pipeline.faiss_index_file_path)
             logger.info("FAISS index loaded successfully.")
 
-            # Load response pool
-            response_pool_path = self.data_pipeline.
+            # Load response pool
+            response_pool_path = self.data_pipeline.faiss_index_file_path.replace('.index', '_responses.json')
             if os.path.exists(response_pool_path):
                 with open(response_pool_path, 'r', encoding='utf-8') as f:
                     self.data_pipeline.response_pool = json.load(f)
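The hunk above establishes a naming convention: the response pool lives next to the index as `foo.index` → `foo_responses.json`. A small sketch of the same load step using only standard FAISS and JSON calls (the convention comes from the diff; everything else is illustrative):

```python
import json
import os

import faiss  # pip install faiss-cpu

def load_index_and_responses(index_path: str):
    """Load a FAISS index plus the JSON response pool stored alongside it."""
    index = faiss.read_index(index_path)
    # Naming convention from the diff: 'foo.index' -> 'foo_responses.json'.
    response_pool_path = index_path.replace('.index', '_responses.json')
    response_pool = []
    if os.path.exists(response_pool_path):
        with open(response_pool_path, 'r', encoding='utf-8') as f:
            response_pool = json.load(f)
    return index, response_pool
```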
@@ -263,29 +262,24 @@ class RetrievalChatbot(DeviceAwareModel):
         """
         load_dir = Path(load_dir)
 
-        #
+        # Load config
         with open(load_dir / "config.json", "r") as f:
             config = ChatbotConfig.from_dict(json.load(f))
 
-        #
+        # Initialize chatbot
         chatbot = cls(config, mode=mode)
 
-        #
-        chatbot.encoder.pretrained = TFAutoModel.from_pretrained(
-            load_dir / "shared_encoder",
-            config=config
-        )
+        # Load DistilBERT
+        chatbot.encoder.pretrained = TFAutoModel.from_pretrained(load_dir / "shared_encoder", config=config)
 
         dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
         _ = chatbot.encoder(dummy_input, training=False)
 
-        #
+        # Load tokenizer
         chatbot.tokenizer = AutoTokenizer.from_pretrained(load_dir / "tokenizer")
         logger.info(f"Models and tokenizer loaded from {load_dir}")
 
-
-
-        # 5) Load the custom top layers' weights
+        # Load the custom weights
         custom_weights_path = load_dir / "encoder_custom_weights.weights.h5"
         if custom_weights_path.exists():
             chatbot.encoder.load_weights(str(custom_weights_path))
@@ -293,7 +287,7 @@ class RetrievalChatbot(DeviceAwareModel):
         else:
             logger.warning(f"No custom encoder weights found at {custom_weights_path}. The top-level projection layer won't have learned parameters.")
 
-        #
+        # Handle 'inference' mode: load FAISS, etc.
         if mode == 'inference':
             cls._prepare_model_for_inference(chatbot, load_dir)
 
@@ -301,7 +295,7 @@ class RetrievalChatbot(DeviceAwareModel):
 
     @classmethod
     def _prepare_model_for_inference(cls, chatbot: 'RetrievalChatbot', load_dir: Path) -> None:
-        """
+        """Load inference components."""
         try:
             # Load FAISS index
             faiss_path = load_dir / 'faiss_indices/faiss_index_production.index'
@@ -332,7 +326,7 @@ class RetrievalChatbot(DeviceAwareModel):
             raise
 
     def save_models(self, save_dir: Union[str, Path]):
-        """Save
+        """Save model and config"""
         save_dir = Path(save_dir)
         save_dir.mkdir(parents=True, exist_ok=True)
 
@@ -340,21 +334,13 @@ class RetrievalChatbot(DeviceAwareModel):
         with open(save_dir / "config.json", "w") as f:
             json.dump(self.config.to_dict(), f, indent=2)
 
-        # Save the HF DistilBERT submodule
+        # Save the HF DistilBERT submodule, custom top-level layers, and tokenizer
         self.encoder.pretrained.save_pretrained(save_dir / "shared_encoder")
-
-        # ALSO save custom top-level layers' weights
         self.encoder.save_weights(save_dir / "encoder_custom_weights.weights.h5")
-
-        # Save tokenizer
        self.tokenizer.save_pretrained(save_dir / "tokenizer")
-
        logger.info(f"Models and tokenizer saved to {save_dir}.")
 
-    def
-        return 1 / (1 + np.exp(-x))
-
-    def retrieve_responses_cross_encoder(
+    def retrieve_responses(
         self,
         query: str,
         top_k: int = 10,
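The save/load split above keeps three artifacts per model directory: the HF submodule, the custom head weights, and the tokenizer. A minimal sketch of the same round trip under the assumption that the model exposes the diff's `.pretrained` attribute (directory and file names match the hunks; the dummy forward pass builds the Keras variables before loading head weights):

```python
from pathlib import Path

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModel

def save_all(encoder: tf.keras.Model, tokenizer, save_dir: Path) -> None:
    # Three artifacts, mirroring the diff.
    save_dir.mkdir(parents=True, exist_ok=True)
    encoder.pretrained.save_pretrained(save_dir / "shared_encoder")
    encoder.save_weights(str(save_dir / "encoder_custom_weights.weights.h5"))
    tokenizer.save_pretrained(save_dir / "tokenizer")

def load_all(encoder: tf.keras.Model, load_dir: Path, max_len: int = 512):
    encoder.pretrained = TFAutoModel.from_pretrained(load_dir / "shared_encoder")
    # Build variables with a dummy forward pass before loading head weights,
    # as the diff's load_model does.
    _ = encoder(tf.zeros((1, max_len), dtype=tf.int32), training=False)
    encoder.load_weights(str(load_dir / "encoder_custom_weights.weights.h5"))
    tokenizer = AutoTokenizer.from_pretrained(load_dir / "tokenizer")
    return encoder, tokenizer
```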
@@ -363,20 +349,20 @@ class RetrievalChatbot(DeviceAwareModel):
         summarize_threshold: int = 512
     ) -> List[Tuple[str, float]]:
         """
-        Retrieve top-k responses
-        and cross-encoder re-ranking.
-
+        Retrieve top-k responses using FAISS and cross-encoder re-ranking.
         Args:
             query: The user's input text.
-            top_k: Number of
-            reranker: CrossEncoderReranker for refined scoring
-            summarizer: Summarizer for long queries
-            summarize_threshold: Summarize if
-
+            top_k: Number of FAISS results to return
+            reranker: CrossEncoderReranker for refined scoring
+            summarizer: Summarizer for long queries
+            summarize_threshold: Summarize if conversation tokens > threshold.
         Returns:
             List of (response_text, final_score).
         """
-
+        def sigmoid(x: float) -> float:
+            return 1 / (1 + np.exp(-x))
+
+        # Query summarization
         if summarizer and len(query.split()) > summarize_threshold:
             logger.info(f"Query is long ({len(query.split())} words). Summarizing.")
             query = summarizer.summarize_text(query)
@@ -393,17 +379,17 @@ class RetrievalChatbot(DeviceAwareModel):
 
         texts = [item[0] for item in faiss_candidates]
 
-        # Re-rank these boosted candidates
         if not reranker:
             reranker = CrossEncoderReranker(model_name=self.config.cross_encoder_model)
 
+        # Re-rank the texts (candidates) from FAISS search using the cross-encoder
         ce_logits = reranker.rerank(query, texts, max_length=256)
 
-        # Combine
+        # Combine scores from FAISS and cross-encoder
         final_candidates = []
         for (resp_text, faiss_score), logit in zip(faiss_candidates, ce_logits):
-            ce_prob =
-            faiss_norm = (faiss_score + 1)/2.0
+            ce_prob = sigmoid(logit)            # now in range [0...1]
+            faiss_norm = (faiss_score + 1)/2.0  # now in range [0...1]
             combined_score = 0.85 * ce_prob + 0.15 * faiss_norm
             length_adjusted_score = self.length_adjust_score(resp_text, combined_score)
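The blend above maps both signals into [0, 1] before mixing. A standalone sketch of the same arithmetic (the 0.85 / 0.15 weights and both rescalings come straight from the hunk; the toy logits and FAISS scores are made up):

```python
import numpy as np

def sigmoid(x: float) -> float:
    return 1 / (1 + np.exp(-x))

def combine(ce_logit: float, faiss_score: float) -> float:
    """Blend cross-encoder and FAISS scores, both mapped to [0, 1] first.

    faiss_score is an inner product of unit vectors, so it lies in [-1, 1];
    (s + 1) / 2 rescales it to [0, 1].
    """
    ce_prob = sigmoid(ce_logit)
    faiss_norm = (faiss_score + 1) / 2.0
    return 0.85 * ce_prob + 0.15 * faiss_norm

# Toy candidates: (cross-encoder logit, FAISS cosine similarity)
for logit, cos in [(3.2, 0.71), (-0.4, 0.93), (1.0, 0.55)]:
    print(f"logit={logit:+.1f} cos={cos:.2f} -> {combine(logit, cos):.3f}")
```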
@@ -415,22 +401,22 @@ class RetrievalChatbot(DeviceAwareModel):
         # Return top_k
         return final_candidates[:top_k]
 
-    DOMAIN_KEYWORDS = {
-        'restaurant': ['restaurant', 'dining', 'food', 'dine', 'reservation', 'table', 'menu', 'cuisine', 'eat', 'place to eat', 'hungry', 'chef', 'dish', 'meal', 'brunch', 'bistro', 'buffet', 'catering', 'gourmet', 'fast food', 'fine dining', 'takeaway', 'delivery', 'restaurant booking'],
-        'movie': ['movie', 'cinema', 'film', 'ticket', 'showtime', 'showing', 'theater', 'flick', 'screening', 'film ticket', 'film show', 'blockbuster', 'premiere', 'trailer', 'director', 'actor', 'actress', 'plot', 'genre', 'screen', 'sequel', 'animation', 'documentary'],
-        'ride_share': ['ride', 'taxi', 'uber', 'lyft', 'car service', 'pickup', 'dropoff', 'driver', 'cab', 'hailing', 'rideshare', 'ride hailing', 'carpool', 'chauffeur', 'transit', 'transportation', 'hail ride'],
-        'coffee': ['coffee', 'café', 'cafe', 'starbucks', 'espresso', 'latte', 'mocha', 'americano', 'barista', 'brew', 'cappuccino', 'macchiato', 'iced coffee', 'cold brew', 'espresso machine', 'coffee shop', 'tea', 'chai', 'java', 'bean', 'roast', 'decaf'],
-        'pizza': ['pizza', 'delivery', 'order food', 'pepperoni', 'topping', 'pizzeria', 'slice', 'pie', 'margherita', 'deep dish', 'thin crust', 'cheese', 'oven', 'tossed', 'sauce', 'garlic bread', 'calzone'],
-        'auto': ['car', 'vehicle', 'repair', 'maintenance', 'mechanic', 'oil change', 'garage', 'auto shop', 'tire', 'check engine', 'battery', 'transmission', 'brake', 'engine diagnostics', 'carwash', 'detail', 'alignment', 'exhaust', 'spark plug', 'dashboard'],
-    }
-
     def extract_keywords(self, query: str) -> List[str]:
         """
         Return any domain keywords present in the query (lowercased).
         """
+        domain_keywords = {
+            'restaurant': ['restaurant', 'dining', 'food', 'dine', 'reservation', 'table', 'menu', 'cuisine', 'eat', 'place to eat', 'hungry', 'chef', 'dish', 'meal', 'brunch', 'bistro', 'buffet', 'catering', 'gourmet', 'fast food', 'fine dining', 'takeaway', 'delivery', 'restaurant booking'],
+            'movie': ['movie', 'cinema', 'film', 'ticket', 'showtime', 'showing', 'theater', 'flick', 'screening', 'film ticket', 'film show', 'blockbuster', 'premiere', 'trailer', 'director', 'actor', 'actress', 'plot', 'genre', 'screen', 'sequel', 'animation', 'documentary'],
+            'ride_share': ['ride', 'taxi', 'uber', 'lyft', 'car service', 'pickup', 'dropoff', 'driver', 'cab', 'hailing', 'rideshare', 'ride hailing', 'carpool', 'chauffeur', 'transit', 'transportation', 'hail ride'],
+            'coffee': ['coffee', 'café', 'cafe', 'starbucks', 'espresso', 'latte', 'mocha', 'americano', 'barista', 'brew', 'cappuccino', 'macchiato', 'iced coffee', 'cold brew', 'espresso machine', 'coffee shop', 'tea', 'chai', 'java', 'bean', 'roast', 'decaf'],
+            'pizza': ['pizza', 'delivery', 'order food', 'pepperoni', 'topping', 'pizzeria', 'slice', 'pie', 'margherita', 'deep dish', 'thin crust', 'cheese', 'oven', 'tossed', 'sauce', 'garlic bread', 'calzone'],
+            'auto': ['car', 'vehicle', 'repair', 'maintenance', 'mechanic', 'oil change', 'garage', 'auto shop', 'tire', 'check engine', 'battery', 'transmission', 'brake', 'engine diagnostics', 'carwash', 'detail', 'alignment', 'exhaust', 'spark plug', 'dashboard'],
+        }
+
         query_lower = query.lower()
         found = set()
-        for domain, kw_list in
+        for domain, kw_list in domain_keywords.items():
             for kw in kw_list:
                 if kw in query_lower:
                     found.add(kw)
@@ -456,7 +442,7 @@ class RetrievalChatbot(DeviceAwareModel):
 
     def detect_domain_from_query(self, query: str) -> str:
         """
-        Detect the domain of the query based on keywords.
+        Detect the domain of the query based on keywords. Used for boosting FAISS search.
         """
         domain_patterns = {
             'restaurant': r'\b(restaurant|restaurants?|dining|food|foods?|dine|reservation|reservations?|table|tables?|menu|menus?|cuisine|cuisines?|eat|eats?|place\s?to\s?eat|places\s?to\s?eat|hungry|chef|chefs?|dish|dishes?|meal|meals?|fork|forks?|knife|knives?|spoon|spoons?|brunch|bistro|buffet|buffets?|catering|caterings?|gourmet|fast\s?food|fine\s?dining|takeaway|takeaways?|delivery|deliveries|restaurant\s?booking)\b',
@@ -476,8 +462,7 @@ class RetrievalChatbot(DeviceAwareModel):
 
     def is_numeric_response(self, text: str) -> bool:
         """
-        Return True if `text` is purely digits
-        with optional punctuation like '.' at the end.
+        Return True if `text` is purely digits and/or spaces.
         """
         pattern = r'^[\s]*[\d]+([\s.,\d]+)*[\s]*$'
         return bool(re.match(pattern, text.strip()))
@@ -486,18 +471,16 @@ class RetrievalChatbot(DeviceAwareModel):
         self,
         query: str,
         domain: str = 'other',
-        top_k: int =
-        boost_factor: float = 1.
+        top_k: int = 10,
+        boost_factor: float = 1.15
     ) -> List[Tuple[str, float]]:
         """
         Retrieve top-k responses from the FAISS index (IndexFlatIP) given a user query.
-
         Args:
             query (str): The user input text.
-            domain (str
-            top_k (int
-            boost_factor (float, optional): Factor to boost scores for keyword matches.
-
+            domain (str): The detected domain from possible domains: ['restaurant', 'movie', 'ride_share', 'coffee', 'pizza', 'auto', 'other']
+            top_k (int): Number of top results to return.
+            boost_factor (float, optional): Factor to boost scores for keyword matches.
         Returns:
             List[Tuple[str, float]]: List of (response_text, similarity) sorted by descending similarity.
         """
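A quick standalone check of the keyword scan that this commit moves into `extract_keywords` (keyword lists trimmed here; the full ones are in the hunk above):

```python
DOMAIN_KEYWORDS = {
    'coffee': ['coffee', 'latte', 'espresso', 'cold brew'],
    'pizza': ['pizza', 'pepperoni', 'thin crust'],
}

def extract_keywords(query: str) -> list[str]:
    # Substring scan over every domain's keyword list, as in the diff.
    query_lower = query.lower()
    found = set()
    for kw_list in DOMAIN_KEYWORDS.values():
        for kw in kw_list:
            if kw in query_lower:
                found.add(kw)
    return sorted(found)

print(extract_keywords("Any good cold brew or latte nearby?"))
# ['cold brew', 'latte']
```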
@@ -508,7 +491,7 @@ class RetrievalChatbot(DeviceAwareModel):
         # Search the index
         distances, indices = self.data_pipeline.index.search(q_emb_np, top_k * 10)
 
-        # IndexFlatIP: 'distances' are inner products (cosine similarities for normalized vectors)
+        # IndexFlatIP: 'distances' are inner products (cosine similarities for normalized vectors).
         candidates = []
         for rank, idx in enumerate(indices[0]):
             if idx < 0:
@@ -545,8 +528,7 @@ class RetrievalChatbot(DeviceAwareModel):
         boosted = []
         for (resp_text, resp_domain, score) in in_domain:
             new_score = score
-            # If the domain is known AND the response text
-            # shares any query keywords, apply a small boost
+            # If the domain is known AND the response text shares any query keywords, boost it
             if query_keywords and any(kw in resp_text.lower() for kw in query_keywords):
                 new_score *= boost_factor
 
@@ -558,7 +540,7 @@ class RetrievalChatbot(DeviceAwareModel):
         # Sort boosted responses
         boosted.sort(key=lambda x: x[1], reverse=True)
 
-        # Debug
+        # Debug logging (see FAISS responses)
         # for resp, score in boosted[:100]:
         #     logger.debug(f"Candidate: '{resp}' with score {score}")
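The retrieval path relies on `IndexFlatIP` over L2-normalized vectors, so the returned "distances" really are cosine similarities. A self-contained sketch of that setup (random vectors stand in for encoder embeddings):

```python
import faiss  # pip install faiss-cpu
import numpy as np

dim, n = 768, 1000
rng = np.random.default_rng(0)

# Stand-in for encoder outputs; L2-normalize so inner product == cosine similarity.
vectors = rng.standard_normal((n, dim)).astype('float32')
vectors /= np.linalg.norm(vectors, axis=1, keepdims=True)

index = faiss.IndexFlatIP(dim)   # exact inner-product search, no training needed
index.add(vectors)

query = vectors[:1]              # a known vector: its best hit should be itself
distances, indices = index.search(query, 5)
print(indices[0])                # first index is 0
print(distances[0])              # cosine similarities, first one ~1.0
```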
@@ -572,8 +554,7 @@ class RetrievalChatbot(DeviceAwareModel):
         top_k: int = 10,
     ) -> Tuple[str, List[Tuple[str, float]], Dict[str, Any]]:
         """
-
-        if self.reranker is available.
+        Live chat with the chatbot. Uses same processing flow as validation, except for context handling and quality checking.
         """
         @self.run_on_device
         def get_response(self_arg, query_arg):
@@ -581,7 +562,7 @@ class RetrievalChatbot(DeviceAwareModel):
             conversation_str = self_arg._build_conversation_context(query_arg, conversation_history)
 
             # Retrieve and re-rank
-            results = self_arg.
+            results = self_arg.retrieve_responses(
                 query=conversation_str,
                 top_k=top_k,
                 reranker=self_arg.reranker,
@@ -605,7 +586,9 @@ class RetrievalChatbot(DeviceAwareModel):
         query: str,
         conversation_history: Optional[List[Tuple[str, str]]]
     ) -> str:
-        """
+        """
+        Build conversation context string from conversation history.
+        """
         if not conversation_history:
             return f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"
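`_build_conversation_context` flattens (speaker, text) turns into one tagged string for the encoder. A sketch of that flattening; only the `<USER>` tag is confirmed by the hunk, while `<ASSISTANT>` and the exact turn layout are assumptions, since the rest of the method is not shown in this diff:

```python
from typing import List, Optional, Tuple

def build_conversation_context(
    query: str,
    conversation_history: Optional[List[Tuple[str, str]]],
) -> str:
    """Flatten history into '<USER> ... <ASSISTANT> ...' plus the new query.

    <ASSISTANT> and the layout are hypothetical; <USER> comes from the diff.
    """
    if not conversation_history:
        return f"<USER> {query}"
    parts = []
    for speaker, text in conversation_history:
        tag = "<USER>" if speaker == "user" else "<ASSISTANT>"
        parts.append(f"{tag} {text}")
    parts.append(f"<USER> {query}")
    return " ".join(parts)

print(build_conversation_context("thanks!", [("user", "hi"), ("assistant", "hello")]))
```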
@@ -636,12 +619,12 @@ class RetrievalChatbot(DeviceAwareModel):
     ) -> None:
         """
         Train the retrieval model using a pre-prepared TFRecord dataset.
-        This method handles:
         - Checkpoint loading/restoring
         - LR scheduling
         - Epoch/iteration tracking
-        -
-        -
+        - Training-history logging
+        - Early stopping
+        - Custom loss function (Contrastive loss with hard negative sampling)
         """
         logger.info("Starting training with pre-prepared TFRecord dataset...")
 
@@ -673,7 +656,7 @@ class RetrievalChatbot(DeviceAwareModel):
         steps_per_epoch = math.ceil(train_size / batch_size)
         val_steps = math.ceil(val_size / batch_size)
         total_steps = steps_per_epoch * epochs
-        buffer_size = max(1, total_pairs //
+        buffer_size = max(1, total_pairs // 2)  # 50% of the dataset for shuffling
 
         logger.info(f"Training pairs: {train_size}")
         logger.info(f"Validation pairs: {val_size}")
@@ -695,7 +678,7 @@ class RetrievalChatbot(DeviceAwareModel):
             self.optimizer = tf.keras.optimizers.Adam(learning_rate=tf.cast(peak_lr, tf.float32))
             logger.info("Using fixed learning rate.")
 
-        #
+        # Dummy step to force initialization
         dummy_input = tf.zeros((1, self.config.max_context_token_limit), dtype=tf.int32)
         with tf.GradientTape() as tape:
             dummy_output = self.encoder(dummy_input)
@@ -710,6 +693,7 @@ class RetrievalChatbot(DeviceAwareModel):
             model=self.encoder
         )
 
+        # Create a CheckpointManager
         manager = tf.train.CheckpointManager(
             checkpoint,
             directory=checkpoint_dir,
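The checkpoint objects above track the epoch counter, optimizer, and model together. A minimal sketch of that setup with a toy model (layer sizes and directory are illustrative; the `ckpt` name and epoch variable mirror the diff):

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
model.build((None, 8))
optimizer = tf.keras.optimizers.Adam(1e-3)

# Track epoch, optimizer, and model in one checkpoint object.
checkpoint = tf.train.Checkpoint(
    epoch=tf.Variable(0, dtype=tf.int32),
    optimizer=optimizer,
    model=model,
)
manager = tf.train.CheckpointManager(
    checkpoint, directory="/tmp/ckpts", max_to_keep=3, checkpoint_name="ckpt"
)

# Resume if a checkpoint exists, as the training loop above does.
if manager.latest_checkpoint:
    checkpoint.restore(manager.latest_checkpoint)
    print(f"Resumed from {manager.latest_checkpoint} at epoch {int(checkpoint.epoch)}")

checkpoint.epoch.assign_add(1)
print(f"Saved: {manager.save()}")  # e.g. /tmp/ckpts/ckpt-1
```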
@@ -717,18 +701,18 @@ class RetrievalChatbot(DeviceAwareModel):
             checkpoint_name='ckpt'
         )
 
-        # Restore from existing checkpoint if
+        # Restore from existing checkpoint if one is provided
         latest_checkpoint = manager.latest_checkpoint
         history_path = Path(checkpoint_dir) / 'training_history.json'
 
-        #
+        # Log epoch losses across runs, including restore from checkpoint
         if not hasattr(self, 'history'):
             self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
 
         if latest_checkpoint and not test_mode:
-            #
-            logger.info(f"\nTrying to load checkpoint from: {latest_checkpoint}")
-            reader = tf.train.load_checkpoint(latest_checkpoint)
+            # Debug checkpoint loading
+            # logger.info(f"\nTrying to load checkpoint from: {latest_checkpoint}")
+            # reader = tf.train.load_checkpoint(latest_checkpoint)
             # shape_from_key = reader.get_variable_to_shape_map()
             # dtype_from_key = reader.get_variable_to_dtype_map()
             # logger.info("\nCheckpoint Variables:")
@@ -752,11 +736,11 @@ class RetrievalChatbot(DeviceAwareModel):
             if initial_epoch == 0:
                 initial_epoch = ckpt_number
 
-            # Assign to checkpoint.epoch
+            # Assign to checkpoint.epoch for counting
             checkpoint.epoch.assign(tf.cast(initial_epoch, tf.int32))
             logger.info(f"Resuming from epoch {initial_epoch}")
 
-            #
+            # Load history from file:
             if history_path.exists():
                 try:
                     with open(history_path, 'r') as f:
@@ -765,7 +749,10 @@ class RetrievalChatbot(DeviceAwareModel):
                 except Exception as e:
                     logger.warning(f"Could not load history, starting fresh: {e}")
 
-            #
+            # Save custom weights not being saved in the full model.
+            # This was a bugfix to extract weights from a checkpoint without retraining.
+            # Before updating save_models, only Distilbert weights were being saved (custom layers were missed).
+            # Not needed, also not harmful.
             self.save_models(Path(checkpoint_dir) / "pretrained_full_model")
             logger.info(f"Manually saved custom weights after restore.")
         else:
@@ -782,13 +769,13 @@ class RetrievalChatbot(DeviceAwareModel):
         train_summary_writer = tf.summary.create_file_writer(train_log_dir)
         val_summary_writer = tf.summary.create_file_writer(val_log_dir)
         logger.info(f"TensorBoard logs will be saved in {log_dir}")
-
+
         # Parse dataset
         dataset = tf.data.TFRecordDataset(tfrecord_file_path)
-
-        #
+
+        # Debug mode uses small subset. Useful for CPU debugging.
         if test_mode:
-            subset_size =
+            subset_size = 200
             dataset = dataset.take(subset_size)
             logger.info(f"TEST MODE: Using only {subset_size} examples")
             # Recompute sizes, steps, epochs, etc., as needed
@@ -804,38 +791,36 @@ class RetrievalChatbot(DeviceAwareModel):
         early_stopping_patience = 2
         logger.info(f"New training pairs: {train_size}")
         logger.info(f"New validation pairs: {val_size}")
-
+
         dataset = dataset.map(
-            lambda x: parse_tfrecord_fn(x, self.config.max_context_token_limit, self.
+            lambda x: parse_tfrecord_fn(x, self.config.max_context_token_limit, self.data_pipeline.neg_samples),
             num_parallel_calls=tf.data.AUTOTUNE
         )
-
+
         # Train/val split
         train_dataset = dataset.take(train_size)
         val_dataset = dataset.skip(train_size).take(val_size)
-
+
         # Shuffle and batch
         train_dataset = train_dataset.shuffle(buffer_size=buffer_size)
         train_dataset = train_dataset.batch(batch_size, drop_remainder=True)
         train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
-
+
         val_dataset = val_dataset.batch(batch_size, drop_remainder=False)
         val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE)
         val_dataset = val_dataset.cache()
-
+
         # Training loop
         best_val_loss = float("inf")
         epochs_no_improve = 0
-
+
         for epoch in range(int(checkpoint.epoch.numpy()) + 1, epochs + 1):
             checkpoint.epoch.assign(epoch)
             logger.info(f"Starting Epoch {epoch}...")
-
-            # --- Training Phase ---
+
             epoch_loss_avg = tf.keras.metrics.Mean(dtype=tf.float32)
             batches_processed = 0
-
-            # Progress bar
+
             try:
                 train_pbar = tqdm(
                     total=steps_per_epoch,
@@ -846,7 +831,8 @@ class RetrievalChatbot(DeviceAwareModel):
             except ImportError:
                 train_pbar = None
                 is_tqdm_train = False
-
+
+            # --- Training ---
             for q_batch, p_batch, n_batch in train_dataset:
                 loss, grad_norm, post_clip_norm = self.train_step(q_batch, p_batch, n_batch)
                 epoch_loss_avg(loss)
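The input pipeline above is a standard tf.data chain. A sketch with a synthetic in-memory dataset standing in for the parsed TFRecords (sizes are illustrative; the 50% shuffle buffer and the drop_remainder choices mirror the diff):

```python
import tensorflow as tf

total_pairs, train_size, batch_size = 1000, 800, 32
dataset = tf.data.Dataset.range(total_pairs)

buffer_size = max(1, total_pairs // 2)  # 50% of the dataset, as in the diff

train_ds = (dataset.take(train_size)
            .shuffle(buffer_size=buffer_size)
            .batch(batch_size, drop_remainder=True)   # keep step shapes static
            .prefetch(tf.data.AUTOTUNE))

val_ds = (dataset.skip(train_size)
          .batch(batch_size, drop_remainder=False)    # keep every validation example
          .prefetch(tf.data.AUTOTUNE)
          .cache())

print(sum(1 for _ in train_ds), sum(1 for _ in val_ds))  # 25 batches, 7 batches
```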
@@ -874,54 +860,54 @@ class RetrievalChatbot(DeviceAwareModel):
                         "lr": f"{current_lr:.2e}",
                         "batches": f"{batches_processed}/{steps_per_epoch}"
                     })
-
+
                 gc.collect()
-
+
                 # End the epoch early if we've processed all steps
                 if batches_processed >= steps_per_epoch:
                     break
-
+
             if is_tqdm_train and train_pbar:
                 train_pbar.close()
-
-            # --- Validation
+
+            # --- Validation ---
             val_loss_avg = tf.keras.metrics.Mean(dtype=tf.float32)
             val_batches_processed = 0
-
+
             try:
                 val_pbar = tqdm(total=val_steps, desc="Validation", unit="batch")
                 is_tqdm_val = True
             except ImportError:
                 val_pbar = None
                 is_tqdm_val = False
-
+
             last_valid_val_loss = None
             valid_batches = False
-
+
             for q_batch, p_batch, n_batch in val_dataset:
                 # If batch is too small, skip
                 if tf.shape(q_batch)[0] < 2:
                     logger.warning(f"Skipping validation batch of size {tf.shape(q_batch)[0]}")
                     continue
-
+
                 valid_batches = True
                 val_loss = self.validation_step(q_batch, p_batch, n_batch)
                 val_loss_avg(val_loss)
                 last_valid_val_loss = val_loss
                 val_batches_processed += 1
-
+
                 if is_tqdm_val:
                     val_pbar.update(1)
                     val_pbar.set_postfix({
                         "val_loss": f"{val_loss.numpy():.4f}",
                         "batches": f"{val_batches_processed}/{val_steps}"
                     })
-
+
                 gc.collect()
-
+
                 if val_batches_processed >= val_steps:
                     break
-
+
             if not valid_batches:
                 # If no valid batch is found, fallback
                 logger.warning("No valid validation batches in this epoch")
@@ -931,29 +917,29 @@ class RetrievalChatbot(DeviceAwareModel):
             else:
                 val_loss = epoch_loss_avg.result()
                 val_loss_avg(val_loss)
-
+
             if is_tqdm_val and val_pbar:
                 val_pbar.close()
-
+
             # End of epoch: final stats
             train_loss = epoch_loss_avg.result().numpy()
             val_loss = val_loss_avg.result().numpy()
             logger.info(f"Epoch {epoch} Complete: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")
-
+
             # TensorBoard epoch logs
             with train_summary_writer.as_default():
                 tf.summary.scalar("epoch_loss", train_loss, step=epoch)
             with val_summary_writer.as_default():
                 tf.summary.scalar("val_loss", val_loss, step=epoch)
-
+
             # Save checkpoint
             manager.save()
-
-            #
+
+            # Save model for iterative testing/inference
             model_save_path = Path(checkpoint_dir) / f"model_epoch_{epoch}"
             self.save_models(model_save_path)
             logger.info(f"Saved model for epoch {epoch} at {model_save_path}")
-
+
             # Update local history
             self.history['train_loss'].append(train_loss)
             self.history['val_loss'].append(val_loss)
@@ -972,13 +958,12 @@ class RetrievalChatbot(DeviceAwareModel):
                 return obj
 
             json_history = convert_to_py_floats(self.history)
-
+
             # Save training history to file every epoch
-            # (Create or overwrite the file so we always have the latest.)
             with open(history_path, 'w') as f:
                 json.dump(json_history, f)
             logger.info(f"Saved training history to {history_path}")
-
+
             # Early stopping
             if val_loss < best_val_loss - min_delta:
                 best_val_loss = val_loss
@@ -990,7 +975,7 @@ class RetrievalChatbot(DeviceAwareModel):
             if epochs_no_improve >= early_stopping_patience:
                 logger.info("Early stopping triggered.")
                 break
-
+
         logger.info("Training completed!")
 
     @tf.function
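The early-stopping rule above stops after `early_stopping_patience` epochs (2 in this commit) without at least `min_delta` improvement in validation loss. A pure-Python sketch of that rule (the `min_delta` value here is illustrative; the diff does not show its default):

```python
def run_early_stopping(val_losses, min_delta=0.001, patience=2):
    best, no_improve = float("inf"), 0
    for epoch, val_loss in enumerate(val_losses, start=1):
        if val_loss < best - min_delta:   # meaningful improvement resets the clock
            best, no_improve = val_loss, 0
        else:
            no_improve += 1
        if no_improve >= patience:
            return f"stopped at epoch {epoch} (best={best:.3f})"
    return f"ran all {len(val_losses)} epochs (best={best:.3f})"

print(run_early_stopping([0.90, 0.80, 0.79, 0.795, 0.81]))
# stopped at epoch 5 (best=0.790)
```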
@@ -1004,37 +989,25 @@ class RetrievalChatbot(DeviceAwareModel):
         Single training step using queries, positives, and hard negatives.
         """
         with tf.GradientTape() as tape:
-            # Encode queries
+            # Encode queries, positives, and negatives
             q_enc = self.encoder(q_batch, training=True)  # [batch_size, embed_dim]
-
-            # Encode positives
             p_enc = self.encoder(p_batch, training=True)  # [batch_size, embed_dim]
-
-            # Encode negatives
-            # n_batch: [batch_size, neg_samples, max_length]
             shape = tf.shape(n_batch)
             bs = shape[0]
             neg_samples = shape[1]
 
-            # Flatten negatives to feed them in one pass:
-            # => [batch_size * neg_samples, max_length]
+            # Flatten negatives to feed them in one pass: [batch_size * neg_samples, max_length]
             n_batch_flat = tf.reshape(n_batch, [bs * neg_samples, shape[2]])
             n_enc_flat = self.encoder(n_batch_flat, training=True)  # [bs*neg_samples, embed_dim]
 
             # Reshape back => [batch_size, neg_samples, embed_dim]
             n_enc = tf.reshape(n_enc_flat, [bs, neg_samples, -1])
 
-            # Combine the positive embedding and negative embeddings along dim=1
-            #
-
-            combined_p_n = tf.concat(
-                [tf.expand_dims(p_enc, axis=1), n_enc],
-                axis=1
-            )  # [bs, (1+neg_samples), embed_dim]
+            # Combine the positive embedding and negative embeddings along dim=1: shape [batch_size, 1 + neg_samples, embed_dim]
+            # Col 1 is the pos, subsequent cols are negatives
+            combined_p_n = tf.concat([tf.expand_dims(p_enc, axis=1), n_enc], axis=1)  # [bs, (1+neg_samples), embed_dim]
 
-            #
-            # We'll use `tf.einsum` to handle the batch dimension properly
-            # dot_products => shape [batch_size, (1+neg_samples)]
+            # Compute scores: dot product of q_enc with each column in combined_p_n. `tf.einsum` handles the batch dimension
             dot_products = tf.cast(tf.einsum('bd,bkd->bk', q_enc, combined_p_n), tf.float32)
             labels = tf.zeros([bs], dtype=tf.int32)  # Keep labels as int32
             loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
@@ -1043,14 +1016,13 @@ class RetrievalChatbot(DeviceAwareModel):
             )
             loss = tf.cast(tf.reduce_mean(loss), tf.float32)
 
-        # Calculate gradients
+        # Calculate gradients and clip
         gradients = tape.gradient(loss, self.encoder.trainable_variables)
         gradients_norm = tf.cast(tf.linalg.global_norm(gradients), tf.float32)
         max_grad_norm = tf.constant(1.5, dtype=tf.float32)
         gradients, _ = tf.clip_by_global_norm(gradients, max_grad_norm, gradients_norm)
         post_clip_norm = tf.cast(tf.linalg.global_norm(gradients), tf.float32)
 
-        # Apply gradients
         self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_variables))
 
         return loss, gradients_norm, post_clip_norm
@@ -1064,6 +1036,7 @@ class RetrievalChatbot(DeviceAwareModel):
     ) -> tf.Tensor:
         """
         Single validation step using queries, positives, and hard negatives.
+        Same idea as train_step, but without gradient updates.
         """
         q_enc = self.encoder(q_batch, training=False)
         p_enc = self.encoder(p_batch, training=False)
@@ -1082,7 +1055,7 @@ class RetrievalChatbot(DeviceAwareModel):
         )
 
         dot_products = tf.cast(tf.einsum('bd,bkd->bk', q_enc, combined_p_n), tf.float32)
-        labels = tf.zeros([bs], dtype=tf.int32)
+        labels = tf.zeros([bs], dtype=tf.int32)
 
         loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
             labels=labels,
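The loss above is a softmax contrastive loss: the positive sits in column 0 of the candidate stack, so the correct "class" for every query is index 0. A standalone sketch with random unit vectors in place of encoder outputs (shapes are tiny for readability):

```python
import tensorflow as tf

bs, neg_samples, dim = 4, 3, 8
q_enc = tf.math.l2_normalize(tf.random.normal((bs, dim)), axis=-1)
p_enc = tf.math.l2_normalize(tf.random.normal((bs, dim)), axis=-1)
n_enc = tf.math.l2_normalize(tf.random.normal((bs, neg_samples, dim)), axis=-1)

# [bs, 1 + neg_samples, dim]: column 0 is the positive, the rest are negatives.
combined_p_n = tf.concat([tf.expand_dims(p_enc, axis=1), n_enc], axis=1)

# Batched dot products: one similarity per candidate, shape [bs, 1 + neg_samples].
logits = tf.einsum('bd,bkd->bk', q_enc, combined_p_n)

labels = tf.zeros([bs], dtype=tf.int32)  # index of the positive column
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
)
print(float(loss))
```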
@@ -1098,7 +1071,9 @@ class RetrievalChatbot(DeviceAwareModel):
         peak_lr: float,
         warmup_steps: int
     ) -> tf.keras.optimizers.schedules.LearningRateSchedule:
-        """
         class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
             def __init__(
                 self,
@@ -1110,11 +1085,11 @@ class RetrievalChatbot(DeviceAwareModel):
                 self.total_steps = tf.cast(total_steps, tf.float32)
                 self.peak_lr = tf.cast(peak_lr, tf.float32)
 
-                #
                 adjusted_warmup_steps = min(warmup_steps, max(1, total_steps // 10))
                 self.warmup_steps = tf.cast(adjusted_warmup_steps, tf.float32)
 
-                # Calculate
                 self.initial_lr = tf.cast(self.peak_lr * 0.1, tf.float32)
                 self.min_lr = tf.cast(self.peak_lr * 0.01, tf.float32)
 
@@ -1128,21 +1103,20 @@ class RetrievalChatbot(DeviceAwareModel):
             def __call__(self, step):
                 step = tf.cast(step, tf.float32)
 
-                # Warmup
                 warmup_factor = tf.cast(tf.minimum(1.0, step / self.warmup_steps), tf.float32)
                 warmup_lr = self.initial_lr + (self.peak_lr - self.initial_lr) * warmup_factor
 
-                # Decay
                 decay_steps = tf.cast(tf.maximum(1.0, self.total_steps - self.warmup_steps), tf.float32)
                 decay_factor = tf.cast((step - self.warmup_steps) / decay_steps, tf.float32)
                 decay_factor = tf.cast(tf.minimum(tf.maximum(0.0, decay_factor), 1.0), tf.float32)
                 cosine_decay = tf.cast(0.5 * (1.0 + tf.cos(tf.constant(math.pi, dtype=tf.float32) * decay_factor)), tf.float32)
                 decay_lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay
 
-                # Choose between warmup and decay
                 final_lr = tf.where(step < self.warmup_steps, warmup_lr, decay_lr)
 
-                # Ensure
                 final_lr = tf.maximum(self.min_lr, final_lr)
                 final_lr = tf.where(tf.math.is_finite(final_lr), final_lr, self.min_lr)
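The schedule above warms up linearly from 0.1x peak to peak, then cosine-decays down to 0.01x peak (the 0.1 / 0.01 factors and the warmup cap at total_steps // 10 come from the hunk). A dependency-free sketch of the same curve; peak_lr of 5e-4 is chosen to match the new config's learning_rate:

```python
import math

def lr_at(step, total_steps, peak_lr, warmup_steps):
    warmup_steps = min(warmup_steps, max(1, total_steps // 10))
    initial_lr, min_lr = peak_lr * 0.1, peak_lr * 0.01
    if step < warmup_steps:
        # Linear warmup from initial_lr up to peak_lr.
        return initial_lr + (peak_lr - initial_lr) * min(1.0, step / warmup_steps)
    # Cosine decay from peak_lr down to min_lr.
    decay_steps = max(1.0, total_steps - warmup_steps)
    decay_factor = min(max(0.0, (step - warmup_steps) / decay_steps), 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * decay_factor))
    return max(min_lr, min_lr + (peak_lr - min_lr) * cosine)

for s in [0, 100, 200, 1000, 2000]:
    print(s, round(lr_at(s, total_steps=2000, peak_lr=5e-4, warmup_steps=200), 6))
# 0 5e-05, 100 0.000275, 200 0.0005, 1000 ~0.000296, 2000 5e-06
```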
| 24 |
|
| 25 |
@dataclass
|
| 26 |
class ChatbotConfig:
|
| 27 |
+
"""RetrievalChatbot Config"""
|
| 28 |
max_context_token_limit: int = 512
|
| 29 |
embedding_dim: int = 768
|
| 30 |
encoder_units: int = 256
|
| 31 |
num_attention_heads: int = 8
|
| 32 |
dropout_rate: float = 0.2
|
| 33 |
l2_reg_weight: float = 0.001
|
| 34 |
+
learning_rate: float = 0.0005
|
| 35 |
min_text_length: int = 3
|
| 36 |
+
max_context_turns: int = 20
|
| 37 |
warmup_steps: int = 200
|
| 38 |
pretrained_model: str = 'distilbert-base-uncased'
|
| 39 |
cross_encoder_model: str = 'cross-encoder/ms-marco-MiniLM-L-12-v2'
|
| 40 |
+
summarizer_model: str = 't5-small'
|
| 41 |
dtype: str = 'float32'
|
| 42 |
freeze_embeddings: bool = False
|
| 43 |
embedding_batch_size: int = 64
|
| 44 |
search_batch_size: int = 64
|
| 45 |
max_batch_size: int = 64
|
|
|
|
| 46 |
max_retries: int = 3
|
| 47 |
|
| 48 |
def to_dict(self) -> Dict:
|
|
|
|
| 57 |
if k in cls.__dataclass_fields__})
|
| 58 |
|
| 59 |
class EncoderModel(tf.keras.Model):
|
| 60 |
+
"""Dual encoder model with pretrained DistilBERT embeddings."""
|
| 61 |
def __init__(
|
| 62 |
self,
|
| 63 |
config: ChatbotConfig,
|
|
|
|
| 71 |
self.pretrained = TFAutoModel.from_pretrained(config.pretrained_model)
|
| 72 |
self._freeze_layers()
|
| 73 |
|
| 74 |
+
# Add Global Average Pooling, Projection, Dropout, and Normalization layers
|
| 75 |
self.pooler = tf.keras.layers.GlobalAveragePooling1D()
|
| 76 |
self.projection = tf.keras.layers.Dense(
|
| 77 |
config.embedding_dim,
|
|
|
|
| 86 |
)
|
| 87 |
|
| 88 |
def _freeze_layers(self):
|
| 89 |
+
"""Freeze n layers of the pretrained model"""
|
| 90 |
if self.config.freeze_embeddings:
|
| 91 |
self.pretrained.trainable = False
|
| 92 |
logger.info("All pretrained layers frozen.")
|
|
|
|
| 95 |
for i, layer in enumerate(self.pretrained.layers):
|
| 96 |
if isinstance(layer, tf.keras.layers.Layer):
|
| 97 |
if hasattr(layer, 'trainable'):
|
|
|
|
| 98 |
if i < 1:
|
| 99 |
layer.trainable = False
|
| 100 |
logger.info(f"Layer {i} frozen.")
|
| 101 |
else:
|
| 102 |
layer.trainable = True
|
| 103 |
+
logger.info(f"Layer {i} trainable.")
|
| 104 |
|
| 105 |
def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
|
| 106 |
"""Forward pass."""
|
| 107 |
# Get pretrained embeddings
|
| 108 |
pretrained_outputs = self.pretrained(inputs, training=training)
|
| 109 |
+
x = pretrained_outputs.last_hidden_state # Shape: [batch_size, seq_len, embedding_dim]
|
| 110 |
|
| 111 |
# Apply pooling, projection, dropout, and normalization
|
| 112 |
+
x = self.pooler(x) # Shape: [batch_size, 768]
|
| 113 |
+
x = self.projection(x) # Shape: [batch_size, 768]
|
| 114 |
x = self.dropout(x, training=training)
|
| 115 |
+
x = self.normalize(x) # Shape: [batch_size, 768]
|
| 116 |
|
| 117 |
return x
|
| 118 |
|
| 119 |
def get_config(self) -> dict:
|
| 120 |
+
"""Return the model config"""
|
| 121 |
config = super().get_config()
|
| 122 |
config.update({
|
| 123 |
"config": self.config.to_dict(),
|
|
|
|
| 126 |
return config
|
| 127 |
|
| 128 |
class RetrievalChatbot(DeviceAwareModel):
|
| 129 |
+
"""
|
| 130 |
+
Retrieval-based learning chatbot model.
|
| 131 |
+
Uses trained embeddings and FAISS for similarity search.
|
| 132 |
+
"""
|
| 133 |
def __init__(
|
| 134 |
self,
|
| 135 |
config: ChatbotConfig,
|
|
|
|
| 145 |
self.device = device or self._setup_default_device()
|
| 146 |
self.mode = mode.lower()
|
| 147 |
|
| 148 |
+
# Initialize reranker, summarizer, tokenizer, and encoder
|
| 149 |
self.reranker = reranker or self._initialize_reranker()
|
| 150 |
self.tokenizer = self._initialize_tokenizer()
|
| 151 |
self.encoder = self._initialize_encoder()
|
|
|
|
| 157 |
config=self.config,
|
| 158 |
tokenizer=self.tokenizer,
|
| 159 |
encoder=self.encoder,
|
|
|
|
| 160 |
response_pool=[],
|
| 161 |
max_length=self.config.max_context_token_limit,
|
| 162 |
query_embeddings_cache={},
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
)
|
| 164 |
|
| 165 |
# Collect unique responses from dialogues
|
|
|
|
| 195 |
"""Initialize the Summarizer."""
|
| 196 |
return Summarizer(
|
| 197 |
tokenizer=self.tokenizer,
|
| 198 |
+
model_name=self.config.summarizer_model,
|
| 199 |
max_summary_length=self.config.max_context_token_limit // 4,
|
| 200 |
device=self.device,
|
| 201 |
max_summary_rounds=2
|
|
|
|
| 227 |
new_vocab_size = len(self.tokenizer)
|
| 228 |
encoder.pretrained.resize_token_embeddings(new_vocab_size)
|
| 229 |
logger.info(f"Token embeddings resized to: {new_vocab_size}")
|
| 230 |
+
|
| 231 |
return encoder
|
| 232 |
|
| 233 |
def _load_faiss_index_and_responses(self) -> None:
|
| 234 |
"""Load FAISS index and response pool for inference."""
|
| 235 |
try:
|
| 236 |
+
logger.info(f"Loading FAISS index from {self.data_pipeline.faiss_index_file_path}...")
|
| 237 |
+
self.data_pipeline.load_faiss_index(self.data_pipeline.faiss_index_file_path)
|
| 238 |
logger.info("FAISS index loaded successfully.")
|
| 239 |
|
| 240 |
+
# Load response pool
|
| 241 |
+
response_pool_path = self.data_pipeline.faiss_index_file_path.replace('.index', '_responses.json')
|
| 242 |
if os.path.exists(response_pool_path):
|
| 243 |
with open(response_pool_path, 'r', encoding='utf-8') as f:
|
| 244 |
self.data_pipeline.response_pool = json.load(f)
|
|
|
|
| 262 |
"""
|
| 263 |
load_dir = Path(load_dir)
|
| 264 |
|
| 265 |
+
# Load config
|
| 266 |
with open(load_dir / "config.json", "r") as f:
|
| 267 |
config = ChatbotConfig.from_dict(json.load(f))
|
| 268 |
|
| 269 |
+
# Initialize chatbot
|
| 270 |
chatbot = cls(config, mode=mode)
|
| 271 |
|
| 272 |
+
# Load DistilBERT
|
| 273 |
+
chatbot.encoder.pretrained = TFAutoModel.from_pretrained(load_dir / "shared_encoder", config=config)
|
|
|
|
|
|
|
|
|
|
| 274 |
|
| 275 |
dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
|
| 276 |
_ = chatbot.encoder(dummy_input, training=False)
|
| 277 |
|
| 278 |
+
# Load tokenizer
|
| 279 |
chatbot.tokenizer = AutoTokenizer.from_pretrained(load_dir / "tokenizer")
|
| 280 |
logger.info(f"Models and tokenizer loaded from {load_dir}")
|
| 281 |
|
| 282 |
+
# Load the custom weights
|
|
|
|
|
|
|
| 283 |
custom_weights_path = load_dir / "encoder_custom_weights.weights.h5"
|
| 284 |
if custom_weights_path.exists():
|
| 285 |
chatbot.encoder.load_weights(str(custom_weights_path))
|
|
|
|
| 287 |
else:
|
| 288 |
logger.warning(f"No custom encoder weights found at {custom_weights_path}. The top-level projection layer won't have learned parameters.")
|
| 289 |
|
| 290 |
+
# Handle 'inference' mode: load FAISS, etc.
|
| 291 |
if mode == 'inference':
|
| 292 |
cls._prepare_model_for_inference(chatbot, load_dir)
|
| 293 |
|
|
|
|
| 295 |
|
| 296 |
@classmethod
|
| 297 |
def _prepare_model_for_inference(cls, chatbot: 'RetrievalChatbot', load_dir: Path) -> None:
|
| 298 |
+
"""Load inference components."""
|
| 299 |
try:
|
| 300 |
# Load FAISS index
|
| 301 |
faiss_path = load_dir / 'faiss_indices/faiss_index_production.index'
|
|
|
|
| 326 |
raise
|
| 327 |
|
| 328 |
def save_models(self, save_dir: Union[str, Path]):
|
| 329 |
+
"""Save model and config"""
|
| 330 |
save_dir = Path(save_dir)
|
| 331 |
save_dir.mkdir(parents=True, exist_ok=True)
|
| 332 |
|
|
|
|
| 334 |
with open(save_dir / "config.json", "w") as f:
|
| 335 |
json.dump(self.config.to_dict(), f, indent=2)
|
| 336 |
|
| 337 |
+
# Save the HF DistilBERT submodule, custom top-level layers, and tokenizer
|
| 338 |
self.encoder.pretrained.save_pretrained(save_dir / "shared_encoder")
|
|
|
|
|
|
|
| 339 |
self.encoder.save_weights(save_dir / "encoder_custom_weights.weights.h5")
|
|
|
|
|
|
|
| 340 |
self.tokenizer.save_pretrained(save_dir / "tokenizer")
|
|
|
|
| 341 |
logger.info(f"Models and tokenizer saved to {save_dir}.")
|
| 342 |
|
| 343 |
+
def retrieve_responses(
|
|
|
|
|
|
|
|
|
|
| 344 |
self,
|
| 345 |
query: str,
|
| 346 |
top_k: int = 10,
|
|
|
|
| 349 |
summarize_threshold: int = 512
|
| 350 |
) -> List[Tuple[str, float]]:
|
| 351 |
"""
|
| 352 |
+
Retrieve top-k responses using FAISS and cross-encoder re-ranking.
|
|
|
|
|
|
|
| 353 |
Args:
|
| 354 |
query: The user's input text.
|
| 355 |
+
top_k: Number of FAISS results to return
|
| 356 |
+
reranker: CrossEncoderReranker for refined scoring
|
| 357 |
+
summarizer: Summarizer for long queries
|
| 358 |
+
summarize_threshold: Summarize if conversation tokens > threshold.
|
|
|
|
| 359 |
Returns:
|
| 360 |
List of (response_text, final_score).
|
| 361 |
"""
|
| 362 |
+
def sigmoid(x: float) -> float:
|
| 363 |
+
return 1 / (1 + np.exp(-x))
|
| 364 |
+
|
| 365 |
+
# Query summarization
|
| 366 |
if summarizer and len(query.split()) > summarize_threshold:
|
| 367 |
logger.info(f"Query is long ({len(query.split())} words). Summarizing.")
|
| 368 |
query = summarizer.summarize_text(query)
|
|
|
|
| 379 |
|
| 380 |
texts = [item[0] for item in faiss_candidates]
|
| 381 |
|
|
|
|
| 382 |
if not reranker:
|
| 383 |
reranker = CrossEncoderReranker(model_name=self.config.cross_encoder_model)
|
| 384 |
|
| 385 |
+
# Re-rank the texts (candidates) from FAISS search using the cross-encoder
|
| 386 |
ce_logits = reranker.rerank(query, texts, max_length=256)
|
| 387 |
|
| 388 |
+
# Combine scores from FAISS and cross-encoder
|
| 389 |
final_candidates = []
|
| 390 |
for (resp_text, faiss_score), logit in zip(faiss_candidates, ce_logits):
|
| 391 |
+
ce_prob = sigmoid(logit) # now in range [0...1]
|
| 392 |
+
faiss_norm = (faiss_score + 1)/2.0 # now in range [0...1]
|
| 393 |
combined_score = 0.85 * ce_prob + 0.15 * faiss_norm
|
| 394 |
length_adjusted_score = self.length_adjust_score(resp_text, combined_score)
|
| 395 |
|
|
|
|
| 401 |
# Return top_k
|
| 402 |
return final_candidates[:top_k]
|
| 403 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
def extract_keywords(self, query: str) -> List[str]:
|
| 405 |
"""
|
| 406 |
Return any domain keywords present in the query (lowercased).
|
| 407 |
"""
|
| 408 |
+
domain_keywords = {
|
| 409 |
+
'restaurant': ['restaurant', 'dining', 'food', 'dine', 'reservation', 'table', 'menu', 'cuisine', 'eat', 'place to eat', 'hungry', 'chef', 'dish', 'meal', 'brunch', 'bistro', 'buffet', 'catering', 'gourmet', 'fast food', 'fine dining', 'takeaway', 'delivery', 'restaurant booking'],
|
| 410 |
+
'movie': ['movie', 'cinema', 'film', 'ticket', 'showtime', 'showing', 'theater', 'flick', 'screening', 'film ticket', 'film show', 'blockbuster', 'premiere', 'trailer', 'director', 'actor', 'actress', 'plot', 'genre', 'screen', 'sequel', 'animation', 'documentary'],
|
| 411 |
+
'ride_share': ['ride', 'taxi', 'uber', 'lyft', 'car service', 'pickup', 'dropoff', 'driver', 'cab', 'hailing', 'rideshare', 'ride hailing', 'carpool', 'chauffeur', 'transit', 'transportation', 'hail ride'],
|
| 412 |
+
'coffee': ['coffee', 'cafΓ©', 'cafe', 'starbucks', 'espresso', 'latte', 'mocha', 'americano', 'barista', 'brew', 'cappuccino', 'macchiato', 'iced coffee', 'cold brew', 'espresso machine', 'coffee shop', 'tea', 'chai', 'java', 'bean', 'roast', 'decaf'],
|
| 413 |
+
'pizza': ['pizza', 'delivery', 'order food', 'pepperoni', 'topping', 'pizzeria', 'slice', 'pie', 'margherita', 'deep dish', 'thin crust', 'cheese', 'oven', 'tossed', 'sauce', 'garlic bread', 'calzone'],
|
| 414 |
+
'auto': ['car', 'vehicle', 'repair', 'maintenance', 'mechanic', 'oil change', 'garage', 'auto shop', 'tire', 'check engine', 'battery', 'transmission', 'brake', 'engine diagnostics', 'carwash', 'detail', 'alignment', 'exhaust', 'spark plug', 'dashboard'],
|
| 415 |
+
}
|
| 416 |
+
|
| 417 |
query_lower = query.lower()
|
| 418 |
found = set()
|
| 419 |
+
for domain, kw_list in domain_keywords.items():
|
| 420 |
for kw in kw_list:
|
| 421 |
if kw in query_lower:
|
| 422 |
found.add(kw)
|
|
|
|
| 442 |
|
| 443 |
def detect_domain_from_query(self, query: str) -> str:
|
| 444 |
"""
|
| 445 |
+
Detect the domain of the query based on keywords. Used for boosting FAISS search.
|
| 446 |
"""
|
| 447 |
domain_patterns = {
|
| 448 |
'restaurant': r'\b(restaurant|restaurants?|dining|food|foods?|dine|reservation|reservations?|table|tables?|menu|menus?|cuisine|cuisines?|eat|eats?|place\s?to\s?eat|places\s?to\s?eat|hungry|chef|chefs?|dish|dishes?|meal|meals?|fork|forks?|knife|knives?|spoon|spoons?|brunch|bistro|buffet|buffets?|catering|caterings?|gourmet|fast\s?food|fine\s?dining|takeaway|takeaways?|delivery|deliveries|restaurant\s?booking)\b',
|
|
|
|
| 462 |
|
| 463 |
def is_numeric_response(self, text: str) -> bool:
|
| 464 |
"""
|
| 465 |
+
Return True if `text` is purely digits and/or spaces.
|
|
|
|
| 466 |
"""
|
| 467 |
pattern = r'^[\s]*[\d]+([\s.,\d]+)*[\s]*$'
|
| 468 |
return bool(re.match(pattern, text.strip()))
|
|
|
|
| 471 |
self,
|
| 472 |
query: str,
|
| 473 |
domain: str = 'other',
|
| 474 |
+
top_k: int = 10,
|
| 475 |
+
boost_factor: float = 1.15
|
| 476 |
) -> List[Tuple[str, float]]:
|
| 477 |
"""
|
| 478 |
Retrieve top-k responses from the FAISS index (IndexFlatIP) given a user query.
|
|
|
|
| 479 |
Args:
|
| 480 |
query (str): The user input text.
|
| 481 |
+
domain (str): The detected domain from possible domains: ['restaurant', 'movie', 'ride_share', 'coffee', 'pizza', 'auto', 'other']
|
| 482 |
+
top_k (int): Number of top results to return.
|
| 483 |
+
boost_factor (float, optional): Factor to boost scores for keyword matches.
|
|
|
|
| 484 |
Returns:
|
| 485 |
List[Tuple[str, float]]: List of (response_text, similarity) sorted by descending similarity.
|
| 486 |
"""
|
|
|
|
| 491 |
# Search the index
|
| 492 |
distances, indices = self.data_pipeline.index.search(q_emb_np, top_k * 10)
|
| 493 |
|
| 494 |
+
# IndexFlatIP: 'distances' are inner products (cosine similarities for normalized vectors).
|
| 495 |
candidates = []
|
| 496 |
for rank, idx in enumerate(indices[0]):
|
| 497 |
if idx < 0:
|
|
|
|
| 528 |
boosted = []
|
| 529 |
for (resp_text, resp_domain, score) in in_domain:
|
| 530 |
new_score = score
|
| 531 |
+
# If the domain is known AND the response text shares any query keywords, boost it
|
|
|
|
| 532 |
if query_keywords and any(kw in resp_text.lower() for kw in query_keywords):
|
| 533 |
new_score *= boost_factor
|
| 534 |
|
|
|
|
| 540 |
# Sort boosted responses
|
| 541 |
boosted.sort(key=lambda x: x[1], reverse=True)
|
| 542 |
|
| 543 |
+
# Debug logging (see FAISS responses)
|
| 544 |
# for resp, score in boosted[:100]:
|
| 545 |
# logger.debug(f"Candidate: '{resp}' with score {score}")
|
| 546 |
|
|
|
|
| 554 |
top_k: int = 10,
|
| 555 |
) -> Tuple[str, List[Tuple[str, float]], Dict[str, Any]]:
|
| 556 |
"""
|
| 557 |
+
Live chat with the chatbot. Uses same processing flow as validation, except for context handling and quality checking.
|
|
|
|
| 558 |
"""
|
| 559 |
@self.run_on_device
|
| 560 |
def get_response(self_arg, query_arg):
|
|
|
|
| 562 |
conversation_str = self_arg._build_conversation_context(query_arg, conversation_history)
|
| 563 |
|
| 564 |
# Retrieve and re-rank
|
| 565 |
+
results = self_arg.retrieve_responses(
|
| 566 |
query=conversation_str,
|
| 567 |
top_k=top_k,
|
| 568 |
reranker=self_arg.reranker,
|
|
|
|
| 586 |
query: str,
|
| 587 |
conversation_history: Optional[List[Tuple[str, str]]]
|
| 588 |
) -> str:
|
| 589 |
+
"""
|
| 590 |
+
Build conversation context string from conversation history.
|
| 591 |
+
"""
|
| 592 |
if not conversation_history:
|
| 593 |
return f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"
|
| 594 |
|
|
|
|
| 619 |
) -> None:
|
| 620 |
"""
|
| 621 |
Train the retrieval model using a pre-prepared TFRecord dataset.
|
|
|
|
| 622 |
- Checkpoint loading/restoring
|
| 623 |
- LR scheduling
|
| 624 |
- Epoch/iteration tracking
|
| 625 |
+
- Training-history logging
|
| 626 |
+
- Early stopping
|
| 627 |
+
- Custom loss function (Contrastive loss with hard negative sampling))
|
| 628 |
"""
|
| 629 |
logger.info("Starting training with pre-prepared TFRecord dataset...")
|
| 630 |
|
|
|
|
| 656 |
steps_per_epoch = math.ceil(train_size / batch_size)
|
| 657 |
val_steps = math.ceil(val_size / batch_size)
|
| 658 |
total_steps = steps_per_epoch * epochs
|
| 659 |
+
buffer_size = max(1, total_pairs // 2) # 50% of the dataset for shuffling
|
| 660 |
|
| 661 |
logger.info(f"Training pairs: {train_size}")
|
| 662 |
logger.info(f"Validation pairs: {val_size}")
|
|
|
|
| 678 |
self.optimizer = tf.keras.optimizers.Adam(learning_rate=tf.cast(peak_lr, tf.float32))
|
| 679 |
logger.info("Using fixed learning rate.")
|
| 680 |
|
| 681 |
+
# Dummy step to force initialization
|
| 682 |
dummy_input = tf.zeros((1, self.config.max_context_token_limit), dtype=tf.int32)
|
| 683 |
with tf.GradientTape() as tape:
|
| 684 |
dummy_output = self.encoder(dummy_input)
|
|
|
|
| 693 |
model=self.encoder
|
| 694 |
)
|
| 695 |
|
| 696 |
+
# Create a CheckpointManager
|
| 697 |
manager = tf.train.CheckpointManager(
|
| 698 |
checkpoint,
|
| 699 |
directory=checkpoint_dir,
|
|
|
|
| 701 |
checkpoint_name='ckpt'
|
| 702 |
)
|
| 703 |
|
| 704 |
+
# Restore from existing checkpoint if one is provided
|
| 705 |
latest_checkpoint = manager.latest_checkpoint
|
| 706 |
history_path = Path(checkpoint_dir) / 'training_history.json'
|
| 707 |
|
| 708 |
+
# Log epoch losses across runs, including restore from checkpoint
|
| 709 |
if not hasattr(self, 'history'):
|
| 710 |
self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
|
| 711 |
|
| 712 |
if latest_checkpoint and not test_mode:
|
| 713 |
+
# Debug checkpoint loading
|
| 714 |
+
# logger.info(f"\nTrying to load checkpoint from: {latest_checkpoint}")
|
| 715 |
+
# reader = tf.train.load_checkpoint(latest_checkpoint)
|
| 716 |
# shape_from_key = reader.get_variable_to_shape_map()
|
| 717 |
# dtype_from_key = reader.get_variable_to_dtype_map()
|
| 718 |
# logger.info("\nCheckpoint Variables:")
|
|
|
|
| 736 |
if initial_epoch == 0:
|
| 737 |
initial_epoch = ckpt_number
|
| 738 |
|
| 739 |
+
# Assign to checkpoint.epoch for counting
|
| 740 |
checkpoint.epoch.assign(tf.cast(initial_epoch, tf.int32))
|
| 741 |
logger.info(f"Resuming from epoch {initial_epoch}")
|
| 742 |
|
| 743 |
+
# Load history from file:
|
| 744 |
if history_path.exists():
|
| 745 |
try:
|
| 746 |
with open(history_path, 'r') as f:
|
|
|
|
| 749 |
except Exception as e:
|
| 750 |
logger.warning(f"Could not load history, starting fresh: {e}")
|
| 751 |
|
| 752 |
+
# Save custom weights not being saved in the full model.
|
| 753 |
+
# This was a bugfix to extract weights from a checkpoint without retraining.
|
| 754 |
+
# Before updating save_models, only Distilbert weights were being saved (custom layers were missed).
|
| 755 |
+
# Not needed, also not harmful.
|
| 756 |
self.save_models(Path(checkpoint_dir) / "pretrained_full_model")
|
| 757 |
logger.info(f"Manually saved custom weights after restore.")
|
| 758 |
else:
|
|
|
|
| 769 |
train_summary_writer = tf.summary.create_file_writer(train_log_dir)
|
| 770 |
val_summary_writer = tf.summary.create_file_writer(val_log_dir)
|
| 771 |
logger.info(f"TensorBoard logs will be saved in {log_dir}")
|
| 772 |
+
|
| 773 |
# Parse dataset
|
| 774 |
dataset = tf.data.TFRecordDataset(tfrecord_file_path)
|
| 775 |
+
|
| 776 |
+
# Debug mode uses small subset. Useful for CPU debugging.
|
| 777 |
if test_mode:
|
| 778 |
+
subset_size = 200
|
| 779 |
dataset = dataset.take(subset_size)
|
| 780 |
logger.info(f"TEST MODE: Using only {subset_size} examples")
|
| 781 |
# Recompute sizes, steps, epochs, etc., as needed
|
|
|
|
| 791 |
early_stopping_patience = 2
|
| 792 |
logger.info(f"New training pairs: {train_size}")
|
| 793 |
logger.info(f"New validation pairs: {val_size}")
|
| 794 |
+
|
| 795 |
dataset = dataset.map(
|
| 796 |
+
lambda x: parse_tfrecord_fn(x, self.config.max_context_token_limit, self.data_pipeline.neg_samples),
|
| 797 |
num_parallel_calls=tf.data.AUTOTUNE
|
| 798 |
)
|
| 799 |
+
|
| 800 |
# Train/val split
|
| 801 |
train_dataset = dataset.take(train_size)
|
| 802 |
val_dataset = dataset.skip(train_size).take(val_size)
|
| 803 |
+
|
| 804 |
# Shuffle and batch
|
| 805 |
train_dataset = train_dataset.shuffle(buffer_size=buffer_size)
|
| 806 |
train_dataset = train_dataset.batch(batch_size, drop_remainder=True)
|
| 807 |
train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
|
| 808 |
+
|
| 809 |
val_dataset = val_dataset.batch(batch_size, drop_remainder=False)
|
| 810 |
val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE)
|
| 811 |
val_dataset = val_dataset.cache()
|
| 812 |
+
|
| 813 |
# Training loop
|
| 814 |
best_val_loss = float("inf")
|
| 815 |
epochs_no_improve = 0
|
| 816 |
+
|
| 817 |
for epoch in range(int(checkpoint.epoch.numpy()) + 1, epochs + 1):
|
| 818 |
checkpoint.epoch.assign(epoch)
|
| 819 |
logger.info(f"Starting Epoch {epoch}...")
|
| 820 |
+
|
|
|
|
| 821 |
epoch_loss_avg = tf.keras.metrics.Mean(dtype=tf.float32)
|
| 822 |
batches_processed = 0
|
| 823 |
+
|
|
|
|
| 824 |
try:
|
| 825 |
train_pbar = tqdm(
|
| 826 |
total=steps_per_epoch,
|
|
|
|
| 831 |
except ImportError:
|
| 832 |
train_pbar = None
|
| 833 |
is_tqdm_train = False
|
| 834 |
+
|
| 835 |
+
# --- Training ---
|
| 836 |
for q_batch, p_batch, n_batch in train_dataset:
|
| 837 |
loss, grad_norm, post_clip_norm = self.train_step(q_batch, p_batch, n_batch)
|
| 838 |
epoch_loss_avg(loss)
|
|
|
|
| 860 |
"lr": f"{current_lr:.2e}",
|
| 861 |
"batches": f"{batches_processed}/{steps_per_epoch}"
|
| 862 |
})
|
| 863 |
+
|
| 864 |
gc.collect()
|
| 865 |
+
|
| 866 |
# End the epoch early if we've processed all steps
|
| 867 |
if batches_processed >= steps_per_epoch:
|
| 868 |
break
|
| 869 |
+
|
| 870 |
if is_tqdm_train and train_pbar:
|
| 871 |
train_pbar.close()
|
| 872 |
+
|
| 873 |
+
# --- Validation ---
|
| 874 |
val_loss_avg = tf.keras.metrics.Mean(dtype=tf.float32)
|
| 875 |
val_batches_processed = 0
|
| 876 |
+
|
| 877 |
try:
|
| 878 |
val_pbar = tqdm(total=val_steps, desc="Validation", unit="batch")
|
| 879 |
is_tqdm_val = True
|
| 880 |
except ImportError:
|
| 881 |
val_pbar = None
|
| 882 |
is_tqdm_val = False
|
| 883 |
+
|
| 884 |
last_valid_val_loss = None
|
| 885 |
valid_batches = False
|
| 886 |
+
|
| 887 |
for q_batch, p_batch, n_batch in val_dataset:
|
| 888 |
# If batch is too small, skip
|
| 889 |
if tf.shape(q_batch)[0] < 2:
|
| 890 |
logger.warning(f"Skipping validation batch of size {tf.shape(q_batch)[0]}")
|
| 891 |
continue
|
| 892 |
+
|
| 893 |
valid_batches = True
|
| 894 |
val_loss = self.validation_step(q_batch, p_batch, n_batch)
|
| 895 |
val_loss_avg(val_loss)
|
| 896 |
last_valid_val_loss = val_loss
|
| 897 |
val_batches_processed += 1
|
| 898 |
+
|
| 899 |
if is_tqdm_val:
|
| 900 |
val_pbar.update(1)
|
| 901 |
val_pbar.set_postfix({
|
| 902 |
"val_loss": f"{val_loss.numpy():.4f}",
|
| 903 |
"batches": f"{val_batches_processed}/{val_steps}"
|
| 904 |
})
|
| 905 |
+
|
| 906 |
gc.collect()
|
| 907 |
+
|
| 908 |
if val_batches_processed >= val_steps:
|
| 909 |
break
|
| 910 |
+
|
| 911 |
if not valid_batches:
|
| 912 |
# If no valid batch is found, fallback
|
| 913 |
logger.warning("No valid validation batches in this epoch")
|
|
|
|
| 917 |
else:
|
| 918 |
val_loss = epoch_loss_avg.result()
|
| 919 |
val_loss_avg(val_loss)
|
| 920 |
+
|
| 921 |
if is_tqdm_val and val_pbar:
|
| 922 |
val_pbar.close()
|
| 923 |
+
|
| 924 |
# End of epoch: final stats
|
| 925 |
train_loss = epoch_loss_avg.result().numpy()
|
| 926 |
val_loss = val_loss_avg.result().numpy()
|
| 927 |
logger.info(f"Epoch {epoch} Complete: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")
|
| 928 |
+
|
| 929 |
# TensorBoard epoch logs
|
| 930 |
with train_summary_writer.as_default():
|
| 931 |
tf.summary.scalar("epoch_loss", train_loss, step=epoch)
|
| 932 |
with val_summary_writer.as_default():
|
| 933 |
tf.summary.scalar("val_loss", val_loss, step=epoch)
|
| 934 |
+
|
| 935 |
# Save checkpoint
|
| 936 |
manager.save()
|
| 937 |
+
|
| 938 |
+
# Save model for iterative testing/inference
|
| 939 |
model_save_path = Path(checkpoint_dir) / f"model_epoch_{epoch}"
|
| 940 |
self.save_models(model_save_path)
|
| 941 |
logger.info(f"Saved model for epoch {epoch} at {model_save_path}")
|
| 942 |
+
|
| 943 |
# Update local history
|
| 944 |
self.history['train_loss'].append(train_loss)
|
| 945 |
self.history['val_loss'].append(val_loss)
|
|
|
|
| 958 |
return obj
|
| 959 |
|
| 960 |
json_history = convert_to_py_floats(self.history)
|
| 961 |
+
|
| 962 |
# Save training history to file every epoch
|
|
|
|
| 963 |
with open(history_path, 'w') as f:
|
| 964 |
json.dump(json_history, f)
|
| 965 |
logger.info(f"Saved training history to {history_path}")
|
| 966 |
+
|
| 967 |
# Early stopping
|
| 968 |
if val_loss < best_val_loss - min_delta:
|
| 969 |
best_val_loss = val_loss
|
|
|
|
| 975 |
if epochs_no_improve >= early_stopping_patience:
|
| 976 |
logger.info("Early stopping triggered.")
|
| 977 |
break
|
| 978 |
+
|
| 979 |
logger.info("Training completed!")
|
| 980 |
|
| 981 |
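For reference, the early-stopping rule above reduced to a standalone sketch. The loss values, min_delta, and patience below are made-up stand-ins for the trainer's arguments:

# Hypothetical, self-contained illustration of the early-stopping bookkeeping.
best_val_loss = float("inf")
epochs_no_improve = 0
min_delta = 1e-4              # assumed improvement threshold
early_stopping_patience = 3   # assumed patience

for epoch, val_loss in enumerate([0.92, 0.85, 0.86, 0.87, 0.88], start=1):
    if val_loss < best_val_loss - min_delta:
        best_val_loss = val_loss
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
    if epochs_no_improve >= early_stopping_patience:
        print(f"Early stopping triggered at epoch {epoch}")  # fires at epoch 5 here
        break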
    @tf.function
    ...
        """
        Single training step using queries, positives, and hard negatives.
        """
        with tf.GradientTape() as tape:
            # Encode queries, positives, and negatives
            q_enc = self.encoder(q_batch, training=True)   # [batch_size, embed_dim]
            p_enc = self.encoder(p_batch, training=True)   # [batch_size, embed_dim]

            shape = tf.shape(n_batch)
            bs = shape[0]
            neg_samples = shape[1]

            # Flatten negatives to feed them in one pass: [batch_size * neg_samples, max_length]
            n_batch_flat = tf.reshape(n_batch, [bs * neg_samples, shape[2]])
            n_enc_flat = self.encoder(n_batch_flat, training=True)   # [bs*neg_samples, embed_dim]

            # Reshape back => [batch_size, neg_samples, embed_dim]
            n_enc = tf.reshape(n_enc_flat, [bs, neg_samples, -1])

            # Combine the positive and negative embeddings along dim=1:
            # shape [batch_size, 1 + neg_samples, embed_dim].
            # Column 0 is the positive; the remaining columns are the negatives.
            combined_p_n = tf.concat([tf.expand_dims(p_enc, axis=1), n_enc], axis=1)

            # Compute scores: dot product of q_enc with each column in combined_p_n.
            # `tf.einsum` handles the batch dimension.
            dot_products = tf.cast(tf.einsum('bd,bkd->bk', q_enc, combined_p_n), tf.float32)
            labels = tf.zeros([bs], dtype=tf.int32)   # Keep labels as int32
            loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                ...
            )
            loss = tf.cast(tf.reduce_mean(loss), tf.float32)

        # Calculate gradients and clip
        gradients = tape.gradient(loss, self.encoder.trainable_variables)
        gradients_norm = tf.cast(tf.linalg.global_norm(gradients), tf.float32)
        max_grad_norm = tf.constant(1.5, dtype=tf.float32)
        gradients, _ = tf.clip_by_global_norm(gradients, max_grad_norm, gradients_norm)
        post_clip_norm = tf.cast(tf.linalg.global_norm(gradients), tf.float32)

        self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_variables))

        return loss, gradients_norm, post_clip_norm
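To make the shapes concrete, here is a toy, self-contained version of the scoring and loss above. Random tensors stand in for real encoder outputs; only the einsum pattern and the all-zero label convention mirror the code:

import tensorflow as tf

bs, neg_samples, embed_dim = 2, 3, 4
q_enc = tf.random.normal([bs, embed_dim])                # queries
p_enc = tf.random.normal([bs, embed_dim])                # positives
n_enc = tf.random.normal([bs, neg_samples, embed_dim])   # hard negatives

# Column 0 is the positive; columns 1..neg_samples are negatives.
combined_p_n = tf.concat([tf.expand_dims(p_enc, axis=1), n_enc], axis=1)

# logits[b, k] = dot(q_enc[b], combined_p_n[b, k]) -> shape [bs, 1 + neg_samples]
logits = tf.einsum('bd,bkd->bk', q_enc, combined_p_n)

# The positive always sits at index 0, so every label is 0.
labels = tf.zeros([bs], dtype=tf.int32)
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))
print(float(loss))  # softmax cross-entropy over 1 positive vs. neg_samples negatives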
    ...
    ) -> tf.Tensor:
        """
        Single validation step using queries, positives, and hard negatives.
        Same idea as train_step, but without gradient updates.
        """
        q_enc = self.encoder(q_batch, training=False)
        p_enc = self.encoder(p_batch, training=False)
        ...
        )

        dot_products = tf.cast(tf.einsum('bd,bkd->bk', q_enc, combined_p_n), tf.float32)
        labels = tf.zeros([bs], dtype=tf.int32)

        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels,
            ...
        ...
        peak_lr: float,
        warmup_steps: int
    ) -> tf.keras.optimizers.schedules.LearningRateSchedule:
        """
        Custom learning rate schedule with warmup and cosine decay.
        """
        class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
            def __init__(
                self,
                ...
                self.total_steps = tf.cast(total_steps, tf.float32)
                self.peak_lr = tf.cast(peak_lr, tf.float32)

                # Cap warmup_steps at 10% of total_steps
                adjusted_warmup_steps = min(warmup_steps, max(1, total_steps // 10))
                self.warmup_steps = tf.cast(adjusted_warmup_steps, tf.float32)

                # Calculate constants
                self.initial_lr = tf.cast(self.peak_lr * 0.1, tf.float32)
                self.min_lr = tf.cast(self.peak_lr * 0.01, tf.float32)
                ...
            def __call__(self, step):
                step = tf.cast(step, tf.float32)

                # Warmup
                warmup_factor = tf.cast(tf.minimum(1.0, step / self.warmup_steps), tf.float32)
                warmup_lr = self.initial_lr + (self.peak_lr - self.initial_lr) * warmup_factor

                # Cosine decay
                decay_steps = tf.cast(tf.maximum(1.0, self.total_steps - self.warmup_steps), tf.float32)
                decay_factor = tf.cast((step - self.warmup_steps) / decay_steps, tf.float32)
                decay_factor = tf.cast(tf.minimum(tf.maximum(0.0, decay_factor), 1.0), tf.float32)
                cosine_decay = tf.cast(0.5 * (1.0 + tf.cos(tf.constant(math.pi, dtype=tf.float32) * decay_factor)), tf.float32)
                decay_lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay

                final_lr = tf.where(step < self.warmup_steps, warmup_lr, decay_lr)

                # Ensure the learning rate stays finite and above the floor
                final_lr = tf.maximum(self.min_lr, final_lr)
                final_lr = tf.where(tf.math.is_finite(final_lr), final_lr, self.min_lr)
                ...
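As a sanity check on the schedule's shape, the same arithmetic in plain Python, using a hypothetical peak_lr and step count (not values from this repo):

import math

peak_lr = 1e-3          # hypothetical peak learning rate
total_steps = 1000      # hypothetical total training steps
warmup_steps = min(200, max(1, total_steps // 10))  # capped at 10%, as above
initial_lr, min_lr = peak_lr * 0.1, peak_lr * 0.01

def lr_at(step: int) -> float:
    if step < warmup_steps:
        # Linear warmup from initial_lr to peak_lr
        return initial_lr + (peak_lr - initial_lr) * min(1.0, step / warmup_steps)
    decay_steps = max(1.0, total_steps - warmup_steps)
    decay_factor = min(max(0.0, (step - warmup_steps) / decay_steps), 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * decay_factor))
    return max(min_lr, min_lr + (peak_lr - min_lr) * cosine)

for s in (0, 50, 100, 500, 1000):
    print(s, f"{lr_at(s):.6f}")  # ramps to peak at step 100, decays to min_lr at 1000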
chatbot_validator.py
CHANGED

@@ -113,7 +113,7 @@ class ChatbotValidator:
         logger.info(f"\nTest Case {i}: {query}")

         # Retrieve top_k responses, then evaluate with quality checker
-        responses = self.chatbot.
+        responses = self.chatbot.retrieve_responses(query, top_k=top_k, reranker=reranker)
         quality_metrics = self.quality_checker.check_response_quality(query, responses)

         # Aggregate metrics and log
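The call pattern introduced here, shown end to end with stub objects. Everything below is hypothetical (names, scores, thresholds); only the retrieve-then-check flow matches the diff:

class StubChatbot:
    def retrieve_responses(self, query, top_k=10, reranker=None):
        # Pretend retrieval: (response, score) pairs sorted by score.
        candidates = [("Try resetting it from the login page.", 0.82),
                      ("Our store hours are 9 to 5.", 0.31)]
        return candidates[:top_k]

class StubQualityChecker:
    def check_response_quality(self, query, responses):
        top_score = responses[0][1] if responses else 0.0
        return {"is_confident": top_score > 0.5, "top_score": top_score}

chatbot, quality_checker = StubChatbot(), StubQualityChecker()
responses = chatbot.retrieve_responses("How do I reset my password?", top_k=5, reranker=None)
print(quality_checker.check_response_quality("How do I reset my password?", responses))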
{data_augmentation → data_augmentation_code}/augmentation_processing_pipeline.py
RENAMED
File without changes

{data_augmentation → data_augmentation_code}/back_translator.py
RENAMED
File without changes

{data_augmentation → data_augmentation_code}/dialogue_augmenter.py
RENAMED
File without changes

{data_augmentation → data_augmentation_code}/main.py
RENAMED
File without changes

{data_augmentation → data_augmentation_code}/paraphraser.py
RENAMED
File without changes

{data_augmentation → data_augmentation_code}/pipeline_config.py
RENAMED
File without changes

{data_augmentation → data_augmentation_code}/quality_metrics.py
RENAMED
File without changes

{data_augmentation → data_augmentation_code}/schema_guided_dialogue_processor.py
RENAMED
File without changes

{data_augmentation → data_augmentation_code}/taskmaster_processor.py
RENAMED
File without changes
validate_model.py → run_chatbot_validation.py
RENAMED

@@ -39,7 +39,7 @@ def run_interactive_chat(chatbot, quality_checker):
     else:
         print("\n[Low Confidence]: Consider rephrasing your query for better assistance.")

-def validate_chatbot():
+def run_chatbot_validation():
     # Initialize environment
     env = EnvironmentSetup()
     env.initialize()

@@ -86,15 +86,15 @@ def validate_chatbot():
     try:
         chatbot.data_pipeline.load_faiss_index(FAISS_INDEX_PATH)
         logger.info(f"FAISS index loaded from {FAISS_INDEX_PATH}.")
-        logger.info("FAISS dimensions:
-        logger.info("FAISS index type:
-        logger.info("FAISS index total vectors:
-        logger.info("FAISS is_trained:
+        logger.info(f"FAISS dimensions: {chatbot.data_pipeline.index.d}")
+        logger.info(f"FAISS index type: {type(chatbot.data_pipeline.index)}")
+        logger.info(f"FAISS index total vectors: {chatbot.data_pipeline.index.ntotal}")
+        logger.info(f"FAISS is_trained: {chatbot.data_pipeline.index.is_trained}")

         with open(RESPONSE_POOL_PATH, "r", encoding="utf-8") as f:
             chatbot.data_pipeline.response_pool = json.load(f)
         logger.info(f"Response pool loaded from {RESPONSE_POOL_PATH}.")
-        logger.info("\nTotal responses in pool:
+        logger.info(f"\nTotal responses in pool: {len(chatbot.data_pipeline.response_pool)}")

         # Validate dimension consistency
         chatbot.data_pipeline.validate_faiss_index()

@@ -130,4 +130,4 @@ def validate_chatbot():
     run_interactive_chat(chatbot, quality_checker)

 if __name__ == "__main__":
-    validate_chatbot()
+    run_chatbot_validation()
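The four attributes logged above are standard FAISS index properties. A tiny throwaway example (arbitrary dimension, random vectors, not this project's index) shows what each reports:

import faiss
import numpy as np

dim = 8
index = faiss.IndexFlatIP(dim)                      # inner-product index
index.add(np.random.rand(3, dim).astype("float32")) # FAISS expects float32

print(index.d)           # vector dimensionality -> 8
print(type(index))       # the concrete index class
print(index.ntotal)      # number of stored vectors -> 3
print(index.is_trained)  # True: flat indices need no training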
tf_data_pipeline.py
CHANGED

@@ -24,19 +24,19 @@ class TFDataPipeline:
         config,
         tokenizer,
         encoder,
-        index_file_path: str,
         response_pool: List[str],
-        max_length: int,
         query_embeddings_cache: dict,
-        neg_samples: int,
+        max_length: int = 512,
+        neg_samples: int = 10,
         index_type: str = 'IndexFlatIP',
+        faiss_index_file_path: str = 'new_iteration/data_prep_iterative_models/faiss_indices/faiss_index_production.index',
         nlist: int = 100,
         max_retries: int = 3
     ):
         self.config = config
         self.tokenizer = tokenizer
         self.encoder = encoder
-        self.
+        self.faiss_index_file_path = faiss_index_file_path
         self.response_pool = response_pool
         self.max_length = max_length
         self.neg_samples = neg_samples

@@ -53,9 +53,9 @@ class TFDataPipeline:
         self.build_text_to_domain_map()

         # Initialize FAISS index
-        if os.path.exists(
-            logger.info(f"Loading existing FAISS index from {
-            self.index = faiss.read_index(
+        if os.path.exists(faiss_index_file_path):
+            logger.info(f"Loading existing FAISS index from {faiss_index_file_path}...")
+            self.index = faiss.read_index(faiss_index_file_path)
             self.validate_faiss_index()
             logger.info("FAISS index loaded and validated successfully.")
         else:

@@ -83,18 +83,18 @@ class TFDataPipeline:
             self.query_embeddings_cache[query] = hf[query][:]
         logger.info(f"Embeddings cache loaded from {cache_file_path}.")

-    def save_faiss_index(self,
-        faiss.write_index(self.index,
-        logger.info(f"FAISS index saved to {
+    def save_faiss_index(self, faiss_index_file_path: str):
+        faiss.write_index(self.index, faiss_index_file_path)
+        logger.info(f"FAISS index saved to {faiss_index_file_path}")

-    def load_faiss_index(self,
+    def load_faiss_index(self, faiss_index_file_path: str):
         """Load FAISS index from specified file path."""
-        if os.path.exists(
-            self.index = faiss.read_index(
-            logger.info(f"FAISS index loaded from {
+        if os.path.exists(faiss_index_file_path):
+            self.index = faiss.read_index(faiss_index_file_path)
+            logger.info(f"FAISS index loaded from {faiss_index_file_path}.")
         else:
-            logger.error(f"FAISS index file not found at {
-            raise FileNotFoundError(f"FAISS index file not found at {
+            logger.error(f"FAISS index file not found at {faiss_index_file_path}.")
+            raise FileNotFoundError(f"FAISS index file not found at {faiss_index_file_path}.")

     def validate_faiss_index(self):
         """Validates FAISS index dimensionality."""
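A minimal round trip through faiss.write_index / faiss.read_index, mirroring what save_faiss_index and load_faiss_index wrap above. The temporary path, dimension, and random data are assumptions for illustration:

import os
import tempfile

import faiss
import numpy as np

dim = 8
index = faiss.IndexFlatIP(dim)   # inner-product index, the pipeline's default type
index.add(np.random.rand(5, dim).astype("float32"))

path = os.path.join(tempfile.gettempdir(), "example_faiss.index")
faiss.write_index(index, path)        # what save_faiss_index wraps

if os.path.exists(path):              # same guard load_faiss_index uses
    restored = faiss.read_index(path)
    assert restored.ntotal == 5 and restored.d == dim
else:
    raise FileNotFoundError(f"FAISS index file not found at {path}.")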