JoeArmani committed
Commit 74af405 · 1 Parent(s): 2183656

data processing pipeline

Files changed (4):
  1. chatbot_model.py +133 -637
  2. requirements.txt +1 -0
  3. run_data_preparer.py +182 -0
  4. tf_data_pipeline.py +734 -0
chatbot_model.py CHANGED
@@ -2,7 +2,7 @@ import time
 from transformers import TFAutoModel, AutoTokenizer
 import tensorflow as tf
 import numpy as np
-from typing import Generator, List, Tuple, Dict, Optional, Union, Any
 import math
 from dataclasses import dataclass
 import json
@@ -10,6 +10,7 @@ from pathlib import Path
 import datetime
 import faiss
 import gc
 from response_quality_checker import ResponseQualityChecker
 from cross_encoder_reranker import CrossEncoderReranker
 from conversation_summarizer import DeviceAwareModel, Summarizer
@@ -24,14 +25,12 @@ logger = config_logger(__name__)
 @dataclass
 class ChatbotConfig:
     """Configuration for the RetrievalChatbot."""
-    vocab_size: int = 30526  # DistilBERT vocab size + special tokens
     max_context_token_limit: int = 512
     embedding_dim: int = 768
     encoder_units: int = 256
     num_attention_heads: int = 8
     dropout_rate: float = 0.2
     l2_reg_weight: float = 0.001
-    margin: float = 0.3
     learning_rate: float = 0.001
     min_text_length: int = 3
     max_context_turns: int = 5
@@ -39,16 +38,19 @@ class ChatbotConfig:
     pretrained_model: str = 'distilbert-base-uncased'
     dtype: str = 'float32'
     freeze_embeddings: bool = False
-    embedding_batch_size: int = 128
-    # Additional configurations can be added here

-    def to_dict(self) -> dict:
         """Convert config to dictionary."""
-        return {k: str(v) if isinstance(v, Path) else v
                 for k, v in self.__dict__.items()}

     @classmethod
-    def from_dict(cls, config_dict: dict) -> 'ChatbotConfig':
         """Create config from dictionary."""
         return cls(**{k: v for k, v in config_dict.items()
                       if k in cls.__dataclass_fields__})
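
(For orientation: to_dict/from_dict above are a plain dataclass round trip. A minimal sketch of persisting and restoring the config, assuming only the API shown in this diff; the file name is illustrative:)

    import json
    from chatbot_model import ChatbotConfig

    config = ChatbotConfig(embedding_dim=768)
    with open("chatbot_config.json", "w") as f:
        json.dump(config.to_dict(), f)  # Path values are stringified by to_dict()
    with open("chatbot_config.json") as f:
        restored = ChatbotConfig.from_dict(json.load(f))  # unknown keys are filtered out
    assert restored.embedding_dim == 768
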
@@ -59,24 +61,17 @@ class EncoderModel(tf.keras.Model):
         self,
         config: ChatbotConfig,
         name: str = "encoder",
-        shared_weights: bool = False,
         **kwargs
     ):
         super().__init__(name=name, **kwargs)
         self.config = config
-        self.shared_weights = shared_weights

         # Load pretrained model
         self.pretrained = TFAutoModel.from_pretrained(config.pretrained_model)

-        # Freeze pretrained weights if specified
-        self.pretrained.distilbert.embeddings.trainable = False
-        for i, layer_module in enumerate(self.pretrained.distilbert.transformer.layer):
-            if i < 1:  # freeze first layer
-                layer_module.trainable = False
-            else:
-                layer_module.trainable = True
-
         # Pooling layer (Global Average Pooling)
         self.pooler = tf.keras.layers.GlobalAveragePooling1D()
@@ -90,9 +85,27 @@
         # Dropout and normalization
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
         self.normalize = tf.keras.layers.Lambda(
-            lambda x: tf.nn.l2_normalize(x, axis=1)
         )

     def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
         """Forward pass."""
         # Get pretrained embeddings
@@ -112,46 +125,33 @@
         config = super().get_config()
         config.update({
             "config": self.config.to_dict(),
-            "shared_weights": self.shared_weights,
             "name": self.name
         })
         return config

 class RetrievalChatbot(DeviceAwareModel):
     """Retrieval-based chatbot using pretrained embeddings and FAISS for similarity search."""
-    def __init__(self, config: ChatbotConfig, dialogues: List[dict] = [], device: str = None,
-                 strategy=None, reranker: Optional[CrossEncoderReranker] = None,
-                 summarizer: Optional[Summarizer] = None
-                 ):
         self.config = config
         self.strategy = strategy
-        self.setup_device(device)
-
-        if reranker is None:
-            logger.info("Creating default CrossEncoderReranker...")
-            reranker = CrossEncoderReranker(model_name="cross-encoder/ms-marco-MiniLM-L-12-v2")
-        self.reranker = reranker
-
-        if summarizer is None:
-            logger.info("Creating default Summarizer...")
-            summarizer = Summarizer(device=self.device)
-        self.summarizer = summarizer
-
-        # Special tokens
-        self.special_tokens = {
-            "user": "<USER>",
-            "assistant": "<ASSISTANT>",
-            "context": "<CONTEXT>",
-            "sep": "<SEP>"
-        }
-
-        # Initialize tokenizer and add special tokens
-        self.tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
-        self.tokenizer.add_special_tokens(
-            {'additional_special_tokens': list(self.special_tokens.values())}
-        )

         self.memory_monitor = GPUMemoryMonitor()
         self.min_batch_size = 8
         self.max_batch_size = 128
         self.current_batch_size = 32
@@ -166,9 +166,62 @@ class RetrievalChatbot(DeviceAwareModel):
             "train_metrics": {},
             "val_metrics": {}
         }
-
     def build_models(self):
-        """Initialize the shared encoder."""
         logger.info("Building encoder model...")
         tf.keras.backend.clear_session()

@@ -176,6 +229,7 @@ class RetrievalChatbot(DeviceAwareModel):
         self.encoder = EncoderModel(
             self.config,
             name="shared_encoder",
         )

         # Resize token embeddings after adding special tokens
@@ -183,31 +237,14 @@ class RetrievalChatbot(DeviceAwareModel):
         self.encoder.pretrained.resize_token_embeddings(new_vocab_size)
         logger.info(f"Token embeddings resized to: {new_vocab_size}")

-        # Initialize FAISS index (moved here from __init__)
         self._initialize_faiss()
-        # Compute embeddings after FAISS is initialized and moved
-        self._compute_and_index_embeddings()

-        # Try different ways to get embedding dimension
-        try:
-            # First try: from config
-            embedding_dim = self.encoder.pretrained.config.dim
-            logger.info("Got embedding dim from config")
-        except AttributeError:
-            try:
-                # Second try: from word embeddings
-                embedding_dim = self.encoder.pretrained.distilbert.embeddings.word_embeddings.embedding_dim
-                logger.info("Got embedding dim from word embeddings")
-            except AttributeError:
-                try:
-                    # Third try: from embeddings module
-                    embedding_dim = self.encoder.pretrained.distilbert.embeddings.embedding_dim
-                    logger.info("Got embedding dim from embeddings module")
-                except AttributeError:
-                    # Fallback to config value
-                    embedding_dim = self.config.embedding_dim
-                    logger.info("Using config embedding dim")

         vocab_size = len(self.tokenizer)

         logger.info(f"Encoder Embedding Dimension: {embedding_dim}")
@@ -217,29 +254,6 @@ class RetrievalChatbot(DeviceAwareModel):
         else:
             logger.error("Vocabulary size is less than embedding dimension.")
             raise ValueError("Vocabulary size is less than embedding dimension.")
-    def _collect_responses(self, dialogues: List[dict]) -> Tuple[List[str], List[str]]:
-        """Collect all unique responses from dialogues."""
-        logger.info("Collecting responses from dialogues...")
-
-        responses = []
-        try:
-            progress_bar = tqdm(dialogues, desc="Collecting assistant responses")
-        except ImportError:
-            progress_bar = dialogues
-            logger.info("Progress bar disabled - continuing without visual progress")
-
-        for dialogue in progress_bar:
-            turns = dialogue.get('turns', [])
-            for turn in turns:
-                if turn.get('speaker') == 'assistant' and 'text' in turn:
-                    responses.append(turn['text'].strip())
-
-        # Remove duplicates
-        unique_responses = list(set(responses))
-        logger.info(f"Found {len(unique_responses)} unique responses.")
-
-        return responses, unique_responses

     def _adjust_batch_size(self) -> None:
         """Dynamically adjust batch size based on GPU memory usage."""
@@ -288,6 +302,7 @@ class RetrievalChatbot(DeviceAwareModel):
             logger.warning(f"Using CPU due to GPU initialization error: {e}")

         # TODO: figure out bug with faiss-gpu
         try:
             # Create appropriate index based on dataset size
             if len(self.unique_responses) < 1000:
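
(Context for the size-based index choice above: once response embeddings are L2-normalized, inner-product search is equivalent to cosine similarity, and an exact IndexFlatIP is usually sufficient for small pools. A minimal sketch of that pattern, with illustrative sizes not taken from this commit:)

    import faiss
    import numpy as np

    dim = 768
    vectors = np.random.rand(500, dim).astype(np.float32)
    faiss.normalize_L2(vectors)        # in-place; inner product == cosine afterwards

    index = faiss.IndexFlatIP(dim)     # exact search, no training step needed
    index.add(vectors)
    scores, ids = index.search(vectors[:1], 5)   # top-5 most similar responses
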
@@ -860,33 +875,33 @@
         logger.info(f"Models and tokenizer loaded from {load_dir}.")
         return chatbot

-    @staticmethod
-    def load_training_data(data_path: Union[str, Path], debug_samples: Optional[int] = None) -> List[dict]:
-        """
-        Load training data from a JSON file.

-        Args:
-            data_path (Union[str, Path]): Path to the JSON file containing dialogues.
-            debug_samples (Optional[int]): Number of samples to load for debugging.

-        Returns:
-            List[dict]: List of dialogue dictionaries.
-        """
-        logger.info(f"Loading training data from {data_path}...")
-        data_path = Path(data_path)
-        if not data_path.exists():
-            logger.error(f"Data file {data_path} does not exist.")
-            return []

-        with open(data_path, 'r', encoding='utf-8') as f:
-            dialogues = json.load(f)

-        if debug_samples is not None:
-            dialogues = dialogues[:debug_samples]
-            logger.info(f"Debug mode: Limited to {debug_samples} dialogues")

-        logger.info(f"Loaded {len(dialogues)} dialogues.")
-        return dialogues

     def train_streaming(
         self,
@@ -1336,522 +1351,3 @@ class RetrievalChatbot(DeviceAwareModel):

         conversation_parts.append(f"{self.special_tokens['user']} {query}")
         return "\n".join(conversation_parts)
-
-class TFDataPipeline:
-    def __init__(
-        self,
-        embedding_batch_size,
-        tokenizer,
-        encoder,
-        index,
-        response_pool,
-        max_length: int,
-        neg_samples: int,
-    ):
-        self.embedding_batch_size = embedding_batch_size
-        self.tokenizer = tokenizer
-        self.encoder = encoder
-        self.index = index  # CPU version of the index
-        self.response_pool = response_pool
-        self.max_length = max_length
-        self.neg_samples = neg_samples
-        self.embedding_batch_size = 16 if len(response_pool) < 100 else 64
-        self.search_batch_size = 8 if len(response_pool) < 100 else 32
-        self.max_batch_size = 32 if len(response_pool) < 100 else 256
-        self.memory_monitor = GPUMemoryMonitor()
-        self.max_retries = 3
-
-        # In-memory cache for embeddings
-        self.query_embeddings_cache = {}
-
-    def _extract_pairs_from_dialogue(self, dialogue: dict) -> List[Tuple[str, str]]:
-        """Extract query-response pairs from a dialogue."""
-        pairs = []
-        turns = dialogue.get('turns', [])
-
-        for i in range(len(turns) - 1):
-            current_turn = turns[i]
-            next_turn = turns[i+1]
-
-            if (current_turn.get('speaker') == 'user' and
-                next_turn.get('speaker') == 'assistant' and
-                'text' in current_turn and
-                'text' in next_turn):
-
-                query = current_turn['text'].strip()
-                positive = next_turn['text'].strip()
-                pairs.append((query, positive))
-
-        return pairs
-
-    def estimate_total_pairs(self, dialogues: List[dict]) -> int:
-        """Estimate total number of training pairs including hard negatives."""
-        base_pairs = sum(
-            len([
-                1 for i in range(len(d.get('turns', [])) - 1)
-                if (d['turns'][i].get('speaker') == 'user' and
-                    d['turns'][i+1].get('speaker') == 'assistant')
-            ])
-            for d in dialogues
-        )
-        # Account for hard negatives
-        return base_pairs * (1 + self.neg_samples)
-
-    def _find_hard_negatives_batch(self, queries: List[str], positives: List[str]) -> List[List[str]]:
-        """Find hard negatives for a batch of queries with error handling and retries."""
-        retry_count = 0
-        total_responses = len(self.response_pool)
-
-        while retry_count < self.max_retries:
-            try:
-                query_embeddings = np.vstack([
-                    self.query_embeddings_cache[q] for q in queries
-                ]).astype(np.float32)
-
-                query_embeddings = np.ascontiguousarray(query_embeddings)
-                faiss.normalize_L2(query_embeddings)
-
-                k = 1  # TODO: try higher k for better results
-                #logger.debug(f"Searching with k={k} among {total_responses} responses")
-
-                distances, indices = self.index.search(query_embeddings, k)
-
-                all_negatives = []
-                for query_indices, query, positive in zip(indices, queries, positives):
-                    negatives = []
-                    positive_strip = positive.strip()
-                    seen = {positive_strip}
-
-                    for idx in query_indices:
-                        if idx >= 0 and idx < total_responses:
-                            candidate = self.response_pool[idx].strip()
-                            if candidate and candidate not in seen:
-                                seen.add(candidate)
-                                negatives.append(candidate)
-                                if len(negatives) >= self.neg_samples:
-                                    break
-
-                    # Pad with a special empty negative if necessary
-                    while len(negatives) < self.neg_samples:
-                        negatives.append("<EMPTY_NEGATIVE>")  # Use a special token
-
-                    all_negatives.append(negatives)
-
-                return all_negatives
-
-            except Exception as e:
-                retry_count += 1
-                logger.warning(f"Hard negative search attempt {retry_count} failed: {e}")
-                if retry_count == self.max_retries:
-                    logger.error("Max retries reached for hard negative search")
-                    return [["<EMPTY_NEGATIVE>"] * self.neg_samples for _ in queries]  # Return empty negatives for all queries
-                gc.collect()
-                if tf.config.list_physical_devices('GPU'):
-                    tf.keras.backend.clear_session()
-
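
(The method above is the hard-negative miner this commit later moves into tf_data_pipeline.py: nearest FAISS neighbors of the query that are not the gold response become negatives. A compact sketch of the core idea, assuming a normalized index and a parallel response list; the function name is illustrative:)

    def mine_negatives(index, response_pool, query_vec, positive, k=5, n_neg=3):
        # query_vec: shape [1, dim], already L2-normalized
        _, ids = index.search(query_vec, k)
        negatives = []
        for idx in ids[0]:
            if idx < 0:                      # FAISS pads with -1 when fewer than k hits exist
                continue
            cand = response_pool[idx].strip()
            if cand and cand != positive.strip() and cand not in negatives:
                negatives.append(cand)
            if len(negatives) == n_neg:
                break
        return negatives + ["<EMPTY_NEGATIVE>"] * (n_neg - len(negatives))

(Note that with k=1, as in the diff, at most one candidate is inspected per query, so most slots end up padded with <EMPTY_NEGATIVE>; the inline TODO about trying a higher k addresses exactly this.)
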
-    def _tokenize_negatives_tf(self, negatives):
-        """Tokenizes negatives using tf.py_function."""
-        # Handle the case where negatives is an empty tensor
-        if tf.size(negatives) == 0:
-            return tf.zeros([0, self.neg_samples, self.max_length], dtype=tf.int32)
-
-        # Convert EagerTensor to a list of strings
-        negatives_list = []
-        for neg_list in negatives.numpy():
-            decoded_negs = [neg.decode("utf-8") for neg in neg_list if neg]  # Filter out empty strings
-            negatives_list.append(decoded_negs)
-
-        # Flatten the list of lists
-        flattened_negatives = [neg for sublist in negatives_list for neg in sublist]
-
-        # Tokenize the flattened negatives
-        if flattened_negatives:
-            n_tokens = self.tokenizer(
-                flattened_negatives,
-                padding='max_length',
-                truncation=True,
-                max_length=self.max_length,
-                return_tensors='tf'
-            )
-            # Reshape the tokens
-            n_tokens_reshaped = tf.reshape(n_tokens['input_ids'], [-1, self.neg_samples, self.max_length])
-            return n_tokens_reshaped
-        else:
-            return tf.zeros([0, self.neg_samples, self.max_length], dtype=tf.int32)
-
-    def _compute_embeddings(self, queries: List[str]) -> None:
-        """Computes and caches embeddings for new queries."""
-        new_queries = [q for q in queries if q not in self.query_embeddings_cache]
-        if not new_queries:
-            return  # All queries already cached
-
-        new_embeddings = []
-        for i in range(0, len(new_queries), self.embedding_batch_size):
-            batch_queries = new_queries[i:i + self.embedding_batch_size]
-
-            encoded = self.tokenizer(
-                batch_queries,
-                padding=True,
-                truncation=True,
-                max_length=self.max_length,
-                return_tensors='tf'
-            )
-
-            # Compute embeddings on CPU
-            with tf.device('/CPU:0'):
-                batch_embeddings = self.encoder(encoded['input_ids'], training=False).numpy()
-
-            new_embeddings.extend(batch_embeddings)
-
-        # Update cache with new embeddings
-        for query, emb in zip(new_queries, new_embeddings):
-            self.query_embeddings_cache[query] = emb
-
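
(The caching in _compute_embeddings is a plain memo dict keyed by the raw query string, so repeated queries across batches and epochs are encoded once. A toy illustration of the same pattern, detached from TF; encode_fn stands in for the encoder and is an assumption of this sketch:)

    cache = {}

    def embed_all(texts, encode_fn):
        missing = [t for t in texts if t not in cache]
        if missing:
            for text, vec in zip(missing, encode_fn(missing)):
                cache[text] = vec          # memoize by exact string
        return [cache[t] for t in texts]
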
-    def data_generator(self, dialogues: List[dict]) -> Generator[Tuple[str, str, List[str]], None, None]:
-        """
-        Generates training examples: (query, positive, hard_negatives).
-        The outer loop is wrapped with tqdm for progress tracking.
-        """
-        total_dialogues = len(dialogues)
-        logger.debug(f"Total dialogues to process: {total_dialogues}")
-
-        # Initialize tqdm progress bar
-        with tqdm(total=total_dialogues, desc="Processing Dialogues", unit="dialogue") as pbar:
-            for dialogue in dialogues:
-                pairs = self._extract_pairs_from_dialogue(dialogue)
-                for query, positive in pairs:
-                    # Ensure embeddings are computed, find hard negatives, etc.
-                    self._compute_embeddings([query])
-                    hard_negatives = self._find_hard_negatives_batch([query], [positive])[0]
-                    yield (query, positive, hard_negatives)
-                pbar.update(1)
-
-    def _prepare_batch(self, queries: tf.Tensor, positives: tf.Tensor, negatives: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
-        """Prepares a batch of data for training."""
-
-        # Convert EagerTensors to lists of strings
-        queries_list = [query.decode("utf-8") for query in queries.numpy()]
-        positives_list = [pos.decode("utf-8") for pos in positives.numpy()]
-
-        # Tokenize queries and positives
-        q_tokens = self.tokenizer(queries_list, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf')
-        p_tokens = self.tokenizer(positives_list, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf')
-
-        # Decode negatives and ensure they are lists of strings
-        negatives_list = []
-        for neg_list in negatives.numpy():
-            decoded_negs = [neg.decode("utf-8") for neg in neg_list if neg]  # Filter out empty strings
-            negatives_list.append(decoded_negs)
-
-        # Flatten negatives for tokenization if there are any valid negatives
-        flattened_negatives = [neg for sublist in negatives_list for neg in sublist if neg]
-
-        # Tokenize negatives if there are any
-        n_tokens_reshaped = None
-        if flattened_negatives:
-            n_tokens = self.tokenizer(flattened_negatives, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf')
-
-            # Reshape n_tokens to match the expected shape based on the number of negatives per query
-            # This part may need adjustment if the number of negatives varies per query
-            n_tokens_reshaped = tf.reshape(n_tokens['input_ids'], [len(queries_list), -1, self.max_length])
-        else:
-            # Create a placeholder tensor for the case where there are no negatives
-            n_tokens_reshaped = tf.zeros([len(queries_list), 0, self.max_length], dtype=tf.int32)
-
-        # Ensure n_tokens_reshaped has a consistent shape even when there are no negatives
-        # Adjust shape to [batch_size, num_neg_samples, max_length]
-        if n_tokens_reshaped.shape[1] != self.neg_samples:
-            # Pad or truncate the second dimension to match neg_samples
-            padding = tf.zeros([len(queries_list), tf.maximum(0, self.neg_samples - n_tokens_reshaped.shape[1]), self.max_length], dtype=tf.int32)
-            n_tokens_reshaped = tf.concat([n_tokens_reshaped, padding], axis=1)
-            n_tokens_reshaped = n_tokens_reshaped[:, :self.neg_samples, :]
-
-        # Concatenate the positive and negative examples along the 'neg_samples' dimension
-        combined_p_n_tokens = tf.concat([tf.expand_dims(p_tokens['input_ids'], axis=1), n_tokens_reshaped], axis=1)
-
-        return q_tokens['input_ids'], combined_p_n_tokens
-
-    def get_tf_dataset(self, dialogues: List[dict], batch_size: int) -> tf.data.Dataset:
-        """
-        Creates a tf.data.Dataset for streaming training that yields
-        (input_ids_query, input_ids_positive, input_ids_negatives).
-        """
-        # 1) Start with a generator dataset
-        dataset = tf.data.Dataset.from_generator(
-            lambda: self.data_generator(dialogues),
-            output_signature=(
-                tf.TensorSpec(shape=(), dtype=tf.string),       # Query (single string)
-                tf.TensorSpec(shape=(), dtype=tf.string),       # Positive (single string)
-                tf.TensorSpec(shape=(None,), dtype=tf.string)   # Hard Negatives (list of strings)
-            )
-        )
-
-        # 2) Batch the raw strings
-        dataset = dataset.batch(batch_size)
-
-        # 3) Now map them through a tokenize step (via py_function)
-        dataset = dataset.map(
-            lambda q, p, n: self._tokenize_triple(q, p, n),
-            num_parallel_calls=1  # tf.data.AUTOTUNE
-        )
-
-        dataset = dataset.prefetch(tf.data.AUTOTUNE)
-        return dataset
-
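
(Consuming the dataset built by get_tf_dataset is ordinary tf.data iteration; the shapes follow the set_shape calls in _tokenize_triple below. A hedged usage sketch, with pipeline and dialogues assumed to exist:)

    dataset = pipeline.get_tf_dataset(dialogues, batch_size=32)
    for q_ids, p_ids, n_ids in dataset.take(1):
        # q_ids: [batch, max_length], p_ids: [batch, max_length]
        # n_ids: [batch, neg_samples, max_length]
        print(q_ids.shape, p_ids.shape, n_ids.shape)

(Because raw strings are batched before the py_function tokenizer map, the HF tokenizer sees whole batches rather than single examples, keeping the Python round trips to one per batch.)
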
-    def _tokenize_triple(
-        self,
-        q: tf.Tensor,
-        p: tf.Tensor,
-        n: tf.Tensor
-    ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
-        """
-        Wraps a Python function via tf.py_function to convert tf.Tensors of strings
-        -> Python lists of strings -> HF tokenizer -> Tensors of IDs.
-
-        q is shape [batch_size], p is shape [batch_size],
-        n is shape [batch_size, neg_samples] (i.e., each row is a list of negatives).
-        """
-        # Use tf.py_function with limited parallelism
-        q_ids, p_ids, n_ids = tf.py_function(
-            func=self._tokenize_triple_py,
-            inp=[q, p, n, tf.constant(self.max_length), tf.constant(self.neg_samples)],
-            Tout=[tf.int32, tf.int32, tf.int32]
-        )
-
-        # Manually set shape information (py_function erases static shapes)
-        q_ids.set_shape([None, self.max_length])                    # [batch_size, max_length]
-        p_ids.set_shape([None, self.max_length])                    # [batch_size, max_length]
-        n_ids.set_shape([None, self.neg_samples, self.max_length])  # [batch_size, neg_samples, max_length]
-
-        return q_ids, p_ids, n_ids
-    def _tokenize_triple_py(
-        self,
-        q: tf.Tensor,
-        p: tf.Tensor,
-        n: tf.Tensor,
-        max_len: tf.Tensor,
-        neg_samples: tf.Tensor
-    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
-        """
-        Python function that:
-          - Decodes each tf.string Tensor to a Python list of strings
-          - Calls the HF tokenizer
-          - Reshapes negatives
-          - Returns np.array of int32s for (q_ids, p_ids, n_ids).
-
-        q: shape [batch_size], p: shape [batch_size]
-        n: shape [batch_size, neg_samples]
-        max_len: scalar int
-        neg_samples: scalar int
-        """
-        max_len = int(max_len.numpy())  # Convert to Python int
-        neg_samples = int(neg_samples.numpy())
-
-        # 1) Convert Tensors -> Python lists of strings
-        q_list = [q_i.decode("utf-8") for q_i in q.numpy()]  # shape [batch_size]
-        p_list = [p_i.decode("utf-8") for p_i in p.numpy()]  # shape [batch_size]
-
-        # shape [batch_size, neg_samples], decode each row
-        n_list = []
-        for row in n.numpy():
-            # row is shape [neg_samples], each is a tf.string
-            decoded = [neg.decode("utf-8") for neg in row]
-            n_list.append(decoded)
-
-        # 2) Tokenize queries & positives
-        q_enc = self.tokenizer(
-            q_list,
-            padding="max_length",
-            truncation=True,
-            max_length=max_len,
-            return_tensors="np"
-        )
-        p_enc = self.tokenizer(
-            p_list,
-            padding="max_length",
-            truncation=True,
-            max_length=max_len,
-            return_tensors="np"
-        )
-
-        # 3) Tokenize negatives
-        # Flatten [batch_size, neg_samples] -> single list
-        flattened_negatives = [neg for row in n_list for neg in row]
-        if len(flattened_negatives) == 0:
-            # No negatives at all: return a zero array
-            n_ids = np.zeros((len(q_list), neg_samples, max_len), dtype=np.int32)
-        else:
-            n_enc = self.tokenizer(
-                flattened_negatives,
-                padding="max_length",
-                truncation=True,
-                max_length=max_len,
-                return_tensors="np"
-            )
-            # shape [batch_size * neg_samples, max_len]
-            n_input_ids = n_enc["input_ids"]
-
-            # We want to reshape to [batch_size, neg_samples, max_len]
-            # Handle cases where there might be fewer negatives
-            batch_size = len(q_list)
-            n_ids_list = []
-            for i in range(batch_size):
-                start_idx = i * neg_samples
-                end_idx = start_idx + neg_samples
-                row_negs = n_input_ids[start_idx:end_idx]
-
-                # If fewer negatives, pad with zeros
-                if row_negs.shape[0] < neg_samples:
-                    deficit = neg_samples - row_negs.shape[0]
-                    pad_arr = np.zeros((deficit, max_len), dtype=np.int32)
-                    row_negs = np.concatenate([row_negs, pad_arr], axis=0)
-
-                n_ids_list.append(row_negs)
-
-            # stack them -> shape [batch_size, neg_samples, max_len]
-            n_ids = np.stack(n_ids_list, axis=0)
-
-        # 4) Return as np.int32 arrays
-        q_ids = q_enc["input_ids"].astype(np.int32)  # shape [batch_size, max_len]
-        p_ids = p_enc["input_ids"].astype(np.int32)  # shape [batch_size, max_len]
-        n_ids = n_ids.astype(np.int32)               # shape [batch_size, neg_samples, max_len]
-
-        return q_ids, p_ids, n_ids
 from transformers import TFAutoModel, AutoTokenizer
 import tensorflow as tf
 import numpy as np
+from typing import List, Tuple, Dict, Optional, Union, Any
 import math
 from dataclasses import dataclass
 import json
 import datetime
 import faiss
 import gc
+from tf_data_pipeline import TFDataPipeline
 from response_quality_checker import ResponseQualityChecker
 from cross_encoder_reranker import CrossEncoderReranker
 from conversation_summarizer import DeviceAwareModel, Summarizer

 @dataclass
 class ChatbotConfig:
     """Configuration for the RetrievalChatbot."""
     max_context_token_limit: int = 512
     embedding_dim: int = 768
     encoder_units: int = 256
     num_attention_heads: int = 8
     dropout_rate: float = 0.2
     l2_reg_weight: float = 0.001
     learning_rate: float = 0.001
     min_text_length: int = 3
     max_context_turns: int = 5
     pretrained_model: str = 'distilbert-base-uncased'
     dtype: str = 'float32'
     freeze_embeddings: bool = False
+    embedding_batch_size: int = 64
+    search_batch_size: int = 64
+    max_batch_size: int = 64
+    neg_samples: int = 3
+    max_retries: int = 3

+    def to_dict(self) -> Dict:
         """Convert config to dictionary."""
+        return {k: (str(v) if isinstance(v, Path) else v)
                 for k, v in self.__dict__.items()}

     @classmethod
+    def from_dict(cls, config_dict: Dict) -> 'ChatbotConfig':
         """Create config from dictionary."""
         return cls(**{k: v for k, v in config_dict.items()
                       if k in cls.__dataclass_fields__})

         self,
         config: ChatbotConfig,
         name: str = "encoder",
         **kwargs
     ):
         super().__init__(name=name, **kwargs)
         self.config = config

         # Load pretrained model
         self.pretrained = TFAutoModel.from_pretrained(config.pretrained_model)

+        # Freeze layers based on config
+        self._freeze_layers()
+
         # Pooling layer (Global Average Pooling)
         self.pooler = tf.keras.layers.GlobalAveragePooling1D()

         # Dropout and normalization
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
         self.normalize = tf.keras.layers.Lambda(
+            lambda x: tf.nn.l2_normalize(x, axis=1),
+            name="l2_normalize"
         )

+    def _freeze_layers(self):
+        """Freeze layers of the pretrained model based on configuration."""
+        if self.config.freeze_embeddings:
+            self.pretrained.trainable = False
+            logger.info("All pretrained layers frozen.")
+        else:
+            # Freeze only the first 'n' transformer layers
+            for i, layer in enumerate(self.pretrained.layers):
+                if isinstance(layer, tf.keras.layers.Layer):
+                    if hasattr(layer, 'trainable'):
+                        # Freeze the first transformer block
+                        if i < 1:
+                            layer.trainable = False
+                            logger.info(f"Layer {i} frozen.")
+                        else:
+                            layer.trainable = True
+
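(A quick way to check what _freeze_layers actually froze is to compare trainable parameter counts; a hedged sketch using only names from the diff:)

    encoder = EncoderModel(config)
    trainable = sum(int(tf.size(w)) for w in encoder.pretrained.trainable_weights)
    total = sum(int(tf.size(w)) for w in encoder.pretrained.weights)
    print(f"trainable: {trainable:,} / {total:,}")

(Worth noting: on a HF TFAutoModel, .layers typically exposes the main transformer layer as a single child, so the i < 1 branch may freeze the entire stack, embeddings included, unlike the per-block loop the removed code used via pretrained.distilbert.transformer.layer.)
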
     def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
         """Forward pass."""
         # Get pretrained embeddings

         config = super().get_config()
         config.update({
             "config": self.config.to_dict(),
             "name": self.name
         })
         return config

 class RetrievalChatbot(DeviceAwareModel):
     """Retrieval-based chatbot using pretrained embeddings and FAISS for similarity search."""
+    def __init__(
+        self,
+        config: ChatbotConfig,
+        dialogues: List[dict] = [],
+        device: str = None,
+        strategy=None,
+        reranker: Optional[CrossEncoderReranker] = None,
+        summarizer: Optional[Summarizer] = None
+    ):
+        super().__init__()
         self.config = config
         self.strategy = strategy
+        self.device = device or self._setup_default_device()

+        # Initialize reranker, summarizer, tokenizer, and memory monitor
+        self.reranker = reranker or self._initialize_reranker()
+        self.summarizer = summarizer or self._initialize_summarizer()
+        self.tokenizer = self._initialize_tokenizer()
         self.memory_monitor = GPUMemoryMonitor()
+
+        # Initialize models
         self.min_batch_size = 8
         self.max_batch_size = 128
         self.current_batch_size = 32

             "train_metrics": {},
             "val_metrics": {}
         }
+
+    def _setup_default_device(self) -> str:
+        """Set up default device if none is provided."""
+        if tf.config.list_physical_devices('GPU'):
+            return 'GPU'
+        else:
+            return 'CPU'
+
+    def _initialize_reranker(self) -> CrossEncoderReranker:
+        """Initialize the CrossEncoderReranker."""
+        logger.info("Initializing default CrossEncoderReranker...")
+        return CrossEncoderReranker(model_name="cross-encoder/ms-marco-MiniLM-L-12-v2")
+
+    def _initialize_summarizer(self) -> Summarizer:
+        """Initialize the Summarizer."""
+        logger.info("Initializing default Summarizer...")
+        return Summarizer(device=self.device)
+
+    def _initialize_tokenizer(self) -> AutoTokenizer:
+        """Initialize the tokenizer and add special tokens."""
+        logger.info("Initializing tokenizer and adding special tokens...")
+        tokenizer = AutoTokenizer.from_pretrained(self.config.pretrained_model)
+        special_tokens = {
+            "user": "<USER>",
+            "assistant": "<ASSISTANT>",
+            "context": "<CONTEXT>",
+            "sep": "<SEP>"
+        }
+        tokenizer.add_special_tokens(
+            {'additional_special_tokens': list(special_tokens.values())}
+        )
+        return tokenizer
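
(The added special tokens mark speaker roles when a conversation is flattened into one string, as the conversation_parts code at the end of the file shows. A small usage sketch with the tokenizer built above; the example text is illustrative:)

    text = "<USER> How do I reset my password? <ASSISTANT> Click 'Forgot password' on the login page."
    ids = tokenizer(text)["input_ids"]
    # The role markers survive as single token IDs instead of being split into word pieces.
    print(tokenizer.convert_ids_to_tokens(ids)[:5])

(This is also why build_models below resizes the embedding matrix with resize_token_embeddings: the four new tokens need rows in the embedding table.)
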
201
+
202
+ def _collect_responses(self, dialogues: List[dict]) -> Tuple[List[str], List[str]]:
203
+ """
204
+ Collect unique responses from dialogues.
205
+ Returns:
206
+ response_pool: List of all possible responses.
207
+ unique_responses: List of unique responses.
208
+ """
209
+ logger.info("Collecting unique responses from dialogues...")
210
+ responses = set()
211
+ for dialogue in dialogues:
212
+ turns = dialogue.get('turns', [])
213
+ for turn in turns:
214
+ if turn.get('speaker') == 'assistant' and 'text' in turn:
215
+ response = turn['text'].strip()
216
+ if len(response) >= self.config.min_text_length:
217
+ responses.add(response)
218
+ response_pool = list(responses)
219
+ unique_responses = list(responses) # Assuming uniqueness
220
+ logger.info(f"Collected {len(response_pool)} unique responses.")
221
+ return response_pool, unique_responses
222
+
223
  def build_models(self):
224
+ """Initialize the shared encoder and FAISS index."""
225
  logger.info("Building encoder model...")
226
  tf.keras.backend.clear_session()
227
 
 
229
  self.encoder = EncoderModel(
230
  self.config,
231
  name="shared_encoder",
232
+ shared_weights=True # If weight sharing is intended
233
  )
234
 
235
  # Resize token embeddings after adding special tokens
 
237
  self.encoder.pretrained.resize_token_embeddings(new_vocab_size)
238
  logger.info(f"Token embeddings resized to: {new_vocab_size}")
239
 
240
+ # Initialize FAISS index
241
  self._initialize_faiss()
 
 
242
 
243
+ # Compute and index embeddings
244
+ self._compute_and_index_embeddings()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
+ # Retrieve embedding dimension from encoder
247
+ embedding_dim = self.config.embedding_dim
248
  vocab_size = len(self.tokenizer)
249
 
250
  logger.info(f"Encoder Embedding Dimension: {embedding_dim}")
 
254
  else:
255
  logger.error("Vocabulary size is less than embedding dimension.")
256
  raise ValueError("Vocabulary size is less than embedding dimension.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
 
258
  def _adjust_batch_size(self) -> None:
259
  """Dynamically adjust batch size based on GPU memory usage."""
 
302
  logger.warning(f"Using CPU due to GPU initialization error: {e}")
303
 
304
  # TODO: figure out buf with faiss-gpu
305
+ # TODO: consider IndexIVFFlat in the future (speed).
306
  try:
307
  # Create appropriate index based on dataset size
308
  if len(self.unique_responses) < 1000:
 
875
  logger.info(f"Models and tokenizer loaded from {load_dir}.")
876
  return chatbot
877
 
878
+ # @staticmethod
879
+ # def load_training_data(data_path: Union[str, Path], debug_samples: Optional[int] = None) -> List[dict]:
880
+ # """
881
+ # Load training data from a JSON file.
882
 
883
+ # Args:
884
+ # data_path (Union[str, Path]): Path to the JSON file containing dialogues.
885
+ # debug_samples (Optional[int]): Number of samples to load for debugging.
886
 
887
+ # Returns:
888
+ # List[dict]: List of dialogue dictionaries.
889
+ # """
890
+ # logger.info(f"Loading training data from {data_path}...")
891
+ # data_path = Path(data_path)
892
+ # if not data_path.exists():
893
+ # logger.error(f"Data file {data_path} does not exist.")
894
+ # return []
895
 
896
+ # with open(data_path, 'r', encoding='utf-8') as f:
897
+ # dialogues = json.load(f)
898
 
899
+ # if debug_samples is not None:
900
+ # dialogues = dialogues[:debug_samples]
901
+ # logger.info(f"Debug mode: Limited to {debug_samples} dialogues")
902
 
903
+ # logger.info(f"Loaded {len(dialogues)} dialogues.")
904
+ # return dialogues
905
 
906
  def train_streaming(
907
  self,
 
1351
 
1352
  conversation_parts.append(f"{self.special_tokens['user']} {query}")
1353
  return "\n".join(conversation_parts)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,4 +1,5 @@
 faiss-cpu>=1.7.0     # Required for Facebook AI Similarity Search
+h5py>=3.1.0          # For saving and loading models
 ipython>=8.0.0       # For interactive Python
 loguru>=0.7.0        # Enhanced logging (optional but recommended)
 matplotlib>=3.5.0    # For validation plotting
run_data_preparer.py ADDED
@@ -0,0 +1,182 @@
+import os
+import sys
+import faiss
+import pickle
+from transformers import AutoTokenizer
+from tqdm.auto import tqdm
+from chatbot_model import ChatbotConfig, EncoderModel
+from environment_setup import EnvironmentSetup
+from tf_data_pipeline import TFDataPipeline
+from logger_config import config_logger
+
+logger = config_logger(__name__)
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+def cleanup_test_indices(faiss_dir, test_prefix='test_'):
+    test_files = [f for f in os.listdir(faiss_dir) if f.startswith(test_prefix)]
+    for file in test_files:
+        file_path = os.path.join(faiss_dir, file)
+        os.remove(file_path)
+        logger.info(f"Removed test FAISS index file: {file_path}")
+
+def main():
+    # Constants
+    MODELS_DIR = 'models'
+    PROCESSED_DATA_DIR = 'processed_outputs'
+    CACHE_DIR = 'cache'
+    TOKENIZER_DIR = os.path.join(MODELS_DIR, 'tokenizer')
+    FAISS_INDICES_DIR = os.path.join(MODELS_DIR, 'faiss_indices')
+    TF_RECORD_DIR = 'training_data'
+    FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_production.index')
+    FAISS_INDEX_TEST_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_test.index')
+    ENVIRONMENT = 'test'  # or 'production'
+    if ENVIRONMENT == 'test':
+        FAISS_INDEX_PATH = FAISS_INDEX_TEST_PATH
+    else:
+        FAISS_INDEX_PATH = FAISS_INDEX_PRODUCTION_PATH
+    JSON_TRAINING_DATA_PATH = os.path.join(PROCESSED_DATA_DIR, 'augmented_dialogues.json')
+    CACHE_FILE = os.path.join(CACHE_DIR, 'query_embeddings_cache.pkl')
+    TF_RECORD_PATH = os.path.join(TF_RECORD_DIR, 'training_data.tfrecord')
+    DEBUG_SAMPLES = None
+
+    # Ensure output directories exist
+    os.makedirs(MODELS_DIR, exist_ok=True)
+    os.makedirs(PROCESSED_DATA_DIR, exist_ok=True)
+    os.makedirs(CACHE_DIR, exist_ok=True)
+    os.makedirs(TOKENIZER_DIR, exist_ok=True)
+    os.makedirs(FAISS_INDICES_DIR, exist_ok=True)
+    os.makedirs(TF_RECORD_DIR, exist_ok=True)
+
+    # Initialize configuration
+    config = ChatbotConfig()
+    logger.info(f"Chatbot Configuration: {config}")
+
+    # Initialize tokenizer
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
+        logger.info(f"Tokenizer '{config.pretrained_model}' loaded successfully.")
+    except Exception as e:
+        logger.error(f"Failed to load tokenizer: {e}")
+        sys.exit(1)
+
+    # Add special tokens
+    try:
+        tokenizer.add_special_tokens({'additional_special_tokens': ['<EMPTY_NEGATIVE>']})
+        logger.info("Added special tokens to tokenizer.")
+    except Exception as e:
+        logger.error(f"Failed to add special tokens: {e}")
+        sys.exit(1)
+
+    # Initialize encoder model
+    try:
+        encoder = EncoderModel(config=config)
+        logger.info("EncoderModel initialized successfully.")
+    except Exception as e:
+        logger.error(f"Failed to initialize EncoderModel: {e}")
+        sys.exit(1)
+
+    # Resize token embeddings in encoder to match tokenizer
+    try:
+        encoder.pretrained.resize_token_embeddings(len(tokenizer))
+        logger.info(f"Token embeddings resized to: {len(tokenizer)}")
+    except Exception as e:
+        logger.error(f"Failed to resize token embeddings: {e}")
+        sys.exit(1)
+
+    # Load JSON dialogues
+    try:
+        dialogues = TFDataPipeline.load_json_training_data(JSON_TRAINING_DATA_PATH, DEBUG_SAMPLES)
+        logger.info(f"Loaded {len(dialogues)} dialogues from {JSON_TRAINING_DATA_PATH}.")
+    except Exception as e:
+        logger.error(f"Failed to load dialogues: {e}")
+        sys.exit(1)
+
+    # Load or initialize query_embeddings_cache
+    try:
+        if os.path.exists(CACHE_FILE):
+            with open(CACHE_FILE, 'rb') as f:
+                query_embeddings_cache = pickle.load(f)
+            logger.info(f"Loaded {len(query_embeddings_cache)} query embeddings from {CACHE_FILE}.")
+        else:
+            query_embeddings_cache = {}
+            logger.info("Initialized empty query embeddings cache.")
+    except Exception as e:
+        logger.error(f"Failed to load or initialize query embeddings cache: {e}")
+        sys.exit(1)
+
+    # Initialize TFDataPipeline
+    try:
+        data_pipeline = TFDataPipeline(
+            config=config,
+            tokenizer=tokenizer,
+            encoder=encoder,
+            index_file_path=FAISS_INDEX_PATH,
+            response_pool=[],
+            max_length=config.max_context_token_limit,
+            neg_samples=config.neg_samples,
+            query_embeddings_cache=query_embeddings_cache,
+            max_retries=config.max_retries
+        )
+        logger.info("TFDataPipeline initialized successfully.")
+    except Exception as e:
+        logger.error(f"Failed to initialize TFDataPipeline: {e}")
+        sys.exit(1)
+
+    # Collect unique assistant responses from dialogues
+    try:
+        response_pool = data_pipeline.collect_responses(dialogues)
+        data_pipeline.response_pool = response_pool
+        logger.info(f"Collected {len(response_pool)} unique assistant responses from dialogues.")
+    except Exception as e:
+        logger.error(f"Failed to collect responses: {e}")
+        sys.exit(1)
+
+    # Compute and add response embeddings to FAISS index
+    try:
+        logger.info("Computing and adding response embeddings to FAISS index...")
+        data_pipeline._compute_and_index_response_embeddings()
+        logger.info("Response embeddings computed and added to FAISS index.")
+    except Exception as e:
+        logger.error(f"Failed to compute or add response embeddings: {e}")
+        sys.exit(1)
+
+    # Save FAISS index
+    try:
+        logger.info(f"Saving FAISS index to {FAISS_INDEX_PATH}...")
+        faiss.write_index(data_pipeline.index, FAISS_INDEX_PATH)
+        logger.info("FAISS index saved successfully.")
+    except Exception as e:
+        logger.error(f"Failed to save FAISS index: {e}")
+        sys.exit(1)
+
+    # Prepare and save training data as TFRecords
+    try:
+        logger.info("Starting data preparation and saving as TFRecord...")
+        data_pipeline.prepare_and_save_data(dialogues, TF_RECORD_PATH)
+        logger.info(f"Data saved as TFRecord at {TF_RECORD_PATH}.")
+    except Exception as e:
+        logger.error(f"Failed during data preparation and saving: {e}")
+        sys.exit(1)
+
+    # Save query embeddings cache
+    try:
+        with open(CACHE_FILE, 'wb') as f:
+            pickle.dump(data_pipeline.query_embeddings_cache, f)
+        logger.info(f"Saved {len(data_pipeline.query_embeddings_cache)} query embeddings to {CACHE_FILE}.")
+    except Exception as e:
+        logger.error(f"Failed to save query embeddings cache: {e}")
+        sys.exit(1)
+
+    # Save Tokenizer (including special tokens)
+    try:
+        tokenizer.save_pretrained(TOKENIZER_DIR)
+        logger.info(f"Tokenizer saved to {TOKENIZER_DIR}.")
+    except Exception as e:
+        logger.error(f"Failed to save tokenizer: {e}")
+        sys.exit(1)
+
+    logger.info("Data preparation pipeline completed successfully.")
+
+if __name__ == "__main__":
+    main()
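
(Presumably the preparer is run directly from the repository root once processed_outputs/augmented_dialogues.json exists:

    python run_data_preparer.py

It reuses existing artifacts where it can: a FAISS index already present at the configured path is reloaded and validated rather than rebuilt, and an existing query_embeddings_cache.pkl is loaded rather than recomputed, per the branches above.)
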
tf_data_pipeline.py ADDED
@@ -0,0 +1,734 @@
+import os
+import gc
+import numpy as np
+import faiss
+import tensorflow as tf
+import h5py
+from tqdm import tqdm
+import json
+from pathlib import Path
+from typing import Union, Optional, List, Tuple, Generator
+from transformers import AutoTokenizer
+from gpu_monitor import GPUMemoryMonitor
+
+from logger_config import config_logger
+logger = config_logger(__name__)
+
+class TFDataPipeline:
+    def __init__(
+        self,
+        config,
+        tokenizer,
+        encoder,
+        index_file_path: str,
+        response_pool: List[str],
+        max_length: int,
+        query_embeddings_cache: dict,
+        neg_samples: int = 3,
+        index_type: str = 'IndexFlatIP',
+        nlist: int = 100,
+        max_retries: int = 3
+    ):
+        self.config = config
+        self.tokenizer = tokenizer
+        self.encoder = encoder
+        self.index_file_path = index_file_path
+        self.response_pool = response_pool
+        self.max_length = max_length
+        self.neg_samples = neg_samples
+        self.query_embeddings_cache = query_embeddings_cache  # In-memory cache for embeddings
+        self.index_type = index_type
+        self.nlist = nlist
+        self.embedding_batch_size = 16 if len(response_pool) < 100 else 64
+        self.search_batch_size = 16 if len(response_pool) < 100 else 64
+        self.max_batch_size = 16 if len(response_pool) < 100 else 64
+        self.memory_monitor = GPUMemoryMonitor()
+        self.max_retries = max_retries
+
+        if os.path.exists(index_file_path):
+            logger.info(f"Loading existing FAISS index from {index_file_path}...")
+            self.index = faiss.read_index(index_file_path)
+            self.validate_faiss_index()
+            logger.info("FAISS index loaded and validated successfully.")
+        else:
+            # Initialize FAISS index
+            dimension = self.encoder.config.embedding_dim
+            self.index = faiss.IndexFlatIP(dimension)
+            logger.info(f"Initialized FAISS IndexFlatIP with dimension {dimension}.")
+
+        if not self.index.is_trained:
+            # Train the index if it's not trained.  # TODO: Replace 'dimension' with embedding size
+            dimension = self.query_embeddings_cache[next(iter(self.query_embeddings_cache))].shape[0]
+            self.index.train(np.array(list(self.query_embeddings_cache.values())).astype(np.float32))
+            self.index.add(np.array(list(self.query_embeddings_cache.values())).astype(np.float32))
+
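(A note on the is_trained branch above: IndexFlatIP is always trained, so that branch only fires for quantizing indexes. A hedged sketch of the IVF case that the index_type/nlist parameters appear to anticipate; sizes are illustrative:)

    import faiss
    import numpy as np

    dim, nlist = 768, 100
    quantizer = faiss.IndexFlatIP(dim)
    index = faiss.IndexIVFFlat(quantizer, dim, nlist, faiss.METRIC_INNER_PRODUCT)

    train_vecs = np.random.rand(10000, dim).astype(np.float32)
    faiss.normalize_L2(train_vecs)
    index.train(train_vecs)   # required before add() for IVF indexes
    index.add(train_vecs)
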
+    def validate_faiss_index(self):
+        """Validates that the FAISS index has the correct dimensionality."""
+        expected_dim = self.encoder.config.embedding_dim
+        if self.index.d != expected_dim:
+            logger.error(f"FAISS index dimension {self.index.d} does not match encoder embedding dimension {expected_dim}.")
+            raise ValueError("FAISS index dimensionality mismatch.")
+        logger.info("FAISS index dimension validated successfully.")
+
+    def save_embeddings_cache_hdf5(self, cache_file_path: str):
+        """Save the embeddings cache to an HDF5 file."""
+        with h5py.File(cache_file_path, 'w') as hf:
+            for query, emb in self.query_embeddings_cache.items():
+                hf.create_dataset(query, data=emb)
+        logger.info(f"Embeddings cache saved to {cache_file_path}.")
+
+    def load_embeddings_cache_hdf5(self, cache_file_path: str):
+        """Load the embeddings cache from an HDF5 file."""
+        with h5py.File(cache_file_path, 'r') as hf:
+            for query in hf.keys():
+                self.query_embeddings_cache[query] = hf[query][:]
+        logger.info(f"Embeddings cache loaded from {cache_file_path}.")
+
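(One caveat with the HDF5 cache above: h5py treats '/' in a dataset name as a group separator, so a raw query string containing a slash would be stored as a nested group and not round-trip through hf.keys(). A defensive variant might sanitize keys first; a sketch, not part of this commit:)

    import hashlib
    safe_key = hashlib.sha1(query.encode("utf-8")).hexdigest()  # '/' is a path separator in HDF5
    ds = hf.create_dataset(safe_key, data=emb)
    ds.attrs["query"] = query   # keep the original string for load-time reconstruction
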
+    def save_faiss_index(self, index_file_path: str):
+        faiss.write_index(self.index, index_file_path)
+        logger.info(f"FAISS index saved to {index_file_path}")
+
+    def load_faiss_index(self, index_file_path: str):
+        self.index = faiss.read_index(index_file_path)
+        logger.info(f"FAISS index loaded from {index_file_path}")
+
+    def save_tokenizer(self, tokenizer_dir: str):
+        self.tokenizer.save_pretrained(tokenizer_dir)
+        logger.info(f"Tokenizer saved to {tokenizer_dir}")
+
+    def load_tokenizer(self, tokenizer_dir: str):
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
+        logger.info(f"Tokenizer loaded from {tokenizer_dir}")
+
+    def estimate_total_pairs(self, dialogues: List[dict]) -> int:
+        """Estimate total number of training pairs including hard negatives."""
+        base_pairs = sum(
+            len([
+                1 for i in range(len(d.get('turns', [])) - 1)
+                if (d['turns'][i].get('speaker') == 'user' and
+                    d['turns'][i+1].get('speaker') == 'assistant')
+            ])
+            for d in dialogues
+        )
+        # Account for hard negatives
+        return base_pairs * (1 + self.neg_samples)
+
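(A quick sanity check of the estimate: each user→assistant pair yields one (query, positive) example plus neg_samples hard negatives, so with 1,000 base pairs and the default neg_samples = 3 the method returns 1,000 * (1 + 3) = 4,000 — the row count against which progress bars and TFRecord sizing can be budgeted.)
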
118
+ @staticmethod
119
+ def load_json_training_data(data_path: Union[str, Path], debug_samples: Optional[int] = None) -> List[dict]:
120
+ """
121
+ Load training data from a JSON file.
122
+
123
+ Args:
124
+ data_path (Union[str, Path]): Path to the JSON file containing dialogues.
125
+ debug_samples (Optional[int]): Number of samples to load for debugging.
126
+
127
+ Returns:
128
+ List[dict]: List of dialogue dictionaries.
129
+ """
130
+ logger.info(f"Loading training data from {data_path}...")
131
+ data_path = Path(data_path)
132
+ if not data_path.exists():
133
+ logger.error(f"Data file {data_path} does not exist.")
134
+ return []
135
+
136
+ with open(data_path, 'r', encoding='utf-8') as f:
137
+ dialogues = json.load(f)
138
+
139
+ if debug_samples is not None:
140
+ dialogues = dialogues[:debug_samples]
141
+ logger.info(f"Debug mode: Limited to {debug_samples} dialogues")
142
+
143
+ logger.info(f"Loaded {len(dialogues)} dialogues.")
144
+ return dialogues
145
+
+     def collect_responses(self, dialogues: List[dict]) -> List[str]:
+         """Extract unique assistant responses from dialogues."""
+         response_set = set()
+         for dialogue in dialogues:
+             turns = dialogue.get('turns', [])
+             for turn in turns:
+                 speaker = turn.get('speaker')
+                 text = turn.get('text', '').strip()
+                 if speaker == 'assistant' and text:
+                     # Character-length check as a cheap proxy for the token budget;
+                     # shorter responses are always kept.
+                     if len(text) <= self.max_length:
+                         response_set.add(text)
+         logger.info(f"Collected {len(response_set)} unique assistant responses from dialogues.")
+         return list(response_set)
+
+     def _extract_pairs_from_dialogue(self, dialogue: dict) -> List[Tuple[str, str]]:
+         """Extract (query, positive) pairs from adjacent user->assistant turns."""
+         pairs = []
+         turns = dialogue.get('turns', [])
+
+         for i in range(len(turns) - 1):
+             current_turn = turns[i]
+             next_turn = turns[i + 1]
+
+             if (current_turn.get('speaker') == 'user' and
+                     next_turn.get('speaker') == 'assistant' and
+                     'text' in current_turn and
+                     'text' in next_turn):
+                 query = current_turn['text'].strip()
+                 positive = next_turn['text'].strip()
+                 pairs.append((query, positive))
+
+         return pairs
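+
+     # Example: for the dialogue sketched above, this returns
+     #     [("How do I reset my password?", "Use the 'Forgot password' link on the login page.")]
+     # Only adjacent user->assistant turns yield pairs; other transitions are skipped.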
+
+     def _compute_and_index_response_embeddings(self):
+         """Compute embeddings for the response pool and add them to the FAISS index."""
+         logger.info("Computing embeddings for the response pool...")
+
+         # Log a small sample of the response pool for sanity checking.
+         for idx, response in enumerate(self.response_pool[:5], 1):
+             logger.debug(f"Response {idx}: {response} (Type: {type(response)})")
+
+         # Ensure all responses are strings before tokenization.
+         if not all(isinstance(response, str) for response in self.response_pool):
+             logger.error("All elements in response_pool must be strings.")
+             raise ValueError("Invalid data type in response_pool.")
+
+         # Tokenize and encode in batches: this keeps memory bounded and matches
+         # how query embeddings are computed in _compute_embeddings.
+         batch_size = getattr(self, 'embedding_batch_size', 64)  # Default to 64 if not set
+         embeddings = []
+         for i in range(0, len(self.response_pool), batch_size):
+             batch_texts = self.response_pool[i:i + batch_size]
+             encoded = self.tokenizer(
+                 batch_texts,
+                 padding=True,
+                 truncation=True,
+                 max_length=self.max_length,
+                 return_tensors='tf'
+             )
+             batch_embeddings = self.encoder(encoded['input_ids'], training=False).numpy()
+             # L2-normalize so inner-product search behaves like cosine similarity.
+             faiss.normalize_L2(batch_embeddings)
+             embeddings.append(batch_embeddings)
+
+         if embeddings:
+             embeddings = np.vstack(embeddings).astype(np.float32)
+             logger.info(f"Adding {len(embeddings)} response embeddings to FAISS index...")
+             self.index.add(embeddings)
+             logger.info("Response embeddings added to FAISS index.")
+         else:
+             logger.warning("No embeddings to add to FAISS index.")
+
+         # Sanity check: verify the number of embeddings now in the FAISS index.
+         logger.info(f"Total embeddings in FAISS index after addition: {self.index.ntotal}")
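+
+     # Because embeddings are L2-normalized before indexing, inner-product search is
+     # equivalent to cosine similarity. A minimal sketch of a compatible index
+     # (assuming a flat inner-product index; the actual index type is configured
+     # elsewhere in this class):
+     #
+     #     index = faiss.IndexFlatIP(embedding_dim)  # e.g., embedding_dim = 768
+     #     faiss.normalize_L2(embeddings)            # embeddings: float32 [N, embedding_dim]
+     #     index.add(embeddings)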
+
+     def _find_hard_negatives_batch(self, queries: List[str], positives: List[str]) -> List[List[str]]:
+         """Find hard negatives for a batch of queries, with error handling and retries."""
+         retry_count = 0
+         total_responses = len(self.response_pool)
+
+         # Retrieve one extra candidate per query so the positive itself can be skipped.
+         k = self.neg_samples + 1
+
+         while retry_count < self.max_retries:
+             try:
+                 # Gather cached query embeddings in sub-batches to manage memory.
+                 batch_size = 128  # Sub-batch size; adjust as needed.
+                 query_embeddings = []
+                 for i in range(0, len(queries), batch_size):
+                     sub_queries = queries[i:i + batch_size]
+                     sub_embeddings = np.vstack([
+                         self.query_embeddings_cache[q] for q in sub_queries
+                     ]).astype(np.float32)
+                     faiss.normalize_L2(sub_embeddings)
+                     query_embeddings.append(sub_embeddings)
+                 query_embeddings = np.vstack(query_embeddings)
+
+                 # FAISS requires a contiguous memory layout.
+                 query_embeddings = np.ascontiguousarray(query_embeddings)
+
+                 # Perform the FAISS search on CPU.
+                 distances, indices = self.index.search(query_embeddings, k)
+
+                 all_negatives = []
+                 for query_indices, query, positive in zip(indices, queries, positives):
+                     negatives = []
+                     positive_strip = positive.strip()
+                     seen = {positive_strip}
+
+                     for idx in query_indices:
+                         if 0 <= idx < total_responses:
+                             candidate = self.response_pool[idx].strip()
+                             if candidate and candidate not in seen:
+                                 seen.add(candidate)
+                                 negatives.append(candidate)
+                                 if len(negatives) >= self.neg_samples:
+                                     break
+
+                     # Pad with a special token if too few negatives were found.
+                     while len(negatives) < self.neg_samples:
+                         negatives.append("<EMPTY_NEGATIVE>")
+
+                     all_negatives.append(negatives)
+
+                 return all_negatives
+
+             except KeyError as ke:
+                 retry_count += 1
+                 logger.warning(f"Hard negative search attempt {retry_count} failed due to missing embeddings: {ke}")
+                 if retry_count == self.max_retries:
+                     logger.error("Max retries reached for hard negative search due to missing embeddings.")
+                     return [["<EMPTY_NEGATIVE>"] * self.neg_samples for _ in queries]
+                 # Recompute any missing query embeddings before retrying.
+                 self._compute_embeddings(queries)
+                 gc.collect()
+                 if tf.config.list_physical_devices('GPU'):
+                     tf.keras.backend.clear_session()
+             except Exception as e:
+                 retry_count += 1
+                 logger.warning(f"Hard negative search attempt {retry_count} failed: {e}")
+                 if retry_count == self.max_retries:
+                     logger.error("Max retries reached for hard negative search.")
+                     return [["<EMPTY_NEGATIVE>"] * self.neg_samples for _ in queries]
+                 # Perform memory cleanup before retrying.
+                 gc.collect()
+                 if tf.config.list_physical_devices('GPU'):
+                     tf.keras.backend.clear_session()
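+
+     # Contract sketch (hypothetical call): for B queries, the method returns B lists
+     # of exactly `neg_samples` strings, padded with "<EMPTY_NEGATIVE>" when needed:
+     #
+     #     negs = pipeline._find_hard_negatives_batch(
+     #         ["how do i track my order?"], ["You can track it from your account page."])
+     #     assert len(negs) == 1 and len(negs[0]) == pipeline.neg_samples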
+
+     def _tokenize_and_encode(self, queries: List[str], positives: List[str], negatives: List[List[str]]) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+         """
+         Tokenize and encode the queries, positives, and negatives.
+
+         Returns:
+             query_ids: [batch_size, max_length]
+             positive_ids: [batch_size, max_length]
+             negative_ids: [batch_size, neg_samples, max_length]
+         """
+         # Tokenize queries.
+         q_enc = self.tokenizer(
+             queries,
+             padding="max_length",
+             truncation=True,
+             max_length=self.max_length,
+             return_tensors="np"
+         )
+         # Tokenize positives.
+         p_enc = self.tokenizer(
+             positives,
+             padding="max_length",
+             truncation=True,
+             max_length=self.max_length,
+             return_tensors="np"
+         )
+         # Tokenize negatives: flatten [batch_size, neg_samples] -> single list.
+         flattened_negatives = [neg for sublist in negatives for neg in sublist]
+         if len(flattened_negatives) == 0:
+             # No negatives at all: return a zero array.
+             n_ids = np.zeros((len(queries), self.neg_samples, self.max_length), dtype=np.int32)
+         else:
+             n_enc = self.tokenizer(
+                 flattened_negatives,
+                 padding="max_length",
+                 truncation=True,
+                 max_length=self.max_length,
+                 return_tensors="np"
+             )
+             n_input_ids = n_enc["input_ids"]
+
+             # Reshape to [batch_size, neg_samples, max_length]; upstream padding
+             # guarantees exactly neg_samples entries per query.
+             batch_size = len(queries)
+             n_ids = n_input_ids.reshape(batch_size, self.neg_samples, self.max_length)
+
+         # Convert everything to int32.
+         query_ids = q_enc["input_ids"].astype(np.int32)
+         positive_ids = p_enc["input_ids"].astype(np.int32)
+         negative_ids = n_ids.astype(np.int32)
+
+         return query_ids, positive_ids, negative_ids
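+
+     # Shape example (hypothetical values): with batch_size=2, neg_samples=3, and
+     # max_length=512, _tokenize_and_encode returns arrays of shapes
+     #     query_ids:    (2, 512)
+     #     positive_ids: (2, 512)
+     #     negative_ids: (2, 3, 512)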
+
+     def prepare_and_save_data(self, dialogues: List[dict], tfrecord_file_path: str, batch_size: int = 32):
+         """Process dialogues in batches and save the results to a TFRecord file."""
+         with tf.io.TFRecordWriter(tfrecord_file_path) as writer:
+             total_dialogues = len(dialogues)
+             logger.debug(f"Total dialogues to process: {total_dialogues}")
+
+             with tqdm(total=total_dialogues, desc="Processing Dialogues", unit="dialogue") as pbar:
+                 for i in range(0, total_dialogues, batch_size):
+                     batch_dialogues = dialogues[i:i + batch_size]
+                     # For each dialogue: extract pairs, mine hard negatives,
+                     # tokenize, and serialize to the TFRecord.
+                     for dialogue in batch_dialogues:
+                         pairs = self._extract_pairs_from_dialogue(dialogue)
+                         queries = []
+                         positives = []
+
+                         for query, positive in pairs:
+                             queries.append(query)
+                             positives.append(positive)
+
+                         if queries:
+                             # Compute and cache query embeddings before searching.
+                             self._compute_embeddings(queries)
+
+                             # Find hard negatives.
+                             hard_negatives = self._find_hard_negatives_batch(queries, positives)
+
+                             # Log the first few examples for inspection.
+                             for idx, negatives in enumerate(hard_negatives[:5]):
+                                 logger.debug(f"Query: {queries[idx]}")
+                                 logger.debug(f"Positive: {positives[idx]}")
+                                 logger.debug(f"Hard Negatives: {negatives}")
+
+                             # Tokenize and encode.
+                             query_ids, positive_ids, negative_ids = self._tokenize_and_encode(queries, positives, hard_negatives)
+
+                             # Serialize each example and write it to the TFRecord.
+                             # (TFRecord int64_list stores the int32 token IDs as int64.)
+                             for q_id, p_id, n_id in zip(query_ids, positive_ids, negative_ids):
+                                 feature = {
+                                     'query_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=q_id)),
+                                     'positive_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=p_id)),
+                                     'negative_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=n_id.flatten())),
+                                 }
+                                 example = tf.train.Example(features=tf.train.Features(feature=feature))
+                                 writer.write(example.SerializeToString())
+
+                     pbar.update(len(batch_dialogues))
+         logger.info(f"Data preparation complete. TFRecord saved at {tfrecord_file_path}")
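+
+     # Sketch of a matching parse function for the records written above
+     # (illustrative; assumes the same `max_length` and `neg_samples` used at write time):
+     #
+     #     def parse_example(serialized, max_length, neg_samples):
+     #         feature_spec = {
+     #             'query_ids': tf.io.FixedLenFeature([max_length], tf.int64),
+     #             'positive_ids': tf.io.FixedLenFeature([max_length], tf.int64),
+     #             'negative_ids': tf.io.FixedLenFeature([neg_samples * max_length], tf.int64),
+     #         }
+     #         parsed = tf.io.parse_single_example(serialized, feature_spec)
+     #         q = tf.cast(parsed['query_ids'], tf.int32)
+     #         p = tf.cast(parsed['positive_ids'], tf.int32)
+     #         n = tf.reshape(tf.cast(parsed['negative_ids'], tf.int32), [neg_samples, max_length])
+     #         return q, p, n
+     #
+     #     dataset = tf.data.TFRecordDataset("train.tfrecord").map(
+     #         lambda s: parse_example(s, pipeline.max_length, pipeline.neg_samples))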
+
+     def _tokenize_negatives_tf(self, negatives):
+         """Tokenize negatives; intended to be called inside tf.py_function (relies on .numpy())."""
+         # Handle the case where negatives is an empty tensor.
+         if tf.size(negatives) == 0:
+             return tf.zeros([0, self.neg_samples, self.max_length], dtype=tf.int32)
+
+         # Convert the EagerTensor to a list of lists of strings.
+         negatives_list = []
+         for neg_list in negatives.numpy():
+             decoded_negs = [neg.decode("utf-8") for neg in neg_list if neg]  # Filter out empty strings.
+             negatives_list.append(decoded_negs)
+
+         # Flatten the list of lists.
+         flattened_negatives = [neg for sublist in negatives_list for neg in sublist]
+
+         # Tokenize the flattened negatives, if any.
+         if flattened_negatives:
+             n_tokens = self.tokenizer(
+                 flattened_negatives,
+                 padding='max_length',
+                 truncation=True,
+                 max_length=self.max_length,
+                 return_tensors='tf'
+             )
+             # Reshape to [batch_size, neg_samples, max_length].
+             n_tokens_reshaped = tf.reshape(n_tokens['input_ids'], [-1, self.neg_samples, self.max_length])
+             return n_tokens_reshaped
+         else:
+             return tf.zeros([0, self.neg_samples, self.max_length], dtype=tf.int32)
+
+     def _compute_embeddings(self, queries: List[str]) -> None:
+         """Compute and cache embeddings for any queries not already in the cache."""
+         new_queries = [q for q in queries if q not in self.query_embeddings_cache]
+         if not new_queries:
+             return  # All queries already cached.
+
+         # Compute embeddings for the new queries in batches.
+         new_embeddings = []
+         for i in range(0, len(new_queries), self.embedding_batch_size):
+             batch_queries = new_queries[i:i + self.embedding_batch_size]
+             encoded = self.tokenizer(
+                 batch_queries,
+                 padding=True,
+                 truncation=True,
+                 max_length=self.max_length,
+                 return_tensors='tf'
+             )
+             batch_embeddings = self.encoder(encoded['input_ids'], training=False).numpy()
+             faiss.normalize_L2(batch_embeddings)
+             new_embeddings.extend(batch_embeddings)
+
+         # Update the cache.
+         for query, emb in zip(new_queries, new_embeddings):
+             self.query_embeddings_cache[query] = emb
+
+     def data_generator(self, dialogues: List[dict]) -> Generator[Tuple[str, str, List[str]], None, None]:
+         """
+         Generate training examples as (query, positive, hard_negatives) triples.
+         The outer loop is wrapped with tqdm for progress tracking.
+         """
+         total_dialogues = len(dialogues)
+         logger.debug(f"Total dialogues to process: {total_dialogues}")
+
+         with tqdm(total=total_dialogues, desc="Processing Dialogues", unit="dialogue") as pbar:
+             for dialogue in dialogues:
+                 pairs = self._extract_pairs_from_dialogue(dialogue)
+                 for query, positive in pairs:
+                     # Ensure the query embedding is cached, then mine hard negatives.
+                     self._compute_embeddings([query])
+                     hard_negatives = self._find_hard_negatives_batch([query], [positive])[0]
+                     yield (query, positive, hard_negatives)
+                 pbar.update(1)
+
+     def _prepare_batch(self, queries: tf.Tensor, positives: tf.Tensor, negatives: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
+         """Prepare a batch of data for training (eager-mode helper)."""
+         # Convert EagerTensors to lists of strings.
+         queries_list = [query.decode("utf-8") for query in queries.numpy()]
+         positives_list = [pos.decode("utf-8") for pos in positives.numpy()]
+
+         # Tokenize queries and positives.
+         q_tokens = self.tokenizer(queries_list, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf')
+         p_tokens = self.tokenizer(positives_list, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf')
+
+         # Decode negatives into lists of strings, dropping empty entries.
+         negatives_list = []
+         for neg_list in negatives.numpy():
+             decoded_negs = [neg.decode("utf-8") for neg in neg_list if neg]
+             negatives_list.append(decoded_negs)
+
+         # Flatten negatives for tokenization.
+         flattened_negatives = [neg for sublist in negatives_list for neg in sublist if neg]
+
+         # Tokenize negatives, if any.
+         if flattened_negatives:
+             n_tokens = self.tokenizer(flattened_negatives, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf')
+             # Reshape to [batch_size, num_negatives, max_length]. This assumes the
+             # same number of negatives per query; adjust if that ever varies.
+             n_tokens_reshaped = tf.reshape(n_tokens['input_ids'], [len(queries_list), -1, self.max_length])
+         else:
+             # Placeholder tensor for the no-negatives case.
+             n_tokens_reshaped = tf.zeros([len(queries_list), 0, self.max_length], dtype=tf.int32)
+
+         # Pad or truncate the negatives dimension to exactly neg_samples,
+         # giving a consistent shape of [batch_size, neg_samples, max_length].
+         if n_tokens_reshaped.shape[1] != self.neg_samples:
+             padding = tf.zeros([len(queries_list), tf.maximum(0, self.neg_samples - n_tokens_reshaped.shape[1]), self.max_length], dtype=tf.int32)
+             n_tokens_reshaped = tf.concat([n_tokens_reshaped, padding], axis=1)
+             n_tokens_reshaped = n_tokens_reshaped[:, :self.neg_samples, :]
+
+         # Concatenate positives and negatives along the candidates dimension:
+         # [batch_size, 1 + neg_samples, max_length].
+         combined_p_n_tokens = tf.concat([tf.expand_dims(p_tokens['input_ids'], axis=1), n_tokens_reshaped], axis=1)
+
+         return q_tokens['input_ids'], combined_p_n_tokens
+
+     def get_tf_dataset(self, dialogues: List[dict], batch_size: int) -> tf.data.Dataset:
+         """
+         Create a tf.data.Dataset for streaming training that yields
+         (input_ids_query, input_ids_positive, input_ids_negatives).
+         """
+         # 1) Start with a generator dataset.
+         dataset = tf.data.Dataset.from_generator(
+             lambda: self.data_generator(dialogues),
+             output_signature=(
+                 tf.TensorSpec(shape=(), dtype=tf.string),                   # Query (single string)
+                 tf.TensorSpec(shape=(), dtype=tf.string),                   # Positive (single string)
+                 tf.TensorSpec(shape=(self.neg_samples,), dtype=tf.string)   # Hard negatives (list of strings)
+             )
+         )
+
+         # 2) Batch the raw strings.
+         dataset = dataset.batch(batch_size, drop_remainder=True)
+
+         # 3) Tokenize each batch via tf.py_function. Parallelism is kept at 1
+         #    because the Python-level tokenizer call is not safely parallelizable here.
+         dataset = dataset.map(
+             lambda q, p, n: self._tokenize_triple(q, p, n),
+             num_parallel_calls=1
+         )
+
+         dataset = dataset.prefetch(tf.data.AUTOTUNE)
+         return dataset
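+
+     # Hypothetical usage (assumes `pipeline` and `dialogues` are already prepared):
+     #
+     #     dataset = pipeline.get_tf_dataset(dialogues, batch_size=32)
+     #     for q_ids, p_ids, n_ids in dataset.take(1):
+     #         print(q_ids.shape, p_ids.shape, n_ids.shape)
+     #         # -> (32, max_length) (32, max_length) (32, neg_samples, max_length)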
+
+     def _tokenize_triple(
+         self,
+         q: tf.Tensor,
+         p: tf.Tensor,
+         n: tf.Tensor
+     ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
+         """
+         Wrap a Python function via tf.py_function to convert tf.Tensors of strings
+         -> Python lists of strings -> HF tokenizer -> Tensors of IDs.
+
+         q and p have shape [batch_size]; n has shape [batch_size, neg_samples]
+         (each row is a list of negatives).
+         """
+         # Use tf.py_function with limited parallelism.
+         q_ids, p_ids, n_ids = tf.py_function(
+             func=self._tokenize_triple_py,
+             inp=[q, p, n, tf.constant(self.max_length), tf.constant(self.neg_samples)],
+             Tout=[tf.int32, tf.int32, tf.int32]
+         )
+
+         # tf.py_function loses static shape information, so set it manually.
+         q_ids.set_shape([None, self.max_length])                    # [batch_size, max_length]
+         p_ids.set_shape([None, self.max_length])                    # [batch_size, max_length]
+         n_ids.set_shape([None, self.neg_samples, self.max_length])  # [batch_size, neg_samples, max_length]
+
+         return q_ids, p_ids, n_ids
+
+     def _tokenize_triple_py(
+         self,
+         q: tf.Tensor,
+         p: tf.Tensor,
+         n: tf.Tensor,
+         max_len: tf.Tensor,
+         neg_samples: tf.Tensor
+     ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+         """
+         Python-side tokenization helper that:
+           - Decodes each tf.string Tensor to a Python list of strings
+           - Calls the HF tokenizer
+           - Reshapes negatives
+           - Returns np.int32 arrays for (q_ids, p_ids, n_ids)
+
+         q: shape [batch_size], p: shape [batch_size]
+         n: shape [batch_size, neg_samples]
+         max_len, neg_samples: scalar ints
+         """
+         max_len = int(max_len.numpy())  # Convert to Python int.
+         neg_samples = int(neg_samples.numpy())
+
+         # 1) Convert Tensors -> Python lists of strings.
+         q_list = [q_i.decode("utf-8") for q_i in q.numpy()]  # shape [batch_size]
+         p_list = [p_i.decode("utf-8") for p_i in p.numpy()]  # shape [batch_size]
+
+         # Decode each row of negatives (shape [batch_size, neg_samples]).
+         n_list = []
+         for row in n.numpy():
+             decoded = [neg.decode("utf-8") for neg in row]
+             n_list.append(decoded)
+
+         # 2) Tokenize queries & positives.
+         q_enc = self.tokenizer(
+             q_list,
+             padding="max_length",
+             truncation=True,
+             max_length=max_len,
+             return_tensors="np"
+         )
+         p_enc = self.tokenizer(
+             p_list,
+             padding="max_length",
+             truncation=True,
+             max_length=max_len,
+             return_tensors="np"
+         )
+
+         # 3) Tokenize negatives: flatten [batch_size, neg_samples] -> single list.
+         flattened_negatives = [neg for row in n_list for neg in row]
+         if len(flattened_negatives) == 0:
+             # No negatives at all: return a zero array.
+             n_ids = np.zeros((len(q_list), neg_samples, max_len), dtype=np.int32)
+         else:
+             n_enc = self.tokenizer(
+                 flattened_negatives,
+                 padding="max_length",
+                 truncation=True,
+                 max_length=max_len,
+                 return_tensors="np"
+             )
+             # Shape [batch_size * neg_samples, max_len].
+             n_input_ids = n_enc["input_ids"]
+
+             # Reshape to [batch_size, neg_samples, max_len], zero-padding any
+             # row that came back with fewer than neg_samples negatives.
+             batch_size = len(q_list)
+             n_ids_list = []
+             for i in range(batch_size):
+                 start_idx = i * neg_samples
+                 end_idx = start_idx + neg_samples
+                 row_negs = n_input_ids[start_idx:end_idx]
+
+                 if row_negs.shape[0] < neg_samples:
+                     deficit = neg_samples - row_negs.shape[0]
+                     pad_arr = np.zeros((deficit, max_len), dtype=np.int32)
+                     row_negs = np.concatenate([row_negs, pad_arr], axis=0)
+
+                 n_ids_list.append(row_negs)
+
+             # Stack -> shape [batch_size, neg_samples, max_len].
+             n_ids = np.stack(n_ids_list, axis=0)
+
+         # 4) Return int32 arrays.
+         q_ids = q_enc["input_ids"].astype(np.int32)  # [batch_size, max_len]
+         p_ids = p_enc["input_ids"].astype(np.int32)  # [batch_size, max_len]
+         n_ids = n_ids.astype(np.int32)               # [batch_size, neg_samples, max_len]
+
+         return q_ids, p_ids, n_ids
+