JoeArmani committed 9b5daff (1 parent: f5346f7)

improve split processes

Changed files:
- chatbot_model.py      +86  -94
- run_data_preparer.py   +1   -1
- run_model_train.py    +31  -10
- tf_data_pipeline.py  +230  -57
chatbot_model.py CHANGED

@@ -138,12 +138,14 @@ class RetrievalChatbot(DeviceAwareModel):
         device: str = None,
         strategy=None,
         reranker: Optional[CrossEncoderReranker] = None,
-        summarizer: Optional[Summarizer] = None
+        summarizer: Optional[Summarizer] = None,
+        mode: str = 'preparation'
     ):
         super().__init__()
         self.config = config
         self.strategy = strategy
         self.device = device or self._setup_default_device()
+        self.mode = mode.lower()
 
         # Initialize reranker, summarizer, tokenizer, and memory monitor
         self.reranker = reranker or self._initialize_reranker()
@@ -151,13 +153,10 @@ class RetrievalChatbot(DeviceAwareModel):
         self.tokenizer = self._initialize_tokenizer()
         self.memory_monitor = GPUMemoryMonitor()
 
-        # Initialize models
-        self.min_batch_size = 8
-        self.max_batch_size = 128
-        self.current_batch_size = 32
-
-        # Collect unique responses from dialogues
-        self.response_pool, self.unique_responses = self._collect_responses(dialogues)
+        # # Initialize models
+        # self.min_batch_size = 8
+        # self.max_batch_size = 128
+        # self.current_batch_size = 32
 
         # Initialize training history
         self.history = {
@@ -166,6 +165,15 @@ class RetrievalChatbot(DeviceAwareModel):
             "train_metrics": {},
             "val_metrics": {}
         }
+
+        # Collect unique responses from dialogues
+        if self.mode == 'preparation':
+            # Collect unique responses from dialogues only in preparation mode
+            self.response_pool, self.unique_responses = self._collect_responses(dialogues)
+        else:
+            # In training mode, assume response_pool is handled via TFRecord
+            self.response_pool = []
+            self.unique_responses = []
 
     def _setup_default_device(self) -> str:
         """Set up default device if none is provided."""
@@ -236,11 +244,14 @@ class RetrievalChatbot(DeviceAwareModel):
         self.encoder.pretrained.resize_token_embeddings(new_vocab_size)
         logger.info(f"Token embeddings resized to: {new_vocab_size}")
 
-
-
-
-
-
+        if self.mode == 'preparation':
+            # Initialize FAISS index only in preparation mode
+            self._initialize_faiss()
+            # Compute and index embeddings
+            self._compute_and_index_embeddings()
+        else:
+            # In training mode, skip FAISS indexing from dialogues
+            logger.info("Training mode: Skipping FAISS index initialization from dialogues.")
 
         # Retrieve embedding dimension from encoder
         embedding_dim = self.config.embedding_dim
@@ -271,10 +282,10 @@ class RetrievalChatbot(DeviceAwareModel):
         self.current_batch_size = new_size
 
     def _initialize_faiss(self):
-        """Initialize FAISS with
+        """Initialize FAISS with safe GPU handling and memory monitoring."""
         logger.info("Initializing FAISS index...")
 
-        #
+        # Detect if we have GPU-enabled FAISS
         self.faiss_gpu = False
         self.gpu_resources = []
 
@@ -294,53 +305,30 @@ class RetrievalChatbot(DeviceAwareModel):
                 self.gpu_resources.append(res)
                 self.faiss_gpu = True
                 logger.info(f"FAISS GPU resources initialized on {ngpus} GPUs")
-            else:
-                logger.info("Using CPU-only FAISS build")
-
         except Exception as e:
             logger.warning(f"Using CPU due to GPU initialization error: {e}")
 
-        # TODO: figure out buf with faiss-gpu
-        # TODO: consider IndexIVFFlat in the future (speed).
         try:
             # Create appropriate index based on dataset size
             if len(self.unique_responses) < 1000:
                 logger.info("Small dataset detected, using simple FlatIP index")
                 self.index = faiss.IndexFlatIP(self.config.embedding_dim)
             else:
-                #
-                # nlist = min(
-                #     25,  # max clusters
-                #     max(int(math.sqrt(len(self.unique_responses))), 1)  # min 1 cluster
-                # )
-                # logger.info(f"Using IVF index with {nlist} clusters")
-
-                # quantizer = faiss.IndexFlatIP(self.config.embedding_dim)
-                # self.index = faiss.IndexIVFFlat(
-                #     quantizer,
-                #     self.config.embedding_dim,
-                #     nlist,
-                #     faiss.METRIC_INNER_PRODUCT
-                # )
+                # For larger datasets, consider using more efficient indices like IVF
                 self.index = faiss.IndexFlatIP(self.config.embedding_dim)
 
-            #
-
-            # # Set search parameters for IVF index
-            # if isinstance(self.index, faiss.IndexIVFFlat):
-            #     self.index.nprobe = min(10, nlist)
-
+            # Move to GPU(s) if available and needed
+            if self.faiss_gpu and self.gpu_resources:
+                try:
+                    if len(self.gpu_resources) > 1:
+                        self.index = faiss.index_cpu_to_gpus_list(self.index, self.gpu_resources)
+                        logger.info("FAISS index distributed across multiple GPUs")
+                    else:
+                        self.index = faiss.index_cpu_to_gpu(self.gpu_resources[0], 0, self.index)
+                        logger.info("FAISS index moved to single GPU")
+                except Exception as e:
+                    logger.warning(f"Failed to move index to GPU: {e}. Falling back to CPU")
+                    self.faiss_gpu = False
         except Exception as e:
             logger.error(f"Error initializing FAISS: {e}")
             raise
@@ -353,21 +341,16 @@ class RetrievalChatbot(DeviceAwareModel):
         """
         Encodes responses with more conservative memory management.
         """
+        if not responses:
+            logger.info("No responses to encode. Returning empty tensor.")
+            return tf.constant([], dtype=tf.float32)
+
         all_embeddings = []
         self.current_batch_size = batch_size
 
         if self.memory_monitor.has_gpu:
            batch_size = 128
 
-        # Memory stats
-        # if self.memory_monitor.has_gpu:
-        #     initial_stats = self.memory_monitor.get_memory_stats()
-        #     if initial_stats:
-        #         logger.info("Initial GPU memory state:")
-        #         logger.info(f"Total: {initial_stats.total / 1e9:.2f}GB")
-        #         logger.info(f"Used: {initial_stats.used / 1e9:.2f}GB")
-        #         logger.info(f"Free: {initial_stats.free / 1e9:.2f}GB")
-
         total_processed = 0
 
         with tqdm(total=len(responses), desc="Encoding responses") as pbar:
@@ -434,7 +417,10 @@ class RetrievalChatbot(DeviceAwareModel):
             raise
 
         # Concatenate results
-
+        if not all_embeddings:
+            logger.info("No embeddings were encoded. Returning empty tensor.")
+            return tf.constant([], dtype=tf.float32)
+
         if len(all_embeddings) == 1:
             final_embeddings = all_embeddings[0]
         else:
@@ -727,7 +713,11 @@ class RetrievalChatbot(DeviceAwareModel):
             raise
 
     def verify_faiss_index(self):
-        """Verify that FAISS index matches the response pool."""
+        """Verify that FAISS index matches the response pool, if index exists."""
+        if not hasattr(self, 'index') or self.index is None:
+            logger.info("FAISS index not initialized. Skipping verification.")
+            return
+
        indexed_size = self.index.ntotal
        pool_size = len(self.response_pool)
        logger.info(f"FAISS index size: {indexed_size}")
@@ -813,6 +803,10 @@ class RetrievalChatbot(DeviceAwareModel):
 
     def retrieve_responses_faiss(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
         """Retrieve top-k responses using FAISS."""
+        if not hasattr(self, 'index') or self.index is None:
+            logger.warning("FAISS index not initialized. Cannot retrieve responses.")
+            return []
+
         # Encode the query
         q_emb = self.encode_query(query)  # Shape: [1, embedding_dim]
         q_emb_np = q_emb.numpy().astype('float32')  # Ensure type match
@@ -874,32 +868,6 @@ class RetrievalChatbot(DeviceAwareModel):
         logger.info(f"Models and tokenizer loaded from {load_dir}.")
         return chatbot
 
-    def parse_tfrecord_fn(example_proto, max_length, neg_samples):
-        """
-        Parses a single TFRecord example.
-
-        Args:
-            example_proto: A serialized TFRecord example.
-            max_length: The maximum sequence length for tokenization.
-            neg_samples: The number of hard negatives per query.
-
-        Returns:
-            A tuple of (query_ids, positive_ids, negative_ids).
-        """
-        feature_description = {
-            'query_ids': tf.io.FixedLenFeature([max_length], tf.int64),
-            'positive_ids': tf.io.FixedLenFeature([max_length], tf.int64),
-            'negative_ids': tf.io.FixedLenFeature([neg_samples * max_length], tf.int64),
-        }
-        parsed_features = tf.io.parse_single_example(example_proto, feature_description)
-
-        query_ids = tf.cast(parsed_features['query_ids'], tf.int32)
-        positive_ids = tf.cast(parsed_features['positive_ids'], tf.int32)
-        negative_ids = tf.cast(parsed_features['negative_ids'], tf.int32)
-        negative_ids = tf.reshape(negative_ids, [neg_samples, max_length])
-
-        return query_ids, positive_ids, negative_ids
-
     def train_streaming(
         self,
         tfrecord_file_path: str,
@@ -915,10 +883,34 @@ class RetrievalChatbot(DeviceAwareModel):
     ) -> None:
         """Training using a pre-prepared TFRecord dataset."""
         logger.info("Starting training with pre-prepared TFRecord dataset...")
-
-
-
-
+
+        def parse_tfrecord_fn(example_proto, max_length, neg_samples):
+            """
+            Parses a single TFRecord example.
+
+            Args:
+                example_proto: A serialized TFRecord example.
+                max_length: The maximum sequence length for tokenization.
+                neg_samples: The number of hard negatives per query.
+
+            Returns:
+                A tuple of (query_ids, positive_ids, negative_ids).
+            """
+            feature_description = {
+                'query_ids': tf.io.FixedLenFeature([max_length], tf.int64),
+                'positive_ids': tf.io.FixedLenFeature([max_length], tf.int64),
+                'negative_ids': tf.io.FixedLenFeature([neg_samples * max_length], tf.int64),
+            }
+            parsed_features = tf.io.parse_single_example(example_proto, feature_description)
+
+            query_ids = tf.cast(parsed_features['query_ids'], tf.int32)
+            positive_ids = tf.cast(parsed_features['positive_ids'], tf.int32)
+            negative_ids = tf.cast(parsed_features['negative_ids'], tf.int32)
+            negative_ids = tf.reshape(negative_ids, [neg_samples, max_length])
+
+            return query_ids, positive_ids, negative_ids
+
+        # Calculate total steps by counting the number of records in the TFRecord
         raw_dataset = tf.data.TFRecordDataset(tfrecord_file_path)
         total_pairs = sum(1 for _ in raw_dataset)
         logger.info(f"Total pairs in TFRecord: {total_pairs}")
@@ -964,12 +956,12 @@ class RetrievalChatbot(DeviceAwareModel):
         logger.info(f"TensorBoard logs will be saved in {log_dir}")
 
         # Define the parsing function with the appropriate max_length and neg_samples
-        parse_fn = lambda x:
+        parse_fn = lambda x: parse_tfrecord_fn(x, self.config.max_context_token_limit, self.config.neg_samples)
 
         # Create the full dataset
         dataset = tf.data.TFRecordDataset(tfrecord_file_path)
         dataset = dataset.map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)
-        dataset = dataset.shuffle(buffer_size=10000)  # Adjust buffer size as needed
+        dataset = dataset.shuffle(buffer_size=10000)  # Adjust buffer size as needed
         dataset = dataset.batch(batch_size, drop_remainder=True)
         dataset = dataset.prefetch(tf.data.AUTOTUNE)
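Note: pulled out as a standalone sketch, the input pipeline this file converges on looks as follows. The parse function and pipeline stages are taken directly from the diff above; the constants (512 tokens, 3 negatives, batch size 16) are assumptions mirroring what run_model_train.py configures.

    import tensorflow as tf

    MAX_LENGTH = 512   # assumption: mirrors config.max_context_token_limit
    NEG_SAMPLES = 3    # assumption: mirrors config.neg_samples

    def parse_tfrecord_fn(example_proto, max_length=MAX_LENGTH, neg_samples=NEG_SAMPLES):
        # Fixed-length schema: one query, one positive, neg_samples negatives, all padded to max_length
        feature_description = {
            'query_ids': tf.io.FixedLenFeature([max_length], tf.int64),
            'positive_ids': tf.io.FixedLenFeature([max_length], tf.int64),
            'negative_ids': tf.io.FixedLenFeature([neg_samples * max_length], tf.int64),
        }
        parsed = tf.io.parse_single_example(example_proto, feature_description)
        query_ids = tf.cast(parsed['query_ids'], tf.int32)
        positive_ids = tf.cast(parsed['positive_ids'], tf.int32)
        # Negatives are stored flattened; restore the [neg_samples, max_length] shape
        negative_ids = tf.reshape(tf.cast(parsed['negative_ids'], tf.int32),
                                  [neg_samples, max_length])
        return query_ids, positive_ids, negative_ids

    dataset = (
        tf.data.TFRecordDataset('training_data/training_data.tfrecord')
        .map(parse_tfrecord_fn, num_parallel_calls=tf.data.AUTOTUNE)
        .shuffle(buffer_size=10000)
        .batch(16, drop_remainder=True)   # yields [16, 512], [16, 512], [16, 3, 512]
        .prefetch(tf.data.AUTOTUNE)
    )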
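Taken together, these changes split the lifecycle in two: 'preparation' builds the response pool and FAISS index from dialogues, while 'training' skips both and streams from the TFRecord. A hedged sketch of the intended usage; the `dialogues`, `batch_size`, and `epochs` arguments are assumptions about parts of the signatures not shown in this diff.

    # Preparation mode: collect responses, build and populate the FAISS index
    prep_bot = RetrievalChatbot(config, dialogues, mode='preparation')  # `dialogues` assumed
    prep_bot.build_models()           # runs _initialize_faiss() and _compute_and_index_embeddings()
    prep_bot.verify_faiss_index()     # the index exists in this mode, so this actually verifies

    # Training mode: no FAISS work; pairs come from the prepared TFRecord
    train_bot = RetrievalChatbot(config, mode='training')
    train_bot.build_models()          # logs "Training mode: Skipping FAISS index initialization..."
    train_bot.train_streaming(
        tfrecord_file_path='training_data/training_data.tfrecord',
        batch_size=16,                # assumed parameter, matching run_model_train.py
        epochs=20,                    # assumed parameter, matching EPOCHS
    )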
run_data_preparer.py CHANGED

@@ -30,7 +30,7 @@ def main():
     TF_RECORD_DIR = 'training_data'
     FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_production.index')
     FAISS_INDEX_TEST_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_test.index')
-    ENVIRONMENT = '
+    ENVIRONMENT = 'production'  # or 'test'
     if ENVIRONMENT == 'test':
         FAISS_INDEX_PATH = FAISS_INDEX_TEST_PATH
     else:
run_model_train.py CHANGED

@@ -31,36 +31,57 @@ def run_interactive_chat(chatbot, quality_checker):
         for resp, score in candidates[1:4]:
             print(f"Score: {score:.4f} - {resp}")
 
+def inspect_tfrecord(tfrecord_file_path, num_examples=3):
+    def parse_example(example_proto):
+        feature_description = {
+            'query_ids': tf.io.FixedLenFeature([512], tf.int64),  # Adjust max_length if different
+            'positive_ids': tf.io.FixedLenFeature([512], tf.int64),
+            'negative_ids': tf.io.FixedLenFeature([3 * 512], tf.int64),  # Adjust neg_samples if different
+        }
+        return tf.io.parse_single_example(example_proto, feature_description)
+
+    dataset = tf.data.TFRecordDataset(tfrecord_file_path)
+    dataset = dataset.map(parse_example)
+
+    for i, example in enumerate(dataset.take(num_examples)):
+        print(f"Example {i+1}:")
+        print(f"Query IDs: {example['query_ids'].numpy()}")
+        print(f"Positive IDs: {example['positive_ids'].numpy()}")
+        print(f"Negative IDs: {example['negative_ids'].numpy()}")
+        print("-" * 50)
+
 def main():
+
+    # Quick test to inspect TFRecord
+    #inspect_tfrecord('training_data/training_data.tfrecord', num_examples=3)
+
     # Initialize environment
     tf.keras.backend.clear_session()
     env = EnvironmentSetup()
     env.initialize()
 
-
-    EPOCHS =
+    # Training configuration
+    EPOCHS = 20
     TF_RECORD_FILE_PATH = 'training_data/training_data.tfrecord'
 
     # Optimize batch size for Colab
     batch_size = env.optimize_batch_size(base_batch_size=16)
 
+
     # Initialize configuration
     config = ChatbotConfig(
         embedding_dim=768,  # DistilBERT
         max_context_token_limit=512,
         freeze_embeddings=False,
-        neg_samples=3,
     )
 
-    #
-    # dialogues = RetrievalChatbot.load_training_data(data_path=TRAINING_DATA_PATH, debug_samples=DEBUG_SAMPLES)
-    # print(dialogues)
-
-    # Initialize chatbot and verify FAISS index
+    # Initialize chatbot
     #with env.strategy.scope():
-    chatbot = RetrievalChatbot(config)
+    chatbot = RetrievalChatbot(config, mode='training')
     chatbot.build_models()
-
+
+    if chatbot.mode == 'preparation':
+        chatbot.verify_faiss_index()
 
     chatbot.train_streaming(
         tfrecord_file_path=TF_RECORD_FILE_PATH,
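One caveat with `inspect_tfrecord`: its `[512]` and `[3 * 512]` shapes hard-code the same `max_context_token_limit=512` and `neg_samples=3` that the config uses, so the two can silently drift apart. A small sketch of deriving the schema from the config instead; it assumes `ChatbotConfig` continues to expose both attributes.

    import tensorflow as tf

    def make_feature_description(cfg):
        # One source of truth for the TFRecord schema: the ChatbotConfig
        max_len, negs = cfg.max_context_token_limit, cfg.neg_samples
        return {
            'query_ids': tf.io.FixedLenFeature([max_len], tf.int64),
            'positive_ids': tf.io.FixedLenFeature([max_len], tf.int64),
            'negative_ids': tf.io.FixedLenFeature([negs * max_len], tf.int64),
        }

    # Hypothetical usage inside inspect_tfrecord:
    #   dataset = dataset.map(lambda ex: tf.io.parse_single_example(ex, make_feature_description(config)))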
tf_data_pipeline.py CHANGED

@@ -4,6 +4,7 @@ import numpy as np
 import faiss
 import tensorflow as tf
 import h5py
+import math
 from tqdm import tqdm
 import json
 from pathlib import Path
@@ -146,7 +147,7 @@ class TFDataPipeline:
     def collect_responses(self, dialogues: List[dict]) -> List[str]:
         """Extract unique assistant responses from dialogues."""
         response_set = set()
-        for dialogue in dialogues:
+        for dialogue in tqdm(dialogues, desc="Processing Dialogues", unit="dialogue"):
             turns = dialogue.get('turns', [])
             for turn in turns:
                 speaker = turn.get('speaker')
@@ -180,20 +181,17 @@ class TFDataPipeline:
 
     def _compute_and_index_response_embeddings(self):
         """
-        Computes embeddings for the response pool and adds them to the FAISS index.
+        Computes embeddings for the response pool and adds them to the FAISS index with progress bars.
         """
         logger.info("Computing embeddings for the response pool...")
 
-        # Log the contents and types of response_pool
-        for idx, response in enumerate(self.response_pool[:5], 1):  # Log first 5 responses
-            logger.debug(f"Response {idx}: {response} (Type: {type(response)})")
-
         # Ensure all responses are strings
         if not all(isinstance(response, str) for response in self.response_pool):
             logger.error("All elements in response_pool must be strings.")
             raise ValueError("Invalid data type in response_pool.")
 
-        #
+        # Tokenization
+        logger.info("Tokenizing responses...")
         encoded_responses = self.tokenizer(
             self.response_pool,
             padding=True,
@@ -203,28 +201,87 @@ class TFDataPipeline:
         )
         response_ids = encoded_responses['input_ids']
 
-        # Compute embeddings in batches
+        # Compute embeddings in batches with progress bar
         batch_size = getattr(self, 'embedding_batch_size', 64)  # Default to 64 if not set
+        total_responses = len(response_ids)
+        logger.info(f"Computing embeddings in batches of {batch_size}...")
         embeddings = []
-        for i in range(0, len(response_ids), batch_size):
-            batch_ids = response_ids[i:i+batch_size]
-            # Compute embeddings
-            batch_embeddings = self.encoder(batch_ids, training=False).numpy()
-            # Normalize embeddings if using inner product or cosine similarity
-            faiss.normalize_L2(batch_embeddings)
-            embeddings.append(batch_embeddings)
+
+        with tqdm(total=total_responses, desc="Computing Embeddings", unit="response") as pbar:
+            for i in range(0, total_responses, batch_size):
+                batch_ids = response_ids[i:i + batch_size]
+                # Compute embeddings
+                batch_embeddings = self.encoder(batch_ids, training=False).numpy()
+                # Normalize embeddings for cosine similarity
+                faiss.normalize_L2(batch_embeddings)
+                embeddings.append(batch_embeddings)
+                pbar.update(len(batch_ids))
 
         if embeddings:
             embeddings = np.vstack(embeddings).astype(np.float32)
-            # Add embeddings to FAISS index
+            # Add embeddings to FAISS index with progress bar
             logger.info(f"Adding {len(embeddings)} response embeddings to FAISS index...")
-            self.index.add(embeddings)
+
+            # Determine number of batches for indexing
+            index_batch_size = getattr(self, 'index_batch_size', 1000)  # Adjust as needed
+            total_embeddings = len(embeddings)
+            num_index_batches = math.ceil(total_embeddings / index_batch_size)
+
+            with tqdm(total=total_embeddings, desc="Indexing Embeddings", unit="embedding") as pbar_index:
+                for i in range(0, total_embeddings, index_batch_size):
+                    batch = embeddings[i:i + index_batch_size]
+                    self.index.add(batch)
+                    pbar_index.update(len(batch))
+
             logger.info("Response embeddings added to FAISS index.")
         else:
             logger.warning("No embeddings to add to FAISS index.")
 
         # **Sanity Check:** Verify the number of embeddings in FAISS index
         logger.info(f"Total embeddings in FAISS index after addition: {self.index.ntotal}")
+    # def _compute_and_index_response_embeddings(self):
+    #     """
+    #     Computes embeddings for the response pool and adds them to the FAISS index.
+    #     """
+    #     logger.info("Computing embeddings for the response pool...")
+
+    #     # Ensure all responses are strings
+    #     if not all(isinstance(response, str) for response in self.response_pool):
+    #         logger.error("All elements in response_pool must be strings.")
+    #         raise ValueError("Invalid data type in response_pool.")
+
+    #     # Proceed with tokenization
+    #     encoded_responses = self.tokenizer(
+    #         self.response_pool,
+    #         padding=True,
+    #         truncation=True,
+    #         max_length=self.max_length,
+    #         return_tensors='tf'
+    #     )
+    #     response_ids = encoded_responses['input_ids']
+
+    #     # Compute embeddings in batches
+    #     batch_size = getattr(self, 'embedding_batch_size', 64)  # Default to 64 if not set
+    #     embeddings = []
+    #     for i in range(0, len(response_ids), batch_size):
+    #         batch_ids = response_ids[i:i+batch_size]
+    #         # Compute embeddings
+    #         batch_embeddings = self.encoder(batch_ids, training=False).numpy()
+    #         # Normalize embeddings if using inner product or cosine similarity
+    #         faiss.normalize_L2(batch_embeddings)
+    #         embeddings.append(batch_embeddings)
+
+    #     if embeddings:
+    #         embeddings = np.vstack(embeddings).astype(np.float32)
+    #         # Add embeddings to FAISS index
+    #         logger.info(f"Adding {len(embeddings)} response embeddings to FAISS index...")
+    #         self.index.add(embeddings)
+    #         logger.info("Response embeddings added to FAISS index.")
+    #     else:
+    #         logger.warning("No embeddings to add to FAISS index.")
+
+    #     # **Sanity Check:** Verify the number of embeddings in FAISS index
+    #     logger.info(f"Total embeddings in FAISS index after addition: {self.index.ntotal}")
 
     def _find_hard_negatives_batch(self, queries: List[str], positives: List[str]) -> List[List[str]]:
         """Find hard negatives for a batch of queries with error handling and retries."""
@@ -349,53 +406,169 @@ class TFDataPipeline:
 
         return query_ids, positive_ids, negative_ids
 
-    def prepare_and_save_data(self, dialogues: List[dict], tfrecord_file_path: str, batch_size: int = 32):
-        """Processes dialogues in batches and saves to a TFRecord file."""
-        with tf.io.TFRecordWriter(tfrecord_file_path) as writer:
-            total_dialogues = len(dialogues)
-            logger.debug(f"Total dialogues to process: {total_dialogues}")
-
-            with tqdm(total=total_dialogues, desc="Processing Dialogues", unit="dialogue") as pbar:
-                for i in range(0, total_dialogues, batch_size):
-                    batch_dialogues = dialogues[i:i+batch_size]
-                    for dialogue in batch_dialogues:
-                        pairs = self._extract_pairs_from_dialogue(dialogue)
-                        queries = []
-                        positives = []
-
-                        for query, positive in pairs:
-                            queries.append(query)
-                            positives.append(positive)
-
-                        if queries:
-                            self._compute_embeddings(queries)
-                            hard_negatives = self._find_hard_negatives_batch(queries, positives)
-                            query_ids, positive_ids, negative_ids = self._tokenize_and_encode(queries, positives, hard_negatives)
-
-                            for q_id, p_id, n_id in zip(query_ids, positive_ids, negative_ids):
-                                feature = {
-                                    'query_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=q_id)),
-                                    'positive_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=p_id)),
-                                    'negative_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=n_id.flatten())),
-                                }
-                                example = tf.train.Example(features=tf.train.Features(feature=feature))
-                                writer.write(example.SerializeToString())
-
-                    pbar.update(len(batch_dialogues))
-            logger.info(f"Data preparation complete. TFRecord saved at {tfrecord_file_path}")
+    # Testing updated batch tokenization
+    def prepare_and_save_data(self, dialogues: List[dict], tf_record_path: str, batch_size: int = 32):
+        """
+        Processes dialogues in batches and saves to a TFRecord file using optimized batch tokenization and encoding.
+
+        Args:
+            dialogues (List[dict]): List of dialogue dictionaries.
+            tf_record_path (str): Path to save the TFRecord file.
+            batch_size (int): Number of dialogues to process per batch.
+        """
+        logger.info(f"Preparing and saving data to {tf_record_path}...")
+
+        num_dialogues = len(dialogues)
+        num_batches = math.ceil(num_dialogues / batch_size)
+
+        with tf.io.TFRecordWriter(tf_record_path) as writer:
+            # Initialize progress bar
+            with tqdm(total=num_batches, desc="Preparing Data Batches", unit="batch") as pbar:
+                for i in range(num_batches):
+                    start_idx = i * batch_size
+                    end_idx = min(start_idx + batch_size, num_dialogues)
+                    batch_dialogues = dialogues[start_idx:end_idx]
+
+                    # Extract all query-positive pairs in the batch
+                    queries = []
+                    positives = []
+                    for dialogue in batch_dialogues:
+                        pairs = self._extract_pairs_from_dialogue(dialogue)
+                        for query, positive in pairs:
+                            if len(query) <= self.max_length and len(positive) <= self.max_length:
+                                queries.append(query)
+                                positives.append(positive)
+
+                    if not queries:
+                        pbar.update(1)
+                        continue  # Skip if no valid queries
+
+                    # Compute and cache query embeddings
+                    try:
+                        self._compute_embeddings(queries)
+                    except Exception as e:
+                        logger.error(f"Error computing embeddings: {e}")
+                        pbar.update(1)
+                        continue  # Skip to the next batch
+
+                    # Find hard negatives for the batch
+                    try:
+                        hard_negatives = self._find_hard_negatives_batch(queries, positives)
+                    except Exception as e:
+                        logger.error(f"Error finding hard negatives: {e}")
+                        pbar.update(1)
+                        continue  # Skip to the next batch
+
+                    # Tokenize and encode all queries, positives, and negatives in the batch
+                    try:
+                        encoded_queries = self.tokenizer.batch_encode_plus(
+                            queries,
+                            max_length=self.config.max_context_token_limit,
+                            truncation=True,
+                            padding='max_length',
+                            return_tensors='tf'
+                        )
+                        encoded_positives = self.tokenizer.batch_encode_plus(
+                            positives,
+                            max_length=self.config.max_context_token_limit,
+                            truncation=True,
+                            padding='max_length',
+                            return_tensors='tf'
+                        )
+                    except Exception as e:
+                        logger.error(f"Error during tokenization: {e}")
+                        pbar.update(1)
+                        continue  # Skip to the next batch
+
+                    # Flatten hard_negatives while maintaining alignment
+                    # Assuming hard_negatives is a list of lists, where each sublist corresponds to a query
+                    try:
+                        flattened_negatives = [neg for sublist in hard_negatives for neg in sublist]
+                        encoded_negatives = self.tokenizer.batch_encode_plus(
+                            flattened_negatives,
+                            max_length=self.config.max_context_token_limit,
+                            truncation=True,
+                            padding='max_length',
+                            return_tensors='tf'
+                        )
+
+                        # Reshape encoded_negatives['input_ids'] to [num_queries, num_negatives, max_length]
+                        num_negatives = self.config.neg_samples
+                        reshaped_negatives = encoded_negatives['input_ids'].numpy().reshape(-1, num_negatives, self.config.max_context_token_limit)
+                    except Exception as e:
+                        logger.error(f"Error during negatives tokenization: {e}")
+                        pbar.update(1)
+                        continue  # Skip to the next batch
+
+                    # Serialize each example and write to TFRecord
+                    for j in range(len(queries)):
+                        try:
+                            q_id = encoded_queries['input_ids'][j].numpy()
+                            p_id = encoded_positives['input_ids'][j].numpy()
+                            n_id = reshaped_negatives[j]
+
+                            feature = {
+                                'query_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=q_id)),
+                                'positive_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=p_id)),
+                                'negative_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=n_id.flatten())),
+                            }
+                            example = tf.train.Example(features=tf.train.Features(feature=feature))
+                            writer.write(example.SerializeToString())
+                        except Exception as e:
+                            logger.error(f"Error serializing example {j} in batch {i}: {e}")
+                            continue  # Skip to the next example
+
+                    # Update progress bar
+                    pbar.update(1)
+
+        logger.info(f"Data preparation complete. TFRecord saved.")
+    # def prepare_and_save_data(self, dialogues: List[dict], tfrecord_file_path: str, batch_size: int = 32):
+    #     """Processes dialogues in batches and saves to a TFRecord file."""
+    #     with tf.io.TFRecordWriter(tfrecord_file_path) as writer:
+    #         total_dialogues = len(dialogues)
+    #         logger.debug(f"Total dialogues to process: {total_dialogues}")
+
+    #         with tqdm(total=total_dialogues, desc="Processing Dialogues", unit="dialogue") as pbar:
+    #             for i in range(0, total_dialogues, batch_size):
+    #                 batch_dialogues = dialogues[i:i+batch_size]
+    #                 # Process each batch_dialogues
+    #                 # Extract pairs, find negatives, tokenize, and serialize
+    #                 # Example:
+    #                 for dialogue in batch_dialogues:
+    #                     pairs = self._extract_pairs_from_dialogue(dialogue)
+    #                     queries = []
+    #                     positives = []
+
+    #                     for query, positive in pairs:
+    #                         queries.append(query)
+    #                         positives.append(positive)
+
+    #                     if queries:
+    #                         # **Compute and cache query embeddings before searching**
+    #                         self._compute_embeddings(queries)
+
+    #                         # Find hard negatives
+    #                         hard_negatives = self._find_hard_negatives_batch(queries, positives)
+
+    #                         # # for idx, negatives in enumerate(hard_negatives[:5]):  # Log first 5 examples
+    #                         # #     logger.debug(f"Query: {queries[idx]}")
+    #                         # #     logger.debug(f"Positive: {positives[idx]}")
+    #                         # #     logger.debug(f"Hard Negatives: {negatives}")
+    #                         # Tokenize and encode
+    #                         query_ids, positive_ids, negative_ids = self._tokenize_and_encode(queries, positives, hard_negatives)
+
+    #                         # Serialize each example and write to TFRecord
+    #                         for q_id, p_id, n_id in zip(query_ids, positive_ids, negative_ids):
+    #                             feature = {
+    #                                 'query_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=q_id)),
+    #                                 'positive_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=p_id)),
+    #                                 'negative_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=n_id.flatten())),
+    #                             }
+    #                             example = tf.train.Example(features=tf.train.Features(feature=feature))
+    #                             writer.write(example.SerializeToString())
+
+    #                 pbar.update(len(batch_dialogues))
+    #     logger.info(f"Data preparation complete. TFRecord saved at {tfrecord_file_path}")
 
     def _tokenize_negatives_tf(self, negatives):
         """Tokenizes negatives using tf.py_function."""
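Because `IndexFlatIP` scores by raw inner product, the `faiss.normalize_L2` calls above are what make those scores behave as cosine similarity. A self-contained illustration with synthetic data; the dimension and pool size are arbitrary.

    import faiss
    import numpy as np

    dim, n = 768, 1000  # arbitrary embedding size and pool size
    embeddings = np.random.rand(n, dim).astype(np.float32)

    faiss.normalize_L2(embeddings)      # in-place L2 normalization, as in the pipeline above
    index = faiss.IndexFlatIP(dim)      # inner product == cosine similarity on unit-norm vectors
    index.add(embeddings)

    query = np.random.rand(1, dim).astype(np.float32)
    faiss.normalize_L2(query)           # the query must be normalized the same way
    scores, ids = index.search(query, 5)  # top-5 cosine scores and their row indices
    print(scores[0], ids[0])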
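As a sanity check on the serialization format that `prepare_and_save_data` writes and `parse_tfrecord_fn` reads back, here is a minimal round trip for a single example. The shapes (512 tokens, 3 negatives) are assumptions matching the config used elsewhere in this commit.

    import numpy as np
    import tensorflow as tf

    L, K = 512, 3  # assumed max_length and neg_samples
    q = np.zeros(L, dtype=np.int64)
    p = np.ones(L, dtype=np.int64)
    n = np.full((K, L), 2, dtype=np.int64)

    # Serialize exactly as prepare_and_save_data does: negatives are flattened
    feature = {
        'query_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=q)),
        'positive_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=p)),
        'negative_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=n.flatten())),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    serialized = example.SerializeToString()

    # Parse back with the same fixed-length schema and restore the negatives' shape
    parsed = tf.io.parse_single_example(serialized, {
        'query_ids': tf.io.FixedLenFeature([L], tf.int64),
        'positive_ids': tf.io.FixedLenFeature([L], tf.int64),
        'negative_ids': tf.io.FixedLenFeature([K * L], tf.int64),
    })
    negatives = tf.reshape(parsed['negative_ids'], [K, L])
    assert negatives.shape == (K, L) and int(negatives[0, 0]) == 2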