JoeArmani committed
Commit · d7fc7a7
1 Parent(s): c111c20
more structural updates
Browse files
- .gitignore +2 -0
- chatbot_model.py +145 -171
- chatbot_validator.py +1 -1
- {data_augmentation → data_augmentation_code}/augmentation_processing_pipeline.py +0 -0
- {data_augmentation → data_augmentation_code}/back_translator.py +0 -0
- {data_augmentation → data_augmentation_code}/dialogue_augmenter.py +0 -0
- {data_augmentation → data_augmentation_code}/main.py +0 -0
- {data_augmentation → data_augmentation_code}/paraphraser.py +0 -0
- {data_augmentation → data_augmentation_code}/pipeline_config.py +0 -0
- {data_augmentation → data_augmentation_code}/quality_metrics.py +0 -0
- {data_augmentation → data_augmentation_code}/schema_guided_dialogue_processor.py +0 -0
- {data_augmentation → data_augmentation_code}/taskmaster_processor.py +0 -0
- validate_model.py → run_chatbot_validation.py +7 -7
- tf_data_pipeline.py +16 -16
.gitignore
CHANGED
@@ -187,3 +187,5 @@ new_iteration/cache/*
 new_iteration/data_prep_iterative_models/*
 new_iteration/training_data/*
 new_iteration/processed_outputs/*
+raw_datasets/*
+
chatbot_model.py
CHANGED
@@ -24,25 +24,25 @@ logger = config_logger(__name__)
 
 @dataclass
 class ChatbotConfig:
-    """
+    """RetrievalChatbot Config"""
     max_context_token_limit: int = 512
     embedding_dim: int = 768
     encoder_units: int = 256
     num_attention_heads: int = 8
     dropout_rate: float = 0.2
     l2_reg_weight: float = 0.001
-    learning_rate: float = 0.
+    learning_rate: float = 0.0005
     min_text_length: int = 3
-    max_context_turns: int =
+    max_context_turns: int = 20
     warmup_steps: int = 200
     pretrained_model: str = 'distilbert-base-uncased'
     cross_encoder_model: str = 'cross-encoder/ms-marco-MiniLM-L-12-v2'
+    summarizer_model: str = 't5-small'
     dtype: str = 'float32'
     freeze_embeddings: bool = False
     embedding_batch_size: int = 64
     search_batch_size: int = 64
     max_batch_size: int = 64
-    neg_samples: int = 10
     max_retries: int = 3
 
     def to_dict(self) -> Dict:

@@ -57,7 +57,7 @@ class ChatbotConfig:
                     if k in cls.__dataclass_fields__})
 
 class EncoderModel(tf.keras.Model):
-    """Dual encoder model with pretrained embeddings."""
+    """Dual encoder model with pretrained DistilBERT embeddings."""
     def __init__(
         self,
         config: ChatbotConfig,

@@ -71,7 +71,7 @@ class EncoderModel(tf.keras.Model):
         self.pretrained = TFAutoModel.from_pretrained(config.pretrained_model)
         self._freeze_layers()
 
-        # Add
+        # Add Global Average Pooling, Projection, Dropout, and Normalization layers
         self.pooler = tf.keras.layers.GlobalAveragePooling1D()
         self.projection = tf.keras.layers.Dense(
             config.embedding_dim,

@@ -86,7 +86,7 @@ class EncoderModel(tf.keras.Model):
         )
 
     def _freeze_layers(self):
-        """Freeze layers of the pretrained model
+        """Freeze n layers of the pretrained model"""
         if self.config.freeze_embeddings:
             self.pretrained.trainable = False
             logger.info("All pretrained layers frozen.")

@@ -95,29 +95,29 @@ class EncoderModel(tf.keras.Model):
             for i, layer in enumerate(self.pretrained.layers):
                 if isinstance(layer, tf.keras.layers.Layer):
                     if hasattr(layer, 'trainable'):
-                        # Freeze the first transformer block
                         if i < 1:
                             layer.trainable = False
                             logger.info(f"Layer {i} frozen.")
                         else:
                             layer.trainable = True
+                            logger.info(f"Layer {i} trainable.")
 
     def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
         """Forward pass."""
         # Get pretrained embeddings
         pretrained_outputs = self.pretrained(inputs, training=training)
-        x = pretrained_outputs.last_hidden_state
+        x = pretrained_outputs.last_hidden_state  # Shape: [batch_size, seq_len, embedding_dim]
 
         # Apply pooling, projection, dropout, and normalization
-        x = self.pooler(x)
-        x = self.projection(x)
+        x = self.pooler(x)      # Shape: [batch_size, 768]
+        x = self.projection(x)  # Shape: [batch_size, 768]
         x = self.dropout(x, training=training)
-        x = self.normalize(x)
+        x = self.normalize(x)   # Shape: [batch_size, 768]
 
         return x
 
     def get_config(self) -> dict:
-        """Return the config
+        """Return the model config"""
         config = super().get_config()
         config.update({
             "config": self.config.to_dict(),
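The hunks above settle the encoder head: mean-pool DistilBERT's token states, project, apply dropout, then L2-normalize so downstream inner products behave like cosine similarities. A minimal standalone sketch of that stack; the normalization layer is an assumption here, since the repo's `self.normalize` is defined outside this diff:

import tensorflow as tf

embedding_dim = 768  # assumption: matches ChatbotConfig.embedding_dim

pooler = tf.keras.layers.GlobalAveragePooling1D()
projection = tf.keras.layers.Dense(embedding_dim)
dropout = tf.keras.layers.Dropout(0.2)
normalize = tf.keras.layers.UnitNormalization()  # assumed stand-in for self.normalize

hidden = tf.random.normal([4, 512, embedding_dim])  # stand-in for last_hidden_state
x = pooler(hidden)               # [4, 768]
x = projection(x)                # [4, 768]
x = dropout(x, training=False)
x = normalize(x)                 # rows now have unit L2 norm
print(tf.norm(x, axis=1).numpy())  # ~1.0 each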
@@ -126,7 +126,10 @@ class EncoderModel(tf.keras.Model):
         return config
 
 class RetrievalChatbot(DeviceAwareModel):
-    """
+    """
+    Retrieval-based learning chatbot model.
+    Uses trained embeddings and FAISS for similarity search.
+    """
     def __init__(
         self,
         config: ChatbotConfig,

@@ -142,7 +145,7 @@ class RetrievalChatbot(DeviceAwareModel):
         self.device = device or self._setup_default_device()
         self.mode = mode.lower()
 
-        # Initialize reranker, summarizer, tokenizer,
+        # Initialize reranker, summarizer, tokenizer, and encoder
         self.reranker = reranker or self._initialize_reranker()
         self.tokenizer = self._initialize_tokenizer()
         self.encoder = self._initialize_encoder()

@@ -154,14 +157,9 @@ class RetrievalChatbot(DeviceAwareModel):
             config=self.config,
             tokenizer=self.tokenizer,
             encoder=self.encoder,
-            index_file_path='new_iteration/data_prep_iterative_models/faiss_indices/faiss_index_production.index',
             response_pool=[],
             max_length=self.config.max_context_token_limit,
             query_embeddings_cache={},
-            neg_samples=self.config.neg_samples,
-            index_type='IndexFlatIP',
-            nlist=100,  # Not used with IndexFlatIP
-            max_retries=self.config.max_retries
         )
 
         # Collect unique responses from dialogues

@@ -197,7 +195,7 @@ class RetrievalChatbot(DeviceAwareModel):
         """Initialize the Summarizer."""
         return Summarizer(
             tokenizer=self.tokenizer,
-            model_name=
+            model_name=self.config.summarizer_model,
             max_summary_length=self.config.max_context_token_limit // 4,
             device=self.device,
             max_summary_rounds=2

@@ -229,17 +227,18 @@ class RetrievalChatbot(DeviceAwareModel):
             new_vocab_size = len(self.tokenizer)
             encoder.pretrained.resize_token_embeddings(new_vocab_size)
             logger.info(f"Token embeddings resized to: {new_vocab_size}")
+
         return encoder
 
     def _load_faiss_index_and_responses(self) -> None:
         """Load FAISS index and response pool for inference."""
         try:
-            logger.info(f"Loading FAISS index from {self.data_pipeline.
-            self.data_pipeline.load_faiss_index(self.data_pipeline.
+            logger.info(f"Loading FAISS index from {self.data_pipeline.faiss_index_file_path}...")
+            self.data_pipeline.load_faiss_index(self.data_pipeline.faiss_index_file_path)
             logger.info("FAISS index loaded successfully.")
 
-            # Load response pool
-            response_pool_path = self.data_pipeline.
+            # Load response pool
+            response_pool_path = self.data_pipeline.faiss_index_file_path.replace('.index', '_responses.json')
             if os.path.exists(response_pool_path):
                 with open(response_pool_path, 'r', encoding='utf-8') as f:
                     self.data_pipeline.response_pool = json.load(f)

@@ -263,29 +262,24 @@ class RetrievalChatbot(DeviceAwareModel):
         """
         load_dir = Path(load_dir)
 
-        #
+        # Load config
         with open(load_dir / "config.json", "r") as f:
             config = ChatbotConfig.from_dict(json.load(f))
 
-        #
+        # Initialize chatbot
         chatbot = cls(config, mode=mode)
 
-        #
-        chatbot.encoder.pretrained = TFAutoModel.from_pretrained(
-            load_dir / "shared_encoder",
-            config=config
-        )
+        # Load DistilBERT
+        chatbot.encoder.pretrained = TFAutoModel.from_pretrained(load_dir / "shared_encoder", config=config)
 
         dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
         _ = chatbot.encoder(dummy_input, training=False)
 
-        #
+        # Load tokenizer
        chatbot.tokenizer = AutoTokenizer.from_pretrained(load_dir / "tokenizer")
         logger.info(f"Models and tokenizer loaded from {load_dir}")
 
-
-
-        # 5) Load the custom top layers' weights
+        # Load the custom weights
         custom_weights_path = load_dir / "encoder_custom_weights.weights.h5"
         if custom_weights_path.exists():
             chatbot.encoder.load_weights(str(custom_weights_path))

@@ -293,7 +287,7 @@ class RetrievalChatbot(DeviceAwareModel):
         else:
             logger.warning(f"No custom encoder weights found at {custom_weights_path}. The top-level projection layer won't have learned parameters.")
 
-        #
+        # Handle 'inference' mode: load FAISS, etc.
         if mode == 'inference':
             cls._prepare_model_for_inference(chatbot, load_dir)

@@ -301,7 +295,7 @@ class RetrievalChatbot(DeviceAwareModel):
 
     @classmethod
     def _prepare_model_for_inference(cls, chatbot: 'RetrievalChatbot', load_dir: Path) -> None:
-        """
+        """Load inference components."""
         try:
             # Load FAISS index
             faiss_path = load_dir / 'faiss_indices/faiss_index_production.index'

@@ -332,7 +326,7 @@ class RetrievalChatbot(DeviceAwareModel):
             raise
 
     def save_models(self, save_dir: Union[str, Path]):
-        """Save
+        """Save model and config"""
         save_dir = Path(save_dir)
         save_dir.mkdir(parents=True, exist_ok=True)

@@ -340,21 +334,13 @@ class RetrievalChatbot(DeviceAwareModel):
         with open(save_dir / "config.json", "w") as f:
             json.dump(self.config.to_dict(), f, indent=2)
 
-        # Save the HF DistilBERT submodule
+        # Save the HF DistilBERT submodule, custom top-level layers, and tokenizer
         self.encoder.pretrained.save_pretrained(save_dir / "shared_encoder")
-
-        # ALSO save custom top-level layers' weights
         self.encoder.save_weights(save_dir / "encoder_custom_weights.weights.h5")
-
-        # Save tokenizer
         self.tokenizer.save_pretrained(save_dir / "tokenizer")
-
         logger.info(f"Models and tokenizer saved to {save_dir}.")
 
-    def
-        return 1 / (1 + np.exp(-x))
-
-    def retrieve_responses_cross_encoder(
+    def retrieve_responses(
         self,
         query: str,
         top_k: int = 10,

@@ -363,20 +349,20 @@ class RetrievalChatbot(DeviceAwareModel):
         summarize_threshold: int = 512
     ) -> List[Tuple[str, float]]:
         """
-        Retrieve top-k responses
-        and cross-encoder re-ranking.
-
+        Retrieve top-k responses using FAISS and cross-encoder re-ranking.
         Args:
             query: The user's input text.
-            top_k: Number of
-            reranker: CrossEncoderReranker for refined scoring
-            summarizer: Summarizer for long queries
-            summarize_threshold: Summarize if
-
+            top_k: Number of FAISS results to return
+            reranker: CrossEncoderReranker for refined scoring
+            summarizer: Summarizer for long queries
+            summarize_threshold: Summarize if conversation tokens > threshold.
         Returns:
             List of (response_text, final_score).
         """
-
+        def sigmoid(x: float) -> float:
+            return 1 / (1 + np.exp(-x))
+
+        # Query summarization
         if summarizer and len(query.split()) > summarize_threshold:
             logger.info(f"Query is long ({len(query.split())} words). Summarizing.")
             query = summarizer.summarize_text(query)

@@ -393,17 +379,17 @@ class RetrievalChatbot(DeviceAwareModel):
 
         texts = [item[0] for item in faiss_candidates]
 
-        # Re-rank these boosted candidates
         if not reranker:
             reranker = CrossEncoderReranker(model_name=self.config.cross_encoder_model)
 
+        # Re-rank the texts (candidates) from FAISS search using the cross-encoder
         ce_logits = reranker.rerank(query, texts, max_length=256)
 
-        # Combine
+        # Combine scores from FAISS and cross-encoder
         final_candidates = []
         for (resp_text, faiss_score), logit in zip(faiss_candidates, ce_logits):
-            ce_prob =
-            faiss_norm = (faiss_score + 1)/2.0
+            ce_prob = sigmoid(logit)            # now in range [0...1]
+            faiss_norm = (faiss_score + 1)/2.0  # now in range [0...1]
             combined_score = 0.85 * ce_prob + 0.15 * faiss_norm
             length_adjusted_score = self.length_adjust_score(resp_text, combined_score)
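For orientation, here is the score blend from retrieve_responses reduced to a toy, runnable sketch. The 0.85/0.15 weights and the [-1, 1] → [0, 1] rescaling come from the diff; the candidate texts, scores, and logits are invented:

import numpy as np

def sigmoid(x: float) -> float:
    return 1 / (1 + np.exp(-x))

faiss_candidates = [("response A", 0.62), ("response B", 0.35)]  # made-up FAISS scores
ce_logits = [2.1, -0.4]                                          # made-up cross-encoder logits

for (text, faiss_score), logit in zip(faiss_candidates, ce_logits):
    ce_prob = sigmoid(logit)            # squash logit into [0, 1]
    faiss_norm = (faiss_score + 1) / 2  # rescale cosine-like score into [0, 1]
    combined = 0.85 * ce_prob + 0.15 * faiss_norm
    print(f"{text}: {combined:.3f}")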
@@ -415,22 +401,22 @@ class RetrievalChatbot(DeviceAwareModel):
         # Return top_k
         return final_candidates[:top_k]
 
-    DOMAIN_KEYWORDS = {
-        'restaurant': ['restaurant', 'dining', 'food', 'dine', 'reservation', 'table', 'menu', 'cuisine', 'eat', 'place to eat', 'hungry', 'chef', 'dish', 'meal', 'brunch', 'bistro', 'buffet', 'catering', 'gourmet', 'fast food', 'fine dining', 'takeaway', 'delivery', 'restaurant booking'],
-        'movie': ['movie', 'cinema', 'film', 'ticket', 'showtime', 'showing', 'theater', 'flick', 'screening', 'film ticket', 'film show', 'blockbuster', 'premiere', 'trailer', 'director', 'actor', 'actress', 'plot', 'genre', 'screen', 'sequel', 'animation', 'documentary'],
-        'ride_share': ['ride', 'taxi', 'uber', 'lyft', 'car service', 'pickup', 'dropoff', 'driver', 'cab', 'hailing', 'rideshare', 'ride hailing', 'carpool', 'chauffeur', 'transit', 'transportation', 'hail ride'],
-        'coffee': ['coffee', 'café', 'cafe', 'starbucks', 'espresso', 'latte', 'mocha', 'americano', 'barista', 'brew', 'cappuccino', 'macchiato', 'iced coffee', 'cold brew', 'espresso machine', 'coffee shop', 'tea', 'chai', 'java', 'bean', 'roast', 'decaf'],
-        'pizza': ['pizza', 'delivery', 'order food', 'pepperoni', 'topping', 'pizzeria', 'slice', 'pie', 'margherita', 'deep dish', 'thin crust', 'cheese', 'oven', 'tossed', 'sauce', 'garlic bread', 'calzone'],
-        'auto': ['car', 'vehicle', 'repair', 'maintenance', 'mechanic', 'oil change', 'garage', 'auto shop', 'tire', 'check engine', 'battery', 'transmission', 'brake', 'engine diagnostics', 'carwash', 'detail', 'alignment', 'exhaust', 'spark plug', 'dashboard'],
-    }
-
     def extract_keywords(self, query: str) -> List[str]:
         """
         Return any domain keywords present in the query (lowercased).
         """
+        domain_keywords = {
+            'restaurant': ['restaurant', 'dining', 'food', 'dine', 'reservation', 'table', 'menu', 'cuisine', 'eat', 'place to eat', 'hungry', 'chef', 'dish', 'meal', 'brunch', 'bistro', 'buffet', 'catering', 'gourmet', 'fast food', 'fine dining', 'takeaway', 'delivery', 'restaurant booking'],
+            'movie': ['movie', 'cinema', 'film', 'ticket', 'showtime', 'showing', 'theater', 'flick', 'screening', 'film ticket', 'film show', 'blockbuster', 'premiere', 'trailer', 'director', 'actor', 'actress', 'plot', 'genre', 'screen', 'sequel', 'animation', 'documentary'],
+            'ride_share': ['ride', 'taxi', 'uber', 'lyft', 'car service', 'pickup', 'dropoff', 'driver', 'cab', 'hailing', 'rideshare', 'ride hailing', 'carpool', 'chauffeur', 'transit', 'transportation', 'hail ride'],
+            'coffee': ['coffee', 'café', 'cafe', 'starbucks', 'espresso', 'latte', 'mocha', 'americano', 'barista', 'brew', 'cappuccino', 'macchiato', 'iced coffee', 'cold brew', 'espresso machine', 'coffee shop', 'tea', 'chai', 'java', 'bean', 'roast', 'decaf'],
+            'pizza': ['pizza', 'delivery', 'order food', 'pepperoni', 'topping', 'pizzeria', 'slice', 'pie', 'margherita', 'deep dish', 'thin crust', 'cheese', 'oven', 'tossed', 'sauce', 'garlic bread', 'calzone'],
+            'auto': ['car', 'vehicle', 'repair', 'maintenance', 'mechanic', 'oil change', 'garage', 'auto shop', 'tire', 'check engine', 'battery', 'transmission', 'brake', 'engine diagnostics', 'carwash', 'detail', 'alignment', 'exhaust', 'spark plug', 'dashboard'],
+        }
+
         query_lower = query.lower()
         found = set()
-        for domain, kw_list in
+        for domain, kw_list in domain_keywords.items():
             for kw in kw_list:
                 if kw in query_lower:
                     found.add(kw)

@@ -456,7 +442,7 @@ class RetrievalChatbot(DeviceAwareModel):
 
     def detect_domain_from_query(self, query: str) -> str:
         """
-        Detect the domain of the query based on keywords.
+        Detect the domain of the query based on keywords. Used for boosting FAISS search.
         """
         domain_patterns = {
             'restaurant': r'\b(restaurant|restaurants?|dining|food|foods?|dine|reservation|reservations?|table|tables?|menu|menus?|cuisine|cuisines?|eat|eats?|place\s?to\s?eat|places\s?to\s?eat|hungry|chef|chefs?|dish|dishes?|meal|meals?|fork|forks?|knife|knives?|spoon|spoons?|brunch|bistro|buffet|buffets?|catering|caterings?|gourmet|fast\s?food|fine\s?dining|takeaway|takeaways?|delivery|deliveries|restaurant\s?booking)\b',

@@ -476,8 +462,7 @@ class RetrievalChatbot(DeviceAwareModel):
 
     def is_numeric_response(self, text: str) -> bool:
         """
-        Return True if `text` is purely digits
-        with optional punctuation like '.' at the end.
+        Return True if `text` is purely digits and/or spaces.
         """
         pattern = r'^[\s]*[\d]+([\s.,\d]+)*[\s]*$'
         return bool(re.match(pattern, text.strip()))

@@ -486,18 +471,16 @@ class RetrievalChatbot(DeviceAwareModel):
         self,
         query: str,
         domain: str = 'other',
-        top_k: int =
-        boost_factor: float = 1.
+        top_k: int = 10,
+        boost_factor: float = 1.15
     ) -> List[Tuple[str, float]]:
         """
         Retrieve top-k responses from the FAISS index (IndexFlatIP) given a user query.
-
         Args:
             query (str): The user input text.
-            domain (str
-            top_k (int
-            boost_factor (float, optional): Factor to boost scores for keyword matches.
-
+            domain (str): The detected domain from possible domains: ['restaurant', 'movie', 'ride_share', 'coffee', 'pizza', 'auto', 'other']
+            top_k (int): Number of top results to return.
+            boost_factor (float, optional): Factor to boost scores for keyword matches.
         Returns:
             List[Tuple[str, float]]: List of (response_text, similarity) sorted by descending similarity.
         """

@@ -508,7 +491,7 @@ class RetrievalChatbot(DeviceAwareModel):
         # Search the index
         distances, indices = self.data_pipeline.index.search(q_emb_np, top_k * 10)
 
-        # IndexFlatIP: 'distances' are inner products (cosine similarities for normalized vectors)
+        # IndexFlatIP: 'distances' are inner products (cosine similarities for normalized vectors).
         candidates = []
         for rank, idx in enumerate(indices[0]):
             if idx < 0:

@@ -545,8 +528,7 @@ class RetrievalChatbot(DeviceAwareModel):
         boosted = []
         for (resp_text, resp_domain, score) in in_domain:
             new_score = score
-            # If the domain is known AND the response text
-            # shares any query keywords, apply a small boost
+            # If the domain is known AND the response text shares any query keywords, boost it
             if query_keywords and any(kw in resp_text.lower() for kw in query_keywords):
                 new_score *= boost_factor
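The FAISS side of this retrieval assumes L2-normalized embeddings in an IndexFlatIP, so inner products equal cosine similarities. A self-contained sketch with placeholder vectors (not the project's real index; only the 768-dim size is taken from the config above):

import faiss
import numpy as np

dim = 768
index = faiss.IndexFlatIP(dim)  # exact inner-product search

responses = np.random.randn(1000, dim).astype("float32")
faiss.normalize_L2(responses)   # unit vectors -> inner product == cosine
index.add(responses)

query = np.random.randn(1, dim).astype("float32")
faiss.normalize_L2(query)
scores, ids = index.search(query, 10)  # top-10 by cosine similarity
print(ids[0], scores[0])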
@@ -558,7 +540,7 @@ class RetrievalChatbot(DeviceAwareModel):
         # Sort boosted responses
         boosted.sort(key=lambda x: x[1], reverse=True)
 
-        # Debug
+        # Debug logging (see FAISS responses)
         # for resp, score in boosted[:100]:
         #     logger.debug(f"Candidate: '{resp}' with score {score}")

@@ -572,8 +554,7 @@ class RetrievalChatbot(DeviceAwareModel):
         top_k: int = 10,
     ) -> Tuple[str, List[Tuple[str, float]], Dict[str, Any]]:
         """
-
-        if self.reranker is available.
+        Live chat with the chatbot. Uses same processing flow as validation, except for context handling and quality checking.
         """
         @self.run_on_device
         def get_response(self_arg, query_arg):

@@ -581,7 +562,7 @@ class RetrievalChatbot(DeviceAwareModel):
            conversation_str = self_arg._build_conversation_context(query_arg, conversation_history)
 
            # Retrieve and re-rank
-           results = self_arg.
+           results = self_arg.retrieve_responses(
                query=conversation_str,
                top_k=top_k,
                reranker=self_arg.reranker,

@@ -605,7 +586,9 @@ class RetrievalChatbot(DeviceAwareModel):
         query: str,
         conversation_history: Optional[List[Tuple[str, str]]]
     ) -> str:
-        """
+        """
+        Build conversation context string from conversation history.
+        """
         if not conversation_history:
             return f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"

@@ -636,12 +619,12 @@ class RetrievalChatbot(DeviceAwareModel):
     ) -> None:
         """
         Train the retrieval model using a pre-prepared TFRecord dataset.
-        This method handles:
         - Checkpoint loading/restoring
         - LR scheduling
         - Epoch/iteration tracking
-        -
-        -
+        - Training-history logging
+        - Early stopping
+        - Custom loss function (Contrastive loss with hard negative sampling)
         """
         logger.info("Starting training with pre-prepared TFRecord dataset...")

@@ -673,7 +656,7 @@ class RetrievalChatbot(DeviceAwareModel):
         steps_per_epoch = math.ceil(train_size / batch_size)
         val_steps = math.ceil(val_size / batch_size)
         total_steps = steps_per_epoch * epochs
-        buffer_size = max(1, total_pairs //
+        buffer_size = max(1, total_pairs // 2)  # 50% of the dataset for shuffling
 
         logger.info(f"Training pairs: {train_size}")
         logger.info(f"Validation pairs: {val_size}")

@@ -695,7 +678,7 @@ class RetrievalChatbot(DeviceAwareModel):
             self.optimizer = tf.keras.optimizers.Adam(learning_rate=tf.cast(peak_lr, tf.float32))
             logger.info("Using fixed learning rate.")
 
-        #
+        # Dummy step to force initialization
         dummy_input = tf.zeros((1, self.config.max_context_token_limit), dtype=tf.int32)
         with tf.GradientTape() as tape:
             dummy_output = self.encoder(dummy_input)

@@ -710,6 +693,7 @@ class RetrievalChatbot(DeviceAwareModel):
             model=self.encoder
         )
 
+        # Create a CheckpointManager
         manager = tf.train.CheckpointManager(
             checkpoint,
             directory=checkpoint_dir,

@@ -717,18 +701,18 @@ class RetrievalChatbot(DeviceAwareModel):
             checkpoint_name='ckpt'
         )
 
-        # Restore from existing checkpoint if
+        # Restore from existing checkpoint if one is provided
         latest_checkpoint = manager.latest_checkpoint
         history_path = Path(checkpoint_dir) / 'training_history.json'
 
-        #
+        # Log epoch losses across runs, including restore from checkpoint
         if not hasattr(self, 'history'):
             self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
 
         if latest_checkpoint and not test_mode:
-            #
-            logger.info(f"\nTrying to load checkpoint from: {latest_checkpoint}")
-            reader = tf.train.load_checkpoint(latest_checkpoint)
+            # Debug checkpoint loading
+            # logger.info(f"\nTrying to load checkpoint from: {latest_checkpoint}")
+            # reader = tf.train.load_checkpoint(latest_checkpoint)
            # shape_from_key = reader.get_variable_to_shape_map()
            # dtype_from_key = reader.get_variable_to_dtype_map()
            # logger.info("\nCheckpoint Variables:")

@@ -752,11 +736,11 @@ class RetrievalChatbot(DeviceAwareModel):
             if initial_epoch == 0:
                 initial_epoch = ckpt_number
 
-            # Assign to checkpoint.epoch
+            # Assign to checkpoint.epoch for counting
             checkpoint.epoch.assign(tf.cast(initial_epoch, tf.int32))
             logger.info(f"Resuming from epoch {initial_epoch}")
 
-            #
+            # Load history from file:
             if history_path.exists():
                 try:
                     with open(history_path, 'r') as f:

@@ -765,7 +749,10 @@ class RetrievalChatbot(DeviceAwareModel):
                 except Exception as e:
                     logger.warning(f"Could not load history, starting fresh: {e}")
 
-            #
+            # Save custom weights not being saved in the full model.
+            # This was a bugfix to extract weights from a checkpoint without retraining.
+            # Before updating save_models, only DistilBERT weights were being saved (custom layers were missed).
+            # Not needed, also not harmful.
             self.save_models(Path(checkpoint_dir) / "pretrained_full_model")
             logger.info(f"Manually saved custom weights after restore.")
         else:

@@ -782,13 +769,13 @@ class RetrievalChatbot(DeviceAwareModel):
         train_summary_writer = tf.summary.create_file_writer(train_log_dir)
         val_summary_writer = tf.summary.create_file_writer(val_log_dir)
         logger.info(f"TensorBoard logs will be saved in {log_dir}")
 
         # Parse dataset
         dataset = tf.data.TFRecordDataset(tfrecord_file_path)
 
-        #
+        # Debug mode uses small subset. Useful for CPU debugging.
         if test_mode:
-            subset_size =
+            subset_size = 200
             dataset = dataset.take(subset_size)
             logger.info(f"TEST MODE: Using only {subset_size} examples")
             # Recompute sizes, steps, epochs, etc., as needed

@@ -804,38 +791,36 @@ class RetrievalChatbot(DeviceAwareModel):
             early_stopping_patience = 2
             logger.info(f"New training pairs: {train_size}")
             logger.info(f"New validation pairs: {val_size}")
 
         dataset = dataset.map(
-            lambda x: parse_tfrecord_fn(x, self.config.max_context_token_limit, self.
+            lambda x: parse_tfrecord_fn(x, self.config.max_context_token_limit, self.data_pipeline.neg_samples),
             num_parallel_calls=tf.data.AUTOTUNE
         )
 
         # Train/val split
         train_dataset = dataset.take(train_size)
         val_dataset = dataset.skip(train_size).take(val_size)
 
         # Shuffle and batch
         train_dataset = train_dataset.shuffle(buffer_size=buffer_size)
         train_dataset = train_dataset.batch(batch_size, drop_remainder=True)
         train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
 
         val_dataset = val_dataset.batch(batch_size, drop_remainder=False)
         val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE)
         val_dataset = val_dataset.cache()
 
         # Training loop
         best_val_loss = float("inf")
         epochs_no_improve = 0
 
         for epoch in range(int(checkpoint.epoch.numpy()) + 1, epochs + 1):
             checkpoint.epoch.assign(epoch)
             logger.info(f"Starting Epoch {epoch}...")
 
-            # --- Training Phase ---
             epoch_loss_avg = tf.keras.metrics.Mean(dtype=tf.float32)
             batches_processed = 0
 
-            # Progress bar
             try:
                 train_pbar = tqdm(
                     total=steps_per_epoch,

@@ -846,7 +831,8 @@ class RetrievalChatbot(DeviceAwareModel):
             except ImportError:
                 train_pbar = None
                 is_tqdm_train = False
 
+            # --- Training ---
             for q_batch, p_batch, n_batch in train_dataset:
                 loss, grad_norm, post_clip_norm = self.train_step(q_batch, p_batch, n_batch)
                 epoch_loss_avg(loss)

@@ -874,54 +860,54 @@ class RetrievalChatbot(DeviceAwareModel):
                         "lr": f"{current_lr:.2e}",
                         "batches": f"{batches_processed}/{steps_per_epoch}"
                     })
 
                 gc.collect()
 
                 # End the epoch early if we've processed all steps
                 if batches_processed >= steps_per_epoch:
                     break
 
             if is_tqdm_train and train_pbar:
                 train_pbar.close()
 
-            # --- Validation
+            # --- Validation ---
             val_loss_avg = tf.keras.metrics.Mean(dtype=tf.float32)
             val_batches_processed = 0
 
             try:
                 val_pbar = tqdm(total=val_steps, desc="Validation", unit="batch")
                 is_tqdm_val = True
             except ImportError:
                 val_pbar = None
                 is_tqdm_val = False
 
             last_valid_val_loss = None
             valid_batches = False
 
             for q_batch, p_batch, n_batch in val_dataset:
                 # If batch is too small, skip
                 if tf.shape(q_batch)[0] < 2:
                     logger.warning(f"Skipping validation batch of size {tf.shape(q_batch)[0]}")
                     continue
 
                 valid_batches = True
                 val_loss = self.validation_step(q_batch, p_batch, n_batch)
                 val_loss_avg(val_loss)
                 last_valid_val_loss = val_loss
                 val_batches_processed += 1
 
                 if is_tqdm_val:
                     val_pbar.update(1)
                     val_pbar.set_postfix({
                         "val_loss": f"{val_loss.numpy():.4f}",
                         "batches": f"{val_batches_processed}/{val_steps}"
                     })
 
                 gc.collect()
 
                 if val_batches_processed >= val_steps:
                     break
 
             if not valid_batches:
                 # If no valid batch is found, fallback
                 logger.warning("No valid validation batches in this epoch")

@@ -931,29 +917,29 @@ class RetrievalChatbot(DeviceAwareModel):
             else:
                 val_loss = epoch_loss_avg.result()
                 val_loss_avg(val_loss)
 
             if is_tqdm_val and val_pbar:
                 val_pbar.close()
 
             # End of epoch: final stats
             train_loss = epoch_loss_avg.result().numpy()
             val_loss = val_loss_avg.result().numpy()
             logger.info(f"Epoch {epoch} Complete: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")
 
             # TensorBoard epoch logs
             with train_summary_writer.as_default():
                 tf.summary.scalar("epoch_loss", train_loss, step=epoch)
             with val_summary_writer.as_default():
                 tf.summary.scalar("val_loss", val_loss, step=epoch)
 
             # Save checkpoint
             manager.save()
 
-            #
+            # Save model for iterative testing/inference
             model_save_path = Path(checkpoint_dir) / f"model_epoch_{epoch}"
             self.save_models(model_save_path)
             logger.info(f"Saved model for epoch {epoch} at {model_save_path}")
 
             # Update local history
             self.history['train_loss'].append(train_loss)
             self.history['val_loss'].append(val_loss)

@@ -972,13 +958,12 @@ class RetrievalChatbot(DeviceAwareModel):
                 return obj
 
             json_history = convert_to_py_floats(self.history)
 
             # Save training history to file every epoch
-            # (Create or overwrite the file so we always have the latest.)
             with open(history_path, 'w') as f:
                 json.dump(json_history, f)
             logger.info(f"Saved training history to {history_path}")
 
             # Early stopping
             if val_loss < best_val_loss - min_delta:
                 best_val_loss = val_loss

@@ -990,7 +975,7 @@ class RetrievalChatbot(DeviceAwareModel):
             if epochs_no_improve >= early_stopping_patience:
                 logger.info("Early stopping triggered.")
                 break
 
         logger.info("Training completed!")
 
     @tf.function

@@ -1004,37 +989,25 @@ class RetrievalChatbot(DeviceAwareModel):
         Single training step using queries, positives, and hard negatives.
         """
         with tf.GradientTape() as tape:
-            # Encode queries
+            # Encode queries, positives, and negatives
             q_enc = self.encoder(q_batch, training=True)  # [batch_size, embed_dim]
-
-            # Encode positives
             p_enc = self.encoder(p_batch, training=True)  # [batch_size, embed_dim]
-
-            # Encode negatives
-            # n_batch: [batch_size, neg_samples, max_length]
             shape = tf.shape(n_batch)
             bs = shape[0]
             neg_samples = shape[1]
 
-            # Flatten negatives to feed them in one pass:
-            # => [batch_size * neg_samples, max_length]
+            # Flatten negatives to feed them in one pass: [batch_size * neg_samples, max_length]
             n_batch_flat = tf.reshape(n_batch, [bs * neg_samples, shape[2]])
             n_enc_flat = self.encoder(n_batch_flat, training=True)  # [bs*neg_samples, embed_dim]
 
             # Reshape back => [batch_size, neg_samples, embed_dim]
             n_enc = tf.reshape(n_enc_flat, [bs, neg_samples, -1])
 
-            # Combine the positive embedding and negative embeddings along dim=1
-            #
-
-            combined_p_n = tf.concat(
-                [tf.expand_dims(p_enc, axis=1), n_enc],
-                axis=1
-            )  # [bs, (1+neg_samples), embed_dim]
+            # Combine the positive embedding and negative embeddings along dim=1: shape [batch_size, 1 + neg_samples, embed_dim]
+            # Col 1 is the pos, subsequent cols are negatives
+            combined_p_n = tf.concat([tf.expand_dims(p_enc, axis=1), n_enc], axis=1)  # [bs, (1+neg_samples), embed_dim]
 
-            #
-            # We'll use `tf.einsum` to handle the batch dimension properly
-            # dot_products => shape [batch_size, (1+neg_samples)]
+            # Compute scores: dot product of q_enc with each column in combined_p_n. `tf.einsum` handles the batch dimension
             dot_products = tf.cast(tf.einsum('bd,bkd->bk', q_enc, combined_p_n), tf.float32)
             labels = tf.zeros([bs], dtype=tf.int32)  # Keep labels as int32
             loss = tf.nn.sparse_softmax_cross_entropy_with_logits(

@@ -1043,14 +1016,13 @@ class RetrievalChatbot(DeviceAwareModel):
         )
         loss = tf.cast(tf.reduce_mean(loss), tf.float32)
 
-        # Calculate gradients
+        # Calculate gradients and clip
         gradients = tape.gradient(loss, self.encoder.trainable_variables)
         gradients_norm = tf.cast(tf.linalg.global_norm(gradients), tf.float32)
         max_grad_norm = tf.constant(1.5, dtype=tf.float32)
         gradients, _ = tf.clip_by_global_norm(gradients, max_grad_norm, gradients_norm)
         post_clip_norm = tf.cast(tf.linalg.global_norm(gradients), tf.float32)
 
-        # Apply gradients
         self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_variables))
 
         return loss, gradients_norm, post_clip_norm
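The train_step hunks implement a softmax contrastive loss over one positive and neg_samples hard negatives per query, with the positive fixed at column 0 so every label is 0. A runnable sketch, with random tensors standing in for encoder outputs:

import tensorflow as tf

bs, neg_samples, dim = 4, 10, 768
q_enc = tf.math.l2_normalize(tf.random.normal([bs, dim]), axis=1)
p_enc = tf.math.l2_normalize(tf.random.normal([bs, dim]), axis=1)
n_enc = tf.math.l2_normalize(tf.random.normal([bs, neg_samples, dim]), axis=2)

# [bs, 1 + neg_samples, dim]: positive first, negatives after
combined = tf.concat([tf.expand_dims(p_enc, axis=1), n_enc], axis=1)

# [bs, 1 + neg_samples] similarity logits, one row per query
logits = tf.einsum('bd,bkd->bk', q_enc, combined)

labels = tf.zeros([bs], dtype=tf.int32)  # the correct class is always index 0
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))
print(float(loss))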
@@ -1064,6 +1036,7 @@ class RetrievalChatbot(DeviceAwareModel):
     ) -> tf.Tensor:
         """
         Single validation step using queries, positives, and hard negatives.
+        Same idea as train_step, but without gradient updates.
         """
         q_enc = self.encoder(q_batch, training=False)
         p_enc = self.encoder(p_batch, training=False)

@@ -1082,7 +1055,7 @@ class RetrievalChatbot(DeviceAwareModel):
         )
 
         dot_products = tf.cast(tf.einsum('bd,bkd->bk', q_enc, combined_p_n), tf.float32)
         labels = tf.zeros([bs], dtype=tf.int32)
 
         loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
             labels=labels,

@@ -1098,7 +1071,9 @@ class RetrievalChatbot(DeviceAwareModel):
         peak_lr: float,
         warmup_steps: int
     ) -> tf.keras.optimizers.schedules.LearningRateSchedule:
-        """
+        """
+        Custom learning rate schedule with warmup and cosine decay.
+        """
         class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
             def __init__(
                 self,

@@ -1110,11 +1085,11 @@ class RetrievalChatbot(DeviceAwareModel):
                 self.total_steps = tf.cast(total_steps, tf.float32)
                 self.peak_lr = tf.cast(peak_lr, tf.float32)
 
-                #
+                # warmup_steps 10% of total_steps
                 adjusted_warmup_steps = min(warmup_steps, max(1, total_steps // 10))
                 self.warmup_steps = tf.cast(adjusted_warmup_steps, tf.float32)
 
-                # Calculate
+                # Calculate constants
                 self.initial_lr = tf.cast(self.peak_lr * 0.1, tf.float32)
                 self.min_lr = tf.cast(self.peak_lr * 0.01, tf.float32)

@@ -1128,21 +1103,20 @@ class RetrievalChatbot(DeviceAwareModel):
             def __call__(self, step):
                 step = tf.cast(step, tf.float32)
 
                 # Warmup
                 warmup_factor = tf.cast(tf.minimum(1.0, step / self.warmup_steps), tf.float32)
                 warmup_lr = self.initial_lr + (self.peak_lr - self.initial_lr) * warmup_factor
 
                 # Decay
                 decay_steps = tf.cast(tf.maximum(1.0, self.total_steps - self.warmup_steps), tf.float32)
                 decay_factor = tf.cast((step - self.warmup_steps) / decay_steps, tf.float32)
                 decay_factor = tf.cast(tf.minimum(tf.maximum(0.0, decay_factor), 1.0), tf.float32)
                 cosine_decay = tf.cast(0.5 * (1.0 + tf.cos(tf.constant(math.pi, dtype=tf.float32) * decay_factor)), tf.float32)
                 decay_lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay
 
-                # Choose between warmup and decay
                 final_lr = tf.where(step < self.warmup_steps, warmup_lr, decay_lr)
 
-                # Ensure
+                # Ensure valid lr
                 final_lr = tf.maximum(self.min_lr, final_lr)
                 final_lr = tf.where(tf.math.is_finite(final_lr), final_lr, self.min_lr)
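The schedule above warms up linearly from 10% of peak_lr, then follows a cosine decay down to a floor of 1% of peak. The same curve in plain Python for quick inspection; the defaults and step values here are arbitrary:

import math

def lr_at(step, total_steps=1000, peak_lr=5e-4, warmup_steps=100):
    warmup_steps = min(warmup_steps, max(1, total_steps // 10))
    initial_lr, min_lr = peak_lr * 0.1, peak_lr * 0.01
    if step < warmup_steps:                       # linear warmup
        return initial_lr + (peak_lr - initial_lr) * step / warmup_steps
    decay_steps = max(1, total_steps - warmup_steps)
    t = min(max(0.0, (step - warmup_steps) / decay_steps), 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * t))  # goes from 1 down to 0
    return max(min_lr, min_lr + (peak_lr - min_lr) * cosine)

for s in (0, 50, 100, 500, 1000):
    print(s, f"{lr_at(s):.2e}")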
24 |
|
25 |
@dataclass
|
26 |
class ChatbotConfig:
|
27 |
+
"""RetrievalChatbot Config"""
|
28 |
max_context_token_limit: int = 512
|
29 |
embedding_dim: int = 768
|
30 |
encoder_units: int = 256
|
31 |
num_attention_heads: int = 8
|
32 |
dropout_rate: float = 0.2
|
33 |
l2_reg_weight: float = 0.001
|
34 |
+
learning_rate: float = 0.0005
|
35 |
min_text_length: int = 3
|
36 |
+
max_context_turns: int = 20
|
37 |
warmup_steps: int = 200
|
38 |
pretrained_model: str = 'distilbert-base-uncased'
|
39 |
cross_encoder_model: str = 'cross-encoder/ms-marco-MiniLM-L-12-v2'
|
40 |
+
summarizer_model: str = 't5-small'
|
41 |
dtype: str = 'float32'
|
42 |
freeze_embeddings: bool = False
|
43 |
embedding_batch_size: int = 64
|
44 |
search_batch_size: int = 64
|
45 |
max_batch_size: int = 64
|
|
|
46 |
max_retries: int = 3
|
47 |
|
48 |
def to_dict(self) -> Dict:
|
|
|
57 |
if k in cls.__dataclass_fields__})
|
58 |
|
59 |
class EncoderModel(tf.keras.Model):
|
60 |
+
"""Dual encoder model with pretrained DistilBERT embeddings."""
|
61 |
def __init__(
|
62 |
self,
|
63 |
config: ChatbotConfig,
|
|
|
71 |
self.pretrained = TFAutoModel.from_pretrained(config.pretrained_model)
|
72 |
self._freeze_layers()
|
73 |
|
74 |
+
# Add Global Average Pooling, Projection, Dropout, and Normalization layers
|
75 |
self.pooler = tf.keras.layers.GlobalAveragePooling1D()
|
76 |
self.projection = tf.keras.layers.Dense(
|
77 |
config.embedding_dim,
|
|
|
86 |
)
|
87 |
|
88 |
def _freeze_layers(self):
|
89 |
+
"""Freeze n layers of the pretrained model"""
|
90 |
if self.config.freeze_embeddings:
|
91 |
self.pretrained.trainable = False
|
92 |
logger.info("All pretrained layers frozen.")
|
|
|
95 |
for i, layer in enumerate(self.pretrained.layers):
|
96 |
if isinstance(layer, tf.keras.layers.Layer):
|
97 |
if hasattr(layer, 'trainable'):
|
|
|
98 |
if i < 1:
|
99 |
layer.trainable = False
|
100 |
logger.info(f"Layer {i} frozen.")
|
101 |
else:
|
102 |
layer.trainable = True
|
103 |
+
logger.info(f"Layer {i} trainable.")
|
104 |
|
105 |
def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
|
106 |
"""Forward pass."""
|
107 |
# Get pretrained embeddings
|
108 |
pretrained_outputs = self.pretrained(inputs, training=training)
|
109 |
+
x = pretrained_outputs.last_hidden_state # Shape: [batch_size, seq_len, embedding_dim]
|
110 |
|
111 |
# Apply pooling, projection, dropout, and normalization
|
112 |
+
x = self.pooler(x) # Shape: [batch_size, 768]
|
113 |
+
x = self.projection(x) # Shape: [batch_size, 768]
|
114 |
x = self.dropout(x, training=training)
|
115 |
+
x = self.normalize(x) # Shape: [batch_size, 768]
|
116 |
|
117 |
return x
|
118 |
|
119 |
def get_config(self) -> dict:
|
120 |
+
"""Return the model config"""
|
121 |
config = super().get_config()
|
122 |
config.update({
|
123 |
"config": self.config.to_dict(),
|
|
|
126 |
return config
|
127 |
|
128 |
class RetrievalChatbot(DeviceAwareModel):
|
129 |
+
"""
|
130 |
+
Retrieval-based learning chatbot model.
|
131 |
+
Uses trained embeddings and FAISS for similarity search.
|
132 |
+
"""
|
133 |
def __init__(
|
134 |
self,
|
135 |
config: ChatbotConfig,
|
|
|
145 |
self.device = device or self._setup_default_device()
|
146 |
self.mode = mode.lower()
|
147 |
|
148 |
+
# Initialize reranker, summarizer, tokenizer, and encoder
|
149 |
self.reranker = reranker or self._initialize_reranker()
|
150 |
self.tokenizer = self._initialize_tokenizer()
|
151 |
self.encoder = self._initialize_encoder()
|
|
|
157 |
config=self.config,
|
158 |
tokenizer=self.tokenizer,
|
159 |
encoder=self.encoder,
|
|
|
160 |
response_pool=[],
|
161 |
max_length=self.config.max_context_token_limit,
|
162 |
query_embeddings_cache={},
|
|
|
|
|
|
|
|
|
163 |
)
|
164 |
|
165 |
# Collect unique responses from dialogues
|
|
|
195 |
"""Initialize the Summarizer."""
|
196 |
return Summarizer(
|
197 |
tokenizer=self.tokenizer,
|
198 |
+
model_name=self.config.summarizer_model,
|
199 |
max_summary_length=self.config.max_context_token_limit // 4,
|
200 |
device=self.device,
|
201 |
max_summary_rounds=2
|
|
|
227 |
new_vocab_size = len(self.tokenizer)
|
228 |
encoder.pretrained.resize_token_embeddings(new_vocab_size)
|
229 |
logger.info(f"Token embeddings resized to: {new_vocab_size}")
|
230 |
+
|
231 |
return encoder
|
232 |
|
233 |
def _load_faiss_index_and_responses(self) -> None:
|
234 |
"""Load FAISS index and response pool for inference."""
|
235 |
try:
|
236 |
+
logger.info(f"Loading FAISS index from {self.data_pipeline.faiss_index_file_path}...")
|
237 |
+
self.data_pipeline.load_faiss_index(self.data_pipeline.faiss_index_file_path)
|
238 |
logger.info("FAISS index loaded successfully.")
|
239 |
|
240 |
+
# Load response pool
|
241 |
+
response_pool_path = self.data_pipeline.faiss_index_file_path.replace('.index', '_responses.json')
|
242 |
if os.path.exists(response_pool_path):
|
243 |
with open(response_pool_path, 'r', encoding='utf-8') as f:
|
244 |
self.data_pipeline.response_pool = json.load(f)
|
|
|
262 |
"""
|
263 |
load_dir = Path(load_dir)
|
264 |
|
265 |
+
# Load config
|
266 |
with open(load_dir / "config.json", "r") as f:
|
267 |
config = ChatbotConfig.from_dict(json.load(f))
|
268 |
|
269 |
+
# Initialize chatbot
|
270 |
chatbot = cls(config, mode=mode)
|
271 |
|
272 |
+
# Load DistilBERT
|
273 |
+
chatbot.encoder.pretrained = TFAutoModel.from_pretrained(load_dir / "shared_encoder", config=config)
|
|
|
|
|
|
|
274 |
|
275 |
dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
|
276 |
_ = chatbot.encoder(dummy_input, training=False)
|
277 |
|
278 |
+
# Load tokenizer
|
279 |
chatbot.tokenizer = AutoTokenizer.from_pretrained(load_dir / "tokenizer")
|
280 |
logger.info(f"Models and tokenizer loaded from {load_dir}")
|
281 |
|
282 |
+
# Load the custom weights
|
|
|
|
|
283 |
custom_weights_path = load_dir / "encoder_custom_weights.weights.h5"
|
284 |
if custom_weights_path.exists():
|
285 |
chatbot.encoder.load_weights(str(custom_weights_path))
|
|
|
287 |
else:
|
288 |
logger.warning(f"No custom encoder weights found at {custom_weights_path}. The top-level projection layer won't have learned parameters.")
|
289 |
|
290 |
+
# Handle 'inference' mode: load FAISS, etc.
|
291 |
if mode == 'inference':
|
292 |
cls._prepare_model_for_inference(chatbot, load_dir)
|
293 |
|
|
|
295 |
|
296 |
@classmethod
|
297 |
def _prepare_model_for_inference(cls, chatbot: 'RetrievalChatbot', load_dir: Path) -> None:
|
298 |
+
"""Load inference components."""
|
299 |
try:
|
300 |
# Load FAISS index
|
301 |
faiss_path = load_dir / 'faiss_indices/faiss_index_production.index'
|
|
|
326 |
raise
|
327 |
|
328 |
def save_models(self, save_dir: Union[str, Path]):
|
329 |
+
"""Save model and config"""
|
330 |
save_dir = Path(save_dir)
|
331 |
save_dir.mkdir(parents=True, exist_ok=True)
|
332 |
|
|
|
334 |
with open(save_dir / "config.json", "w") as f:
|
335 |
json.dump(self.config.to_dict(), f, indent=2)
|
336 |
|
337 |
+
# Save the HF DistilBERT submodule, custom top-level layers, and tokenizer
|
338 |
self.encoder.pretrained.save_pretrained(save_dir / "shared_encoder")
|
|
|
|
|
339 |
self.encoder.save_weights(save_dir / "encoder_custom_weights.weights.h5")
|
|
|
|
|
340 |
self.tokenizer.save_pretrained(save_dir / "tokenizer")
|
|
|
341 |
logger.info(f"Models and tokenizer saved to {save_dir}.")
|
342 |
|
343 |
+
def retrieve_responses(
|
|
|
|
|
|
|
344 |
self,
|
345 |
query: str,
|
346 |
top_k: int = 10,
|
|
|
349 |
summarize_threshold: int = 512
|
350 |
) -> List[Tuple[str, float]]:
|
351 |
"""
|
352 |
+
Retrieve top-k responses using FAISS and cross-encoder re-ranking.
|
|
|
|
|
353 |
Args:
|
354 |
query: The user's input text.
|
355 |
+
top_k: Number of FAISS results to return
|
356 |
+
reranker: CrossEncoderReranker for refined scoring
|
357 |
+
summarizer: Summarizer for long queries
|
358 |
+
summarize_threshold: Summarize if conversation tokens > threshold.
|
|
|
359 |
Returns:
|
360 |
List of (response_text, final_score).
|
361 |
"""
|
362 |
+
def sigmoid(x: float) -> float:
|
363 |
+
return 1 / (1 + np.exp(-x))
|
364 |
+
|
365 |
+
# Query summarization
|
366 |
if summarizer and len(query.split()) > summarize_threshold:
|
367 |
logger.info(f"Query is long ({len(query.split())} words). Summarizing.")
|
368 |
query = summarizer.summarize_text(query)
|
|
|
379 |
|
380 |
texts = [item[0] for item in faiss_candidates]
|
381 |
|
|
|
382 |
if not reranker:
|
383 |
reranker = CrossEncoderReranker(model_name=self.config.cross_encoder_model)
|
384 |
|
385 |
+
# Re-rank the texts (candidates) from FAISS search using the cross-encoder
|
386 |
ce_logits = reranker.rerank(query, texts, max_length=256)
|
387 |
|
388 |
+
# Combine scores from FAISS and cross-encoder
|
389 |
final_candidates = []
|
390 |
for (resp_text, faiss_score), logit in zip(faiss_candidates, ce_logits):
|
391 |
+
ce_prob = sigmoid(logit) # now in range [0...1]
|
392 |
+
faiss_norm = (faiss_score + 1)/2.0 # now in range [0...1]
|
393 |
combined_score = 0.85 * ce_prob + 0.15 * faiss_norm
|
394 |
length_adjusted_score = self.length_adjust_score(resp_text, combined_score)
|
395 |
|
|
|
401 |
# Return top_k
|
402 |
return final_candidates[:top_k]
|
403 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
404 |
def extract_keywords(self, query: str) -> List[str]:
|
405 |
"""
|
406 |
Return any domain keywords present in the query (lowercased).
|
407 |
"""
|
408 |
+
domain_keywords = {
|
409 |
+
'restaurant': ['restaurant', 'dining', 'food', 'dine', 'reservation', 'table', 'menu', 'cuisine', 'eat', 'place to eat', 'hungry', 'chef', 'dish', 'meal', 'brunch', 'bistro', 'buffet', 'catering', 'gourmet', 'fast food', 'fine dining', 'takeaway', 'delivery', 'restaurant booking'],
|
410 |
+
'movie': ['movie', 'cinema', 'film', 'ticket', 'showtime', 'showing', 'theater', 'flick', 'screening', 'film ticket', 'film show', 'blockbuster', 'premiere', 'trailer', 'director', 'actor', 'actress', 'plot', 'genre', 'screen', 'sequel', 'animation', 'documentary'],
|
411 |
+
'ride_share': ['ride', 'taxi', 'uber', 'lyft', 'car service', 'pickup', 'dropoff', 'driver', 'cab', 'hailing', 'rideshare', 'ride hailing', 'carpool', 'chauffeur', 'transit', 'transportation', 'hail ride'],
|
412 |
+
'coffee': ['coffee', 'cafΓ©', 'cafe', 'starbucks', 'espresso', 'latte', 'mocha', 'americano', 'barista', 'brew', 'cappuccino', 'macchiato', 'iced coffee', 'cold brew', 'espresso machine', 'coffee shop', 'tea', 'chai', 'java', 'bean', 'roast', 'decaf'],
|
413 |
+
'pizza': ['pizza', 'delivery', 'order food', 'pepperoni', 'topping', 'pizzeria', 'slice', 'pie', 'margherita', 'deep dish', 'thin crust', 'cheese', 'oven', 'tossed', 'sauce', 'garlic bread', 'calzone'],
|
414 |
+
'auto': ['car', 'vehicle', 'repair', 'maintenance', 'mechanic', 'oil change', 'garage', 'auto shop', 'tire', 'check engine', 'battery', 'transmission', 'brake', 'engine diagnostics', 'carwash', 'detail', 'alignment', 'exhaust', 'spark plug', 'dashboard'],
|
415 |
+
}
|
416 |
+
|
417 |
query_lower = query.lower()
|
418 |
found = set()
|
419 |
+
for domain, kw_list in domain_keywords.items():
|
420 |
for kw in kw_list:
|
421 |
if kw in query_lower:
|
422 |
found.add(kw)
|
|
|
442 |
|
443 |
def detect_domain_from_query(self, query: str) -> str:
|
444 |
"""
|
445 |
+
Detect the domain of the query based on keywords. Used for boosting FAISS search.
|
446 |
"""
|
447 |
domain_patterns = {
|
448 |
'restaurant': r'\b(restaurant|restaurants?|dining|food|foods?|dine|reservation|reservations?|table|tables?|menu|menus?|cuisine|cuisines?|eat|eats?|place\s?to\s?eat|places\s?to\s?eat|hungry|chef|chefs?|dish|dishes?|meal|meals?|fork|forks?|knife|knives?|spoon|spoons?|brunch|bistro|buffet|buffets?|catering|caterings?|gourmet|fast\s?food|fine\s?dining|takeaway|takeaways?|delivery|deliveries|restaurant\s?booking)\b',
|
|
|
462 |
|
463 |
def is_numeric_response(self, text: str) -> bool:
|
464 |
"""
|
465 |
+
Return True if `text` is purely digits and/or spaces.
|
|
|
466 |
"""
|
467 |
pattern = r'^[\s]*[\d]+([\s.,\d]+)*[\s]*$'
|
468 |
return bool(re.match(pattern, text.strip()))
|
|
|
471 |
self,
|
472 |
query: str,
|
473 |
domain: str = 'other',
|
474 |
+
top_k: int = 10,
|
475 |
+
boost_factor: float = 1.15
|
476 |
) -> List[Tuple[str, float]]:
|
477 |
"""
|
478 |
Retrieve top-k responses from the FAISS index (IndexFlatIP) given a user query.
|
|
|
479 |
Args:
|
480 |
query (str): The user input text.
|
481 |
+
domain (str): The detected domain from possible domains: ['restaurant', 'movie', 'ride_share', 'coffee', 'pizza', 'auto', 'other']
|
482 |
+
top_k (int): Number of top results to return.
|
483 |
+
boost_factor (float, optional): Factor to boost scores for keyword matches.
|
|
|
484 |
Returns:
|
485 |
List[Tuple[str, float]]: List of (response_text, similarity) sorted by descending similarity.
|
486 |
"""
|
|
|
491 |
# Search the index
|
492 |
distances, indices = self.data_pipeline.index.search(q_emb_np, top_k * 10)
|
493 |
|
494 |
+
# IndexFlatIP: 'distances' are inner products (cosine similarities for normalized vectors).
|
495 |
candidates = []
|
496 |
for rank, idx in enumerate(indices[0]):
|
497 |
if idx < 0:
|
|
|
528 |
boosted = []
|
529 |
for (resp_text, resp_domain, score) in in_domain:
|
530 |
new_score = score
|
531 |
+
# If the domain is known AND the response text shares any query keywords, boost it
|
|
|
532 |
if query_keywords and any(kw in resp_text.lower() for kw in query_keywords):
|
533 |
new_score *= boost_factor
|
534 |
|
|
|
540 |
# Sort boosted responses
|
541 |
boosted.sort(key=lambda x: x[1], reverse=True)
|
542 |
|
543 |
+
# Debug logging (see FAISS responses)
|
544 |
# for resp, score in boosted[:100]:
|
545 |
# logger.debug(f"Candidate: '{resp}' with score {score}")
|
546 |
|
|
|
554 |
top_k: int = 10,
|
555 |
) -> Tuple[str, List[Tuple[str, float]], Dict[str, Any]]:
|
556 |
"""
|
557 |
+
Live chat with the chatbot. Uses same processing flow as validation, except for context handling and quality checking.
|
|
|
558 |
"""
|
559 |
@self.run_on_device
|
560 |
def get_response(self_arg, query_arg):
|
|
|
562 |
conversation_str = self_arg._build_conversation_context(query_arg, conversation_history)
|
563 |
|
564 |
# Retrieve and re-rank
|
565 |
+
results = self_arg.retrieve_responses(
|
566 |
query=conversation_str,
|
567 |
top_k=top_k,
|
568 |
reranker=self_arg.reranker,
|
|
|
586 |
query: str,
|
587 |
conversation_history: Optional[List[Tuple[str, str]]]
|
588 |
) -> str:
|
589 |
+
"""
|
590 |
+
Build conversation context string from conversation history.
|
591 |
+
"""
|
592 |
if not conversation_history:
|
593 |
return f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"
|
594 |
|
|
|
619 |
) -> None:
|
620 |
"""
|
621 |
Train the retrieval model using a pre-prepared TFRecord dataset.
|
|
|
622 |
- Checkpoint loading/restoring
|
623 |
- LR scheduling
|
624 |
- Epoch/iteration tracking
|
625 |
+
- Training-history logging
|
626 |
+
- Early stopping
|
627 |
+
- Custom loss function (Contrastive loss with hard negative sampling))
|
628 |
"""
|
629 |
logger.info("Starting training with pre-prepared TFRecord dataset...")
|
630 |
|
|
|
656 |
steps_per_epoch = math.ceil(train_size / batch_size)
|
657 |
val_steps = math.ceil(val_size / batch_size)
|
658 |
total_steps = steps_per_epoch * epochs
|
659 |
+
buffer_size = max(1, total_pairs // 2) # 50% of the dataset for shuffling
|
660 |
|
661 |
logger.info(f"Training pairs: {train_size}")
|
662 |
logger.info(f"Validation pairs: {val_size}")
|
|
|
678 |
self.optimizer = tf.keras.optimizers.Adam(learning_rate=tf.cast(peak_lr, tf.float32))
|
679 |
logger.info("Using fixed learning rate.")
|
680 |
|
681 |
+
# Dummy step to force initialization
|
682 |
dummy_input = tf.zeros((1, self.config.max_context_token_limit), dtype=tf.int32)
|
683 |
with tf.GradientTape() as tape:
|
684 |
dummy_output = self.encoder(dummy_input)
|
|
|
693 |
model=self.encoder
|
694 |
)
|
695 |
|
696 |
+
# Create a CheckpointManager
|
697 |
manager = tf.train.CheckpointManager(
|
698 |
checkpoint,
|
699 |
directory=checkpoint_dir,
|
|
|
701 |
checkpoint_name='ckpt'
|
702 |
)
|
703 |
|
704 |
+
# Restore from existing checkpoint if one is provided
|
705 |
latest_checkpoint = manager.latest_checkpoint
|
706 |
history_path = Path(checkpoint_dir) / 'training_history.json'
|
707 |
|
708 |
+
# Log epoch losses across runs, including restore from checkpoint
|
709 |
if not hasattr(self, 'history'):
|
710 |
self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
|
711 |
|
712 |
if latest_checkpoint and not test_mode:
|
713 |
+
# Debug checkpoint loading
|
714 |
+
# logger.info(f"\nTrying to load checkpoint from: {latest_checkpoint}")
|
715 |
+
# reader = tf.train.load_checkpoint(latest_checkpoint)
|
716 |
# shape_from_key = reader.get_variable_to_shape_map()
|
717 |
# dtype_from_key = reader.get_variable_to_dtype_map()
|
718 |
# logger.info("\nCheckpoint Variables:")
|
|
|
736 |
if initial_epoch == 0:
|
737 |
initial_epoch = ckpt_number
|
738 |
|
739 |
+
# Assign to checkpoint.epoch for counting
|
740 |
checkpoint.epoch.assign(tf.cast(initial_epoch, tf.int32))
|
741 |
logger.info(f"Resuming from epoch {initial_epoch}")
|
742 |
|
743 |
+
# Load history from file:
|
744 |
if history_path.exists():
|
745 |
try:
|
746 |
with open(history_path, 'r') as f:
|
|
|
749 |
except Exception as e:
|
750 |
logger.warning(f"Could not load history, starting fresh: {e}")
|
751 |
|
752 |
+
# Save custom weights not being saved in the full model.
|
753 |
+
# This was a bugfix to extract weights from a checkpoint without retraining.
|
754 |
+
# Before updating save_models, only Distilbert weights were being saved (custom layers were missed).
|
755 |
+
# Not needed, also not harmful.
|
756 |
self.save_models(Path(checkpoint_dir) / "pretrained_full_model")
|
757 |
logger.info(f"Manually saved custom weights after restore.")
|
758 |
else:
|
|
|
769 |
train_summary_writer = tf.summary.create_file_writer(train_log_dir)
|
770 |
val_summary_writer = tf.summary.create_file_writer(val_log_dir)
|
771 |
logger.info(f"TensorBoard logs will be saved in {log_dir}")
|
772 |
+
|
773 |
# Parse dataset
|
774 |
dataset = tf.data.TFRecordDataset(tfrecord_file_path)
|
775 |
+
|
776 |
+
# Debug mode uses small subset. Useful for CPU debugging.
|
777 |
if test_mode:
|
778 |
+
subset_size = 200
|
779 |
dataset = dataset.take(subset_size)
|
780 |
logger.info(f"TEST MODE: Using only {subset_size} examples")
|
781 |
# Recompute sizes, steps, epochs, etc., as needed
|
|
|
791 |
early_stopping_patience = 2
|
792 |
logger.info(f"New training pairs: {train_size}")
|
793 |
logger.info(f"New validation pairs: {val_size}")
|
794 |
+
|
795 |
dataset = dataset.map(
|
796 |
+
lambda x: parse_tfrecord_fn(x, self.config.max_context_token_limit, self.data_pipeline.neg_samples),
|
797 |
num_parallel_calls=tf.data.AUTOTUNE
|
798 |
)
|
799 |
+
|
800 |
# Train/val split
|
801 |
train_dataset = dataset.take(train_size)
|
802 |
val_dataset = dataset.skip(train_size).take(val_size)
|
803 |
+
|
804 |
# Shuffle and batch
|
805 |
train_dataset = train_dataset.shuffle(buffer_size=buffer_size)
|
806 |
train_dataset = train_dataset.batch(batch_size, drop_remainder=True)
|
807 |
train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
|
808 |
+
|
809 |
val_dataset = val_dataset.batch(batch_size, drop_remainder=False)
|
810 |
val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE)
|
811 |
val_dataset = val_dataset.cache()
|
812 |
+
|
813 |
# Training loop
|
814 |
best_val_loss = float("inf")
|
815 |
epochs_no_improve = 0
|
816 |
+
|
817 |
for epoch in range(int(checkpoint.epoch.numpy()) + 1, epochs + 1):
|
818 |
checkpoint.epoch.assign(epoch)
|
819 |
logger.info(f"Starting Epoch {epoch}...")
|
820 |
+
|
|
|
821 |
epoch_loss_avg = tf.keras.metrics.Mean(dtype=tf.float32)
|
822 |
batches_processed = 0
|
823 |
+
|
|
|
824 |
try:
|
825 |
train_pbar = tqdm(
|
826 |
total=steps_per_epoch,
|
|
|
831 |
except ImportError:
|
832 |
train_pbar = None
|
833 |
is_tqdm_train = False
|
834 |
+
|
835 |
+
# --- Training ---
|
836 |
for q_batch, p_batch, n_batch in train_dataset:
|
837 |
loss, grad_norm, post_clip_norm = self.train_step(q_batch, p_batch, n_batch)
|
838 |
epoch_loss_avg(loss)
|
|
|
860 |
"lr": f"{current_lr:.2e}",
|
861 |
"batches": f"{batches_processed}/{steps_per_epoch}"
|
862 |
})
|
863 |
+
|
864 |
gc.collect()
|
865 |
+
|
866 |
# End the epoch early if we've processed all steps
|
867 |
if batches_processed >= steps_per_epoch:
|
868 |
break
|
869 |
+
|
870 |
if is_tqdm_train and train_pbar:
|
871 |
train_pbar.close()
|
872 |
+
|
873 |
+
# --- Validation ---
|
874 |
val_loss_avg = tf.keras.metrics.Mean(dtype=tf.float32)
|
875 |
val_batches_processed = 0
|
876 |
+
|
877 |
try:
|
878 |
val_pbar = tqdm(total=val_steps, desc="Validation", unit="batch")
|
879 |
is_tqdm_val = True
|
880 |
except ImportError:
|
881 |
val_pbar = None
|
882 |
is_tqdm_val = False
|
883 |
+
|
884 |
last_valid_val_loss = None
|
885 |
valid_batches = False
|
886 |
+
|
887 |
for q_batch, p_batch, n_batch in val_dataset:
|
888 |
# If batch is too small, skip
|
889 |
if tf.shape(q_batch)[0] < 2:
|
890 |
logger.warning(f"Skipping validation batch of size {tf.shape(q_batch)[0]}")
|
891 |
continue
|
892 |
+
|
893 |
valid_batches = True
|
894 |
val_loss = self.validation_step(q_batch, p_batch, n_batch)
|
895 |
val_loss_avg(val_loss)
|
896 |
last_valid_val_loss = val_loss
|
897 |
val_batches_processed += 1
|
898 |
+
|
899 |
if is_tqdm_val:
|
900 |
val_pbar.update(1)
|
901 |
val_pbar.set_postfix({
|
902 |
"val_loss": f"{val_loss.numpy():.4f}",
|
903 |
"batches": f"{val_batches_processed}/{val_steps}"
|
904 |
})
|
905 |
+
|
906 |
gc.collect()
|
907 |
+
|
908 |
if val_batches_processed >= val_steps:
|
909 |
break
|
910 |
+
|
911 |
if not valid_batches:
|
912 |
# If no valid batch is found, fallback
|
913 |
logger.warning("No valid validation batches in this epoch")
|
|
|
917 |
else:
|
918 |
val_loss = epoch_loss_avg.result()
|
919 |
val_loss_avg(val_loss)
|
920 |
+
|
921 |
if is_tqdm_val and val_pbar:
|
922 |
val_pbar.close()
|
923 |
+
|
924 |
# End of epoch: final stats
|
925 |
train_loss = epoch_loss_avg.result().numpy()
|
926 |
val_loss = val_loss_avg.result().numpy()
|
927 |
logger.info(f"Epoch {epoch} Complete: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")
|
928 |
+
|
929 |
# TensorBoard epoch logs
|
930 |
with train_summary_writer.as_default():
|
931 |
tf.summary.scalar("epoch_loss", train_loss, step=epoch)
|
932 |
with val_summary_writer.as_default():
|
933 |
tf.summary.scalar("val_loss", val_loss, step=epoch)
|
934 |
+
|
935 |
# Save checkpoint
|
936 |
manager.save()
|
937 |
+
|
938 |
+
# Save model for iterative testing/inference
|
939 |
model_save_path = Path(checkpoint_dir) / f"model_epoch_{epoch}"
|
940 |
self.save_models(model_save_path)
|
941 |
logger.info(f"Saved model for epoch {epoch} at {model_save_path}")
|
942 |
+
|
943 |
# Update local history
|
944 |
self.history['train_loss'].append(train_loss)
|
945 |
self.history['val_loss'].append(val_loss)
|
|
|
958 |
return obj
|
959 |
|
960 |
json_history = convert_to_py_floats(self.history)
|
961 |
+
|
962 |
# Save training history to file every epoch
|
|
|
963 |
with open(history_path, 'w') as f:
|
964 |
json.dump(json_history, f)
|
965 |
logger.info(f"Saved training history to {history_path}")
|
966 |
+
|
967 |
# Early stopping
|
968 |
if val_loss < best_val_loss - min_delta:
|
969 |
best_val_loss = val_loss
|
|
|
975 |
if epochs_no_improve >= early_stopping_patience:
|
976 |
logger.info("Early stopping triggered.")
|
977 |
break
|
978 |
+
|
979 |
logger.info("Training completed!")
    @tf.function
    ...
        Single training step using queries, positives, and hard negatives.
        """
        with tf.GradientTape() as tape:
            # Encode queries, positives, and negatives
            q_enc = self.encoder(q_batch, training=True)  # [batch_size, embed_dim]
            p_enc = self.encoder(p_batch, training=True)  # [batch_size, embed_dim]

            shape = tf.shape(n_batch)
            bs = shape[0]
            neg_samples = shape[1]

            # Flatten negatives to feed them in one pass: [batch_size * neg_samples, max_length]
            n_batch_flat = tf.reshape(n_batch, [bs * neg_samples, shape[2]])
            n_enc_flat = self.encoder(n_batch_flat, training=True)  # [bs*neg_samples, embed_dim]

            # Reshape back => [batch_size, neg_samples, embed_dim]
            n_enc = tf.reshape(n_enc_flat, [bs, neg_samples, -1])

            # Combine the positive and negative embeddings along dim=1: shape [batch_size, 1 + neg_samples, embed_dim].
            # Column 0 is the positive; the remaining columns are negatives.
            combined_p_n = tf.concat([tf.expand_dims(p_enc, axis=1), n_enc], axis=1)  # [bs, (1+neg_samples), embed_dim]

            # Compute scores: dot product of q_enc with each column in combined_p_n; `tf.einsum` handles the batch dimension
            dot_products = tf.cast(tf.einsum('bd,bkd->bk', q_enc, combined_p_n), tf.float32)
            labels = tf.zeros([bs], dtype=tf.int32)  # Keep labels as int32; class 0 (the positive) is correct
            loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                ...
            )
            loss = tf.cast(tf.reduce_mean(loss), tf.float32)

        # Calculate gradients and clip
        gradients = tape.gradient(loss, self.encoder.trainable_variables)
        gradients_norm = tf.cast(tf.linalg.global_norm(gradients), tf.float32)
        max_grad_norm = tf.constant(1.5, dtype=tf.float32)
        gradients, _ = tf.clip_by_global_norm(gradients, max_grad_norm, gradients_norm)
        post_clip_norm = tf.cast(tf.linalg.global_norm(gradients), tf.float32)

        self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_variables))

        return loss, gradients_norm, post_clip_norm
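Why labels is all zeros: after the concat above, the positive sits at column 0 of combined_p_n, so each query's (1 + neg_samples) similarity scores form one softmax classification problem whose correct class is index 0. A self-contained sketch with dummy tensors (batch size, negative count, and embedding width below are illustrative, not taken from the training pipeline):

import tensorflow as tf

bs, neg_samples, embed_dim = 4, 10, 768
q_enc = tf.random.normal([bs, embed_dim])               # query embeddings
p_enc = tf.random.normal([bs, embed_dim])               # positive embeddings
n_enc = tf.random.normal([bs, neg_samples, embed_dim])  # hard-negative embeddings

# Column 0 is the positive; columns 1..neg_samples are negatives
combined_p_n = tf.concat([tf.expand_dims(p_enc, axis=1), n_enc], axis=1)

# logits[b, k] = dot(q_enc[b], combined_p_n[b, k]) -> [bs, 1 + neg_samples]
logits = tf.einsum('bd,bkd->bk', q_enc, combined_p_n)

labels = tf.zeros([bs], dtype=tf.int32)  # the positive is always class 0
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))
print(float(loss))  # softmax cross-entropy of 1 positive vs. neg_samples negatives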
    ...
    ) -> tf.Tensor:
        """
        Single validation step using queries, positives, and hard negatives.
        Same idea as train_step, but without gradient updates.
        """
        q_enc = self.encoder(q_batch, training=False)
        p_enc = self.encoder(p_batch, training=False)
        ...
        )

        dot_products = tf.cast(tf.einsum('bd,bkd->bk', q_enc, combined_p_n), tf.float32)
        labels = tf.zeros([bs], dtype=tf.int32)

        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels,
            ...
        peak_lr: float,
        warmup_steps: int
    ) -> tf.keras.optimizers.schedules.LearningRateSchedule:
        """
        Custom learning rate schedule with warmup and cosine decay.
        """
        class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
            def __init__(
                self,
                ...
                self.total_steps = tf.cast(total_steps, tf.float32)
                self.peak_lr = tf.cast(peak_lr, tf.float32)

                # Cap warmup_steps at 10% of total_steps
                adjusted_warmup_steps = min(warmup_steps, max(1, total_steps // 10))
                self.warmup_steps = tf.cast(adjusted_warmup_steps, tf.float32)

                # Calculate constants
                self.initial_lr = tf.cast(self.peak_lr * 0.1, tf.float32)
                self.min_lr = tf.cast(self.peak_lr * 0.01, tf.float32)
            ...
            def __call__(self, step):
                step = tf.cast(step, tf.float32)

                # Warmup: linear ramp from initial_lr up to peak_lr
                warmup_factor = tf.cast(tf.minimum(1.0, step / self.warmup_steps), tf.float32)
                warmup_lr = self.initial_lr + (self.peak_lr - self.initial_lr) * warmup_factor

                # Decay: cosine curve from peak_lr down to min_lr
                decay_steps = tf.cast(tf.maximum(1.0, self.total_steps - self.warmup_steps), tf.float32)
                decay_factor = tf.cast((step - self.warmup_steps) / decay_steps, tf.float32)
                decay_factor = tf.cast(tf.minimum(tf.maximum(0.0, decay_factor), 1.0), tf.float32)
                cosine_decay = tf.cast(0.5 * (1.0 + tf.cos(tf.constant(math.pi, dtype=tf.float32) * decay_factor)), tf.float32)
                decay_lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay

                final_lr = tf.where(step < self.warmup_steps, warmup_lr, decay_lr)

                # Ensure a valid, finite learning rate
                final_lr = tf.maximum(self.min_lr, final_lr)
                final_lr = tf.where(tf.math.is_finite(final_lr), final_lr, self.min_lr)
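To sanity-check the warmup/cosine shape without running TensorFlow, the schedule's math can be mirrored in plain Python (peak_lr, total_steps, and warmup_steps below are illustrative values, not the training configuration):

import math

def lr_at(step, total_steps=1000, peak_lr=2e-5, warmup_steps=100):
    # Mirror of the schedule above: cap warmup at 10% of total steps
    warmup_steps = min(warmup_steps, max(1, total_steps // 10))
    initial_lr, min_lr = peak_lr * 0.1, peak_lr * 0.01
    if step < warmup_steps:
        # Linear warmup from initial_lr to peak_lr
        return initial_lr + (peak_lr - initial_lr) * min(1.0, step / warmup_steps)
    decay_steps = max(1.0, total_steps - warmup_steps)
    decay_factor = min(max(0.0, (step - warmup_steps) / decay_steps), 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * decay_factor))
    return max(min_lr, min_lr + (peak_lr - min_lr) * cosine)

for s in (0, 50, 100, 500, 1000):
    print(s, lr_at(s))  # ramps to peak_lr by step 100, then cosine-decays toward min_lr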
chatbot_validator.py
CHANGED
@@ -113,7 +113,7 @@ class ChatbotValidator:
            logger.info(f"\nTest Case {i}: {query}")

            # Retrieve top_k responses, then evaluate with quality checker
-           responses = self.chatbot.…
+           responses = self.chatbot.retrieve_responses(query, top_k=top_k, reranker=reranker)
            quality_metrics = self.quality_checker.check_response_quality(query, responses)

            # Aggregate metrics and log
{data_augmentation → data_augmentation_code}/augmentation_processing_pipeline.py
RENAMED
File without changes

{data_augmentation → data_augmentation_code}/back_translator.py
RENAMED
File without changes

{data_augmentation → data_augmentation_code}/dialogue_augmenter.py
RENAMED
File without changes

{data_augmentation → data_augmentation_code}/main.py
RENAMED
File without changes

{data_augmentation → data_augmentation_code}/paraphraser.py
RENAMED
File without changes

{data_augmentation → data_augmentation_code}/pipeline_config.py
RENAMED
File without changes

{data_augmentation → data_augmentation_code}/quality_metrics.py
RENAMED
File without changes

{data_augmentation → data_augmentation_code}/schema_guided_dialogue_processor.py
RENAMED
File without changes

{data_augmentation → data_augmentation_code}/taskmaster_processor.py
RENAMED
File without changes
validate_model.py → run_chatbot_validation.py
RENAMED
@@ -39,7 +39,7 @@ def run_interactive_chat(chatbot, quality_checker):
    else:
        print("\n[Low Confidence]: Consider rephrasing your query for better assistance.")

-def validate_chatbot():
+def run_chatbot_validation():
    # Initialize environment
    env = EnvironmentSetup()
    env.initialize()
@@ -86,15 +86,15 @@ def validate_chatbot():
    try:
        chatbot.data_pipeline.load_faiss_index(FAISS_INDEX_PATH)
        logger.info(f"FAISS index loaded from {FAISS_INDEX_PATH}.")
-       logger.info("FAISS dimensions: …
-       logger.info("FAISS index type: …
-       logger.info("FAISS index total vectors: …
-       logger.info("FAISS is_trained: …
+       logger.info(f"FAISS dimensions: {chatbot.data_pipeline.index.d}")
+       logger.info(f"FAISS index type: {type(chatbot.data_pipeline.index)}")
+       logger.info(f"FAISS index total vectors: {chatbot.data_pipeline.index.ntotal}")
+       logger.info(f"FAISS is_trained: {chatbot.data_pipeline.index.is_trained}")

        with open(RESPONSE_POOL_PATH, "r", encoding="utf-8") as f:
            chatbot.data_pipeline.response_pool = json.load(f)
        logger.info(f"Response pool loaded from {RESPONSE_POOL_PATH}.")
-       logger.info("\nTotal responses in pool: …
+       logger.info(f"\nTotal responses in pool: {len(chatbot.data_pipeline.response_pool)}")

        # Validate dimension consistency
        chatbot.data_pipeline.validate_faiss_index()
@@ -130,4 +130,4 @@ def validate_chatbot():
    run_interactive_chat(chatbot, quality_checker)

if __name__ == "__main__":
-   validate_chatbot()
+   run_chatbot_validation()
tf_data_pipeline.py
CHANGED
@@ -24,19 +24,19 @@ class TFDataPipeline:
        config,
        tokenizer,
        encoder,
-       index_file_path: str,
        response_pool: List[str],
-       max_length: int,
        query_embeddings_cache: dict,
-       …
+       max_length: int = 512,
+       neg_samples: int = 10,
        index_type: str = 'IndexFlatIP',
+       faiss_index_file_path: str = 'new_iteration/data_prep_iterative_models/faiss_indices/faiss_index_production.index',
        nlist: int = 100,
        max_retries: int = 3
    ):
        self.config = config
        self.tokenizer = tokenizer
        self.encoder = encoder
-       self.index_file_path = index_file_path
+       self.faiss_index_file_path = faiss_index_file_path
        self.response_pool = response_pool
        self.max_length = max_length
        self.neg_samples = neg_samples
@@ -53,9 +53,9 @@ class TFDataPipeline:
        self.build_text_to_domain_map()

        # Initialize FAISS index
-       if os.path.exists(index_file_path):
-           logger.info(f"Loading existing FAISS index from {index_file_path}...")
-           self.index = faiss.read_index(index_file_path)
+       if os.path.exists(faiss_index_file_path):
+           logger.info(f"Loading existing FAISS index from {faiss_index_file_path}...")
+           self.index = faiss.read_index(faiss_index_file_path)
            self.validate_faiss_index()
            logger.info("FAISS index loaded and validated successfully.")
        else:
@@ -83,18 +83,18 @@ class TFDataPipeline:
            self.query_embeddings_cache[query] = hf[query][:]
        logger.info(f"Embeddings cache loaded from {cache_file_path}.")

-   def save_faiss_index(self, index_file_path: str):
-       faiss.write_index(self.index, index_file_path)
-       logger.info(f"FAISS index saved to {index_file_path}")
+   def save_faiss_index(self, faiss_index_file_path: str):
+       faiss.write_index(self.index, faiss_index_file_path)
+       logger.info(f"FAISS index saved to {faiss_index_file_path}")

-   def load_faiss_index(self, index_file_path: str):
+   def load_faiss_index(self, faiss_index_file_path: str):
        """Load FAISS index from specified file path."""
-       if os.path.exists(index_file_path):
-           self.index = faiss.read_index(index_file_path)
-           logger.info(f"FAISS index loaded from {index_file_path}.")
+       if os.path.exists(faiss_index_file_path):
+           self.index = faiss.read_index(faiss_index_file_path)
+           logger.info(f"FAISS index loaded from {faiss_index_file_path}.")
        else:
-           logger.error(f"FAISS index file not found at {index_file_path}.")
-           raise FileNotFoundError(f"FAISS index file not found at {index_file_path}.")
+           logger.error(f"FAISS index file not found at {faiss_index_file_path}.")
+           raise FileNotFoundError(f"FAISS index file not found at {faiss_index_file_path}.")

    def validate_faiss_index(self):
        """Validates FAISS index dimensionality."""