JoeArmani committed · Commit 5b413d1 · Parent(s): ee0f664

training and inference updates

Browse files:
- chatbot_model.py +443 -995
- chatbot_validator.py +26 -26
- conversation_summarizer.py +11 -4
- environment_setup.py +0 -9
- run_data_preparer.py → prepare_data.py +14 -18
- response_quality_checker.py +40 -56
- test_trained_model.py +0 -0
- tf_data_pipeline.py +111 -327
- run_model_train.py → train_model.py +17 -51
- validate_model.py +117 -0
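The diff below reworks RetrievalChatbot around an explicit mode ('training' vs. 'inference') and delegates FAISS indexing to TFDataPipeline. A minimal usage sketch of that new API follows; the checkpoint path and the no-argument ChatbotConfig() construction are assumptions for illustration, not taken from this commit:

    from chatbot_model import RetrievalChatbot, ChatbotConfig

    # Training-time construction: the FAISS index is managed by the data pipeline.
    config = ChatbotConfig()  # assumed default configuration
    trainer = RetrievalChatbot(config, mode='training')

    # Inference-time loading restores the encoder, tokenizer, FAISS index, and response pool.
    bot = RetrievalChatbot.load_model("checkpoints/chatbot", mode='inference')
    top_responses = bot.retrieve_responses_faiss("How do I reset my password?", top_k=5)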
chatbot_model.py
CHANGED
@@ -1,7 +1,6 @@
-import
 from transformers import TFAutoModel, AutoTokenizer
 import tensorflow as tf
-import numpy as np
 from typing import List, Tuple, Dict, Optional, Union, Any
 import math
 from dataclasses import dataclass
@@ -66,23 +65,17 @@ class EncoderModel(tf.keras.Model):
         super().__init__(name=name, **kwargs)
         self.config = config

-        # Load pretrained model
         self.pretrained = TFAutoModel.from_pretrained(config.pretrained_model)
-
-        # Freeze layers based on config
         self._freeze_layers()

-        # Pooling layer (Global Average Pooling)
         self.pooler = tf.keras.layers.GlobalAveragePooling1D()
-
-        # Projection layer
         self.projection = tf.keras.layers.Dense(
             config.embedding_dim,
             activation='tanh',
             name="projection"
         )
-
-        # Dropout and normalization
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
         self.normalize = tf.keras.layers.Lambda(
             lambda x: tf.nn.l2_normalize(x, axis=1),
@@ -110,13 +103,13 @@ class EncoderModel(tf.keras.Model):
         """Forward pass."""
         # Get pretrained embeddings
         pretrained_outputs = self.pretrained(inputs, training=training)
-        x = pretrained_outputs.last_hidden_state

         # Apply pooling, projection, dropout, and normalization
-        x = self.pooler(x)
-        x = self.projection(x)
-        x = self.dropout(x, training=training)
-        x = self.normalize(x)

         return x

@@ -134,12 +127,11 @@ class RetrievalChatbot(DeviceAwareModel):
     def __init__(
         self,
         config: ChatbotConfig,
-        dialogues: List[dict] = [],
         device: str = None,
         strategy=None,
         reranker: Optional[CrossEncoderReranker] = None,
         summarizer: Optional[Summarizer] = None,
-        mode: str = '
     ):
         super().__init__()
         self.config = config
@@ -147,17 +139,37 @@ class RetrievalChatbot(DeviceAwareModel):
         self.device = device or self._setup_default_device()
         self.mode = mode.lower()

-        # Initialize reranker, summarizer, tokenizer, and memory monitor
         self.reranker = reranker or self._initialize_reranker()
-        self.summarizer = summarizer or self._initialize_summarizer()
         self.tokenizer = self._initialize_tokenizer()
         self.memory_monitor = GPUMemoryMonitor()

-        #
-
-
-

         # Initialize training history
         self.history = {
             "train_loss": [],
@@ -165,15 +177,7 @@ class RetrievalChatbot(DeviceAwareModel):
             "train_metrics": {},
             "val_metrics": {}
         }
-
-        # Collect unique responses from dialogues
-        if self.mode == 'preparation':
-            # Collect unique responses from dialogues only in preparation mode
-            self.response_pool, self.unique_responses = self._collect_responses(dialogues)
-        else:
-            # In training mode, assume response_pool is handled via TFRecord
-            self.response_pool = []
-            self.unique_responses = []

     def _setup_default_device(self) -> str:
         """Set up default device if none is provided."""
@@ -189,8 +193,13 @@ class RetrievalChatbot(DeviceAwareModel):

     def _initialize_summarizer(self) -> Summarizer:
         """Initialize the Summarizer."""
-
-

     def _initialize_tokenizer(self) -> AutoTokenizer:
         """Initialize the tokenizer and add special tokens."""
@@ -207,559 +216,127 @@ class RetrievalChatbot(DeviceAwareModel):
         )
         return tokenizer

-    def
-        """
-
-
-        response_pool: List of all possible responses.
-        unique_responses: List of unique responses.
-        """
-        logger.info("Collecting unique responses from dialogues...")
-        responses = set()
-        for dialogue in dialogues:
-            turns = dialogue.get('turns', [])
-            for turn in turns:
-                if turn.get('speaker') == 'assistant' and 'text' in turn:
-                    response = turn['text'].strip()
-                    if len(response) >= self.config.min_text_length:
-                        responses.add(response)
-        response_pool = list(responses)
-        unique_responses = list(responses)  # Assuming uniqueness
-        logger.info(f"Collected {len(response_pool)} unique responses.")
-        return response_pool, unique_responses
-
-    def build_models(self):
-        """Initialize the shared encoder and FAISS index."""
-        logger.info("Building encoder model...")
-        tf.keras.backend.clear_session()
-
-        # Shared encoder for both queries and responses
-        self.encoder = EncoderModel(
             self.config,
             name="shared_encoder",
         )

-        # Resize token embeddings after adding special tokens
         new_vocab_size = len(self.tokenizer)
-
         logger.info(f"Token embeddings resized to: {new_vocab_size}")
-
-        if self.mode == 'preparation':
-            # Initialize FAISS index only in preparation mode
-            self._initialize_faiss()
-            # Compute and index embeddings
-            self._compute_and_index_embeddings()
-        else:
-            # In training mode, skip FAISS indexing from dialogues
-            logger.info("Training mode: Skipping FAISS index initialization from dialogues.")
-
-        # Retrieve embedding dimension from encoder
-        embedding_dim = self.config.embedding_dim
-        vocab_size = len(self.tokenizer)
-
-        logger.info(f"Encoder Embedding Dimension: {embedding_dim}")
-        logger.info(f"Encoder Embedding Vocabulary Size: {vocab_size}")
-        if vocab_size >= embedding_dim:
-            logger.info("Encoder model built and embeddings resized successfully.")
-        else:
-            logger.error("Vocabulary size is less than embedding dimension.")
-            raise ValueError("Vocabulary size is less than embedding dimension.")

-    def
-        """
-        if self.memory_monitor.should_reduce_batch_size():
-            new_size = max(self.min_batch_size, self.current_batch_size // 2)
-            if new_size != self.current_batch_size:
-                logger.info(f"Reducing batch size to {new_size} due to high memory usage")
-                self.current_batch_size = new_size
-                gc.collect()
-                if tf.config.list_physical_devices('GPU'):
-                    tf.keras.backend.clear_session()
-        elif self.memory_monitor.can_increase_batch_size():
-            new_size = min(self.max_batch_size, self.current_batch_size * 2)
-            if new_size != self.current_batch_size:
-                logger.info(f"Increasing batch size to {new_size}")
-                self.current_batch_size = new_size
-
-    def _initialize_faiss(self):
-        """Initialize FAISS with safe GPU handling and memory monitoring."""
-        logger.info("Initializing FAISS index...")
-
-        # Detect if we have GPU-enabled FAISS
-        self.faiss_gpu = False
-        self.gpu_resources = []
-
-        try:
-            if hasattr(faiss, 'get_num_gpus'):
-                ngpus = faiss.get_num_gpus()
-                if ngpus > 0:
-                    # Configure GPU resources with memory limit
-                    for i in range(ngpus):
-                        res = faiss.StandardGpuResources()
-                        # Set temp memory to 1/4 of total memory to avoid OOM
-                        if self.memory_monitor.has_gpu:
-                            stats = self.memory_monitor.get_memory_stats()
-                            if stats:
-                                temp_memory = int(stats.total * 0.25)  # 25% of total memory
-                                res.setTempMemory(temp_memory)
-                        self.gpu_resources.append(res)
-                    self.faiss_gpu = True
-                    logger.info(f"FAISS GPU resources initialized on {ngpus} GPUs")
-        except Exception as e:
-            logger.warning(f"Using CPU due to GPU initialization error: {e}")
-
         try:
-
-
-
-
         else:
-
-

-            # Move to GPU(s) if available and needed
-            if self.faiss_gpu and self.gpu_resources:
-                try:
-                    if len(self.gpu_resources) > 1:
-                        self.index = faiss.index_cpu_to_gpus_list(self.index, self.gpu_resources)
-                        logger.info("FAISS index distributed across multiple GPUs")
-                    else:
-                        self.index = faiss.index_cpu_to_gpu(self.gpu_resources[0], 0, self.index)
-                        logger.info("FAISS index moved to single GPU")
-                except Exception as e:
-                    logger.warning(f"Failed to move index to GPU: {e}. Falling back to CPU")
-                    self.faiss_gpu = False
         except Exception as e:
-            logger.error(f"
             raise
-
-
-
-        responses: List[str],
-        batch_size: int = 64
-    ) -> tf.Tensor:
-        """
-        Encodes responses with more conservative memory management.
         """
-
-            logger.info("No responses to encode. Returning empty tensor.")
-            return tf.constant([], dtype=tf.float32)
-
-        all_embeddings = []
-        self.current_batch_size = batch_size
-
-        if self.memory_monitor.has_gpu:
-            batch_size = 128
-
-        total_processed = 0

-
-
-
-
-
-
-                    self.current_batch_size = max(128, self.current_batch_size // 2)
-                    logger.info(f"High GPU memory usage ({gpu_usage:.1%}), reducing batch size to {self.current_batch_size}")
-                    gc.collect()
-                    tf.keras.backend.clear_session()
-
-                # Get batch
-                end_idx = min(total_processed + self.current_batch_size, len(responses))
-                batch_texts = responses[total_processed:end_idx]
-
-                try:
-                    # Tokenize
-                    encodings = self.tokenizer(
-                        batch_texts,
-                        padding='max_length',
-                        truncation=True,
-                        max_length=self.config.max_context_token_limit,
-                        return_tensors='tf'
-                    )
-
-                    # Encode
-                    embeddings_batch = self.encoder(encodings['input_ids'], training=False)
-
-                    # Cast to float32
-                    if embeddings_batch.dtype != tf.float32:
-                        embeddings_batch = tf.cast(embeddings_batch, tf.float32)
-
-                    # Store
-                    all_embeddings.append(embeddings_batch)
-
-                    # Update progress
-                    batch_processed = len(batch_texts)
-                    total_processed += batch_processed
-
-                    # Update progress bar
-                    if self.memory_monitor.has_gpu:
-                        gpu_usage = self.memory_monitor.get_memory_usage()
-                        pbar.set_postfix({
-                            'GPU mem': f'{gpu_usage:.1%}',
-                            'batch_size': self.current_batch_size
-                        })
-                    pbar.update(batch_processed)
-
-                    # Memory cleanup every 1000 samples
-                    if total_processed % 1000 == 0:
-                        gc.collect()
-                        if tf.config.list_physical_devices('GPU'):
-                            tf.keras.backend.clear_session()
-
-                except tf.errors.ResourceExhaustedError:
-                    logger.warning("GPU memory exhausted during encoding, reducing batch size")
-                    self.current_batch_size = max(8, self.current_batch_size // 2)
-                    continue
-
-                except Exception as e:
-                    logger.error(f"Error during encoding: {str(e)}")
-                    raise
-
-        # Concatenate results
-        if not all_embeddings:
-            logger.info("No embeddings were encoded. Returning empty tensor.")
-            return tf.constant([], dtype=tf.float32)
-
-        if len(all_embeddings) == 1:
-            final_embeddings = all_embeddings[0]
-        else:
-            final_embeddings = tf.concat(all_embeddings, axis=0)
-
-        return final_embeddings
-
-    def _train_faiss_index(self, response_embeddings: np.ndarray) -> None:
-        """Train FAISS index with better memory management and robust fallback mechanisms."""
-        if self.index.is_trained:
-            logger.info("Index already trained, skipping training phase")
-            return
-
-        logger.info("Starting FAISS index training...")

-
-
-
-            logger.info(f"Using {subset_size} samples for initial training attempt")
-            subset_idx = np.random.choice(len(response_embeddings), subset_size, replace=False)
-            training_embeddings = response_embeddings[subset_idx].copy()  # Make a copy
-
-            # Ensure contiguous memory layout
-            training_embeddings = np.ascontiguousarray(training_embeddings)
-
-            # Force cleanup before training
-            gc.collect()
-            if tf.config.list_physical_devices('GPU'):
-                tf.keras.backend.clear_session()
-
-            # Verify data properties
-            logger.info(f"FAISS training data shape: {training_embeddings.shape}")
-            logger.info(f"FAISS training data dtype: {training_embeddings.dtype}")
-
-            logger.info("Starting initial training attempt...")
-            self.index.train(training_embeddings)
-            logger.info("Training completed successfully")
-
-        except (RuntimeError, Exception) as e:
-            logger.warning(f"Initial training attempt failed: {str(e)}")
-            logger.info("Attempting fallback strategy...")
-
-            try:
-                # Move to CPU for more stable training
-                if self.faiss_gpu:
-                    logger.info("Moving index to CPU for fallback training")
-                    cpu_index = faiss.index_gpu_to_cpu(self.index)
-                else:
-                    cpu_index = self.index
-
-                # Create simpler index type if needed
-                if isinstance(cpu_index, faiss.IndexIVFFlat):
-                    logger.info("Creating simpler FlatL2 index for fallback")
-                    cpu_index = faiss.IndexFlatL2(self.config.embedding_dim)
-
-                # Use even smaller subset for CPU training
-                subset_size = min(2000, len(response_embeddings))
-                subset_idx = np.random.choice(len(response_embeddings), subset_size, replace=False)
-                fallback_embeddings = response_embeddings[subset_idx].copy()
-
-                # Ensure data is properly formatted
-                if not fallback_embeddings.flags['C_CONTIGUOUS']:
-                    fallback_embeddings = np.ascontiguousarray(fallback_embeddings)
-                if fallback_embeddings.dtype != np.float32:
-                    fallback_embeddings = fallback_embeddings.astype(np.float32)
-
-                # Train on CPU
-                logger.info("Training fallback index on CPU...")
-                cpu_index.train(fallback_embeddings)
-
-                # Move back to GPU if needed
-                if self.faiss_gpu:
-                    logger.info("Moving trained index back to GPU...")
-                    if len(self.gpu_resources) > 1:
-                        self.index = faiss.index_cpu_to_gpus_list(cpu_index, self.gpu_resources)
-                    else:
-                        self.index = faiss.index_cpu_to_gpu(self.gpu_resources[0], 0, cpu_index)
-                else:
-                    self.index = cpu_index
-
-                logger.info("Fallback training completed successfully")
-
-            except Exception as e2:
-                logger.error(f"Fallback training also failed: {str(e2)}")
-                logger.warning("Creating basic brute-force index as last resort")
-
-                try:
-                    # Create basic brute-force index as last resort
-                    dim = response_embeddings.shape[1]
-                    basic_index = faiss.IndexFlatL2(dim)
-
-                    if self.faiss_gpu:
-                        if len(self.gpu_resources) > 1:
-                            self.index = faiss.index_cpu_to_gpus_list(basic_index, self.gpu_resources)
-                        else:
-                            self.index = faiss.index_cpu_to_gpu(self.gpu_resources[0], 0, basic_index)
-                    else:
-                        self.index = basic_index
-
-                    logger.info("Basic index created as fallback")
-
-                except Exception as e3:
-                    logger.error(f"All training attempts failed: {str(e3)}")
-                    raise RuntimeError("Unable to create working FAISS index")
-
-    def _add_vectors_to_index(self, response_embeddings: np.ndarray) -> None:
-        """Add vectors to FAISS index with enhanced memory management."""
-        logger.info("Starting vector addition process...")

-        #
-
-        min_batch_size = 32
-        max_batch_size = 1024

-
-
-

-
-
-
-                if self.memory_monitor.has_gpu:
-                    gpu_usage = self.memory_monitor.get_memory_usage()
-                    #logger.info(f"GPU memory usage before batch: {gpu_usage:.1%}")
-
-                    # Force cleanup if memory usage is high
-                    if gpu_usage > 0.7:  # Lower threshold to 70%
-                        logger.info("High memory usage detected, forcing cleanup")
-                        gc.collect()
-                        tf.keras.backend.clear_session()
-
-                # Get batch
-                end_idx = min(total_added + initial_batch_size, len(response_embeddings))
-                batch = response_embeddings[total_added:end_idx]
-
-                # Add batch
-                self.index.add(batch)
-
-                # Update progress
-                batch_size = len(batch)
-                total_added += batch_size
-
-                # Memory cleanup every few batches
-                if total_added % (initial_batch_size * 5) == 0:
-                    gc.collect()
-                    if tf.config.list_physical_devices('GPU'):
-                        tf.keras.backend.clear_session()
-
-                # Gradually increase batch size
-                if initial_batch_size < max_batch_size:
-                    initial_batch_size = min(initial_batch_size + 25, max_batch_size)
-
-            except Exception as e:
-                logger.warning(f"Error adding batch: {str(e)}")
-                retry_count += 1
-
-                if retry_count > max_retries:
-                    logger.error("Max retries exceeded.")
-                    raise
-
-                # Reduce batch size
-                initial_batch_size = max(min_batch_size, initial_batch_size // 2)
-                logger.info(f"Reducing batch size to {initial_batch_size} and retrying...")
-
-                # Cleanup
-                gc.collect()
-                if tf.config.list_physical_devices('GPU'):
-                    tf.keras.backend.clear_session()
-
-                time.sleep(1)  # Brief pause before retry

-
-
-
-        """CPU fallback with extra safeguards and progress tracking."""
-        logger.info(f"CPU Fallback: Adding {len(remaining_embeddings)} remaining vectors...")

         try:
-            #
-
-
-
         else:
-
-
-            # Add remaining vectors on CPU with very small batches
-            batch_size = 128
-            total_added = already_added
-
-            for i in range(0, len(remaining_embeddings), batch_size):
-                end_idx = min(i + batch_size, len(remaining_embeddings))
-                batch = remaining_embeddings[i:end_idx]
-
-                # Add batch
-                cpu_index.add(batch)
-
-                # Update progress
-                total_added += len(batch)
-                if i % (batch_size * 10) == 0:
-                    logger.info(f"Added {total_added} vectors total "
-                                f"({i}/{len(remaining_embeddings)} in current phase)")
-
-                # Periodic cleanup
-                if i % (batch_size * 20) == 0:
-                    gc.collect()

-            #
-
-
-
-
-
-                self.index = faiss.index_cpu_to_gpu(self.gpu_resources[0], 0, cpu_index)
             else:
-
-
-            logger.info("CPU fallback completed successfully")
-
-        except Exception as e:
-            logger.error(f"Error during CPU fallback: {str(e)}")
-            raise
-
-    def _compute_and_index_embeddings(self):
-        """Compute embeddings and build FAISS index with simpler handling."""
-        logger.info("Computing embeddings and indexing with FAISS...")
-
-        try:
-            # Encode responses with memory monitoring
-            logger.info("Encoding unique responses")
-            response_embeddings = self.encode_responses(self.unique_responses)
-            response_embeddings = response_embeddings.numpy()
-
-            # Memory cleanup after encoding
-            gc.collect()
-            if tf.config.list_physical_devices('GPU'):
-                tf.keras.backend.clear_session()
-
-            # Ensure float32 and memory contiguous
-            response_embeddings = response_embeddings.astype('float32')
-            response_embeddings = np.ascontiguousarray(response_embeddings)
-
-            # Log memory state before normalization
-            if self.memory_monitor.has_gpu:
-                stats = self.memory_monitor.get_memory_stats()
-                if stats:
-                    logger.info(f"GPU memory before normalization: {stats.used/1e9:.2f}GB used")
-
-            # Normalize embeddings
-            logger.info("Normalizing embeddings with FAISS")
-            faiss.normalize_L2(response_embeddings)
-
-            # Create and initialize simple FlatIP index
-            dim = response_embeddings.shape[1]
-            if self.faiss_gpu:
-                cpu_index = faiss.IndexFlatIP(dim)
-                if len(self.gpu_resources) > 1:
-                    self.index = faiss.index_cpu_to_gpus_list(cpu_index, self.gpu_resources)
-                else:
-                    self.index = faiss.index_cpu_to_gpu(self.gpu_resources[0], 0, cpu_index)
-            else:
-                self.index = faiss.IndexFlatIP(dim)
-
-            # Add vectors to index
-            self._add_vectors_to_index(response_embeddings)
-
-            # Store responses and embeddings
-            self.response_pool = self.unique_responses
-            self.response_embeddings = response_embeddings
-
-            # Final memory cleanup
-            gc.collect()
-            if tf.config.list_physical_devices('GPU'):
-                tf.keras.backend.clear_session()
-
-            # Log final state
-            logger.info(f"Successfully indexed {self.index.ntotal} responses")
-            if self.memory_monitor.has_gpu:
-                stats = self.memory_monitor.get_memory_stats()
-                if stats:
-                    logger.info(f"Final GPU memory usage: {stats.used/1e9:.2f}GB used")
-
-            logger.info("Indexing completed successfully")

         except Exception as e:
-            logger.error(f"Error
-            # Ensure cleanup even on error
-            gc.collect()
-            if tf.config.list_physical_devices('GPU'):
-                tf.keras.backend.clear_session()
             raise
-
-    def
-        """
-
-
-            return
-
-        indexed_size = self.index.ntotal
-        pool_size = len(self.response_pool)
-        logger.info(f"FAISS index size: {indexed_size}")
-        logger.info(f"Response pool size: {pool_size}")
-        if indexed_size != pool_size:
-            logger.warning("Mismatch between FAISS index size and response pool size.")
-        else:
-            logger.info("FAISS index correctly matches the response pool.")
-
-    def encode_query(self, query: str, context: Optional[List[Tuple[str, str]]] = None) -> tf.Tensor:
-        """Encode a query with optional conversation context."""
-        # Prepare query with context
-        if context:
-            context_str = ' '.join([
-                f"{self.special_tokens['user']} {q} "
-                f"{self.special_tokens['assistant']} {r}"
-                for q, r in context[-self.config.max_context_turns:]
-            ])
-            query = f"{context_str} {self.special_tokens['user']} {query}"
-        else:
-            query = f"{self.special_tokens['user']} {query}"

-        #
-
-
-            padding='max_length',
-            truncation=True,
-            max_length=self.config.max_context_token_limit,
-            return_tensors='tf'
-        )
-        input_ids = encodings['input_ids']

-        #
-
-        new_vocab_size = len(self.tokenizer)

-
-
-            raise ValueError("Token ID exceeds vocabulary size.")

-
-        return self.encoder(input_ids, training=False)

     def retrieve_responses_cross_encoder(
         self,
@@ -786,7 +363,7 @@ class RetrievalChatbot(DeviceAwareModel):

         # 2) Dense retrieval
         dense_topk = self.retrieve_responses_faiss(query, top_k=top_k)  # [(resp, dense_score), ...]
-
         if not dense_topk:
             return []

@@ -800,75 +377,228 @@ class RetrievalChatbot(DeviceAwareModel):
         combined.sort(key=lambda x: x[1], reverse=True)

         return combined

     def retrieve_responses_faiss(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
         """Retrieve top-k responses using FAISS."""
-        if not hasattr(self, 'index') or self.index is None:
             logger.warning("FAISS index not initialized. Cannot retrieve responses.")
             return []

-        # Encode the query
-        q_emb = self.encode_query(query)  #
         q_emb_np = q_emb.numpy().astype('float32')  # Ensure type match

         # Normalize the query embedding for cosine similarity
         faiss.normalize_L2(q_emb_np)

         # Search the FAISS index
-        distances, indices = self.index.search(q_emb_np, top_k)

         # Map indices to responses and distances to similarities
         top_responses = []
         for i, idx in enumerate(indices[0]):
-            if idx < len(self.response_pool):
-                top_responses.append((self.response_pool[idx], float(distances[0][i])))
             else:
                 logger.warning(f"FAISS returned invalid index {idx}. Skipping.")

         return top_responses
-
-
-
-
-

-
-
-
-
-        # Save models
-        self.encoder.pretrained.save_pretrained(save_dir / "shared_encoder")

-
-

-
-
-    @classmethod
-    def load_models(cls, load_dir: Union[str, Path]) -> 'RetrievalChatbot':
-        """Load saved models and configuration."""
-        load_dir = Path(load_dir)

-
-
-
-
-
-

-
-
-
-
-

-
-

-
-

-    def
         self,
         tfrecord_file_path: str,
         epochs: int = 20,
@@ -876,10 +606,12 @@ class RetrievalChatbot(DeviceAwareModel):
         validation_split: float = 0.2,
         checkpoint_dir: str = "checkpoints/",
         use_lr_schedule: bool = True,
-        peak_lr: float =
         warmup_steps_ratio: float = 0.1,
         early_stopping_patience: int = 3,
         min_delta: float = 1e-4,
     ) -> None:
         """Training using a pre-prepared TFRecord dataset."""
         logger.info("Starting training with pre-prepared TFRecord dataset...")
@@ -908,8 +640,8 @@ class RetrievalChatbot(DeviceAwareModel):
             negative_ids = tf.cast(parsed_features['negative_ids'], tf.int32)
             negative_ids = tf.reshape(negative_ids, [neg_samples, max_length])

-            return query_ids, positive_ids, negative_ids
-
         # Calculate total steps by counting the number of records in the TFRecord
         raw_dataset = tf.data.TFRecordDataset(tfrecord_file_path)
         total_pairs = sum(1 for _ in raw_dataset)
@@ -920,6 +652,7 @@ class RetrievalChatbot(DeviceAwareModel):
         steps_per_epoch = math.ceil(train_size / batch_size)
         val_steps = math.ceil(val_size / batch_size)
         total_steps = steps_per_epoch * epochs

         logger.info(f"Training pairs: {train_size}")
         logger.info(f"Validation pairs: {val_size}")
@@ -942,9 +675,42 @@ class RetrievalChatbot(DeviceAwareModel):
             logger.info("Using fixed learning rate.")

         # Initialize checkpoint manager
-        checkpoint = tf.train.Checkpoint(
-

         # Setup TensorBoard
         log_dir = Path(checkpoint_dir) / "tensorboard_logs"
         log_dir.mkdir(parents=True, exist_ok=True)
@@ -960,20 +726,47 @@ class RetrievalChatbot(DeviceAwareModel):

         # Create the full dataset
         dataset = tf.data.TFRecordDataset(tfrecord_file_path)
         dataset = dataset.map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)
-
-
-        dataset = dataset.prefetch(tf.data.AUTOTUNE)
-
-        # Split into training and validation
         train_dataset = dataset.take(train_size)
         val_dataset = dataset.skip(train_size).take(val_size)

         # Training loop
         best_val_loss = float("inf")
         epochs_no_improve = 0

-        for epoch in range(1, epochs + 1):
             # --- Training Phase ---
             epoch_loss_avg = tf.keras.metrics.Mean()
             batches_processed = 0
@@ -987,13 +780,28 @@ class RetrievalChatbot(DeviceAwareModel):
                 logger.info("Training progress bar disabled")

             for q_batch, p_batch, n_batch in train_dataset:
-                loss = self.train_step(q_batch, p_batch, n_batch)
                 epoch_loss_avg(loss)
                 batches_processed += 1

                 # Log to TensorBoard
                 with train_summary_writer.as_default():
-

                 # Update progress bar
                 if use_lr_schedule:
@@ -1005,6 +813,8 @@ class RetrievalChatbot(DeviceAwareModel):
                     train_pbar.update(1)
                     train_pbar.set_postfix({
                         "loss": f"{loss.numpy():.4f}",
                         "lr": f"{current_lr:.2e}",
                         "batches": f"{batches_processed}/{steps_per_epoch}"
                     })
@@ -1064,6 +874,11 @@ class RetrievalChatbot(DeviceAwareModel):

             # Save checkpoint
             manager.save()

             # Store metrics in history
             self.history['train_loss'].append(train_loss)
@@ -1074,8 +889,14 @@ class RetrievalChatbot(DeviceAwareModel):
             else:
                 current_lr = float(self.optimizer.learning_rate.numpy())

             self.history.setdefault('learning_rate', []).append(current_lr)

             # Early stopping logic
             if val_loss < best_val_loss - min_delta:
                 best_val_loss = val_loss
@@ -1144,10 +965,19 @@ class RetrievalChatbot(DeviceAwareModel):
             )
             loss = tf.reduce_mean(loss)

-        #
         gradients = tape.gradient(loss, self.encoder.trainable_variables)
         self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_variables))
-

     @tf.function
     def validation_step(
@@ -1185,316 +1015,6 @@ class RetrievalChatbot(DeviceAwareModel):
         loss = tf.reduce_mean(loss)

         return loss
-    # def train_streaming(
-    #     self,
-    #     dialogues: List[dict],
-    #     epochs: int = 20,
-    #     batch_size: int = 16,
-    #     validation_split: float = 0.2,
-    #     checkpoint_dir: str = "checkpoints/",
-    #     use_lr_schedule: bool = True,
-    #     peak_lr: float = 2e-5,
-    #     warmup_steps_ratio: float = 0.1,
-    #     early_stopping_patience: int = 3,
-    #     min_delta: float = 1e-4,
-    #     neg_samples: int = 1
-    # ) -> None:
-    #     """Streaming training with tf.data pipeline."""
-    #     logger.info("Starting streaming training pipeline with tf.data...")
-
-    #     # Initialize TFDataPipeline (replaces StreamingDataPipeline)
-    #     dataset_preparer = TFDataPipeline(
-    #         embedding_batch_size=self.config.embedding_batch_size,
-    #         tokenizer=self.tokenizer,
-    #         encoder=self.encoder,
-    #         index=self.index,  # Pass CPU version of FAISS index
-    #         response_pool=self.response_pool,
-    #         max_length=self.config.max_context_token_limit,
-    #         neg_samples=neg_samples
-    #     )
-
-    #     # Calculate total steps for learning rate schedule
-    #     total_pairs = dataset_preparer.estimate_total_pairs(dialogues)
-    #     train_size = int(total_pairs * (1 - validation_split))
-    #     val_size = int(total_pairs * validation_split)
-    #     steps_per_epoch = int(math.ceil(train_size / batch_size))
-    #     val_steps = int(math.ceil(val_size / batch_size))
-    #     total_steps = steps_per_epoch * epochs
-
-    #     logger.info(f"Total pairs: {total_pairs}")
-    #     logger.info(f"Training pairs: {train_size}")
-    #     logger.info(f"Validation pairs: {val_size}")
-    #     logger.info(f"Steps per epoch: {steps_per_epoch}")
-    #     logger.info(f"Validation steps: {val_steps}")
-    #     logger.info(f"Total steps: {total_steps}")
-
-    #     # Set up optimizer with learning rate schedule
-    #     if use_lr_schedule:
-    #         warmup_steps = int(total_steps * warmup_steps_ratio)
-    #         lr_schedule = self._get_lr_schedule(
-    #             total_steps=total_steps,
-    #             peak_lr=peak_lr,
-    #             warmup_steps=warmup_steps
-    #         )
-    #         self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
-    #         logger.info("Using custom learning rate schedule.")
-    #     else:
-    #         self.optimizer = tf.keras.optimizers.Adam(learning_rate=peak_lr)
-    #         logger.info("Using fixed learning rate.")
-
-    #     # Initialize checkpoint manager
-    #     checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, model=self.encoder)
-    #     manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
-
-    #     # Setup TensorBoard
-    #     log_dir = Path(checkpoint_dir) / "tensorboard_logs"
-    #     log_dir.mkdir(parents=True, exist_ok=True)
-    #     current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
-    #     train_log_dir = str(log_dir / f"train_{current_time}")
-    #     val_log_dir = str(log_dir / f"val_{current_time}")
-    #     train_summary_writer = tf.summary.create_file_writer(train_log_dir)
-    #     val_summary_writer = tf.summary.create_file_writer(val_log_dir)
-    #     logger.info(f"TensorBoard logs will be saved in {log_dir}")
-
-    #     # Create training and validation datasets
-    #     train_dataset = dataset_preparer.get_tf_dataset(dialogues, batch_size).take(train_size)
-    #     val_dataset = dataset_preparer.get_tf_dataset(dialogues, batch_size).skip(train_size).take(val_size)
-
-    #     # Training loop
-    #     best_val_loss = float("inf")
-    #     epochs_no_improve = 0
-
-    #     for epoch in range(1, epochs + 1):
-    #         # --- Training Phase ---
-    #         epoch_loss_avg = tf.keras.metrics.Mean()
-    #         batches_processed = 0
-
-    #         try:
-    #             train_pbar = tqdm(total=steps_per_epoch, desc=f"Training Epoch {epoch}", unit="batch")
-    #             is_tqdm_train = True
-    #         except ImportError:
-    #             train_pbar = None
-    #             is_tqdm_train = False
-    #             logger.info("Training progress bar disabled")
-
-    #         for q_batch, p_batch, n_batch in train_dataset:
-    #             #p_batch = p_n_batch[:, 0, :]  # Extract positive from (positive, negative) pair
-    #             loss = self.train_step(q_batch, p_batch, n_batch)
-    #             epoch_loss_avg(loss)
-    #             batches_processed += 1
-
-    #             # Log to TensorBoard
-    #             with train_summary_writer.as_default():
-    #                 tf.summary.scalar("loss", loss, step=(epoch - 1) * steps_per_epoch + batches_processed)
-
-    #             # Update progress bar
-    #             if use_lr_schedule:
-    #                 current_lr = float(lr_schedule(self.optimizer.iterations))
-    #             else:
-    #                 current_lr = float(self.optimizer.learning_rate.numpy())
-
-    #             if is_tqdm_train:
-    #                 train_pbar.update(1)
-    #                 train_pbar.set_postfix({
-    #                     "loss": f"{loss.numpy():.4f}",
-    #                     "lr": f"{current_lr:.2e}",
-    #                     "batches": f"{batches_processed}/{steps_per_epoch}"
-    #                 })
-
-    #             # Memory cleanup
-    #             gc.collect()
-
-    #             if batches_processed >= steps_per_epoch:
-    #                 break
-
-    #         if is_tqdm_train and train_pbar:
-    #             train_pbar.close()
-
-    #         # --- Validation Phase ---
-    #         val_loss_avg = tf.keras.metrics.Mean()
-    #         val_batches_processed = 0
-
-    #         try:
-    #             val_pbar = tqdm(total=val_steps, desc="Validation", unit="batch")
-    #             is_tqdm_val = True
-    #         except ImportError:
-    #             val_pbar = None
-    #             is_tqdm_val = False
-    #             logger.info("Validation progress bar disabled")
-
-    #         for q_batch, p_batch, n_batch in val_dataset:
-    #             #p_batch = p_n_batch[:, 0, :]  # Extract positive from (positive, negative) pair
-    #             val_loss = self.validation_step(q_batch, p_batch, n_batch)
-    #             val_loss_avg(val_loss)
-    #             val_batches_processed += 1
-
-    #             if is_tqdm_val:
-    #                 val_pbar.update(1)
-    #                 val_pbar.set_postfix({
-    #                     "val_loss": f"{val_loss.numpy():.4f}",
-    #                     "batches": f"{val_batches_processed}/{val_steps}"
-    #                 })
-
-    #             # Memory cleanup
-    #             gc.collect()
-
-
-    #             if val_batches_processed >= val_steps:
-    #                 break
-
-    #         if is_tqdm_val and val_pbar:
-    #             val_pbar.close()
-
-    #         # End of epoch: compute final epoch stats, log, and save checkpoint
-    #         train_loss = epoch_loss_avg.result().numpy()
-    #         val_loss = val_loss_avg.result().numpy()
-    #         logger.info(f"Epoch {epoch} Complete: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")
-
-    #         # Log epoch metrics
-    #         with train_summary_writer.as_default():
-    #             tf.summary.scalar("epoch_loss", train_loss, step=epoch)
-    #         with val_summary_writer.as_default():
-    #             tf.summary.scalar("val_loss", val_loss, step=epoch)
-
-    #         # Save checkpoint
-    #         manager.save()
-
-    #         # Store metrics in history
-    #         self.history['train_loss'].append(train_loss)
-    #         self.history['val_loss'].append(val_loss)
-
-    #         if use_lr_schedule:
-    #             current_lr = float(lr_schedule(self.optimizer.iterations))
-    #         else:
-    #             current_lr = float(self.optimizer.learning_rate.numpy())
-
-    #         self.history.setdefault('learning_rate', []).append(current_lr)
-
-    #         # Early stopping logic
-    #         if val_loss < best_val_loss - min_delta:
-    #             best_val_loss = val_loss
-    #             epochs_no_improve = 0
-    #             logger.info(f"Validation loss improved to {val_loss:.4f}. Reset patience.")
-    #         else:
-    #             epochs_no_improve += 1
-    #             logger.info(f"No improvement this epoch. Patience: {epochs_no_improve}/{early_stopping_patience}")
-    #             if epochs_no_improve >= early_stopping_patience:
-    #                 logger.info("Early stopping triggered.")
-    #                 break
-
-    #     logger.info("Streaming training completed!")
-
-
-    # @tf.function
-    # def train_step(
-    #     self,
-    #     q_batch: tf.Tensor,
-    #     p_batch: tf.Tensor,
-    #     n_batch: tf.Tensor,
-    #     attention_mask: Optional[tf.Tensor] = None
-    # ) -> tf.Tensor:
-    #     """
-    #     Single training step that uses queries, positives, and negatives in a
-    #     contrastive/InfoNCE style. The label is always 0 (the positive) vs.
-    #     the negative alternatives.
-    #     """
-    #     with tf.GradientTape() as tape:
-    #         # Encode queries
-    #         q_enc = self.encoder(q_batch, training=True)  # [batch_size, embed_dim]
-
-    #         # Encode positives
-    #         p_enc = self.encoder(p_batch, training=True)  # [batch_size, embed_dim]
-
-    #         # Encode negatives
-    #         # n_batch: [batch_size, neg_samples, max_length]
-    #         shape = tf.shape(n_batch)
-    #         bs = shape[0]
-    #         neg_samples = shape[1]
-
-    #         # Flatten negatives to feed them in one pass:
-    #         # => [batch_size * neg_samples, max_length]
-    #         n_batch_flat = tf.reshape(n_batch, [bs * neg_samples, shape[2]])
-    #         n_enc_flat = self.encoder(n_batch_flat, training=True)  # [bs*neg_samples, embed_dim]
-
-    #         # Reshape back => [batch_size, neg_samples, embed_dim]
-    #         n_enc = tf.reshape(n_enc_flat, [bs, neg_samples, -1])
-
-    #         # Combine the positive embedding and negative embeddings along dim=1
-    #         # => shape [batch_size, 1 + neg_samples, embed_dim]
-    #         # The first column is the positive; subsequent columns are negatives
-    #         combined_p_n = tf.concat(
-    #             [tf.expand_dims(p_enc, axis=1), n_enc],
-    #             axis=1
-    #         )  # [bs, (1+neg_samples), embed_dim]
-
-    #         # Now compute scores: dot product of q_enc with each column in combined_p_n
-    #         # We'll use `tf.einsum` to handle the batch dimension properly
-    #         # dot_products => shape [batch_size, (1+neg_samples)]
-    #         dot_products = tf.einsum('bd,bkd->bk', q_enc, combined_p_n)
-
-    #         # The label for each row is 0 (the first column is the correct/positive)
-    #         labels = tf.zeros([bs], dtype=tf.int32)
-
-    #         # Cross-entropy over the [batch_size, 1+neg_samples] scores
-    #         loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
-    #             labels=labels,
-    #             logits=dot_products
-    #         )
-    #         loss = tf.reduce_mean(loss)
-
-    #         # If there's an attention_mask you want to apply (less common in this scenario),
-    #         # you could do something like:
-    #         if attention_mask is not None:
-    #             loss = loss * attention_mask
-    #             loss = tf.reduce_sum(loss) / tf.reduce_sum(attention_mask)
-
-    #     # Apply gradients
-    #     gradients = tape.gradient(loss, self.encoder.trainable_variables)
-    #     self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_variables))
-    #     return loss
-
-    # @tf.function
-    # def validation_step(
-    #     self,
-    #     q_batch: tf.Tensor,
-    #     p_batch: tf.Tensor,
-    #     n_batch: tf.Tensor,
-    #     attention_mask: Optional[tf.Tensor] = None
-    # ) -> tf.Tensor:
-    #     """
-    #     Single validation step with queries, positives, and negatives.
-    #     Uses the same loss calculation as train_step, but `training=False`.
-    #     """
-    #     q_enc = self.encoder(q_batch, training=False)
-    #     p_enc = self.encoder(p_batch, training=False)
-
-    #     shape = tf.shape(n_batch)
-    #     bs = shape[0]
-    #     neg_samples = shape[1]
-
-    #     n_batch_flat = tf.reshape(n_batch, [bs * neg_samples, shape[2]])
-    #     n_enc_flat = self.encoder(n_batch_flat, training=False)
-    #     n_enc = tf.reshape(n_enc_flat, [bs, neg_samples, -1])
-
-    #     combined_p_n = tf.concat(
-    #         [tf.expand_dims(p_enc, axis=1), n_enc],
-    #         axis=1
-    #     )
-
-    #     dot_products = tf.einsum('bd,bkd->bk', q_enc, combined_p_n)
-    #     labels = tf.zeros([bs], dtype=tf.int32)
-
-    #     loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
-    #         labels=labels,
-    #         logits=dot_products
-    #     )
-    #     loss = tf.reduce_mean(loss)
-
-    #     if attention_mask is not None:
-    #         loss = loss * attention_mask
-    #         loss = tf.reduce_sum(loss) / tf.reduce_sum(attention_mask)
-
-    #     return loss

     def _get_lr_schedule(
         self,
@@ -1561,75 +1081,3 @@ class RetrievalChatbot(DeviceAwareModel):
         }

         return CustomSchedule(total_steps, peak_lr, warmup_steps)
-
-    def _cosine_similarity(self, emb1: np.ndarray, emb2: np.ndarray) -> np.ndarray:
-        """Compute cosine similarity between two numpy arrays."""
-        normalized_emb1 = emb1 / np.linalg.norm(emb1, axis=1, keepdims=True)
-        normalized_emb2 = emb2 / np.linalg.norm(emb2, axis=1, keepdims=True)
-        return np.dot(normalized_emb1, normalized_emb2.T)
-
-    def chat(
-        self,
-        query: str,
-        conversation_history: Optional[List[Tuple[str, str]]] = None,
-        quality_checker: Optional['ResponseQualityChecker'] = None,
-        top_k: int = 5,
-    ) -> Tuple[str, List[Tuple[str, float]], Dict[str, Any]]:
-        """
-        Example chat method that always uses cross-encoder re-ranking
-        if self.reranker is available.
-        """
-        @self.run_on_device
-        def get_response(self_arg, query_arg):  # Add parameters that match decorator's expectations
-            # 1) Build conversation context string
-            conversation_str = self_arg._build_conversation_context(query_arg, conversation_history)
-
-            # 2) Retrieve + cross-encoder re-rank
-            results = self_arg.retrieve_responses_cross_encoder(
-                query=conversation_str,
-                top_k=top_k,
-                reranker=self_arg.reranker,
-                summarizer=self_arg.summarizer,
-                summarize_threshold=512
-            )
-
-            # 3) Handle empty or confidence
-            if not results:
-                return (
-                    "I'm sorry, but I couldn't find a relevant response.",
-                    [],
-                    {}
-                )
-
-            if quality_checker:
-                metrics = quality_checker.check_response_quality(query_arg, results)
-                if not metrics.get('is_confident', False):
-                    return (
-                        "I need more information to provide a good answer. Could you please clarify?",
-                        results,
-                        metrics
-                    )
-                return results[0][0], results, metrics
-
-            return results[0][0], results, {}
-
-        return get_response(self, query)
-
-    def _build_conversation_context(
-        self,
-        query: str,
-        conversation_history: Optional[List[Tuple[str, str]]]
-    ) -> str:
-        """Build conversation context with better memory management."""
-        if not conversation_history:
-            return f"{self.special_tokens['user']} {query}"
-
-        conversation_parts = []
-        for user_txt, assistant_txt in conversation_history:
-            conversation_parts.extend([
-                f"{self.special_tokens['user']} {user_txt}",
-                f"{self.special_tokens['assistant']} {assistant_txt}"
-            ])
-
-        conversation_parts.append(f"{self.special_tokens['user']} {query}")
-        return "\n".join(conversation_parts)
1 |
+
import os
|
2 |
from transformers import TFAutoModel, AutoTokenizer
|
3 |
import tensorflow as tf
|
|
|
4 |
from typing import List, Tuple, Dict, Optional, Union, Any
|
5 |
import math
|
6 |
from dataclasses import dataclass
|
|
|
65 |
super().__init__(name=name, **kwargs)
|
66 |
self.config = config
|
67 |
|
68 |
+
# Load pretrained model and freeze layers based on config
|
69 |
self.pretrained = TFAutoModel.from_pretrained(config.pretrained_model)
|
|
|
|
|
70 |
self._freeze_layers()
|
71 |
|
72 |
+
# Add Pooling layer (Global Average Pooling), Projection layer, Dropout, and Normalization
|
73 |
self.pooler = tf.keras.layers.GlobalAveragePooling1D()
|
|
|
|
|
74 |
self.projection = tf.keras.layers.Dense(
|
75 |
config.embedding_dim,
|
76 |
activation='tanh',
|
77 |
name="projection"
|
78 |
)
|
|
|
|
|
79 |
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
|
80 |
self.normalize = tf.keras.layers.Lambda(
|
81 |
lambda x: tf.nn.l2_normalize(x, axis=1),
|
|
|
103 |
"""Forward pass."""
|
104 |
# Get pretrained embeddings
|
105 |
pretrained_outputs = self.pretrained(inputs, training=training)
|
106 |
+
x = pretrained_outputs.last_hidden_state # Shape: [batch_size, seq_len, embedding_dim]
|
107 |
|
108 |
# Apply pooling, projection, dropout, and normalization
|
109 |
+
x = self.pooler(x) # Shape: [batch_size, 768]
|
110 |
+
x = self.projection(x) # Shape: [batch_size, 768]
|
111 |
+
x = self.dropout(x, training=training)
|
112 |
+
x = self.normalize(x) # Shape: [batch_size, 768]
|
113 |
|
114 |
return x
|
115 |
|
|
|
127 |
def __init__(
|
128 |
self,
|
129 |
config: ChatbotConfig,
|
|
|
130 |
device: str = None,
|
131 |
strategy=None,
|
132 |
reranker: Optional[CrossEncoderReranker] = None,
|
133 |
summarizer: Optional[Summarizer] = None,
|
134 |
+
mode: str = 'training'
|
135 |
):
|
136 |
super().__init__()
|
137 |
self.config = config
|
|
|
139 |
self.device = device or self._setup_default_device()
|
140 |
self.mode = mode.lower()
|
141 |
|
142 |
+
# Initialize reranker, summarizer, tokenizer, encoder, and memory monitor
|
143 |
self.reranker = reranker or self._initialize_reranker()
|
|
|
144 |
self.tokenizer = self._initialize_tokenizer()
|
145 |
+
self.encoder = self._initialize_encoder()
|
146 |
+
self.summarizer = summarizer or self._initialize_summarizer()
|
147 |
self.memory_monitor = GPUMemoryMonitor()
|
148 |
|
149 |
+
# Initialize data pipeline
|
150 |
+
logger.info("Initializing TFDataPipeline.")
|
151 |
+
self.data_pipeline = TFDataPipeline(
|
152 |
+
config=self.config,
|
153 |
+
tokenizer=self.tokenizer,
|
154 |
+
encoder=self.encoder,
|
155 |
+
index_file_path='path/to/index', # Update as needed # TODO: Update this path
|
156 |
+
response_pool=[],
|
157 |
+
max_length=self.config.max_context_token_limit,
|
158 |
+
query_embeddings_cache={},
|
159 |
+
neg_samples=self.config.neg_samples,
|
160 |
+
index_type='IndexFlatIP',
|
161 |
+
nlist=100, # Not used with IndexFlatIP
|
162 |
+
max_retries=self.config.max_retries
|
163 |
+
)
|
164 |
|
165 |
+
# Collect unique responses from dialogues
|
166 |
+
if self.mode == 'inference':
|
167 |
+
logger.info("Mode set to 'inference'. Loading FAISS index and response pool.")
|
168 |
+
self._load_faiss_index_and_responses()
|
169 |
+
elif self.mode != 'training':
|
170 |
+
logger.error(f"Unsupported mode in RetrievalChatbot init: {self.mode}")
|
171 |
+
raise ValueError(f"Unsupported mode in RetrievalChatbot init: {self.mode}")
|
172 |
+
|
173 |
# Initialize training history
|
174 |
self.history = {
|
175 |
"train_loss": [],
|
|
|
177 |
"train_metrics": {},
|
178 |
"val_metrics": {}
|
179 |
}
|
180 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
|
182 |
def _setup_default_device(self) -> str:
|
183 |
"""Set up default device if none is provided."""
|
|
|
193 |
|
194 |
def _initialize_summarizer(self) -> Summarizer:
|
195 |
"""Initialize the Summarizer."""
|
196 |
+
return Summarizer(
|
197 |
+
tokenizer=self.tokenizer,
|
198 |
+
model_name="t5-small",
|
199 |
+
max_summary_length=self.config.max_context_token_limit // 4,
|
200 |
+
device=self.device,
|
201 |
+
max_summary_rounds=2
|
202 |
+
)
|
203 |
|
204 |
def _initialize_tokenizer(self) -> AutoTokenizer:
|
205 |
"""Initialize the tokenizer and add special tokens."""
|
|
|
216 |
)
|
217 |
return tokenizer
|
218 |
|
219 |
+
def _initialize_encoder(self) -> EncoderModel:
|
220 |
+
"""Initialize the EncoderModel and resize token embeddings."""
|
221 |
+
logger.info("Initializing encoder model...")
|
222 |
+
encoder = EncoderModel(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
223 |
self.config,
|
224 |
name="shared_encoder",
|
225 |
)
|
226 |
|
|
|
227 |
new_vocab_size = len(self.tokenizer)
|
228 |
+
encoder.pretrained.resize_token_embeddings(new_vocab_size)
|
229 |
logger.info(f"Token embeddings resized to: {new_vocab_size}")
|
230 |
+
return encoder
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
231 |
|
232 |
+
def _load_faiss_index_and_responses(self) -> None:
|
233 |
+
"""Load FAISS index and response pool for inference."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
234 |
try:
|
235 |
+
logger.info(f"Loading FAISS index from {self.data_pipeline.index_file_path}...")
|
236 |
+
self.data_pipeline.load_faiss_index(self.data_pipeline.index_file_path)
|
237 |
+
logger.info("FAISS index loaded successfully.")
|
238 |
+
|
239 |
+
# Load response pool associated with the FAISS index
|
240 |
+
response_pool_path = self.data_pipeline.index_file_path.replace('.index', '_responses.json')
|
241 |
+
if os.path.exists(response_pool_path):
|
242 |
+
with open(response_pool_path, 'r', encoding='utf-8') as f:
|
243 |
+
self.data_pipeline.response_pool = json.load(f)
|
244 |
+
logger.info(f"Loaded {len(self.data_pipeline.response_pool)} responses from {response_pool_path}.")
|
245 |
else:
|
246 |
+
logger.error(f"Response pool file not found at {response_pool_path}.")
|
247 |
+
raise FileNotFoundError(f"Response pool file not found at {response_pool_path}.")
|
248 |
+
|
249 |
+
# Validate FAISS index and response pool
|
250 |
+
self.data_pipeline.validate_faiss_index()
|
251 |
+
logger.info("FAISS index and response pool validated successfully.")
|
252 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
253 |
except Exception as e:
|
254 |
+
logger.error(f"Failed to load FAISS index and response pool: {e}")
|
255 |
raise
|

    @classmethod
    def load_model(cls, load_dir: Union[str, Path], mode: str = 'training') -> 'RetrievalChatbot':
        """
        Load saved models and configuration.

        Args:
            load_dir (Union[str, Path]): Directory containing saved model files
            mode (str): Either 'training' or 'inference'. In inference mode,
                also loads the FAISS index and response pool.
        """
        load_dir = Path(load_dir)

        # Load config
        with open(load_dir / "config.json", "r") as f:
            config = ChatbotConfig.from_dict(json.load(f))

        # Initialize chatbot with the appropriate mode
        chatbot = cls(config, mode=mode)

        # Load models
        chatbot.encoder.pretrained = TFAutoModel.from_pretrained(
            load_dir / "shared_encoder",
            config=config
        )

        # Load tokenizer
        chatbot.tokenizer = AutoTokenizer.from_pretrained(load_dir / "tokenizer")
        logger.info(f"Models and tokenizer loaded from {load_dir}")

        # If in inference mode, load additional components
        if mode == 'inference':
            cls._prepare_model_for_inference(chatbot, load_dir)

        return chatbot
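A minimal usage sketch for inference mode, assuming a directory produced by save_models that also contains faiss_index.bin and response_pool.json (the path and query are illustrative):

    from chatbot_model import RetrievalChatbot

    # Hypothetical flow: load a saved checkpoint directory and answer one query.
    chatbot = RetrievalChatbot.load_model("checkpoints/model_epoch_5", mode="inference")
    response, candidates, metrics = chatbot.chat("Can you book a table for two tonight?")
    print(response)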

    @classmethod
    def _prepare_model_for_inference(cls, chatbot: 'RetrievalChatbot', load_dir: Path) -> None:
        """Internal method to load inference components."""
        try:
            # Load FAISS index
            faiss_path = load_dir / 'faiss_index.bin'
            if faiss_path.exists():
                chatbot.index = faiss.read_index(str(faiss_path))
                logger.info("FAISS index loaded successfully")
            else:
                raise FileNotFoundError(f"FAISS index not found at {faiss_path}")

            # Load response pool
            response_pool_path = load_dir / 'response_pool.json'
            if response_pool_path.exists():
                with open(response_pool_path, 'r') as f:
                    chatbot.response_pool = json.load(f)
                logger.info(f"Loaded {len(chatbot.response_pool)} responses")
            else:
                raise FileNotFoundError(f"Response pool not found at {response_pool_path}")

            # Verify dimensions match
            if chatbot.index.d != chatbot.config.embedding_dim:
                raise ValueError(
                    f"FAISS index dimension {chatbot.index.d} doesn't match "
                    f"model dimension {chatbot.config.embedding_dim}"
                )
        except Exception as e:
            logger.error(f"Error loading inference components: {e}")
            raise

    def save_models(self, save_dir: Union[str, Path]):
        """Save models and configuration."""
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

        # Save config
        with open(save_dir / "config.json", "w") as f:
            json.dump(self.config.to_dict(), f, indent=2)

        # Save models
        self.encoder.pretrained.save_pretrained(save_dir / "shared_encoder")

        # Save tokenizer
        self.tokenizer.save_pretrained(save_dir / "tokenizer")

        logger.info(f"Models and tokenizer saved to {save_dir}.")

    def retrieve_responses_cross_encoder(
        self,
        ...

        # 2) Dense retrieval
        dense_topk = self.retrieve_responses_faiss(query, top_k=top_k)  # [(resp, dense_score), ...]

        if not dense_topk:
            return []
        ...
        combined.sort(key=lambda x: x[1], reverse=True)

        return combined

    # def retrieve_responses_cross_encoder(
    #     self,
    #     query: str,
    #     top_k: int,
    #     reranker: Optional[CrossEncoderReranker] = None,
    #     summarizer: Optional[Summarizer] = None,
    #     summarize_threshold: int = 512  # Summarize over 512 tokens
    # ) -> List[Tuple[str, float]]:
    #     """
    #     Retrieve top-k from FAISS, then re-rank them with a cross-encoder.
    #     Optionally summarize the user query if it's too long.
    #     """
    #     if reranker is None:
    #         reranker = self.reranker
    #     if summarizer is None:
    #         summarizer = self.summarizer
    #
    #     # Optional summarization
    #     if summarizer and len(query.split()) > summarize_threshold:
    #         logger.info(f"Query is long. Summarizing before cross-encoder. Original length: {len(query.split())}")
    #         query = summarizer.summarize_text(query)
    #         logger.info(f"Summarized query: {query}")
    #
    #     # 2) Dense retrieval
    #     dense_topk = self.retrieve_responses_faiss(query, top_k=top_k)  # [(resp, dense_score), ...]
    #
    #     if not dense_topk:
    #         return []
    #
    #     # 3) Cross-encoder rerank
    #     candidate_texts = [pair[0] for pair in dense_topk]
    #     cross_scores = reranker.rerank(query, candidate_texts, max_length=256)
    #
    #     # Combine
    #     combined = [(text, score) for (text, _), score in zip(dense_topk, cross_scores)]
    #     # Sort descending by cross-encoder score
    #     combined.sort(key=lambda x: x[1], reverse=True)
    #
    #     return combined

    def retrieve_responses_faiss(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Retrieve top-k responses using FAISS."""
        if not hasattr(self.data_pipeline, 'index') or self.data_pipeline.index is None:
            logger.warning("FAISS index not initialized. Cannot retrieve responses.")
            return []

        # Encode the query using TFDataPipeline's method
        q_emb = self.data_pipeline.encode_query(query)  # Ensure encode_query is within TFDataPipeline
        q_emb_np = q_emb.numpy().astype('float32')  # Ensure type match

        # Normalize the query embedding for cosine similarity
        faiss.normalize_L2(q_emb_np)

        # Search the FAISS index
        distances, indices = self.data_pipeline.index.search(q_emb_np, top_k)

        # Map indices to responses and distances to similarities
        top_responses = []
        for i, idx in enumerate(indices[0]):
            if idx < len(self.data_pipeline.response_pool):
                top_responses.append((self.data_pipeline.response_pool[idx], float(distances[0][i])))
            else:
                logger.warning(f"FAISS returned invalid index {idx}. Skipping.")

        return top_responses

    # def retrieve_responses_faiss(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
    #     """Retrieve top-k responses using FAISS."""
    #     if not hasattr(self, 'index') or self.index is None:
    #         logger.warning("FAISS index not initialized. Cannot retrieve responses.")
    #         return []
    #
    #     # Encode the query
    #     q_emb = self.encode_query(query)  # Shape: [1, embedding_dim]
    #     q_emb_np = q_emb.numpy().astype('float32')  # Ensure type match
    #
    #     # Normalize the query embedding for cosine similarity
    #     faiss.normalize_L2(q_emb_np)
    #
    #     # Search the FAISS index
    #     distances, indices = self.index.search(q_emb_np, top_k)
    #
    #     # Map indices to responses and distances to similarities
    #     top_responses = []
    #     for i, idx in enumerate(indices[0]):
    #         if idx < len(self.response_pool):
    #             top_responses.append((self.response_pool[idx], float(distances[0][i])))
    #         else:
    #             logger.warning(f"FAISS returned invalid index {idx}. Skipping.")
    #
    #     return top_responses
def chat(
|
472 |
+
self,
|
473 |
+
query: str,
|
474 |
+
conversation_history: Optional[List[Tuple[str, str]]] = None,
|
475 |
+
quality_checker: Optional['ResponseQualityChecker'] = None,
|
476 |
+
top_k: int = 5,
|
477 |
+
) -> Tuple[str, List[Tuple[str, float]], Dict[str, Any]]:
|
478 |
+
"""
|
479 |
+
Example chat method that always uses cross-encoder re-ranking
|
480 |
+
if self.reranker is available.
|
481 |
+
"""
|
482 |
+
@self.run_on_device
|
483 |
+
def get_response(self_arg, query_arg):
|
484 |
+
# 1) Build conversation context string
|
485 |
+
conversation_str = self_arg._build_conversation_context(query_arg, conversation_history)
|
486 |
+
|
487 |
+
# 2) Retrieve + cross-encoder re-rank
|
488 |
+
results = self_arg.retrieve_responses_cross_encoder(
|
489 |
+
query=conversation_str,
|
490 |
+
top_k=top_k,
|
491 |
+
reranker=self_arg.reranker,
|
492 |
+
summarizer=self_arg.summarizer,
|
493 |
+
summarize_threshold=512
|
494 |
+
)
|
495 |
+
|
496 |
+
# 3) Handle empty or confidence
|
497 |
+
if not results:
|
498 |
+
return (
|
499 |
+
"I'm sorry, but I couldn't find a relevant response.",
|
500 |
+
[],
|
501 |
+
{}
|
502 |
+
)
|
503 |
+
|
504 |
+
if quality_checker:
|
505 |
+
metrics = quality_checker.check_response_quality(query_arg, results)
|
506 |
+
if not metrics.get('is_confident', False):
|
507 |
+
return (
|
508 |
+
"I need more information to provide a good answer. Could you please clarify?",
|
509 |
+
results,
|
510 |
+
metrics
|
511 |
+
)
|
512 |
+
return results[0][0], results, metrics
|
513 |
+
|
514 |
+
return results[0][0], results, {}
|
515 |
|
516 |
+
return get_response(self, query)
|

    # def chat(
    #     self,
    #     query: str,
    #     conversation_history: Optional[List[Tuple[str, str]]] = None,
    #     quality_checker: Optional['ResponseQualityChecker'] = None,
    #     top_k: int = 5,
    # ) -> Tuple[str, List[Tuple[str, float]], Dict[str, Any]]:
    #     """
    #     Example chat method that always uses cross-encoder re-ranking
    #     if self.reranker is available.
    #     """
    #     @self.run_on_device
    #     def get_response(self_arg, query_arg):  # Add parameters that match decorator's expectations
    #         # 1) Build conversation context string
    #         conversation_str = self_arg._build_conversation_context(query_arg, conversation_history)
    #
    #         # 2) Retrieve + cross-encoder re-rank
    #         results = self_arg.retrieve_responses_cross_encoder(
    #             query=conversation_str,
    #             top_k=top_k,
    #             reranker=self_arg.reranker,
    #             summarizer=self_arg.summarizer,
    #             summarize_threshold=512
    #         )
    #
    #         # 3) Handle empty or confidence
    #         if not results:
    #             return (
    #                 "I'm sorry, but I couldn't find a relevant response.",
    #                 [],
    #                 {}
    #             )
    #
    #         if quality_checker:
    #             metrics = quality_checker.check_response_quality(query_arg, results)
    #             if not metrics.get('is_confident', False):
    #                 return (
    #                     "I need more information to provide a good answer. Could you please clarify?",
    #                     results,
    #                     metrics
    #                 )
    #             return results[0][0], results, metrics
    #
    #         return results[0][0], results, {}
    #
    #     return get_response(self, query)

    def _build_conversation_context(
        self,
        query: str,
        conversation_history: Optional[List[Tuple[str, str]]]
    ) -> str:
        """Build conversation context with better memory management."""
        if not conversation_history:
            return f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"

        conversation_parts = []
        for user_txt, assistant_txt in conversation_history:
            conversation_parts.extend([
                f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {user_txt}",
                f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {assistant_txt}"
            ])

        conversation_parts.append(f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}")
        return "\n".join(conversation_parts)
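For a one-turn history, the builder above yields a newline-joined transcript tagged with the registered special tokens. An illustrative call, assuming '<USER>' and '<ASSISTANT>' were added as additional special tokens when the tokenizer was set up:

    # Hypothetical example values, not project fixtures.
    history = [("I need a table for two.", "Sure, which restaurant?")]
    context = chatbot._build_conversation_context("Somewhere downtown, please.", history)
    # context ==
    # "<USER> I need a table for two.\n<ASSISTANT> Sure, which restaurant?\n<USER> Somewhere downtown, please."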

    # def _build_conversation_context(
    #     self,
    #     query: str,
    #     conversation_history: Optional[List[Tuple[str, str]]]
    # ) -> str:
    #     """Build conversation context with better memory management."""
    #     if not conversation_history:
    #         return f"{self.special_tokens['user']} {query}"
    #
    #     conversation_parts = []
    #     for user_txt, assistant_txt in conversation_history:
    #         conversation_parts.extend([
    #             f"{self.special_tokens['user']} {user_txt}",
    #             f"{self.special_tokens['assistant']} {assistant_txt}"
    #         ])
    #
    #     conversation_parts.append(f"{self.special_tokens['user']} {query}")
    #     return "\n".join(conversation_parts)

    def train_model(
        self,
        tfrecord_file_path: str,
        epochs: int = 20,
        ...
        validation_split: float = 0.2,
        checkpoint_dir: str = "checkpoints/",
        use_lr_schedule: bool = True,
        peak_lr: float = 1e-5,
        warmup_steps_ratio: float = 0.1,
        early_stopping_patience: int = 3,
        min_delta: float = 1e-4,
        test_mode: bool = False,
        initial_epoch: int = 0
    ) -> None:
        """Training using a pre-prepared TFRecord dataset."""
        logger.info("Starting training with pre-prepared TFRecord dataset...")
        ...
            negative_ids = tf.cast(parsed_features['negative_ids'], tf.int32)
            negative_ids = tf.reshape(negative_ids, [neg_samples, max_length])

            return query_ids, positive_ids, negative_ids

        # Calculate total steps by counting the number of records in the TFRecord
        raw_dataset = tf.data.TFRecordDataset(tfrecord_file_path)
        total_pairs = sum(1 for _ in raw_dataset)
        ...
        steps_per_epoch = math.ceil(train_size / batch_size)
        val_steps = math.ceil(val_size / batch_size)
        total_steps = steps_per_epoch * epochs
        buffer_size = total_pairs // 10  # 10% of the dataset

        logger.info(f"Training pairs: {train_size}")
        logger.info(f"Validation pairs: {val_size}")
        ...
            logger.info("Using fixed learning rate.")

        # Initialize checkpoint manager
        checkpoint = tf.train.Checkpoint(
            epoch=tf.Variable(0),
            optimizer=self.optimizer,
            model=self.encoder,
            variables=self.encoder.variables
        )
        manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3, checkpoint_name='ckpt')

        # Restore from checkpoint if available
        latest_checkpoint = manager.latest_checkpoint
        if latest_checkpoint:
            history_path = Path(checkpoint_dir) / 'training_history.json'
            if history_path.exists():
                try:
                    with open(history_path, 'r') as f:
                        self.history = json.load(f)
                    logger.info(f"Loaded previous training history from {history_path}")
                except Exception as e:
                    logger.warning(f"Could not load history, starting fresh: {e}")
                    self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
            else:
                self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}

            status = checkpoint.restore(latest_checkpoint)
            status.expect_partial()

            logger.info(f"Restored from checkpoint: {latest_checkpoint}")
            # Get the checkpoint number to validate initial_epoch
            ckpt_number = int(latest_checkpoint.split('ckpt-')[-1])
            if initial_epoch == 0:
                initial_epoch = ckpt_number
            logger.info(f"Resuming from epoch {initial_epoch}")
        else:
            logger.info("Starting training from scratch")
            initial_epoch = 0

        # Setup TensorBoard
        log_dir = Path(checkpoint_dir) / "tensorboard_logs"
        log_dir.mkdir(parents=True, exist_ok=True)
        ...
        # Create the full dataset
        dataset = tf.data.TFRecordDataset(tfrecord_file_path)

        # Test mode for debugging
        if test_mode:
            subset_size = 200
            dataset = dataset.take(subset_size)
            logger.info(f"TEST MODE: Using only {subset_size} examples")
            # Recalculate sizes
            total_pairs = subset_size
            train_size = int(total_pairs * (1 - validation_split))
            val_size = total_pairs - train_size
            steps_per_epoch = math.ceil(train_size / batch_size)
            val_steps = math.ceil(val_size / batch_size)
            total_steps = steps_per_epoch * epochs
            buffer_size = total_pairs // 10  # 10% of the dataset
            epochs = min(epochs, 5)  # Limit epochs in test mode
            early_stopping_patience = 2
            logger.info(f"New training pairs: {train_size}")
            logger.info(f"New validation pairs: {val_size}")

        dataset = dataset.map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)

        # Split into training and validation sets
        train_dataset = dataset.take(train_size)
        val_dataset = dataset.skip(train_size).take(val_size)

        # Shuffle the training data
        train_dataset = train_dataset.shuffle(buffer_size=buffer_size)

        # Batch both datasets
        train_dataset = train_dataset.batch(batch_size, drop_remainder=True)
        train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)

        val_dataset = val_dataset.batch(batch_size, drop_remainder=True)
        val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE)
        val_dataset = val_dataset.cache()

        # Training loop
        best_val_loss = float("inf")
        epochs_no_improve = 0
        for epoch in range(initial_epoch + 1, epochs + 1):
            # --- Training Phase ---
            epoch_loss_avg = tf.keras.metrics.Mean()
            batches_processed = 0
            ...
                logger.info("Training progress bar disabled")

            for q_batch, p_batch, n_batch in train_dataset:
                loss, grad_norm, post_clip_norm = self.train_step(q_batch, p_batch, n_batch)

                # Check for gradient issues
                grad_norm_value = float(grad_norm.numpy())
                post_clip_value = float(post_clip_norm.numpy())
                if grad_norm_value < 1e-7:
                    logger.warning(f"Potential vanishing gradient detected: norm = {grad_norm_value:.2e}")
                elif grad_norm_value > 100:
                    logger.warning(f"Potential exploding gradient detected: norm = {grad_norm_value:.2e}")

                if grad_norm_value != post_clip_value:
                    logger.info(f"Gradient clipped: {grad_norm_value:.2e} -> {post_clip_value:.2e}")

                epoch_loss_avg(loss)
                batches_processed += 1

                # Log to TensorBoard
                with train_summary_writer.as_default():
                    step = (epoch - 1) * steps_per_epoch + batches_processed
                    tf.summary.scalar("loss", loss, step=step)
                    tf.summary.scalar("gradient_norm_pre_clip", grad_norm, step=step)
                    tf.summary.scalar("gradient_norm_post_clip", post_clip_norm, step=step)

                # Update progress bar
                if use_lr_schedule:
                ...
                    train_pbar.update(1)
                    train_pbar.set_postfix({
                        "loss": f"{loss.numpy():.4f}",
                        "pre_clip": f"{grad_norm_value:.2e}",
                        "post_clip": f"{post_clip_value:.2e}",
                        "lr": f"{current_lr:.2e}",
                        "batches": f"{batches_processed}/{steps_per_epoch}"
                    })
            ...
            # Save checkpoint
            manager.save()

            # Save model after each epoch for testing/inference
            model_save_path = Path(checkpoint_dir) / f"model_epoch_{epoch}"
            self.save_models(model_save_path)
            logger.info(f"Saved model for epoch {epoch} at {model_save_path}")

            # Store metrics in history
            self.history['train_loss'].append(train_loss)
            ...
            else:
                current_lr = float(self.optimizer.learning_rate.numpy())

            # Log learning rate
            self.history.setdefault('learning_rate', []).append(current_lr)

            # Save history to file
            with open(history_path, 'w') as f:
                json.dump(self.history, f)
            logger.info(f"Saved training history to {history_path}")

            # Early stopping logic
            if val_loss < best_val_loss - min_delta:
                best_val_loss = val_loss
                ...
            )
            loss = tf.reduce_mean(loss)

        # Calculate gradients
        gradients = tape.gradient(loss, self.encoder.trainable_variables)
        gradients_norm = tf.linalg.global_norm(gradients)

        # Clip gradients if norm exceeds threshold
        max_grad_norm = 1.0
        gradients, _ = tf.clip_by_global_norm(gradients, max_grad_norm, gradients_norm)
        post_clip_norm = tf.linalg.global_norm(gradients)

        # Apply gradients
        self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_variables))

        return loss, gradients_norm, post_clip_norm

    @tf.function
    def validation_step(
        ...
        loss = tf.reduce_mean(loss)

        return loss
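The hunk above begins after the loss has already been formed, so its exact definition is not visible in this diff. For intuition only, one common choice for a (query, positive, hard-negatives) batch like this one is softmax cross-entropy over the query's similarity to its positive versus its negatives. A sketch under that assumption, with unit-normalized embeddings:

    import tensorflow as tf

    def in_batch_triplet_softmax_loss(q_emb, p_emb, n_emb):
        # q_emb: [B, D], p_emb: [B, D], n_emb: [B, K, D]; all assumed L2-normalized.
        pos_sim = tf.reduce_sum(q_emb * p_emb, axis=1, keepdims=True)   # [B, 1]
        neg_sim = tf.einsum('bd,bkd->bk', q_emb, n_emb)                 # [B, K]
        logits = tf.concat([pos_sim, neg_sim], axis=1)                  # [B, 1 + K]
        labels = tf.zeros(tf.shape(logits)[0], dtype=tf.int32)          # the positive sits at index 0
        return tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
        )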

    def _get_lr_schedule(
        self,
        ...
        }

        return CustomSchedule(total_steps, peak_lr, warmup_steps)
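CustomSchedule itself is defined outside the lines shown here. A minimal stand-in with the same constructor arguments (linear warmup to peak_lr, then cosine decay) could look like the following; this is an assumption about its behavior, not the project's actual schedule:

    import math
    import tensorflow as tf

    class WarmupCosineSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
        def __init__(self, total_steps, peak_lr, warmup_steps):
            super().__init__()
            self.total_steps = float(total_steps)
            self.peak_lr = peak_lr
            self.warmup_steps = float(max(warmup_steps, 1))

        def __call__(self, step):
            step = tf.cast(step, tf.float32)
            warmup_lr = self.peak_lr * (step / self.warmup_steps)
            progress = (step - self.warmup_steps) / tf.maximum(self.total_steps - self.warmup_steps, 1.0)
            decay_lr = self.peak_lr * 0.5 * (1.0 + tf.cos(math.pi * tf.minimum(progress, 1.0)))
            return tf.where(step < self.warmup_steps, warmup_lr, decay_lr)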
chatbot_validator.py
CHANGED
@@ -1,23 +1,23 @@
 from typing import Dict, List, Tuple, Any, Optional
 import numpy as np
+
 from logger_config import config_logger
 logger = config_logger(__name__)

 class ChatbotValidator:
     """Handles automated validation and performance analysis for the chatbot."""

     def __init__(self, chatbot, quality_checker):
         """
         Initialize the validator.

         Args:
             chatbot: RetrievalChatbot instance
             quality_checker: ResponseQualityChecker instance
         """
         self.chatbot = chatbot
         self.quality_checker = quality_checker

         # Domain-specific test queries aligned with Taskmaster-1 and Schema-Guided
         self.domain_queries = {
             'restaurant': [
@@ -59,50 +59,50 @@ class ChatbotValidator:

     def run_validation(
         self,
-        num_examples: int =
+        num_examples: int = 5,
         top_k: int = 10,
         domains: Optional[List[str]] = None
     ) -> Dict[str, Any]:
         """
         Run comprehensive validation across specified domains.

         Args:
             num_examples: Number of test queries per domain
             top_k: Number of responses to retrieve for each query
             domains: Optional list of specific domains to test

         Returns:
             Dict containing detailed validation metrics and domain-specific performance
         """
         logger.info("\n=== Running Enhanced Automatic Validation ===")

         # Select domains to test
         test_domains = domains if domains else list(self.domain_queries.keys())
         metrics_history = []
         domain_metrics = {}

         # Run validation for each domain
         for domain in test_domains:
             domain_metrics[domain] = []
             queries = self.domain_queries[domain][:num_examples]

             logger.info(f"\n=== Testing {domain.title()} Domain ===")

             for i, query in enumerate(queries, 1):
                 logger.info(f"\nTest Case {i}:")
                 logger.info(f"Query: {query}")

                 # Get responses with increased top_k
                 responses = self.chatbot.retrieve_responses_cross_encoder(query, top_k=top_k)

-                # Enhanced quality checking with context
+                # Enhanced quality checking with context (assuming no context here)
                 quality_metrics = self.quality_checker.check_response_quality(query, responses)

                 # Add domain info
                 quality_metrics['domain'] = domain
                 metrics_history.append(quality_metrics)
                 domain_metrics[domain].append(quality_metrics)

                 # Detailed logging
                 self._log_validation_results(query, responses, quality_metrics, i)

@@ -110,12 +110,12 @@ class ChatbotValidator:
         aggregate_metrics = self._calculate_aggregate_metrics(metrics_history)
         domain_analysis = self._analyze_domain_performance(domain_metrics)
         confidence_analysis = self._analyze_confidence_distribution(metrics_history)

         aggregate_metrics.update({
             'domain_performance': domain_analysis,
             'confidence_analysis': confidence_analysis
         })

         self._log_validation_summary(aggregate_metrics)
         return aggregate_metrics

@@ -129,7 +129,7 @@ class ChatbotValidator:
             'avg_length_score': np.mean([m.get('response_length_score', 0) for m in metrics_history]),
             'avg_score_gap': np.mean([m.get('top_3_score_gap', 0) for m in metrics_history]),
             'confidence_rate': np.mean([m.get('is_confident', False) for m in metrics_history]),

             # Additional statistical metrics
             'median_top_score': np.median([m.get('top_score', 0) for m in metrics_history]),
             'score_std': np.std([m.get('top_score', 0) for m in metrics_history]),

@@ -141,7 +141,7 @@ class ChatbotValidator:
     def _analyze_domain_performance(self, domain_metrics: Dict[str, List[Dict]]) -> Dict[str, Dict]:
         """Analyze performance by domain."""
         domain_analysis = {}

         for domain, metrics in domain_metrics.items():
             domain_analysis[domain] = {
                 'confidence_rate': np.mean([m.get('is_confident', False) for m in metrics]),

@@ -150,13 +150,13 @@ class ChatbotValidator:
                 'avg_top_score': np.mean([m.get('top_score', 0) for m in metrics]),
                 'num_samples': len(metrics)
             }

         return domain_analysis

     def _analyze_confidence_distribution(self, metrics_history: List[Dict]) -> Dict[str, float]:
         """Analyze the distribution of confidence scores."""
         scores = [m.get('top_score', 0) for m in metrics_history]

         return {
             'percentile_25': np.percentile(scores, 25),
             'percentile_50': np.percentile(scores, 50),

@@ -180,7 +180,7 @@ class ChatbotValidator:
         for metric, value in metrics.items():
             if isinstance(value, (int, float)):
                 logger.info(f"  {metric}: {value:.4f}")

         logger.info("\nTop Responses:")
         for i, (response, score) in enumerate(responses[:3], 1):
             logger.info(f"{i}. Score: {score:.4f}. Response: {response}")

@@ -190,18 +190,18 @@ class ChatbotValidator:
     def _log_validation_summary(self, metrics: Dict[str, Any]):
         """Log comprehensive validation summary."""
         logger.info("\n=== Validation Summary ===")

         logger.info("\nOverall Metrics:")
         for metric, value in metrics.items():
             if isinstance(value, (int, float)):
                 logger.info(f"{metric}: {value:.4f}")

         logger.info("\nDomain Performance:")
         for domain, domain_metrics in metrics['domain_performance'].items():
             logger.info(f"\n{domain.title()}:")
             for metric, value in domain_metrics.items():
                 logger.info(f"  {metric}: {value:.4f}")

         logger.info("\nConfidence Distribution:")
         for percentile, value in metrics['confidence_analysis'].items():
             logger.info(f"{percentile}: {value:.4f}")
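A hedged sketch of how the validator is driven once a chatbot and quality checker exist (the object construction is assumed to happen elsewhere, e.g. in the new validate_model.py):

    # Hypothetical wiring; 'chatbot' and 'quality_checker' are built from a saved model directory.
    validator = ChatbotValidator(chatbot=chatbot, quality_checker=quality_checker)
    report = validator.run_validation(num_examples=5, top_k=10, domains=['restaurant'])
    print(report['confidence_rate'])
    print(report['domain_performance']['restaurant'])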
conversation_summarizer.py
CHANGED
@@ -49,7 +49,15 @@ class Summarizer(DeviceAwareModel):
     Handles long conversations by intelligent chunking and progressive summarization.
     """

-    def __init__(
+    def __init__(
+        self,
+        tokenizer: AutoTokenizer,
+        model_name="t5-small",
+        max_summary_length=128,
+        device=None,
+        max_summary_rounds=2
+    ):
+        self.tokenizer = tokenizer  # Injected tokenizer
         self.setup_device(device)

         # Initialize model within strategy scope if using distribution
@@ -63,12 +71,11 @@ class Summarizer(DeviceAwareModel):
         self.max_summary_rounds = max_summary_rounds

     def _setup_model(self, model_name):
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)

         # Optimize model for inference
+        self.model.generate = tf.function(
+            self.model.generate,
             input_signature=[
                 {
                     'input_ids': tf.TensorSpec(shape=[None, None], dtype=tf.int32),
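The point of the tf.function wrapper with an input_signature is to trace generate once for dynamic shapes instead of retracing for every new batch or sequence length. A minimal sketch of the same pattern using a positional signature rather than the dict form shown above (model name matches the default here, the call pattern is an assumption):

    import tensorflow as tf
    from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

    tok = AutoTokenizer.from_pretrained("t5-small")
    model = TFAutoModelForSeq2SeqLM.from_pretrained("t5-small")

    # One traced graph covers any [batch, seq_len] int32 input.
    fast_generate = tf.function(
        model.generate,
        input_signature=[tf.TensorSpec(shape=[None, None], dtype=tf.int32)],
    )

    ids = tok("summarize: a very long conversation ...", return_tensors="tf").input_ids
    summary_ids = fast_generate(ids)
    print(tok.decode(summary_ids[0], skip_special_tokens=True))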
environment_setup.py
CHANGED
@@ -122,15 +122,6 @@ class EnvironmentSetup:
         except (subprocess.SubprocessError, FileNotFoundError):
             logger.warning("Could not detect specific GPU model")

-        # # Enable XLA
-        # tf.config.optimizer.set_jit(True)
-        # logger.info("XLA compilation enabled for Colab GPU")
-
-        # # Set mixed precision policy
-        # policy = tf.keras.mixed_precision.Policy('mixed_float16')
-        # tf.keras.mixed_precision.set_global_policy(policy)
-        # logger.info("Mixed precision training enabled (float16)")
-
         strategy = tf.distribute.OneDeviceStrategy("/GPU:0")
         return "GPU", strategy
run_data_preparer.py → prepare_data.py
RENAMED
@@ -1,6 +1,7 @@
 import os
 import sys
 import faiss
+import json
 import pickle
 from transformers import AutoTokenizer
 from tqdm.auto import tqdm
@@ -52,36 +53,24 @@ def main():
     config = ChatbotConfig()
     logger.info(f"Chatbot Configuration: {config}")

-    # Initialize tokenizer
+    # Initialize tokenizer and add special tokens
     try:
         tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
         logger.info(f"Tokenizer '{config.pretrained_model}' loaded successfully.")
-    except Exception as e:
-        logger.error(f"Failed to load tokenizer: {e}")
-        sys.exit(1)
-
-    # Add special tokens
-    try:
         tokenizer.add_special_tokens({'additional_special_tokens': ['<EMPTY_NEGATIVE>']})
         logger.info("Added special tokens to tokenizer.")
     except Exception as e:
-        logger.error(f"Failed to
+        logger.error(f"Failed to load tokenizer: {e}")
         sys.exit(1)

-    # Initialize encoder model
+    # Initialize encoder model and resize token embeddings
     try:
         encoder = EncoderModel(config=config)
         logger.info("EncoderModel initialized successfully.")
-    except Exception as e:
-        logger.error(f"Failed to initialize EncoderModel: {e}")
-        sys.exit(1)
-
-    # Resize token embeddings in encoder to match tokenizer
-    try:
         encoder.pretrained.resize_token_embeddings(len(tokenizer))
         logger.info(f"Token embeddings resized to: {len(tokenizer)}")
     except Exception as e:
-        logger.error(f"Failed to
+        logger.error(f"Failed to initialize EncoderModel: {e}")
         sys.exit(1)

     # Load JSON dialogues
@@ -116,6 +105,8 @@ def main():
         max_length=config.max_context_token_limit,
         neg_samples=config.neg_samples,
         query_embeddings_cache=query_embeddings_cache,
+        index_type='IndexFlatIP',
+        nlist=100,
         max_retries=config.max_retries
     )
     logger.info("TFDataPipeline initialized successfully.")
@@ -135,17 +126,22 @@ def main():
     # Compute and add response embeddings to FAISS index
     try:
         logger.info("Computing and adding response embeddings to FAISS index...")
-        data_pipeline.
+        data_pipeline.compute_and_index_response_embeddings()
         logger.info("Response embeddings computed and added to FAISS index.")
     except Exception as e:
         logger.error(f"Failed to compute or add response embeddings: {e}")
         sys.exit(1)

-    # Save FAISS index
+    # Save FAISS index and response pool
     try:
         logger.info(f"Saving FAISS index to {FAISS_INDEX_PATH}...")
         faiss.write_index(data_pipeline.index, FAISS_INDEX_PATH)
         logger.info("FAISS index saved successfully.")
+
+        response_pool_path = FAISS_INDEX_PATH.replace('.index', '_responses.json')
+        with open(response_pool_path, 'w', encoding='utf-8') as f:
+            json.dump(data_pipeline.response_pool, f, indent=2)
+        logger.info(f"Response pool saved to {response_pool_path}.")
     except Exception as e:
         logger.error(f"Failed to save FAISS index: {e}")
         sys.exit(1)
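The tokenizer change above is why the encoder embeddings must be resized: adding '<EMPTY_NEGATIVE>' grows the vocabulary by one. A standalone illustration of that effect (the base model name is an assumption standing in for config.pretrained_model):

    from transformers import AutoTokenizer, TFAutoModel

    # Illustrative only: adding a special token grows the vocab, so the encoder's
    # embedding matrix must be resized before any encoding happens.
    tok = AutoTokenizer.from_pretrained("distilbert-base-uncased")   # assumed base model
    n_before = len(tok)
    tok.add_special_tokens({'additional_special_tokens': ['<EMPTY_NEGATIVE>']})
    model = TFAutoModel.from_pretrained("distilbert-base-uncased")
    model.resize_token_embeddings(len(tok))
    print(n_before, "->", len(tok))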
response_quality_checker.py
CHANGED
@@ -6,14 +6,14 @@ from logger_config import config_logger
 logger = config_logger(__name__)

 if TYPE_CHECKING:
-    from
+    from tf_data_pipeline import TFDataPipeline

 class ResponseQualityChecker:
     """Enhanced quality checking with dynamic thresholds."""

     def __init__(
         self,
+        data_pipeline: 'TFDataPipeline',
         confidence_threshold: float = 0.6,
         diversity_threshold: float = 0.15,
         min_response_length: int = 5,
@@ -23,15 +23,15 @@ class ResponseQualityChecker:
         self.diversity_threshold = diversity_threshold
         self.min_response_length = min_response_length
         self.similarity_cap = similarity_cap
-        self.
+        self.data_pipeline = data_pipeline  # Reference to TFDataPipeline

         # Dynamic thresholds based on response patterns
         self.thresholds = {
             'relevance': 0.35,
             'length_score': 0.85,
             'score_gap': 0.07
         }

     def check_response_quality(
         self,
         query: str,
@@ -39,11 +39,11 @@ class ResponseQualityChecker:
     ) -> Dict[str, Any]:
         """
         Evaluate the quality of responses based on various metrics.

         Args:
             query: The user's query
             responses: List of (response_text, score) tuples

         Returns:
             Dict containing quality metrics and confidence assessment
         """
@@ -56,7 +56,7 @@ class ResponseQualityChecker:
             'response_length_score': 0.0,
             'top_3_score_gap': 0.0
         }

         # Calculate core metrics
         metrics = {
             'response_diversity': self.calculate_diversity(responses),
@@ -67,10 +67,10 @@ class ResponseQualityChecker:
             'top_score': responses[0][1],
             'top_3_score_gap': self._calculate_score_gap([score for _, score in responses], top_n=3)
         }

         # Determine confidence using thresholds
         metrics['is_confident'] = self._determine_confidence(metrics)

         logger.info(f"Quality metrics: {metrics}")
         return metrics

@@ -78,44 +78,45 @@ class ResponseQualityChecker:
         """Calculate relevance as weighted similarity between query and responses."""
         if not responses:
             return 0.0

         # Get embeddings
-        query_embedding = self.encode_query(query)
+        query_embedding = self.data_pipeline.encode_query(query)
+        response_texts = [resp for resp, _ in responses]
+        response_embeddings = self.data_pipeline.encode_responses(response_texts)

+        # Compute similarities
         similarities = cosine_similarity([query_embedding], response_embeddings)[0]

+        # Apply decreasing weights for later responses
         weights = np.array([1.0 / (i + 1) for i in range(len(similarities))])

         return np.average(similarities, weights=weights)

     def calculate_diversity(self, responses: List[Tuple[str, float]]) -> float:
         """Calculate diversity with length normalization and similarity capping."""
         if not responses:
             return 0.0

+        response_texts = [resp for resp, _ in responses]
+        embeddings = self.data_pipeline.encode_responses(response_texts)
         if len(embeddings) < 2:
             return 1.0

-        # Calculate
+        # Calculate pairwise cosine similarities
         similarity_matrix = cosine_similarity(embeddings)
+        np.fill_diagonal(similarity_matrix, 0)  # Exclude self-similarity

+        # Apply similarity cap
         similarity_matrix = np.minimum(similarity_matrix, self.similarity_cap)

-        length_ratios = length_ratios.reshape(len(responses), len(responses))
-        # Combine factors with weights
-        adjusted_similarity = (similarity_matrix * 0.7 + length_ratios * 0.3)
-        # Calculate final score
-        sum_similarities = np.sum(adjusted_similarity) - len(responses)
-        num_pairs = len(responses) * (len(responses) - 1)
+        # Calculate average similarity
+        sum_similarities = np.sum(similarity_matrix)
+        num_pairs = len(embeddings) * (len(embeddings) - 1)
         avg_similarity = sum_similarities / num_pairs if num_pairs > 0 else 0.0

+        # Diversity is inversely related to average similarity
+        diversity_score = 1 - avg_similarity
+        return diversity_score

     def _determine_confidence(self, metrics: Dict[str, float]) -> bool:
         """Determine confidence using primary and secondary conditions."""
@@ -125,20 +126,20 @@ class ResponseQualityChecker:
             metrics['response_diversity'] >= self.diversity_threshold,
             metrics['response_length_score'] >= self.thresholds['length_score']
         ]

         # Secondary conditions (majority must be met)
         secondary_conditions = [
             metrics['query_response_relevance'] >= self.thresholds['relevance'],
             metrics['top_3_score_gap'] >= self.thresholds['score_gap'],
             metrics['top_score'] >= (self.confidence_threshold * 1.1)  # Extra confidence boost
         ]

         return all(primary_conditions) and sum(secondary_conditions) >= 2

     def _calculate_length_score(self, response: str) -> float:
         """Calculate length score with penalty for very short or long responses."""
         words = len(response.split())

         if words < self.min_response_length:
             return words / self.min_response_length
         elif words > 50:  # Penalty for very long responses
@@ -150,21 +151,4 @@ class ResponseQualityChecker:
         if len(scores) < top_n + 1:
             return 0.0
         gaps = [scores[i] - scores[i + 1] for i in range(min(len(scores) - 1, top_n))]
         return np.mean(gaps)
-
-    def encode_text(self, text: str) -> np.ndarray:
-        """Encode response text to embedding."""
-        embedding_tensor = self.chatbot.encode_responses([text])
-        embedding = embedding_tensor.numpy()[0].astype('float32')
-        return self._normalize_embedding(embedding)
-
-    def encode_query(self, query: str) -> np.ndarray:
-        """Encode query text to embedding."""
-        embedding_tensor = self.chatbot.encode_query(query)
-        embedding = embedding_tensor.numpy()[0].astype('float32')
-        return self._normalize_embedding(embedding)
-
-    def _normalize_embedding(self, embedding: np.ndarray) -> np.ndarray:
-        """Normalize embedding vector."""
-        norm = np.linalg.norm(embedding)
-        return embedding / norm if norm > 0 else embedding
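The reworked diversity metric above reduces to one minus the capped average off-diagonal cosine similarity. A tiny worked example with made-up 2-D embeddings (the cap value is an assumption; the class default is not shown in this hunk):

    import numpy as np
    from sklearn.metrics.pairwise import cosine_similarity

    emb = np.array([[1.0, 0.0], [0.8, 0.6], [0.0, 1.0]])   # three response embeddings
    sim = cosine_similarity(emb)
    np.fill_diagonal(sim, 0)                                 # exclude self-similarity
    sim = np.minimum(sim, 0.9)                               # similarity_cap (assumed value)
    avg = sim.sum() / (len(emb) * (len(emb) - 1))
    diversity = 1 - avg                                      # higher means more varied responses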
test_trained_model.py
DELETED
File without changes
tf_data_pipeline.py
CHANGED
@@ -11,6 +11,7 @@ from pathlib import Path
 from typing import Union, Optional, List, Tuple, Generator
 from transformers import AutoTokenizer
 from typing import List, Tuple, Generator
+from gpu_monitor import GPUMemoryMonitor

 from logger_config import config_logger
@@ -31,7 +32,6 @@ class TFDataPipeline:
         nlist: int = 100,
         max_retries: int = 3
     ):
-        #self.embedding_batch_size = embedding_batch_size
         self.config = config
         self.tokenizer = tokenizer
         self.encoder = encoder
@@ -64,14 +64,6 @@ class TFDataPipeline:
             dimension = self.query_embeddings_cache[next(iter(self.query_embeddings_cache))].shape[0]
             self.index.train(np.array(list(self.query_embeddings_cache.values())).astype(np.float32))
             self.index.add(np.array(list(self.query_embeddings_cache.values())).astype(np.float32))
-
-    def validate_faiss_index(self):
-        """Validates that the FAISS index has the correct dimensionality."""
-        expected_dim = self.encoder.config.embedding_dim
-        if self.index.d != expected_dim:
-            logger.error(f"FAISS index dimension {self.index.d} does not match encoder embedding dimension {expected_dim}.")
-            raise ValueError("FAISS index dimensionality mismatch.")
-        logger.info("FAISS index dimension validated successfully.")

     def save_embeddings_cache_hdf5(self, cache_file_path: str):
         """Save the embeddings cache to an HDF5 file."""
@@ -92,8 +84,21 @@ class TFDataPipeline:
         logger.info(f"FAISS index saved to {index_file_path}")

     def load_faiss_index(self, index_file_path: str):

     def save_tokenizer(self, tokenizer_dir: str):
         self.tokenizer.save_pretrained(tokenizer_dir)
@@ -102,19 +107,6 @@ class TFDataPipeline:
     def load_tokenizer(self, tokenizer_dir: str):
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
         logger.info(f"Tokenizer loaded from {tokenizer_dir}")
-
-    def estimate_total_pairs(self, dialogues: List[dict]) -> int:
-        """Estimate total number of training pairs including hard negatives."""
-        base_pairs = sum(
-            len([
-                1 for i in range(len(d.get('turns', [])) - 1)
-                if (d['turns'][i].get('speaker') == 'user' and
-                    d['turns'][i+1].get('speaker') == 'assistant')
-            ])
-            for d in dialogues
-        )
-        # Account for hard negatives
-        return base_pairs * (1 + self.neg_samples)

     @staticmethod
     def load_json_training_data(data_path: Union[str, Path], debug_samples: Optional[int] = None) -> List[dict]:
@@ -179,7 +171,7 @@ class TFDataPipeline:

         return pairs

-    def
         """
         Computes embeddings for the response pool and adds them to the FAISS index with progress bars.
         """
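The method described by that docstring boils down to tokenizing the response pool, encoding it in batches, L2-normalizing so the inner-product index behaves like cosine similarity, and adding the vectors to FAISS. A condensed, hypothetical helper mirroring that flow (names such as encoder, tokenizer, response_pool and index stand in for the pipeline's own attributes):

    import numpy as np
    import faiss

    def index_response_embeddings(encoder, tokenizer, response_pool, index,
                                  max_length=512, batch_size=64):
        # Tokenize the whole pool once, then encode it in manageable batches.
        ids = tokenizer(response_pool, padding=True, truncation=True,
                        max_length=max_length, return_tensors='tf')['input_ids']
        chunks = []
        for start in range(0, len(ids), batch_size):
            vecs = encoder(ids[start:start + batch_size], training=False).numpy()
            faiss.normalize_L2(vecs)              # unit norm so inner product == cosine
            chunks.append(vecs)
        if chunks:
            index.add(np.vstack(chunks).astype(np.float32))
        return index.ntotal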
@@ -239,49 +231,6 @@ class TFDataPipeline:

         # **Sanity Check:** Verify the number of embeddings in FAISS index
         logger.info(f"Total embeddings in FAISS index after addition: {self.index.ntotal}")
-    # def _compute_and_index_response_embeddings(self):
-    #     """
-    #     Computes embeddings for the response pool and adds them to the FAISS index.
-    #     """
-    #     logger.info("Computing embeddings for the response pool...")
-
-    #     # Ensure all responses are strings
-    #     if not all(isinstance(response, str) for response in self.response_pool):
-    #         logger.error("All elements in response_pool must be strings.")
-    #         raise ValueError("Invalid data type in response_pool.")
-
-    #     # Proceed with tokenization
-    #     encoded_responses = self.tokenizer(
-    #         self.response_pool,
-    #         padding=True,
-    #         truncation=True,
-    #         max_length=self.max_length,
-    #         return_tensors='tf'
-    #     )
-    #     response_ids = encoded_responses['input_ids']
-
-    #     # Compute embeddings in batches
-    #     batch_size = getattr(self, 'embedding_batch_size', 64)  # Default to 64 if not set
-    #     embeddings = []
-    #     for i in range(0, len(response_ids), batch_size):
-    #         batch_ids = response_ids[i:i+batch_size]
-    #         # Compute embeddings
-    #         batch_embeddings = self.encoder(batch_ids, training=False).numpy()
-    #         # Normalize embeddings if using inner product or cosine similarity
-    #         faiss.normalize_L2(batch_embeddings)
-    #         embeddings.append(batch_embeddings)
-
-    #     if embeddings:
-    #         embeddings = np.vstack(embeddings).astype(np.float32)
-    #         # Add embeddings to FAISS index
-    #         logger.info(f"Adding {len(embeddings)} response embeddings to FAISS index...")
-    #         self.index.add(embeddings)
-    #         logger.info("Response embeddings added to FAISS index.")
-    #     else:
-    #         logger.warning("No embeddings to add to FAISS index.")
-
-    #     # **Sanity Check:** Verify the number of embeddings in FAISS index
-    #     logger.info(f"Total embeddings in FAISS index after addition: {self.index.ntotal}")

     def _find_hard_negatives_batch(self, queries: List[str], positives: List[str]) -> List[List[str]]:
         """Find hard negatives for a batch of queries with error handling and retries."""
@@ -355,58 +304,109 @@ class TFDataPipeline:
|
|
355 |
if tf.config.list_physical_devices('GPU'):
|
356 |
tf.keras.backend.clear_session()
|
357 |
|
358 |
-
def
|
359 |
"""
|
360 |
-
|
|
|
|
|
|
|
|
|
|
|
361 |
Returns:
|
362 |
-
|
363 |
-
positive_ids: [batch_size, max_length]
|
364 |
-
negative_ids: [batch_size, neg_samples, max_length]
|
365 |
"""
|
366 |
-
#
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
|
|
|
|
|
|
|
|
378 |
truncation=True,
|
379 |
max_length=self.max_length,
|
380 |
-
return_tensors=
|
381 |
)
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
batch_size = len(queries)
|
400 |
-
n_ids = n_input_ids.reshape(batch_size, self.neg_samples, self.max_length)
|
401 |
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
negative_ids = n_ids.astype(np.int32)
|
406 |
|
407 |
-
|
|
|
|
|
408 |
|
409 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
410 |
def prepare_and_save_data(self, dialogues: List[dict], tf_record_path: str, batch_size: int = 32):
|
411 |
"""
|
412 |
Processes dialogues in batches and saves to a TFRecord file using optimized batch tokenization and encoding.
|
@@ -522,83 +522,6 @@ class TFDataPipeline:
|
|
522 |
pbar.update(1)
|
523 |
|
524 |
logger.info(f"Data preparation complete. TFRecord saved.")
|
525 |
-
# def prepare_and_save_data(self, dialogues: List[dict], tfrecord_file_path: str, batch_size: int = 32):
|
526 |
-
# """Processes dialogues in batches and saves to a TFRecord file."""
|
527 |
-
# with tf.io.TFRecordWriter(tfrecord_file_path) as writer:
|
528 |
-
# total_dialogues = len(dialogues)
|
529 |
-
# logger.debug(f"Total dialogues to process: {total_dialogues}")
|
530 |
-
|
531 |
-
# with tqdm(total=total_dialogues, desc="Processing Dialogues", unit="dialogue") as pbar:
|
532 |
-
# for i in range(0, total_dialogues, batch_size):
|
533 |
-
# batch_dialogues = dialogues[i:i+batch_size]
|
534 |
-
# # Process each batch_dialogues
|
535 |
-
# # Extract pairs, find negatives, tokenize, and serialize
|
536 |
-
# # Example:
|
537 |
-
# for dialogue in batch_dialogues:
|
538 |
-
# pairs = self._extract_pairs_from_dialogue(dialogue)
|
539 |
-
# queries = []
|
540 |
-
# positives = []
|
541 |
-
|
542 |
-
# for query, positive in pairs:
|
543 |
-
# queries.append(query)
|
544 |
-
# positives.append(positive)
|
545 |
-
|
546 |
-
# if queries:
|
547 |
-
# # **Compute and cache query embeddings before searching**
|
548 |
-
# self._compute_embeddings(queries)
|
549 |
-
|
550 |
-
# # Find hard negatives
|
551 |
-
# hard_negatives = self._find_hard_negatives_batch(queries, positives)
|
552 |
-
|
553 |
-
# # for idx, negatives in enumerate(hard_negatives[:5]): # Log first 5 examples
|
554 |
-
# # logger.debug(f"Query: {queries[idx]}")
|
555 |
-
# # logger.debug(f"Positive: {positives[idx]}")
|
556 |
-
# # logger.debug(f"Hard Negatives: {negatives}")
|
557 |
-
# # Tokenize and encode
|
558 |
-
# query_ids, positive_ids, negative_ids = self._tokenize_and_encode(queries, positives, hard_negatives)
|
559 |
-
|
560 |
-
# # Serialize each example and write to TFRecord
|
561 |
-
# for q_id, p_id, n_id in zip(query_ids, positive_ids, negative_ids):
|
562 |
-
# feature = {
|
563 |
-
# 'query_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=q_id)),
|
564 |
-
# 'positive_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=p_id)),
|
565 |
-
# 'negative_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=n_id.flatten())),
|
566 |
-
# }
|
567 |
-
# example = tf.train.Example(features=tf.train.Features(feature=feature))
|
568 |
-
# writer.write(example.SerializeToString())
|
569 |
-
|
570 |
-
# pbar.update(len(batch_dialogues))
|
571 |
-
# logger.info(f"Data preparation complete. TFRecord saved at {tfrecord_file_path}")
|
572 |
-
|
573 |
-
def _tokenize_negatives_tf(self, negatives):
|
574 |
-
"""Tokenizes negatives using tf.py_function."""
|
575 |
-
# Handle the case where negatives is an empty tensor
|
576 |
-
if tf.size(negatives) == 0:
|
577 |
-
return tf.zeros([0, self.neg_samples, self.max_length], dtype=tf.int32)
|
578 |
-
|
579 |
-
# Convert EagerTensor to a list of strings
|
580 |
-
negatives_list = []
|
581 |
-
for neg_list in negatives.numpy():
|
582 |
-
decoded_negs = [neg.decode("utf-8") for neg in neg_list if neg] # Filter out empty strings
|
583 |
-
negatives_list.append(decoded_negs)
|
584 |
-
|
585 |
-
# Flatten the list of lists
|
586 |
-
flattened_negatives = [neg for sublist in negatives_list for neg in sublist]
|
587 |
-
|
588 |
-
# Tokenize the flattened negatives
|
589 |
-
if flattened_negatives:
|
590 |
-
n_tokens = self.tokenizer(
|
591 |
-
flattened_negatives,
|
592 |
-
padding='max_length',
|
593 |
-
truncation=True,
|
594 |
-
max_length=self.max_length,
|
595 |
-
return_tensors='tf'
|
596 |
-
)
|
597 |
-
# Reshape the tokens
|
598 |
-
n_tokens_reshaped = tf.reshape(n_tokens['input_ids'], [-1, self.neg_samples, self.max_length])
|
599 |
-
return n_tokens_reshaped
|
600 |
-
else:
|
601 |
-
return tf.zeros([0, self.neg_samples, self.max_length], dtype=tf.int32)
|
602 |
|
603 |
def _compute_embeddings(self, queries: List[str]) -> None:
|
604 |
new_queries = [q for q in queries if q not in self.query_embeddings_cache]
|
@@ -642,51 +565,6 @@ class TFDataPipeline:
|
|
642 |
hard_negatives = self._find_hard_negatives_batch([query], [positive])[0]
|
643 |
yield (query, positive, hard_negatives)
|
644 |
pbar.update(1)
|
645 |
-
|
646 |
-
def _prepare_batch(self, queries: tf.Tensor, positives: tf.Tensor, negatives: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
|
647 |
-
"""Prepares a batch of data for training."""
|
648 |
-
|
649 |
-
# Convert EagerTensors to lists of strings
|
650 |
-
queries_list = [query.decode("utf-8") for query in queries.numpy()]
|
651 |
-
positives_list = [pos.decode("utf-8") for pos in positives.numpy()]
|
652 |
-
|
653 |
-
# Tokenize queries and positives
|
654 |
-
q_tokens = self.tokenizer(queries_list, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf')
|
655 |
-
p_tokens = self.tokenizer(positives_list, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf')
|
656 |
-
|
657 |
-
# Decode negatives and ensure they are lists of strings
|
658 |
-
negatives_list = []
|
659 |
-
for neg_list in negatives.numpy():
|
660 |
-
decoded_negs = [neg.decode("utf-8") for neg in neg_list if neg] # Filter out empty strings
|
661 |
-
negatives_list.append(decoded_negs)
|
662 |
-
|
663 |
-
# Flatten negatives for tokenization if there are any valid negatives
|
664 |
-
flattened_negatives = [neg for sublist in negatives_list for neg in sublist if neg]
|
665 |
-
|
666 |
-
# Tokenize negatives if there are any
|
667 |
-
n_tokens_reshaped = None
|
668 |
-
if flattened_negatives:
|
669 |
-
n_tokens = self.tokenizer(flattened_negatives, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf')
|
670 |
-
|
671 |
-
# Reshape n_tokens to match the expected shape based on the number of negatives per query
|
672 |
-
# This part may need adjustment if the number of negatives varies per query
|
673 |
-
n_tokens_reshaped = tf.reshape(n_tokens['input_ids'], [len(queries_list), -1, self.max_length])
|
674 |
-
else:
|
675 |
-
# Create a placeholder tensor for the case where there are no negatives
|
676 |
-
n_tokens_reshaped = tf.zeros([len(queries_list), 0, self.max_length], dtype=tf.int32)
|
677 |
-
|
678 |
-
# Ensure n_tokens_reshaped has a consistent shape even when there are no negatives
|
679 |
-
# Adjust shape to [batch_size, num_neg_samples, max_length]
|
680 |
-
if n_tokens_reshaped.shape[1] != self.neg_samples:
|
681 |
-
# Pad or truncate the second dimension to match neg_samples
|
682 |
-
padding = tf.zeros([len(queries_list), tf.maximum(0, self.neg_samples - n_tokens_reshaped.shape[1]), self.max_length], dtype=tf.int32)
|
683 |
-
n_tokens_reshaped = tf.concat([n_tokens_reshaped, padding], axis=1)
|
684 |
-
n_tokens_reshaped = n_tokens_reshaped[:, :self.neg_samples, :]
|
685 |
-
|
686 |
-
# Concatenate the positive and negative examples along the 'neg_samples' dimension
|
687 |
-
combined_p_n_tokens = tf.concat([tf.expand_dims(p_tokens['input_ids'], axis=1), n_tokens_reshaped], axis=1)
|
688 |
-
|
689 |
-
return q_tokens['input_ids'], combined_p_n_tokens
|
690 |
|
691 |
def get_tf_dataset(self, dialogues: List[dict], batch_size: int) -> tf.data.Dataset:
|
692 |
"""
|
@@ -714,32 +592,6 @@ class TFDataPipeline:
|
|
714 |
|
715 |
dataset = dataset.prefetch(tf.data.AUTOTUNE)
|
716 |
return dataset
|
717 |
-
# def get_tf_dataset(self, dialogues: List[dict], batch_size: int) -> tf.data.Dataset:
|
718 |
-
# """
|
719 |
-
# Creates a tf.data.Dataset for streaming training that yields
|
720 |
-
# (input_ids_query, input_ids_positive, input_ids_negatives).
|
721 |
-
# """
|
722 |
-
# # 1) Start with a generator dataset
|
723 |
-
# dataset = tf.data.Dataset.from_generator(
|
724 |
-
# lambda: self.data_generator(dialogues),
|
725 |
-
# output_signature=(
|
726 |
-
# tf.TensorSpec(shape=(), dtype=tf.string), # Query (single string)
|
727 |
-
# tf.TensorSpec(shape=(), dtype=tf.string), # Positive (single string)
|
728 |
-
# tf.TensorSpec(shape=(None,), dtype=tf.string) # Hard Negatives (list of strings)
|
729 |
-
# )
|
730 |
-
# )
|
731 |
-
|
732 |
-
# # 2) Batch the raw strings
|
733 |
-
# dataset = dataset.batch(batch_size)
|
734 |
-
|
735 |
-
# # 3) Now map them through a tokenize step (via py_function)
|
736 |
-
# dataset = dataset.map(
|
737 |
-
# lambda q, p, n: self._tokenize_triple(q, p, n),
|
738 |
-
# num_parallel_calls=1 #tf.data.AUTOTUNE
|
739 |
-
# )
|
740 |
-
|
741 |
-
# dataset = dataset.prefetch(tf.data.AUTOTUNE)
|
742 |
-
# return dataset
|
743 |
|
744 |
def _tokenize_triple(
|
745 |
self,
|
@@ -861,71 +713,3 @@ class TFDataPipeline:
|
|
861 |
n_ids = n_ids.astype(np.int32) # shape [batch_size, neg_samples, max_len]
|
862 |
|
863 |
return q_ids, p_ids, n_ids
|
864 |
-
|
865 |
-
# def parse_tfrecord_fn(example_proto, max_length, neg_samples):
|
866 |
-
# """
|
867 |
-
# Parses a single TFRecord example.
|
868 |
-
|
869 |
-
# Args:
|
870 |
-
# example_proto: A serialized TFRecord example.
|
871 |
-
# max_length: The maximum sequence length for tokenization.
|
872 |
-
# neg_samples: The number of hard negatives per query.
|
873 |
-
|
874 |
-
# Returns:
|
875 |
-
# A tuple of (query_ids, positive_ids, negative_ids).
|
876 |
-
# """
|
877 |
-
# feature_description = {
|
878 |
-
# 'query_ids': tf.io.FixedLenFeature([max_length], tf.int64),
|
879 |
-
# 'positive_ids': tf.io.FixedLenFeature([max_length], tf.int64),
|
880 |
-
# 'negative_ids': tf.io.FixedLenFeature([neg_samples * max_length], tf.int64),
|
881 |
-
# }
|
882 |
-
# parsed_features = tf.io.parse_single_example(example_proto, feature_description)
|
883 |
-
|
884 |
-
# query_ids = tf.cast(parsed_features['query_ids'], tf.int32)
|
885 |
-
# positive_ids = tf.cast(parsed_features['positive_ids'], tf.int32)
|
886 |
-
# negative_ids = tf.cast(parsed_features['negative_ids'], tf.int32)
|
887 |
-
# negative_ids = tf.reshape(negative_ids, [neg_samples, max_length])
|
888 |
-
|
889 |
-
# return query_ids, positive_ids, negative_ids
|
890 |
-
|
891 |
-
# def _find_hard_negatives_batch(self, queries: List[str], positives: List[str]) -> List[List[str]]:
|
892 |
-
# """Find hard negatives for a batch of queries with error handling and retries."""
|
893 |
-
# retry_count = 0
|
894 |
-
# total_responses = len(self.response_pool)
|
895 |
-
|
896 |
-
# while retry_count < self.max_retries:
|
897 |
-
# try:
|
898 |
-
# query_embeddings = np.vstack([
|
899 |
-
# self.query_embeddings_cache[q] for q in queries
|
900 |
-
# ]).astype(np.float32)
|
901 |
-
|
902 |
-
# query_embeddings = np.ascontiguousarray(query_embeddings)
|
903 |
-
# faiss.normalize_L2(query_embeddings)
|
904 |
-
|
905 |
-
# k = 1 # TODO: try higher k for better results
|
906 |
-
# #logger.debug(f"Searching with k={k} among {total_responses} responses")
|
907 |
-
|
908 |
-
# distances, indices = self.index.search(query_embeddings, k)
|
909 |
-
|
910 |
-
# all_negatives = []
|
911 |
-
# for query_indices, query, positive in zip(indices, queries, positives):
|
912 |
-
# negatives = []
|
913 |
-
# positive_strip = positive.strip()
|
914 |
-
# seen = {positive_strip}
|
915 |
-
|
916 |
-
# for idx in query_indices:
|
917 |
-
# if idx >= 0 and idx < total_responses:
|
918 |
-
# candidate = self.response_pool[idx].strip()
|
919 |
-
# if candidate and candidate not in seen:
|
920 |
-
# seen.add(candidate)
|
921 |
-
# negatives.append(candidate)
|
922 |
-
# if len(negatives) >= self.neg_samples:
|
923 |
-
# break
|
924 |
-
|
925 |
-
# # Pad with a special empty negative if necessary
|
926 |
-
# while len(negatives) < self.neg_samples:
|
927 |
-
# negatives.append("<EMPTY_NEGATIVE>") # Use a special token
|
928 |
-
|
929 |
-
# all_negatives.append(negatives)
|
930 |
-
|
931 |
-
# return all_negatives
|
|
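The new load_faiss_index / validate_faiss_index pair assumes the index on disk was written with the same embedding dimension the encoder produces. A minimal round-trip sketch using faiss directly (the file name and dimension here are placeholders, not the repo's paths):

import numpy as np
import faiss

dim = 768                                   # must match encoder.config.embedding_dim
index = faiss.IndexFlatIP(dim)              # inner product over L2-normalized vectors = cosine

vecs = np.random.rand(100, dim).astype(np.float32)
faiss.normalize_L2(vecs)
index.add(vecs)

faiss.write_index(index, "example.index")   # analogous to the pipeline's save step
loaded = faiss.read_index("example.index")  # what load_faiss_index does under the hood

assert loaded.d == dim                      # the check validate_faiss_index performs
assert loaded.ntotal == 100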
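With encode_query and encode_responses returning L2-normalized float32 vectors, retrieval reduces to a single inner-product search against the index. A sketch of the intended flow, assuming a pipeline object that already holds a trained encoder, tokenizer, response_pool, and a populated FAISS index (the pipeline variable and sample strings are illustrative):

import numpy as np

# pipeline: a TFDataPipeline with encoder, tokenizer, response_pool, and a built FAISS index (assumed)
query_vec = pipeline.encode_query(
    "how do I reset my password?",
    context=[("hi", "hello, how can I help?")],   # optional (user, assistant) history
)

# FAISS expects a 2D [n_queries, dim] batch; encode_query returns a 1D vector
distances, indices = pipeline.index.search(query_vec.reshape(1, -1).astype(np.float32), 5)

for score, idx in zip(distances[0], indices[0]):
    print(f"{score:.3f}  {pipeline.response_pool[idx]}")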
run_model_train.py → train_model.py
RENAMED
@@ -1,36 +1,11 @@
 import tensorflow as tf
 from chatbot_model import RetrievalChatbot, ChatbotConfig
 from environment_setup import EnvironmentSetup
-from response_quality_checker import ResponseQualityChecker
-from chatbot_validator import ChatbotValidator
 from training_plotter import TrainingPlotter
 
-# Configure logging
 from logger_config import config_logger
 logger = config_logger(__name__)
 
-def run_interactive_chat(chatbot, quality_checker):
-    """Separate function for interactive chat loop"""
-    while True:
-        user_input = input("You: ")
-        if user_input.lower() in ['quit', 'exit', 'bye']:
-            print("Assistant: Goodbye!")
-            break
-
-        response, candidates, metrics = chatbot.chat(
-            query=user_input,
-            conversation_history=None,
-            quality_checker=quality_checker,
-            top_k=5
-        )
-
-        print(f"Assistant: {response}")
-
-        if metrics.get('is_confident', False):
-            print("\nAlternative responses:")
-            for resp, score in candidates[1:4]:
-                print(f"Score: {score:.4f} - {resp}")
-
 def inspect_tfrecord(tfrecord_file_path, num_examples=3):
     def parse_example(example_proto):
         feature_description = {
@@ -53,7 +28,7 @@ def inspect_tfrecord(tfrecord_file_path, num_examples=3):
 def main():
 
     # Quick test to inspect TFRecord
-    #inspect_tfrecord('training_data/training_data.tfrecord', num_examples=3)
+    # inspect_tfrecord('training_data/training_data.tfrecord', num_examples=3)
 
     # Initialize environment
     tf.keras.backend.clear_session()
@@ -65,49 +40,40 @@ def main():
     TF_RECORD_FILE_PATH = 'training_data/training_data.tfrecord'
 
     # Optimize batch size for Colab
-    batch_size = env.optimize_batch_size(base_batch_size=16)
-
+    batch_size = 32  # env.optimize_batch_size(base_batch_size=16)
 
-    # Initialize
-    config = ChatbotConfig(
-        embedding_dim=768,  # DistilBERT
-        max_context_token_limit=512,
-        freeze_embeddings=False,
-    )
+    # Initialize config
+    config = ChatbotConfig()
 
     # Initialize chatbot
-    #with env.strategy.scope():
     chatbot = RetrievalChatbot(config, mode='training')
-    chatbot.build_models()
 
+    # Load from a checkpoint
+    checkpoint_dir = 'checkpoints/'
+    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
+    initial_epoch = 0
+    if latest_checkpoint:
+        ckpt_number = int(latest_checkpoint.split('ckpt-')[-1])
+        initial_epoch = ckpt_number
+        logger.info(f"Found checkpoint {latest_checkpoint}, resuming from epoch {initial_epoch}")
+
+    # Train the model
+    chatbot.train_model(
         tfrecord_file_path=TF_RECORD_FILE_PATH,
         epochs=EPOCHS,
         batch_size=batch_size,
         use_lr_schedule=True,
+        test_mode=False,
+        initial_epoch=initial_epoch
     )
 
     # Save final model
     model_save_path = env.training_dirs['base'] / 'final_model'
     chatbot.save_models(model_save_path)
 
-    # Run automatic validation
-    quality_checker = ResponseQualityChecker(chatbot=chatbot)
-    validator = ChatbotValidator(chatbot, quality_checker)
-    validation_metrics = validator.run_validation(num_examples=5)
-    logger.info(f"Validation Metrics: {validation_metrics}")
-
     # Plot and save training history
     plotter = TrainingPlotter(save_dir=env.training_dirs['plots'])
     plotter.plot_training_history(chatbot.history)
-    plotter.plot_validation_metrics(validation_metrics)
-
-    # Run interactive chat
-    logger.info("\nStarting interactive chat session...")
-    run_interactive_chat(chatbot, quality_checker)
 
 if __name__ == "__main__":
     main()
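The resume logic above derives initial_epoch by parsing the trailing number of the latest checkpoint path, which assumes checkpoints were written as ckpt-1, ckpt-2, ... with one save per epoch. A sketch of a writer/reader pair under that assumption (the directory and toy model are placeholders):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(2,))])
ckpt = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(ckpt, directory="checkpoints/", max_to_keep=3)

# Writer side: saving once per epoch yields checkpoints/ckpt-1, ckpt-2, ...
for epoch in range(3):
    manager.save()

# Reader side: the same parsing train_model.py uses to pick the starting epoch
latest = tf.train.latest_checkpoint("checkpoints/")    # e.g. 'checkpoints/ckpt-3'
initial_epoch = int(latest.split("ckpt-")[-1]) if latest else 0
print(initial_epoch)                                    # 3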
validate_model.py
ADDED
@@ -0,0 +1,117 @@
+import os
+import json
+from chatbot_model import ChatbotConfig, RetrievalChatbot
+from response_quality_checker import ResponseQualityChecker
+from chatbot_validator import ChatbotValidator
+from training_plotter import TrainingPlotter
+from environment_setup import EnvironmentSetup
+
+from logger_config import config_logger
+logger = config_logger(__name__)
+
+def run_interactive_chat(chatbot, quality_checker):
+    """Separate function for interactive chat loop"""
+    while True:
+        try:
+            user_input = input("You: ")
+        except (KeyboardInterrupt, EOFError):
+            print("\nAssistant: Goodbye!")
+            break
+
+        if user_input.lower() in ['quit', 'exit', 'bye']:
+            print("Assistant: Goodbye!")
+            break
+
+        response, candidates, metrics = chatbot.chat(
+            query=user_input,
+            conversation_history=None,
+            quality_checker=quality_checker,
+            top_k=5
+        )
+
+        print(f"Assistant: {response}")
+
+        if metrics.get('is_confident', False):
+            print("\nAlternative responses:")
+            for resp, score in candidates[1:4]:
+                print(f"Score: {score:.4f} - {resp}")
+        else:
+            print("\n[Low Confidence]: Consider rephrasing your query for better assistance.")
+
+# TODO:
+def validate_chatbot():
+    # Initialize environment
+    env = EnvironmentSetup()
+    env.initialize()
+
+    MODEL_DIR = 'models'
+    FAISS_INDICES_DIR = os.path.join(MODEL_DIR, 'faiss_indices')
+    FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_production.index')
+    FAISS_INDEX_TEST_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_test.index')
+    RESPONSE_POOL_PRODUCTION_PATH = FAISS_INDEX_PRODUCTION_PATH.replace('.index', '_responses.json')
+    RESPONSE_POOL_TEST_PATH = FAISS_INDEX_TEST_PATH.replace('.index', '_responses.json')
+    ENVIRONMENT = 'production'  # or 'test'
+    if ENVIRONMENT == 'test':
+        FAISS_INDEX_PATH = FAISS_INDEX_TEST_PATH
+        RESPONSE_POOL_PATH = RESPONSE_POOL_TEST_PATH
+    else:
+        FAISS_INDEX_PATH = FAISS_INDEX_PRODUCTION_PATH
+        RESPONSE_POOL_PATH = RESPONSE_POOL_PRODUCTION_PATH
+
+    # Load config
+    config = ChatbotConfig()
+
+    # Initialize RetrievalChatbot in 'inference' mode
+    try:
+        chatbot = RetrievalChatbot(config=config, mode='inference')
+        logger.info("RetrievalChatbot initialized in 'inference' mode.")
+    except Exception as e:
+        logger.error(f"Failed to initialize RetrievalChatbot: {e}")
+        return
+
+    # Ensure FAISS index and response pool are accessible, then load
+    if not os.path.exists(FAISS_INDEX_PATH) or not os.path.exists(RESPONSE_POOL_PATH):
+        logger.error("FAISS index or response pool file is missing.")
+        return
+
+    try:
+        chatbot.data_pipeline.load_faiss_index(FAISS_INDEX_PATH)
+        logger.info(f"FAISS index loaded from {FAISS_INDEX_PATH}.")
+
+        with open(RESPONSE_POOL_PATH, 'r', encoding='utf-8') as f:
+            chatbot.data_pipeline.response_pool = json.load(f)
+        logger.info(f"Response pool loaded from {RESPONSE_POOL_PATH}.")
+
+        chatbot.data_pipeline.validate_faiss_index()
+        logger.info("FAISS index and response pool validated successfully.")
+    except Exception as e:
+        logger.error(f"Failed to load FAISS index: {e}")
+        return
+
+    # Initialize ResponseQualityChecker and ChatbotValidator
+    quality_checker = ResponseQualityChecker(data_pipeline=chatbot.data_pipeline)
+    validator = ChatbotValidator(chatbot=chatbot, quality_checker=quality_checker)
+    logger.info("ResponseQualityChecker and ChatbotValidator initialized.")
+
+    # Run validation
+    try:
+        validation_metrics = validator.run_validation(num_examples=5)
+        logger.info(f"Validation Metrics: {validation_metrics}")
+    except Exception as e:
+        logger.error(f"Validation process failed: {e}")
+        return
+
+    # Plot validation_metrics
+    try:
+        plotter = TrainingPlotter(save_dir=env.training_dirs['plots'])
+        plotter.plot_validation_metrics(validation_metrics)
+        logger.info("Validation metrics plotted successfully.")
+    except Exception as e:
+        logger.error(f"Failed to plot validation metrics: {e}")
+
+    # Run interactive chat
+    logger.info("\nStarting interactive chat session...")
+    run_interactive_chat(chatbot, quality_checker)
+
+if __name__ == '__main__':
+    validate_chatbot()
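validate_chatbot expects a FAISS index and a JSON response pool sitting side by side under models/faiss_indices/. A sketch of how a preparation step could write that pair, assuming a pipeline whose index and response_pool are already populated; the direct faiss/json calls here are illustrative and not necessarily the exact helpers in tf_data_pipeline.py:

import os
import json
import faiss

FAISS_INDICES_DIR = os.path.join("models", "faiss_indices")
os.makedirs(FAISS_INDICES_DIR, exist_ok=True)

index_path = os.path.join(FAISS_INDICES_DIR, "faiss_index_production.index")
responses_path = index_path.replace(".index", "_responses.json")

# pipeline: a TFDataPipeline with a populated FAISS index and response_pool (assumed)
faiss.write_index(pipeline.index, index_path)
with open(responses_path, "w", encoding="utf-8") as f:
    json.dump(pipeline.response_pool, f, ensure_ascii=False)

# validate_model.py can now load both files and run validate_chatbot().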