JoeArmani committed
Commit · 2183656
Parent(s): 775baf9

upgrade to tf-dataset

Browse files:
- chatbot_model.py +648 -726
- deduplicate_augmented_dialogues.py +74 -0
- run_model_train.py +3 -3
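This commit replaces the threading/Queue-based StreamingDataPipeline with a TFDataPipeline that feeds training through tf.data. The new training loop consumes `dataset_preparer.get_tf_dataset(dialogues, batch_size)`, but that method's body is not visible in this diff view; the sketch below is only an assumed wiring of how such a method could wrap `data_generator` (which, per the diff, yields `(query, positive, hard_negatives)` strings) — it is not the committed implementation.

import tensorflow as tf

def get_tf_dataset(self, dialogues, batch_size):
    # Assumed wiring, not the committed code: wrap the Python generator in tf.data.
    dataset = tf.data.Dataset.from_generator(
        lambda: self.data_generator(dialogues),
        output_signature=(
            tf.TensorSpec(shape=(), dtype=tf.string),                  # query
            tf.TensorSpec(shape=(), dtype=tf.string),                  # positive
            tf.TensorSpec(shape=(self.neg_samples,), dtype=tf.string), # hard negatives (padded to neg_samples)
        ),
    )
    dataset = dataset.batch(batch_size)
    # _prepare_batch runs the HuggingFace tokenizer (plain Python), so it is wrapped in tf.py_function.
    # The three int32 outputs (query, positive, negative token ids) are an assumption about its return shape.
    dataset = dataset.map(
        lambda q, p, n: tf.py_function(self._prepare_batch, [q, p, n], [tf.int32, tf.int32, tf.int32]),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    return dataset.prefetch(tf.data.AUTOTUNE)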
chatbot_model.py
CHANGED
@@ -2,8 +2,6 @@ import time
 from transformers import TFAutoModel, AutoTokenizer
 import tensorflow as tf
 import numpy as np
-import threading
-from queue import Queue, Empty
 from typing import Generator, List, Tuple, Dict, Optional, Union, Any
 import math
 from dataclasses import dataclass
@@ -12,7 +10,6 @@ from pathlib import Path
 import datetime
 import faiss
 import gc
-import random
 from response_quality_checker import ResponseQualityChecker
 from cross_encoder_reranker import CrossEncoderReranker
 from conversation_summarizer import DeviceAwareModel, Summarizer
@@ -29,7 +26,7 @@ class ChatbotConfig:
     """Configuration for the RetrievalChatbot."""
     vocab_size: int = 30526  # DistilBERT vocab size + special tokens
     max_context_token_limit: int = 512
-    embedding_dim: int =
     encoder_units: int = 256
     num_attention_heads: int = 8
     dropout_rate: float = 0.2
@@ -42,6 +39,7 @@ class ChatbotConfig:
     pretrained_model: str = 'distilbert-base-uncased'
     dtype: str = 'float32'
     freeze_embeddings: bool = False
     # Additional configurations can be added here
 
     def to_dict(self) -> dict:
@@ -103,9 +101,9 @@ class EncoderModel(tf.keras.Model):
 
         # Apply pooling, projection, dropout, and normalization
         x = self.pooler(x)        # Shape: [batch_size, 768]
-        x = self.projection(x)    # Shape: [batch_size,
         x = self.dropout(x, training=training)  # Apply dropout
-        x = self.normalize(x)     # Shape: [batch_size,
 
         return x
 
@@ -139,17 +137,6 @@ class RetrievalChatbot(DeviceAwareModel):
         summarizer = Summarizer(device=self.device)
         self.summarizer = summarizer
 
-        # # Configure XLA optimization if on GPU/TPU
-        # if self.device in ["GPU", "TPU"]:
-        #     tf.config.optimizer.set_jit(True)
-        #     logger.info(f"XLA compilation enabled for {self.device}")
-
-        # # Configure mixed precision for GPU/TPU
-        # if self.device != "CPU":
-        #     policy = tf.keras.mixed_precision.Policy('mixed_float16')
-        #     tf.keras.mixed_precision.set_global_policy(policy)
-        #     logger.info("Mixed precision training enabled (float16)")
-
         # Special tokens
         self.special_tokens = {
             "user": "<USER>",
@@ -354,6 +341,9 @@ class RetrievalChatbot(DeviceAwareModel):
         """
         all_embeddings = []
         self.current_batch_size = batch_size
 
         # Memory stats
         # if self.memory_monitor.has_gpu:
@@ -541,9 +531,9 @@ class RetrievalChatbot(DeviceAwareModel):
         logger.info("Starting vector addition process...")
 
         # Even smaller batches
-        initial_batch_size =
-        min_batch_size =
-        max_batch_size =
 
         total_added = 0
         retry_count = 0
@@ -572,7 +562,6 @@ class RetrievalChatbot(DeviceAwareModel):
                 # Update progress
                 batch_size = len(batch)
                 total_added += batch_size
-                #logger.info(f"Added batch of {batch_size} vectors ({total_added}/{len(response_embeddings)} total)")
 
                 # Memory cleanup every few batches
                 if total_added % (initial_batch_size * 5) == 0:
@@ -618,7 +607,7 @@ class RetrievalChatbot(DeviceAwareModel):
             cpu_index = self.index
 
             # Add remaining vectors on CPU with very small batches
-            batch_size =
             total_added = already_added
 
             for i in range(0, len(remaining_embeddings), batch_size):
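Aside, not part of the diff: the body of this loop is not visible in the hunk above, so the following is only a minimal sketch of a typical batched CPU addition to a FAISS index, reusing the names that do appear in the visible lines (`remaining_embeddings`, `cpu_index`, `already_added`, `batch_size`).

import numpy as np

total_added = already_added
for i in range(0, len(remaining_embeddings), batch_size):
    # FAISS expects contiguous float32 arrays
    batch = np.ascontiguousarray(remaining_embeddings[i:i + batch_size], dtype=np.float32)
    cpu_index.add(batch)
    total_added += len(batch)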
@@ -911,36 +900,33 @@ class RetrievalChatbot(DeviceAwareModel):
         warmup_steps_ratio: float = 0.1,
         early_stopping_patience: int = 3,
         min_delta: float = 1e-4,
-        buffer_size: int = 10,
         neg_samples: int = 1
     ) -> None:
-        """
-        giving priority to training until we meet `steps_per_epoch`, then
-        sending leftover batches to validation.
-        """
-        logger.info("Starting streaming training pipeline...")
 
-        # Initialize
-        dataset_preparer =
             tokenizer=self.tokenizer,
             encoder=self.encoder,
-            index=self.index,
             response_pool=self.response_pool,
             max_length=self.config.max_context_token_limit,
-            batch_size=batch_size,
             neg_samples=neg_samples
         )
 
         # Calculate total steps for learning rate schedule
         total_pairs = dataset_preparer.estimate_total_pairs(dialogues)
-        train_size = total_pairs * (1 - validation_split)
         steps_per_epoch = int(math.ceil(train_size / batch_size))
-        val_steps = int(math.ceil(
         total_steps = steps_per_epoch * epochs
 
         logger.info(f"Total pairs: {total_pairs}")
         logger.info(f"Training pairs: {train_size}")
         logger.info(f"Steps per epoch: {steps_per_epoch}")
         logger.info(f"Validation steps: {val_steps}")
         logger.info(f"Total steps: {total_steps}")
@@ -971,276 +957,245 @@ class RetrievalChatbot(DeviceAwareModel):
         val_log_dir = str(log_dir / f"val_{current_time}")
         train_summary_writer = tf.summary.create_file_writer(train_log_dir)
         val_summary_writer = tf.summary.create_file_writer(val_log_dir)
         logger.info(f"TensorBoard logs will be saved in {log_dir}")
 
         # Training loop
         best_val_loss = float("inf")
         epochs_no_improve = 0
 
-        epoch_pbar = range(1, epochs + 1)
-        is_tqdm_epoch = False
-        logger.info("Epoch progress bar disabled - continuing without visual progress")
-
-        for epoch in epoch_pbar:
-            # Shared queues for streaming pipeline
-            train_queue = Queue(maxsize=buffer_size)
-            val_queue = Queue(maxsize=buffer_size)
-            stop_flag = threading.Event()
-
-            def data_pipeline_worker():
-                """Thread function that processes dialogues and sends batches to train or val."""
-                try:
-                    train_batches_needed = steps_per_epoch  # 9 in your logs
-                    val_batches_needed = val_steps          # 3 in your logs
-                    train_batches_sent = 0
-                    val_batches_sent = 0
-
-                            train_queue.put(batch)
-                            train_batches_sent += 1
-                        elif val_batches_sent < val_batches_needed:
-                            val_queue.put(batch)
-                            val_batches_sent += 1
-                        else:
-                            # We have enough batches for both train & val
-                            break
-
-                    # If we still haven't met our target steps, REPEAT the data
-                    if train_batches_sent < train_batches_needed or val_batches_sent < val_batches_needed:
-                        logger.info("Data exhausted, repeating since we still need more batches...")
-                        # Optionally shuffle again
-                        random.shuffle(dataset_preparer.processed_pairs)
-                    else:
-                        # We have enough
-                        break
-
-                        f"{val_batches_sent}/{val_batches_needed} val batches"
-                    )
-
-                    raise e
-                finally:
-                    train_queue.put(None)
-                    val_queue.put(None)
-
         try:
-            try:
-                train_pbar = tqdm(total=steps_per_epoch, desc=f"Training Epoch {epoch}")
-                is_tqdm_train = True
-            except ImportError:
-                train_pbar = None
-                is_tqdm_train = False
-                logger.info("Training progress bar disabled")
-
-            while batches_processed < steps_per_epoch:
-                try:
-                    batch = train_queue.get(timeout=1200)  # 20 minutes timeout
-                    if batch is None:
-                        logger.warning(f"Received end signal after only {batches_processed}/{steps_per_epoch} batches")
-                        break
-
-                    q_batch, p_batch = batch[0], batch[1]
-                    attention_mask = batch[2] if len(batch) > 2 else None
-
-                    loss = self.train_step(q_batch, p_batch, attention_mask)
-                    epoch_loss_avg(loss)
-                    batches_processed += 1
-
-                    # Log to TensorBoard
-                    with train_summary_writer.as_default():
-                        tf.summary.scalar("loss", loss, step=epoch)
-
-                    # Update progress bar
-                    if use_lr_schedule:
-                        current_lr = float(lr_schedule(self.optimizer.iterations))
-                    else:
-                        current_lr = float(self.optimizer.learning_rate.numpy())
-
-                        train_pbar.set_postfix({
-                            "loss": f"{loss.numpy():.4f}",
-                            "lr": f"{current_lr:.2e}",
-                            "batches": f"{batches_processed}/{steps_per_epoch}"
-                        })
-
-                    break
-
-                is_tqdm_val = True
-            except ImportError:
-                val_pbar = None
-                is_tqdm_val = False
-                logger.info("Validation progress bar disabled")
-
-            while val_batches_processed < val_steps:
-                try:
-                    batch = val_queue.get(timeout=30)
-                    if batch is None:
-                        logger.warning(
-                            f"Received end signal after {val_batches_processed}/{val_steps} validation batches"
-                        )
-                        break
-
-                    q_batch, p_batch = batch[0], batch[1]
-                    attention_mask = batch[2] if len(batch) > 2 else None
-
-                    val_loss = self.validation_step(q_batch, p_batch, attention_mask)
-                    val_loss_avg(val_loss)
-                    val_batches_processed += 1
-
-                    if is_tqdm_val:
-                        val_pbar.update(1)
-                        val_pbar.set_postfix({
-                            "val_loss": f"{val_loss.numpy():.4f}",
-                            "batches": f"{val_batches_processed}/{val_steps}"
-                        })
-
-                except Empty:
-                    logger.warning(
-                        f"Validation queue timeout after {val_batches_processed}/{val_steps} batches"
-                    )
-                    break
-
-            if is_tqdm_val and val_pbar:
-                val_pbar.close()
-
-            # End of epoch: compute final epoch stats
-            train_loss = epoch_loss_avg.result().numpy()
-            val_loss = val_loss_avg.result().numpy()
-            logger.info(f"Epoch {epoch} Complete: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")
-
-            # Log epoch metrics
-            with val_summary_writer.as_default():
-                tf.summary.scalar("val_loss", val_loss, step=epoch)
-
-            # Save checkpoint
-            manager.save()
-
-            # Store metrics in history
-            self.history['train_loss'].append(train_loss)
-            self.history['val_loss'].append(val_loss)
-
-                current_lr = float(self.optimizer.learning_rate.numpy())
-
-            if val_loss < best_val_loss - min_delta:
-                best_val_loss = val_loss
-                epochs_no_improve = 0
-                logger.info(f"Validation loss improved to {val_loss:.4f}. Reset patience.")
-            else:
-                epochs_no_improve += 1
-                logger.info(f"No improvement this epoch. Patience: {epochs_no_improve}/{early_stopping_patience}")
-                if epochs_no_improve >= early_stopping_patience:
-                    logger.info("Early stopping triggered.")
-                    break
 
         logger.info("Streaming training completed!")
 
 
     @tf.function
-    def train_step(
         with tf.GradientTape() as tape:
             loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
-                labels=labels,
             )
             if attention_mask is not None:
                 loss = loss * attention_mask
-                # normalize by the sum of attention_mask
                 loss = tf.reduce_sum(loss) / tf.reduce_sum(attention_mask)
-            else:
-                loss = tf.reduce_mean(loss)
 
         gradients = tape.gradient(loss, self.encoder.trainable_variables)
         self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_variables))
         return loss
 
     @tf.function
-    def validation_step(
         q_enc = self.encoder(q_batch, training=False)
         p_enc = self.encoder(p_batch, training=False)
 
         loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
-            labels=labels,
         )
         if attention_mask is not None:
             loss = loss * attention_mask
             loss = tf.reduce_sum(loss) / tf.reduce_sum(attention_mask)
-        else:
-            loss = tf.reduce_mean(loss)
 
         return loss
@@ -1382,235 +1337,33 @@ class RetrievalChatbot(DeviceAwareModel):
|
|
1382 |
conversation_parts.append(f"{self.special_tokens['user']} {query}")
|
1383 |
return "\n".join(conversation_parts)
|
1384 |
|
1385 |
-
class
|
1386 |
-
"""Helper class to manage the streaming data preparation pipeline with optimized caching and GPU usage."""
|
1387 |
def __init__(
|
1388 |
-
self,
|
1389 |
-
|
1390 |
-
|
1391 |
-
|
1392 |
-
|
1393 |
-
|
1394 |
-
|
1395 |
-
neg_samples: int
|
1396 |
):
|
|
|
1397 |
self.tokenizer = tokenizer
|
1398 |
self.encoder = encoder
|
1399 |
-
self.index = index
|
1400 |
self.response_pool = response_pool
|
1401 |
self.max_length = max_length
|
1402 |
-
self.base_batch_size = batch_size
|
1403 |
self.neg_samples = neg_samples
|
|
|
|
|
|
|
1404 |
self.memory_monitor = GPUMemoryMonitor()
|
1405 |
-
|
1406 |
-
# Caching structures
|
1407 |
-
self.hard_negatives_cache = {}
|
1408 |
-
self.processed_pairs = []
|
1409 |
-
self.query_embeddings_cache = {}
|
1410 |
-
|
1411 |
-
# Error tracking
|
1412 |
-
self.error_count = 0
|
1413 |
self.max_retries = 3
|
1414 |
-
|
1415 |
-
# Batch processing settings
|
1416 |
-
self.current_batch_size = batch_size
|
1417 |
-
self.batch_increase_factor = 1.25
|
1418 |
-
|
1419 |
-
# TODO: use GPU/strategy
|
1420 |
-
if len(response_pool) < 100:
|
1421 |
-
self.embedding_batch_size = 16
|
1422 |
-
self.search_batch_size = 16
|
1423 |
-
self.max_batch_size = 32
|
1424 |
-
self.min_batch_size = 4
|
1425 |
-
else:
|
1426 |
-
self.embedding_batch_size = 64
|
1427 |
-
self.search_batch_size = 64
|
1428 |
-
self.min_batch_size = 8
|
1429 |
-
self.max_batch_size = 64
|
1430 |
-
|
1431 |
-
def save_cache(self, cache_dir: Path) -> None:
|
1432 |
-
"""Save all cached data for future runs."""
|
1433 |
-
cache_dir = Path(cache_dir)
|
1434 |
-
cache_dir.mkdir(parents=True, exist_ok=True)
|
1435 |
-
|
1436 |
-
logger.info(f"Saving cache to {cache_dir}")
|
1437 |
-
|
1438 |
-
# Save embeddings cache
|
1439 |
-
embeddings_path = cache_dir / "query_embeddings.npy"
|
1440 |
-
np.save(
|
1441 |
-
embeddings_path,
|
1442 |
-
{k: v.numpy() if hasattr(v, 'numpy') else v
|
1443 |
-
for k, v in self.query_embeddings_cache.items()}
|
1444 |
-
)
|
1445 |
-
|
1446 |
-
# Save hard negatives and processed pairs
|
1447 |
-
with open(cache_dir / "hard_negatives.json", 'w') as f:
|
1448 |
-
json.dump(self.hard_negatives_cache, f)
|
1449 |
-
|
1450 |
-
with open(cache_dir / "processed_pairs.json", 'w') as f:
|
1451 |
-
json.dump(self.processed_pairs, f)
|
1452 |
-
|
1453 |
-
logger.info("Cache saved successfully")
|
1454 |
-
|
1455 |
-
def load_cache(self, cache_dir: Path) -> bool:
|
1456 |
-
"""Load cached data if available."""
|
1457 |
-
cache_dir = Path(cache_dir)
|
1458 |
-
required_files = [
|
1459 |
-
"query_embeddings.npy",
|
1460 |
-
"hard_negatives.json",
|
1461 |
-
"processed_pairs.json"
|
1462 |
-
]
|
1463 |
-
|
1464 |
-
if not all((cache_dir / f).exists() for f in required_files):
|
1465 |
-
logger.info("Cache files not found")
|
1466 |
-
return False
|
1467 |
-
|
1468 |
-
try:
|
1469 |
-
logger.info("Loading cache...")
|
1470 |
-
|
1471 |
-
# Load embeddings
|
1472 |
-
self.query_embeddings_cache = np.load(
|
1473 |
-
cache_dir / "query_embeddings.npy",
|
1474 |
-
allow_pickle=True
|
1475 |
-
).item()
|
1476 |
-
|
1477 |
-
# Load other caches
|
1478 |
-
with open(cache_dir / "hard_negatives.json", 'r') as f:
|
1479 |
-
self.hard_negatives_cache = json.load(f)
|
1480 |
-
|
1481 |
-
with open(cache_dir / "processed_pairs.json", 'r') as f:
|
1482 |
-
self.processed_pairs = json.load(f)
|
1483 |
-
|
1484 |
-
logger.info(f"Cache loaded successfully: {len(self.processed_pairs)} pairs")
|
1485 |
-
return True
|
1486 |
-
|
1487 |
-
except Exception as e:
|
1488 |
-
logger.error(f"Error loading cache: {e}")
|
1489 |
-
return False
|
1490 |
|
1491 |
-
|
1492 |
-
|
1493 |
-
if self.memory_monitor:
|
1494 |
-
if self.memory_monitor.should_reduce_batch_size():
|
1495 |
-
new_size = max(self.min_batch_size, self.current_batch_size // 2)
|
1496 |
-
if new_size != self.current_batch_size:
|
1497 |
-
if new_size < self.min_batch_size:
|
1498 |
-
logger.info(f"Reducing batch size to {new_size} due to high memory usage")
|
1499 |
-
self.current_batch_size = new_size
|
1500 |
-
gc.collect()
|
1501 |
-
if tf.config.list_physical_devices('GPU'):
|
1502 |
-
tf.keras.backend.clear_session()
|
1503 |
-
|
1504 |
-
elif self.memory_monitor.can_increase_batch_size():
|
1505 |
-
new_size = min(self.max_batch_size, int(self.current_batch_size * self.batch_increase_factor)) # More gradual increase
|
1506 |
-
if new_size != self.current_batch_size:
|
1507 |
-
if new_size > self.max_batch_size:
|
1508 |
-
logger.info(f"Increasing batch size to {new_size}")
|
1509 |
-
self.current_batch_size = new_size
|
1510 |
-
|
1511 |
-
def _add_progress_metrics(self, pbar, **metrics) -> None:
|
1512 |
-
"""Add memory and batch size metrics to progress bars."""
|
1513 |
-
if self.memory_monitor:
|
1514 |
-
gpu_usage = self.memory_monitor.get_memory_usage()
|
1515 |
-
metrics['gpu_mem'] = f"{gpu_usage:.1%}"
|
1516 |
-
metrics['batch_size'] = self.current_batch_size
|
1517 |
-
pbar.set_postfix(**metrics)
|
1518 |
-
|
1519 |
-
def preprocess_dialogues(self, dialogues: List[dict]) -> None:
|
1520 |
-
"""Preprocess all dialogues with error recovery and caching."""
|
1521 |
-
retry_count = 0
|
1522 |
-
|
1523 |
-
while retry_count < self.max_retries:
|
1524 |
-
try:
|
1525 |
-
self._preprocess_dialogues_internal(dialogues)
|
1526 |
-
break
|
1527 |
-
except Exception as e:
|
1528 |
-
retry_count += 1
|
1529 |
-
logger.warning(f"Preprocessing attempt {retry_count} failed: {e}")
|
1530 |
-
if retry_count == self.max_retries:
|
1531 |
-
logger.error("Max retries reached. Falling back to CPU processing")
|
1532 |
-
self._fallback_to_cpu_processing(dialogues)
|
1533 |
-
|
1534 |
-
def _preprocess_dialogues_internal(self, dialogues: List[dict]) -> None:
|
1535 |
-
"""Internal preprocessing implementation with progress tracking."""
|
1536 |
-
logger.info("Starting dialogue preprocessing...")
|
1537 |
-
|
1538 |
-
# Collect unique queries and pairs
|
1539 |
-
unique_queries = set()
|
1540 |
-
query_positive_pairs = []
|
1541 |
-
|
1542 |
-
with tqdm(total=len(dialogues), desc="Collecting dialogue pairs") as pbar:
|
1543 |
-
for dialogue in dialogues:
|
1544 |
-
pairs = self._extract_pairs_from_dialogue(dialogue)
|
1545 |
-
for query, positive in pairs:
|
1546 |
-
unique_queries.add(query)
|
1547 |
-
query_positive_pairs.append((query, positive))
|
1548 |
-
pbar.update(1)
|
1549 |
-
self._add_progress_metrics(pbar, pairs=len(query_positive_pairs))
|
1550 |
-
|
1551 |
-
# Precompute embeddings
|
1552 |
-
logger.info("Precomputing query embeddings...")
|
1553 |
-
self.precompute_query_embeddings(list(unique_queries))
|
1554 |
-
|
1555 |
-
# Find hard negatives
|
1556 |
-
logger.info("Finding hard negatives for all pairs...")
|
1557 |
-
self._find_hard_negatives_for_pairs(query_positive_pairs)
|
1558 |
|
1559 |
-
def precompute_query_embeddings(self, queries: List[str]) -> None:
|
1560 |
-
"""Precompute embeddings for all unique queries in batches."""
|
1561 |
-
unique_queries = list(set(queries))
|
1562 |
-
|
1563 |
-
with tqdm(total=len(unique_queries), desc="Precomputing query embeddings") as pbar:
|
1564 |
-
for i in range(0, len(unique_queries), self.embedding_batch_size):
|
1565 |
-
# Adjust batch size based on memory
|
1566 |
-
self._adjust_batch_size()
|
1567 |
-
batch_size = min(self.embedding_batch_size, len(unique_queries) - i)
|
1568 |
-
|
1569 |
-
# Get batch of queries
|
1570 |
-
batch_queries = unique_queries[i:i + batch_size]
|
1571 |
-
|
1572 |
-
try:
|
1573 |
-
# Tokenize batch
|
1574 |
-
encoded = self.tokenizer(
|
1575 |
-
batch_queries,
|
1576 |
-
padding=True,
|
1577 |
-
truncation=True,
|
1578 |
-
max_length=self.max_length,
|
1579 |
-
return_tensors='tf'
|
1580 |
-
)
|
1581 |
-
|
1582 |
-
# Get embeddings
|
1583 |
-
embeddings = self.encoder(encoded['input_ids'], training=False)
|
1584 |
-
embeddings_np = embeddings.numpy().astype('float32')
|
1585 |
-
|
1586 |
-
# Normalize for similarity search
|
1587 |
-
faiss.normalize_L2(embeddings_np)
|
1588 |
-
|
1589 |
-
# Cache embeddings
|
1590 |
-
for query, emb in zip(batch_queries, embeddings_np):
|
1591 |
-
self.query_embeddings_cache[query] = emb
|
1592 |
-
|
1593 |
-
pbar.update(len(batch_queries))
|
1594 |
-
self._add_progress_metrics(
|
1595 |
-
pbar,
|
1596 |
-
cached=len(self.query_embeddings_cache),
|
1597 |
-
batch_size=batch_size
|
1598 |
-
)
|
1599 |
-
|
1600 |
-
except Exception as e:
|
1601 |
-
logger.warning(f"Error processing batch: {e}")
|
1602 |
-
# Reduce batch size and retry
|
1603 |
-
self.embedding_batch_size = max(self.min_batch_size, self.embedding_batch_size // 2)
|
1604 |
-
continue
|
1605 |
-
|
1606 |
-
# Memory cleanup after successful batch
|
1607 |
-
if i % (self.embedding_batch_size * 10) == 0:
|
1608 |
-
gc.collect()
|
1609 |
-
if tf.config.list_physical_devices('GPU'):
|
1610 |
-
tf.keras.backend.clear_session()
|
1611 |
-
|
1612 |
-
logger.info(f"Cached embeddings for {len(self.query_embeddings_cache)} unique queries")
|
1613 |
-
|
1614 |
def _extract_pairs_from_dialogue(self, dialogue: dict) -> List[Tuple[str, str]]:
|
1615 |
"""Extract query-response pairs from a dialogue."""
|
1616 |
pairs = []
|
@@ -1631,305 +1384,474 @@ class StreamingDataPipeline:
|
|
1631 |
|
1632 |
return pairs
|
1633 |
|
1634 |
-
def
|
1635 |
-
"""
|
1636 |
-
|
1637 |
-
|
1638 |
-
|
1639 |
-
|
1640 |
-
|
1641 |
-
|
1642 |
-
|
1643 |
-
|
1644 |
-
|
1645 |
-
|
1646 |
-
|
1647 |
-
Process dialogues using cached data with dynamic batch sizing.
|
1648 |
-
Yields (q_tokens['input_ids'], p_tokens['input_ids'], attention_mask) tuples.
|
1649 |
-
"""
|
1650 |
-
# Preprocess if not already done
|
1651 |
-
if not self.processed_pairs:
|
1652 |
-
self.preprocess_dialogues(dialogues)
|
1653 |
-
|
1654 |
-
# Generate batches from cached data
|
1655 |
-
current_queries = []
|
1656 |
-
current_positives = []
|
1657 |
-
|
1658 |
-
# Counters for logging
|
1659 |
-
total_examples_yielded = 0
|
1660 |
-
total_batches_yielded = 0
|
1661 |
-
|
1662 |
-
with tqdm(total=len(self.processed_pairs), desc="Generating training batches", leave=False) as pbar:
|
1663 |
-
for i, (query, positive) in enumerate(self.processed_pairs):
|
1664 |
-
# Periodically adjust batch size
|
1665 |
-
if i % 10 == 0: # Check more frequently (e.g., every 10 pairs)
|
1666 |
-
self._adjust_batch_size()
|
1667 |
-
|
1668 |
-
# Add original pair
|
1669 |
-
current_queries.append(query)
|
1670 |
-
current_positives.append(positive)
|
1671 |
-
|
1672 |
-
# Add cached hard negatives for each query
|
1673 |
-
hard_negatives = self.hard_negatives_cache.get((query, positive), [])
|
1674 |
-
for neg_text in hard_negatives:
|
1675 |
-
current_queries.append(query)
|
1676 |
-
current_positives.append(neg_text)
|
1677 |
-
|
1678 |
-
# If we have enough examples to form a full batch, yield it
|
1679 |
-
while len(current_queries) >= self.current_batch_size:
|
1680 |
-
batch_queries = current_queries[:self.current_batch_size]
|
1681 |
-
batch_positives = current_positives[:self.current_batch_size]
|
1682 |
-
|
1683 |
-
# Update counters and logs
|
1684 |
-
batch_size_to_yield = len(batch_queries)
|
1685 |
-
total_examples_yielded += batch_size_to_yield
|
1686 |
-
total_batches_yielded += 1
|
1687 |
-
|
1688 |
-
yield self._prepare_batch(batch_queries, batch_positives, pad_to_batch_size=False)
|
1689 |
-
|
1690 |
-
# Remove used entries
|
1691 |
-
current_queries = current_queries[self.current_batch_size:]
|
1692 |
-
current_positives = current_positives[self.current_batch_size:]
|
1693 |
-
|
1694 |
-
# Update progress bar
|
1695 |
-
pbar.update(1)
|
1696 |
-
self._add_progress_metrics(
|
1697 |
-
pbar,
|
1698 |
-
pairs_processed=pbar.n,
|
1699 |
-
pending_pairs=len(current_queries)
|
1700 |
-
)
|
1701 |
-
|
1702 |
-
# After the loop, if anything is left, yield a final partial batch
|
1703 |
-
if current_queries:
|
1704 |
-
leftover_size = len(current_queries)
|
1705 |
-
total_examples_yielded += leftover_size
|
1706 |
-
total_batches_yielded += 1
|
1707 |
-
|
1708 |
-
yield self._prepare_batch(
|
1709 |
-
current_queries,
|
1710 |
-
current_positives,
|
1711 |
-
pad_to_batch_size=True
|
1712 |
-
)
|
1713 |
-
|
1714 |
-
def _find_hard_negatives_for_pairs(self, query_positive_pairs: List[Tuple[str, str]]) -> None:
|
1715 |
-
"""Process pairs in batches to find hard negatives with GPU acceleration."""
|
1716 |
-
total_pairs = len(query_positive_pairs)
|
1717 |
-
|
1718 |
-
# Use smaller batch size for small datasets
|
1719 |
-
if len(self.response_pool) < 1000:
|
1720 |
-
batch_size = min(8, self.search_batch_size)
|
1721 |
-
else:
|
1722 |
-
batch_size = self.search_batch_size
|
1723 |
-
|
1724 |
-
try:
|
1725 |
-
pbar = tqdm(total=total_pairs, desc="Finding hard negatives")
|
1726 |
-
is_tqdm = True
|
1727 |
-
except ImportError:
|
1728 |
-
pbar = None
|
1729 |
-
is_tqdm = False
|
1730 |
-
logger.info("Progress bar disabled - continuing without visual progress")
|
1731 |
-
|
1732 |
-
for i in range(0, total_pairs, batch_size):
|
1733 |
-
self._adjust_batch_size()
|
1734 |
-
|
1735 |
-
batch_pairs = query_positive_pairs[i:i + batch_size]
|
1736 |
-
batch_queries, batch_positives = zip(*batch_pairs)
|
1737 |
-
|
1738 |
-
batch_negatives = self._find_hard_negatives_batch(
|
1739 |
-
list(batch_queries),
|
1740 |
-
list(batch_positives)
|
1741 |
-
)
|
1742 |
-
|
1743 |
-
for query, positive, negatives in zip(batch_queries, batch_positives, batch_negatives):
|
1744 |
-
self.hard_negatives_cache[(query, positive)] = negatives
|
1745 |
-
self.processed_pairs.append((query, positive))
|
1746 |
-
|
1747 |
-
if is_tqdm:
|
1748 |
-
pbar.update(len(batch_pairs))
|
1749 |
-
self._add_progress_metrics(
|
1750 |
-
pbar,
|
1751 |
-
cached=len(self.processed_pairs),
|
1752 |
-
progress=f"{i+len(batch_pairs)}/{total_pairs}"
|
1753 |
-
)
|
1754 |
-
|
1755 |
-
if is_tqdm:
|
1756 |
-
pbar.close()
|
1757 |
-
|
1758 |
def _find_hard_negatives_batch(self, queries: List[str], positives: List[str]) -> List[List[str]]:
|
1759 |
"""Find hard negatives for a batch of queries with error handling and retries."""
|
1760 |
retry_count = 0
|
1761 |
total_responses = len(self.response_pool)
|
1762 |
-
|
1763 |
-
# For very small datasets (testing), just use random sampling
|
1764 |
-
if total_responses < 100:
|
1765 |
-
all_negatives = []
|
1766 |
-
for positive in positives:
|
1767 |
-
available = [r for r in self.response_pool if r.strip() != positive.strip()]
|
1768 |
-
if available:
|
1769 |
-
negatives = list(np.random.choice(
|
1770 |
-
available,
|
1771 |
-
size=min(self.neg_samples, len(available)),
|
1772 |
-
replace=False
|
1773 |
-
))
|
1774 |
-
else:
|
1775 |
-
negatives = []
|
1776 |
-
# Pad with empty strings if needed
|
1777 |
-
while len(negatives) < self.neg_samples:
|
1778 |
-
negatives.append("")
|
1779 |
-
all_negatives.append(negatives)
|
1780 |
-
return all_negatives
|
1781 |
-
|
1782 |
while retry_count < self.max_retries:
|
1783 |
try:
|
1784 |
-
# Get cached embeddings and ensure they're the right type
|
1785 |
query_embeddings = np.vstack([
|
1786 |
self.query_embeddings_cache[q] for q in queries
|
1787 |
]).astype(np.float32)
|
1788 |
|
1789 |
-
|
1790 |
-
query_embeddings = np.ascontiguousarray(query_embeddings)
|
1791 |
-
|
1792 |
-
# Normalize embeddings
|
1793 |
faiss.normalize_L2(query_embeddings)
|
1794 |
-
|
1795 |
-
k = 1
|
1796 |
#logger.debug(f"Searching with k={k} among {total_responses} responses")
|
1797 |
-
|
1798 |
-
assert query_embeddings.dtype == np.float32, f"Embeddings are not float32: {query_embeddings.dtype}" # Assertion here
|
1799 |
|
1800 |
-
|
1801 |
-
|
1802 |
-
except RuntimeError as e:
|
1803 |
-
logger.error(f"FAISS search failed: {e}")
|
1804 |
-
return self._fallback_random_negatives(queries, positives)
|
1805 |
-
|
1806 |
-
# Process results
|
1807 |
all_negatives = []
|
1808 |
-
for
|
1809 |
negatives = []
|
1810 |
positive_strip = positive.strip()
|
1811 |
-
|
1812 |
-
# Filter valid indices and deduplicate
|
1813 |
seen = {positive_strip}
|
|
|
1814 |
for idx in query_indices:
|
1815 |
if idx >= 0 and idx < total_responses:
|
1816 |
candidate = self.response_pool[idx].strip()
|
1817 |
-
if candidate and candidate not in seen:
|
1818 |
seen.add(candidate)
|
1819 |
negatives.append(candidate)
|
1820 |
if len(negatives) >= self.neg_samples:
|
1821 |
break
|
1822 |
-
|
1823 |
-
#
|
1824 |
-
if len(negatives) < self.neg_samples:
|
1825 |
-
available = [r for r in self.response_pool if r.strip() not in seen and r.strip()]
|
1826 |
-
if available:
|
1827 |
-
additional = np.random.choice(
|
1828 |
-
available,
|
1829 |
-
size=min(self.neg_samples - len(negatives), len(available)),
|
1830 |
-
replace=False
|
1831 |
-
)
|
1832 |
-
negatives.extend(additional)
|
1833 |
-
|
1834 |
-
# Still pad with empty strings if needed
|
1835 |
while len(negatives) < self.neg_samples:
|
1836 |
-
negatives.append("")
|
1837 |
-
|
1838 |
all_negatives.append(negatives)
|
1839 |
-
|
1840 |
return all_negatives
|
1841 |
-
|
1842 |
except Exception as e:
|
1843 |
retry_count += 1
|
1844 |
logger.warning(f"Hard negative search attempt {retry_count} failed: {e}")
|
1845 |
if retry_count == self.max_retries:
|
1846 |
logger.error("Max retries reached for hard negative search")
|
1847 |
-
return [[] for _ in queries] # Return empty
|
1848 |
gc.collect()
|
1849 |
if tf.config.list_physical_devices('GPU'):
|
1850 |
tf.keras.backend.clear_session()
|
1851 |
-
|
1852 |
-
def _fallback_random_negatives(self, queries: List[str], positives: List[str]) -> List[List[str]]:
|
1853 |
-
"""Fallback to random sampling when similarity search fails."""
|
1854 |
-
logger.warning("Falling back to random negative sampling")
|
1855 |
-
all_negatives = []
|
1856 |
-
for positive in positives:
|
1857 |
-
available = [r for r in self.response_pool if r.strip() != positive.strip()]
|
1858 |
-
negatives = list(np.random.choice(
|
1859 |
-
available,
|
1860 |
-
size=min(self.neg_samples, len(available)),
|
1861 |
-
replace=False
|
1862 |
-
)) if available else []
|
1863 |
-
while len(negatives) < self.neg_samples:
|
1864 |
-
negatives.append("")
|
1865 |
-
all_negatives.append(negatives)
|
1866 |
-
return all_negatives
|
1867 |
-
|
1868 |
-
def _prepare_batch(
|
1869 |
-
self,
|
1870 |
-
queries: List[str],
|
1871 |
-
positives: List[str],
|
1872 |
-
pad_to_batch_size: bool = False
|
1873 |
-
) -> Tuple[tf.Tensor, tf.Tensor, Optional[tf.Tensor]]:
|
1874 |
-
"""Prepare a batch with dynamic padding and memory optimization."""
|
1875 |
-
actual_size = len(queries)
|
1876 |
-
|
1877 |
-
# Handle padding if requested and needed
|
1878 |
-
if pad_to_batch_size and actual_size < self.current_batch_size:
|
1879 |
-
padding_needed = self.current_batch_size - actual_size
|
1880 |
-
queries.extend([queries[0]] * padding_needed)
|
1881 |
-
positives.extend([positives[0]] * padding_needed)
|
1882 |
-
# Create attention mask for padded examples
|
1883 |
-
attention_mask = tf.concat([
|
1884 |
-
tf.ones((actual_size,), dtype=tf.float32),
|
1885 |
-
tf.zeros((padding_needed,), dtype=tf.float32)
|
1886 |
-
], axis=0)
|
1887 |
-
else:
|
1888 |
-
attention_mask = None
|
1889 |
|
1890 |
-
|
1891 |
-
|
1892 |
-
|
1893 |
-
|
|
|
|
|
|
|
1894 |
padding='max_length',
|
1895 |
truncation=True,
|
1896 |
max_length=self.max_length,
|
1897 |
return_tensors='tf'
|
1898 |
)
|
1899 |
-
|
1900 |
-
|
1901 |
-
|
|
|
|
|
|
|
|
1902 |
truncation=True,
|
1903 |
max_length=self.max_length,
|
1904 |
return_tensors='tf'
|
1905 |
)
|
1906 |
|
1907 |
-
|
1908 |
-
|
1909 |
-
|
1910 |
-
logger.error(f"Error preparing batch: {e}")
|
1911 |
-
# Emergency memory cleanup
|
1912 |
-
gc.collect()
|
1913 |
-
if tf.config.list_physical_devices('GPU'):
|
1914 |
-
tf.keras.backend.clear_session()
|
1915 |
-
raise
|
1916 |
|
1917 |
-
|
1918 |
-
|
1919 |
-
|
1920 |
-
|
1921 |
-
|
1922 |
-
|
1923 |
-
|
1924 |
-
|
1925 |
-
|
|
|
|
|
|
1926 |
)
|
1927 |
-
|
1928 |
-
|
|
|
|
|
|
|
1929 |
|
1930 |
-
|
1931 |
-
|
1932 |
-
self.query_embeddings_cache.clear()
|
1933 |
-
gc.collect()
|
1934 |
-
if tf.config.list_physical_devices('GPU'):
|
1935 |
-
tf.keras.backend.clear_session()
|
|
|
2 |
from transformers import TFAutoModel, AutoTokenizer
|
3 |
import tensorflow as tf
|
4 |
import numpy as np
|
|
|
|
|
5 |
from typing import Generator, List, Tuple, Dict, Optional, Union, Any
|
6 |
import math
|
7 |
from dataclasses import dataclass
|
|
|
10 |
import datetime
|
11 |
import faiss
|
12 |
import gc
|
|
|
13 |
from response_quality_checker import ResponseQualityChecker
|
14 |
from cross_encoder_reranker import CrossEncoderReranker
|
15 |
from conversation_summarizer import DeviceAwareModel, Summarizer
|
|
|
26 |
"""Configuration for the RetrievalChatbot."""
|
27 |
vocab_size: int = 30526 # DistilBERT vocab size + special tokens
|
28 |
max_context_token_limit: int = 512
|
29 |
+
embedding_dim: int = 768
|
30 |
encoder_units: int = 256
|
31 |
num_attention_heads: int = 8
|
32 |
dropout_rate: float = 0.2
|
|
|
39 |
pretrained_model: str = 'distilbert-base-uncased'
|
40 |
dtype: str = 'float32'
|
41 |
freeze_embeddings: bool = False
|
42 |
+
embedding_batch_size: int = 128
|
43 |
# Additional configurations can be added here
|
44 |
|
45 |
def to_dict(self) -> dict:
|
|
|
101 |
|
102 |
# Apply pooling, projection, dropout, and normalization
|
103 |
x = self.pooler(x) # Shape: [batch_size, 768]
|
104 |
+
x = self.projection(x) # Shape: [batch_size, 768]
|
105 |
x = self.dropout(x, training=training) # Apply dropout
|
106 |
+
x = self.normalize(x) # Shape: [batch_size, 768]
|
107 |
|
108 |
return x
|
109 |
|
|
|
137 |
summarizer = Summarizer(device=self.device)
|
138 |
self.summarizer = summarizer
|
139 |
|
|
|
|
|
|
140 |
# Special tokens
|
141 |
self.special_tokens = {
|
142 |
"user": "<USER>",
|
|
|
341 |
"""
|
342 |
all_embeddings = []
|
343 |
self.current_batch_size = batch_size
|
344 |
+
|
345 |
+
if self.memory_monitor.has_gpu:
|
346 |
+
batch_size = 128
|
347 |
|
348 |
# Memory stats
|
349 |
# if self.memory_monitor.has_gpu:
|
|
|
531 |
logger.info("Starting vector addition process...")
|
532 |
|
533 |
# Even smaller batches
|
534 |
+
initial_batch_size = 128
|
535 |
+
min_batch_size = 32
|
536 |
+
max_batch_size = 1024
|
537 |
|
538 |
total_added = 0
|
539 |
retry_count = 0
|
|
|
562 |
# Update progress
|
563 |
batch_size = len(batch)
|
564 |
total_added += batch_size
|
|
|
565 |
|
566 |
# Memory cleanup every few batches
|
567 |
if total_added % (initial_batch_size * 5) == 0:
|
|
|
607 |
cpu_index = self.index
|
608 |
|
609 |
# Add remaining vectors on CPU with very small batches
|
610 |
+
batch_size = 128
|
611 |
total_added = already_added
|
612 |
|
613 |
for i in range(0, len(remaining_embeddings), batch_size):
|
|
|
900 |
warmup_steps_ratio: float = 0.1,
|
901 |
early_stopping_patience: int = 3,
|
902 |
min_delta: float = 1e-4,
|
|
|
903 |
neg_samples: int = 1
|
904 |
) -> None:
|
905 |
+
"""Streaming training with tf.data pipeline."""
|
906 |
+
logger.info("Starting streaming training pipeline with tf.data...")
|
|
|
|
|
|
|
|
|
907 |
|
908 |
+
# Initialize TFDataPipeline (replaces StreamingDataPipeline)
|
909 |
+
dataset_preparer = TFDataPipeline(
|
910 |
+
embedding_batch_size=self.config.embedding_batch_size,
|
911 |
tokenizer=self.tokenizer,
|
912 |
encoder=self.encoder,
|
913 |
+
index=self.index, # Pass CPU version of FAISS index
|
914 |
response_pool=self.response_pool,
|
915 |
max_length=self.config.max_context_token_limit,
|
|
|
916 |
neg_samples=neg_samples
|
917 |
)
|
918 |
|
919 |
# Calculate total steps for learning rate schedule
|
920 |
total_pairs = dataset_preparer.estimate_total_pairs(dialogues)
|
921 |
+
train_size = int(total_pairs * (1 - validation_split))
|
922 |
+
val_size = int(total_pairs * validation_split)
|
923 |
steps_per_epoch = int(math.ceil(train_size / batch_size))
|
924 |
+
val_steps = int(math.ceil(val_size / batch_size))
|
925 |
total_steps = steps_per_epoch * epochs
|
926 |
|
927 |
logger.info(f"Total pairs: {total_pairs}")
|
928 |
logger.info(f"Training pairs: {train_size}")
|
929 |
+
logger.info(f"Validation pairs: {val_size}")
|
930 |
logger.info(f"Steps per epoch: {steps_per_epoch}")
|
931 |
logger.info(f"Validation steps: {val_steps}")
|
932 |
logger.info(f"Total steps: {total_steps}")
|
|
|
957 |
val_log_dir = str(log_dir / f"val_{current_time}")
|
958 |
train_summary_writer = tf.summary.create_file_writer(train_log_dir)
|
959 |
val_summary_writer = tf.summary.create_file_writer(val_log_dir)
|
|
|
960 |
logger.info(f"TensorBoard logs will be saved in {log_dir}")
|
961 |
|
962 |
+
# Create training and validation datasets
|
963 |
+
train_dataset = dataset_preparer.get_tf_dataset(dialogues, batch_size).take(train_size)
|
964 |
+
val_dataset = dataset_preparer.get_tf_dataset(dialogues, batch_size).skip(train_size).take(val_size)
|
965 |
+
|
966 |
# Training loop
|
967 |
best_val_loss = float("inf")
|
968 |
epochs_no_improve = 0
|
969 |
|
970 |
+
for epoch in range(1, epochs + 1):
|
971 |
+
# --- Training Phase ---
|
972 |
+
epoch_loss_avg = tf.keras.metrics.Mean()
|
973 |
+
batches_processed = 0
|
|
|
|
|
|
974 |
|
975 |
+
try:
|
976 |
+
train_pbar = tqdm(total=steps_per_epoch, desc=f"Training Epoch {epoch}", unit="batch")
|
977 |
+
is_tqdm_train = True
|
978 |
+
except ImportError:
|
979 |
+
train_pbar = None
|
980 |
+
is_tqdm_train = False
|
981 |
+
logger.info("Training progress bar disabled")
|
982 |
+
|
983 |
+
for q_batch, p_batch, n_batch in train_dataset:
|
984 |
+
#p_batch = p_n_batch[:, 0, :] # Extract positive from (positive, negative) pair
|
985 |
+
loss = self.train_step(q_batch, p_batch, n_batch)
|
986 |
+
epoch_loss_avg(loss)
|
987 |
+
batches_processed += 1
|
988 |
+
|
989 |
+
# Log to TensorBoard
|
990 |
+
with train_summary_writer.as_default():
|
991 |
+
tf.summary.scalar("loss", loss, step=(epoch - 1) * steps_per_epoch + batches_processed)
|
992 |
|
993 |
+
# Update progress bar
|
994 |
+
if use_lr_schedule:
|
995 |
+
current_lr = float(lr_schedule(self.optimizer.iterations))
|
996 |
+
else:
|
997 |
+
current_lr = float(self.optimizer.learning_rate.numpy())
|
998 |
|
999 |
+
if is_tqdm_train:
|
1000 |
+
train_pbar.update(1)
|
1001 |
+
train_pbar.set_postfix({
|
1002 |
+
"loss": f"{loss.numpy():.4f}",
|
1003 |
+
"lr": f"{current_lr:.2e}",
|
1004 |
+
"batches": f"{batches_processed}/{steps_per_epoch}"
|
1005 |
+
})
|
1006 |
+
|
1007 |
+
# Memory cleanup
|
1008 |
+
gc.collect()
|
|
|
|
|
|
1009 |
|
1010 |
+
if batches_processed >= steps_per_epoch:
|
1011 |
+
break
|
|
|
|
|
1012 |
|
1013 |
+
if is_tqdm_train and train_pbar:
|
1014 |
+
train_pbar.close()
|
|
|
|
|
|
|
|
|
1015 |
|
1016 |
+
# --- Validation Phase ---
|
1017 |
+
val_loss_avg = tf.keras.metrics.Mean()
|
1018 |
+
val_batches_processed = 0
|
1019 |
|
1020 |
try:
|
1021 |
+
val_pbar = tqdm(total=val_steps, desc="Validation", unit="batch")
|
1022 |
+
is_tqdm_val = True
|
1023 |
+
except ImportError:
|
1024 |
+
val_pbar = None
|
1025 |
+
is_tqdm_val = False
|
1026 |
+
logger.info("Validation progress bar disabled")
|
1027 |
+
|
1028 |
+
for q_batch, p_batch, n_batch in val_dataset:
|
1029 |
+
#p_batch = p_n_batch[:, 0, :] # Extract positive from (positive, negative) pair
|
1030 |
+
val_loss = self.validation_step(q_batch, p_batch, n_batch)
|
1031 |
+
val_loss_avg(val_loss)
|
1032 |
+
val_batches_processed += 1
|
1033 |
+
|
1034 |
+
if is_tqdm_val:
|
1035 |
+
val_pbar.update(1)
|
1036 |
+
val_pbar.set_postfix({
|
1037 |
+
"val_loss": f"{val_loss.numpy():.4f}",
|
1038 |
+
"batches": f"{val_batches_processed}/{val_steps}"
|
1039 |
+
})
|
1040 |
+
|
1041 |
+
# Memory cleanup
|
1042 |
+
gc.collect()
|
1043 |
|
|
|
|
|
|
1044 |
|
1045 |
+
if val_batches_processed >= val_steps:
|
1046 |
+
break
|
|
|
|
|
|
|
|
|
|
|
1047 |
|
1048 |
+
if is_tqdm_val and val_pbar:
|
1049 |
+
val_pbar.close()
|
|
|
1050 |
|
1051 |
+
# End of epoch: compute final epoch stats, log, and save checkpoint
|
1052 |
+
train_loss = epoch_loss_avg.result().numpy()
|
1053 |
+
val_loss = val_loss_avg.result().numpy()
|
1054 |
+
logger.info(f"Epoch {epoch} Complete: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")
|
1055 |
|
1056 |
+
# Log epoch metrics
|
1057 |
+
with train_summary_writer.as_default():
|
1058 |
+
tf.summary.scalar("epoch_loss", train_loss, step=epoch)
|
1059 |
+
with val_summary_writer.as_default():
|
1060 |
+
tf.summary.scalar("val_loss", val_loss, step=epoch)
|
1061 |
|
1062 |
+
# Save checkpoint
|
1063 |
+
manager.save()
|
|
|
|
|
|
1064 |
|
1065 |
+
# Store metrics in history
|
1066 |
+
self.history['train_loss'].append(train_loss)
|
1067 |
+
self.history['val_loss'].append(val_loss)
|
|
|
1068 |
|
1069 |
+
if use_lr_schedule:
|
1070 |
+
current_lr = float(lr_schedule(self.optimizer.iterations))
|
1071 |
+
else:
|
1072 |
+
current_lr = float(self.optimizer.learning_rate.numpy())
|
1073 |
|
1074 |
+
self.history.setdefault('learning_rate', []).append(current_lr)
|
|
|
|
|
|
|
1075 |
|
1076 |
+
# Early stopping logic
|
1077 |
+
if val_loss < best_val_loss - min_delta:
|
1078 |
+
best_val_loss = val_loss
|
1079 |
+
epochs_no_improve = 0
|
1080 |
+
logger.info(f"Validation loss improved to {val_loss:.4f}. Reset patience.")
|
1081 |
+
else:
|
1082 |
+
epochs_no_improve += 1
|
1083 |
+
logger.info(f"No improvement this epoch. Patience: {epochs_no_improve}/{early_stopping_patience}")
|
1084 |
+
if epochs_no_improve >= early_stopping_patience:
|
1085 |
+
logger.info("Early stopping triggered.")
|
1086 |
+
break
|
1087 |
|
1088 |
logger.info("Streaming training completed!")
|
1089 |
|
1090 |
|
1091 |
     @tf.function
+    def train_step(
+        self,
+        q_batch: tf.Tensor,
+        p_batch: tf.Tensor,
+        n_batch: tf.Tensor,
+        attention_mask: Optional[tf.Tensor] = None
+    ) -> tf.Tensor:
+        """
+        Single training step that uses queries, positives, and negatives in a
+        contrastive/InfoNCE style. The label is always 0 (the positive) vs.
+        the negative alternatives.
+        """
         with tf.GradientTape() as tape:
+            # Encode queries
+            q_enc = self.encoder(q_batch, training=True)  # [batch_size, embed_dim]
+
+            # Encode positives
+            p_enc = self.encoder(p_batch, training=True)  # [batch_size, embed_dim]
+
+            # Encode negatives
+            # n_batch: [batch_size, neg_samples, max_length]
+            shape = tf.shape(n_batch)
+            bs = shape[0]
+            neg_samples = shape[1]
+
+            # Flatten negatives to feed them in one pass:
+            # => [batch_size * neg_samples, max_length]
+            n_batch_flat = tf.reshape(n_batch, [bs * neg_samples, shape[2]])
+            n_enc_flat = self.encoder(n_batch_flat, training=True)  # [bs*neg_samples, embed_dim]
+
+            # Reshape back => [batch_size, neg_samples, embed_dim]
+            n_enc = tf.reshape(n_enc_flat, [bs, neg_samples, -1])
+
+            # Combine the positive embedding and negative embeddings along dim=1
+            # => shape [batch_size, 1 + neg_samples, embed_dim]
+            # The first column is the positive; subsequent columns are negatives
+            combined_p_n = tf.concat(
+                [tf.expand_dims(p_enc, axis=1), n_enc],
+                axis=1
+            )  # [bs, (1+neg_samples), embed_dim]
+
+            # Now compute scores: dot product of q_enc with each column in combined_p_n
+            # We'll use `tf.einsum` to handle the batch dimension properly
+            # dot_products => shape [batch_size, (1+neg_samples)]
+            dot_products = tf.einsum('bd,bkd->bk', q_enc, combined_p_n)
+
+            # The label for each row is 0 (the first column is the correct/positive)
+            labels = tf.zeros([bs], dtype=tf.int32)
+
+            # Cross-entropy over the [batch_size, 1+neg_samples] scores
             loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
+                labels=labels,
+                logits=dot_products
             )
+            loss = tf.reduce_mean(loss)
+
+            # If there's an attention_mask you want to apply (less common in this scenario),
+            # you could do something like:
             if attention_mask is not None:
                 loss = loss * attention_mask
                 loss = tf.reduce_sum(loss) / tf.reduce_sum(attention_mask)
 
+        # Apply gradients
         gradients = tape.gradient(loss, self.encoder.trainable_variables)
         self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_variables))
         return loss
 
     @tf.function
+    def validation_step(
+        self,
+        q_batch: tf.Tensor,
+        p_batch: tf.Tensor,
+        n_batch: tf.Tensor,
+        attention_mask: Optional[tf.Tensor] = None
+    ) -> tf.Tensor:
+        """
+        Single validation step with queries, positives, and negatives.
+        Uses the same loss calculation as train_step, but `training=False`.
+        """
         q_enc = self.encoder(q_batch, training=False)
         p_enc = self.encoder(p_batch, training=False)
 
+        shape = tf.shape(n_batch)
+        bs = shape[0]
+        neg_samples = shape[1]
+
+        n_batch_flat = tf.reshape(n_batch, [bs * neg_samples, shape[2]])
+        n_enc_flat = self.encoder(n_batch_flat, training=False)
+        n_enc = tf.reshape(n_enc_flat, [bs, neg_samples, -1])
+
+        combined_p_n = tf.concat(
+            [tf.expand_dims(p_enc, axis=1), n_enc],
+            axis=1
+        )
+
+        dot_products = tf.einsum('bd,bkd->bk', q_enc, combined_p_n)
+        labels = tf.zeros([bs], dtype=tf.int32)
 
         loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
+            labels=labels,
+            logits=dot_products
         )
+        loss = tf.reduce_mean(loss)
+
         if attention_mask is not None:
             loss = loss * attention_mask
            loss = tf.reduce_sum(loss) / tf.reduce_sum(attention_mask)
 
         return loss
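Aside, not part of the diff: for intuition about the scoring above, here is a tiny standalone check of the shape flow with made-up numbers; the names mirror the committed code but the values are purely illustrative.

import tensorflow as tf

# 2 queries, 3 hard negatives each, embedding dim 4 (toy values)
q_enc = tf.random.normal([2, 4])
p_enc = tf.random.normal([2, 4])
n_enc = tf.random.normal([2, 3, 4])

combined_p_n = tf.concat([tf.expand_dims(p_enc, axis=1), n_enc], axis=1)  # [2, 4, 4]: positive in column 0
dot_products = tf.einsum('bd,bkd->bk', q_enc, combined_p_n)               # [2, 4] similarity scores
labels = tf.zeros([2], dtype=tf.int32)                                    # column 0 is always the positive

loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=dot_products)
)
print(dot_products.shape, float(loss))  # (2, 4) and a scalar loss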
|
|
1337 |
conversation_parts.append(f"{self.special_tokens['user']} {query}")
|
1338 |
return "\n".join(conversation_parts)
|
1339 |
|
1340 |
+
class TFDataPipeline:
|
|
|
1341 |
def __init__(
|
1342 |
+
self,
|
1343 |
+
embedding_batch_size,
|
1344 |
+
tokenizer,
|
1345 |
+
encoder,
|
1346 |
+
index,
|
1347 |
+
response_pool,
|
1348 |
+
max_length: int,
|
1349 |
+
neg_samples: int,
|
1350 |
):
|
1351 |
+
self.embedding_batch_size = embedding_batch_size
|
1352 |
self.tokenizer = tokenizer
|
1353 |
self.encoder = encoder
|
1354 |
+
self.index = index # CPU version of the index
|
1355 |
self.response_pool = response_pool
|
1356 |
self.max_length = max_length
|
|
|
1357 |
self.neg_samples = neg_samples
|
1358 |
+
self.embedding_batch_size = 16 if len(response_pool) < 100 else 64
|
1359 |
+
self.search_batch_size = 8 if len(response_pool) < 100 else 32
|
1360 |
+
self.max_batch_size = 32 if len(response_pool) < 100 else 256
|
1361 |
self.memory_monitor = GPUMemoryMonitor()
|
|
|
|
|
|
|
1362 |
self.max_retries = 3
|
|
|
|
|
|
1363 |
|
1364 |
+
# In-memory cache for embeddings
|
1365 |
+
self.query_embeddings_cache = {}
|
|
|
|
|
1366 |
|
|
|
|
|
|
1367 |
def _extract_pairs_from_dialogue(self, dialogue: dict) -> List[Tuple[str, str]]:
|
1368 |
"""Extract query-response pairs from a dialogue."""
|
1369 |
pairs = []
|
|
|
1384 |
|
1385 |
return pairs
|
1386 |
|
1387 |
+
def estimate_total_pairs(self, dialogues: List[dict]) -> int:
|
1388 |
+
"""Estimate total number of training pairs including hard negatives."""
|
1389 |
+
base_pairs = sum(
|
1390 |
+
len([
|
1391 |
+
1 for i in range(len(d.get('turns', [])) - 1)
|
1392 |
+
if (d['turns'][i].get('speaker') == 'user' and
|
1393 |
+
d['turns'][i+1].get('speaker') == 'assistant')
|
1394 |
+
])
|
1395 |
+
for d in dialogues
|
1396 |
+
)
|
1397 |
+
# Account for hard negatives
|
1398 |
+
return base_pairs * (1 + self.neg_samples)
|
1399 |
+
|
|
|
|
|
1400 |
def _find_hard_negatives_batch(self, queries: List[str], positives: List[str]) -> List[List[str]]:
|
1401 |
"""Find hard negatives for a batch of queries with error handling and retries."""
|
1402 |
retry_count = 0
|
1403 |
total_responses = len(self.response_pool)
|
1404 |
+
|
|
|
|
1405 |
while retry_count < self.max_retries:
|
1406 |
try:
|
|
|
1407 |
query_embeddings = np.vstack([
|
1408 |
self.query_embeddings_cache[q] for q in queries
|
1409 |
]).astype(np.float32)
|
1410 |
|
1411 |
+
query_embeddings = np.ascontiguousarray(query_embeddings)
|
|
|
|
|
|
|
1412 |
faiss.normalize_L2(query_embeddings)
|
1413 |
+
|
1414 |
+
k = 1 # TODO: try higher k for better results
|
1415 |
#logger.debug(f"Searching with k={k} among {total_responses} responses")
|
|
|
|
|
1416 |
|
1417 |
+
distances, indices = self.index.search(query_embeddings, k)
|
1418 |
+
|
|
|
|
|
|
|
|
|
|
|
1419 |
all_negatives = []
|
1420 |
+
for query_indices, query, positive in zip(indices, queries, positives):
|
1421 |
negatives = []
|
1422 |
positive_strip = positive.strip()
|
|
|
|
|
1423 |
seen = {positive_strip}
|
1424 |
+
|
1425 |
for idx in query_indices:
|
1426 |
if idx >= 0 and idx < total_responses:
|
1427 |
candidate = self.response_pool[idx].strip()
|
1428 |
+
if candidate and candidate not in seen:
|
1429 |
seen.add(candidate)
|
1430 |
negatives.append(candidate)
|
1431 |
if len(negatives) >= self.neg_samples:
|
1432 |
break
|
1433 |
+
|
1434 |
+
# Pad with a special empty negative if necessary
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1435 |
while len(negatives) < self.neg_samples:
|
1436 |
+
negatives.append("<EMPTY_NEGATIVE>") # Use a special token
|
1437 |
+
|
1438 |
all_negatives.append(negatives)
|
1439 |
+
|
1440 |
return all_negatives
|
1441 |
+
|
1442 |
except Exception as e:
|
1443 |
retry_count += 1
|
1444 |
logger.warning(f"Hard negative search attempt {retry_count} failed: {e}")
|
1445 |
if retry_count == self.max_retries:
|
1446 |
logger.error("Max retries reached for hard negative search")
|
1447 |
+
return [["<EMPTY_NEGATIVE>"] * self.neg_samples for _ in queries] # Return empty negatives for all queries
|
1448 |
gc.collect()
|
1449 |
if tf.config.list_physical_devices('GPU'):
|
1450 |
tf.keras.backend.clear_session()
|
|
|
|
1451 |
|
    def _tokenize_negatives_tf(self, negatives):
        """Tokenizes negatives using tf.py_function."""
        # Handle the case where negatives is an empty tensor
        if tf.size(negatives) == 0:
            return tf.zeros([0, self.neg_samples, self.max_length], dtype=tf.int32)

        # Convert EagerTensor to a list of strings
        negatives_list = []
        for neg_list in negatives.numpy():
            decoded_negs = [neg.decode("utf-8") for neg in neg_list if neg]  # Filter out empty strings
            negatives_list.append(decoded_negs)

        # Flatten the list of lists
        flattened_negatives = [neg for sublist in negatives_list for neg in sublist]

        # Tokenize the flattened negatives
        if flattened_negatives:
            n_tokens = self.tokenizer(
                flattened_negatives,
                padding='max_length',
                truncation=True,
                max_length=self.max_length,
                return_tensors='tf'
            )
            # Reshape the tokens to [batch_size, neg_samples, max_length]
            n_tokens_reshaped = tf.reshape(n_tokens['input_ids'], [-1, self.neg_samples, self.max_length])
            return n_tokens_reshaped
        else:
            return tf.zeros([0, self.neg_samples, self.max_length], dtype=tf.int32)

    def _compute_embeddings(self, queries: List[str]) -> None:
        """Computes and caches embeddings for new queries."""
        new_queries = [q for q in queries if q not in self.query_embeddings_cache]
        if not new_queries:
            return  # All queries already cached

        new_embeddings = []
        for i in range(0, len(new_queries), self.embedding_batch_size):
            batch_queries = new_queries[i:i + self.embedding_batch_size]

            encoded = self.tokenizer(
                batch_queries,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors='tf'
            )

            # Compute embeddings on CPU
            with tf.device('/CPU:0'):
                batch_embeddings = self.encoder(encoded['input_ids'], training=False).numpy()

            new_embeddings.extend(batch_embeddings)

        # Update cache with new embeddings
        for query, emb in zip(new_queries, new_embeddings):
            self.query_embeddings_cache[query] = emb

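    # Illustrative sketch (not part of the module), assuming a chatbot instance `bot`:
    #
    #   bot._compute_embeddings(["hi there", "how are you?"])  # encodes both, fills the cache
    #   bot._compute_embeddings(["hi there"])                  # no-op: already cached
    #   emb = bot.query_embeddings_cache["hi there"]           # np.ndarray of shape [embedding_dim]
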
    def data_generator(self, dialogues: List[dict]) -> Generator[Tuple[str, str, List[str]], None, None]:
        """
        Generates training examples: (query, positive, hard_negatives).
        The outer loop is wrapped with tqdm for progress tracking.
        """
        total_dialogues = len(dialogues)
        logger.debug(f"Total dialogues to process: {total_dialogues}")

        # Initialize tqdm progress bar
        with tqdm(total=total_dialogues, desc="Processing Dialogues", unit="dialogue") as pbar:
            for dialogue in dialogues:
                pairs = self._extract_pairs_from_dialogue(dialogue)
                for query, positive in pairs:
                    # Ensure embeddings are computed, then find hard negatives
                    self._compute_embeddings([query])
                    hard_negatives = self._find_hard_negatives_batch([query], [positive])[0]
                    yield (query, positive, hard_negatives)
                pbar.update(1)

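    # Illustrative sketch (not part of the module): each yielded element is a triple
    #   ("user query text", "ground-truth response text", [neg_1, ..., neg_k])
    # where the negative list always has length neg_samples, padded with "<EMPTY_NEGATIVE>"
    # when the FAISS search returns fewer usable candidates.
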
    def _prepare_batch(self, queries: tf.Tensor, positives: tf.Tensor, negatives: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
        """Prepares a batch of data for training."""

        # Convert EagerTensors to lists of strings
        queries_list = [query.decode("utf-8") for query in queries.numpy()]
        positives_list = [pos.decode("utf-8") for pos in positives.numpy()]

        # Tokenize queries and positives
        q_tokens = self.tokenizer(queries_list, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf')
        p_tokens = self.tokenizer(positives_list, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf')

        # Decode negatives and ensure they are lists of strings
        negatives_list = []
        for neg_list in negatives.numpy():
            decoded_negs = [neg.decode("utf-8") for neg in neg_list if neg]  # Filter out empty strings
            negatives_list.append(decoded_negs)

        # Flatten negatives for tokenization if there are any valid negatives
        flattened_negatives = [neg for sublist in negatives_list for neg in sublist if neg]

        # Tokenize negatives if there are any
        n_tokens_reshaped = None
        if flattened_negatives:
            n_tokens = self.tokenizer(flattened_negatives, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf')

            # Reshape n_tokens to match the expected shape based on the number of negatives per query.
            # This part may need adjustment if the number of negatives varies per query.
            n_tokens_reshaped = tf.reshape(n_tokens['input_ids'], [len(queries_list), -1, self.max_length])
        else:
            # Create a placeholder tensor for the case where there are no negatives
            n_tokens_reshaped = tf.zeros([len(queries_list), 0, self.max_length], dtype=tf.int32)

        # Ensure n_tokens_reshaped has a consistent shape even when there are no negatives.
        # Adjust shape to [batch_size, num_neg_samples, max_length].
        if n_tokens_reshaped.shape[1] != self.neg_samples:
            # Pad or truncate the second dimension to match neg_samples
            padding = tf.zeros([len(queries_list), tf.maximum(0, self.neg_samples - n_tokens_reshaped.shape[1]), self.max_length], dtype=tf.int32)
            n_tokens_reshaped = tf.concat([n_tokens_reshaped, padding], axis=1)
            n_tokens_reshaped = n_tokens_reshaped[:, :self.neg_samples, :]

        # Concatenate the positive and negative examples along the 'neg_samples' dimension
        combined_p_n_tokens = tf.concat([tf.expand_dims(p_tokens['input_ids'], axis=1), n_tokens_reshaped], axis=1)

        return q_tokens['input_ids'], combined_p_n_tokens

    def get_tf_dataset(self, dialogues: List[dict], batch_size: int) -> tf.data.Dataset:
        """
        Creates a tf.data.Dataset for streaming training that yields
        (input_ids_query, input_ids_positive, input_ids_negatives).
        """
        # 1) Start with a generator dataset
        dataset = tf.data.Dataset.from_generator(
            lambda: self.data_generator(dialogues),
            output_signature=(
                tf.TensorSpec(shape=(), dtype=tf.string),       # Query (single string)
                tf.TensorSpec(shape=(), dtype=tf.string),       # Positive (single string)
                tf.TensorSpec(shape=(None,), dtype=tf.string)   # Hard negatives (list of strings)
            )
        )

        # 2) Batch the raw strings
        dataset = dataset.batch(batch_size)

        # 3) Map them through a tokenize step (via py_function)
        dataset = dataset.map(
            lambda q, p, n: self._tokenize_triple(q, p, n),
            num_parallel_calls=1  # tf.data.AUTOTUNE
        )

        dataset = dataset.prefetch(tf.data.AUTOTUNE)
        return dataset

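    # Illustrative sketch (not part of the module): consuming the streaming dataset.
    # The names `bot` and `train_dialogues` are assumptions for the example.
    #
    #   dataset = bot.get_tf_dataset(train_dialogues, batch_size=16)
    #   for q_ids, p_ids, n_ids in dataset:
    #       # q_ids: [16, max_length], p_ids: [16, max_length],
    #       # n_ids: [16, neg_samples, max_length] -- ready to feed the encoder
    #       ...
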
    def _tokenize_triple(
        self,
        q: tf.Tensor,
        p: tf.Tensor,
        n: tf.Tensor
    ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """
        Wraps a Python function via tf.py_function to convert tf.Tensors of strings
        -> Python lists of strings -> HF tokenizer -> Tensors of IDs.

        q is shape [batch_size], p is shape [batch_size],
        n is shape [batch_size, neg_samples] (i.e., each row is a list of negatives).
        """
        # Use tf.py_function with limited parallelism
        q_ids, p_ids, n_ids = tf.py_function(
            func=self._tokenize_triple_py,
            inp=[q, p, n, tf.constant(self.max_length), tf.constant(self.neg_samples)],
            Tout=[tf.int32, tf.int32, tf.int32]
        )

        # Manually set shape information
        q_ids.set_shape([None, self.max_length])                    # [batch_size, max_length]
        p_ids.set_shape([None, self.max_length])                    # [batch_size, max_length]
        n_ids.set_shape([None, self.neg_samples, self.max_length])  # [batch_size, neg_samples, max_length]

        return q_ids, p_ids, n_ids

    def _tokenize_triple_py(
        self,
        q: tf.Tensor,
        p: tf.Tensor,
        n: tf.Tensor,
        max_len: tf.Tensor,
        neg_samples: tf.Tensor
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Python function that:
        - Decodes each tf.string Tensor to a Python list of strings
        - Calls the HF tokenizer
        - Reshapes negatives
        - Returns np.array of int32s for (q_ids, p_ids, n_ids).

        q: shape [batch_size], p: shape [batch_size]
        n: shape [batch_size, neg_samples]
        max_len: scalar int
        neg_samples: scalar int
        """
        max_len = int(max_len.numpy())  # Convert to Python int
        neg_samples = int(neg_samples.numpy())

        # 1) Convert Tensors -> Python lists of strings
        q_list = [q_i.decode("utf-8") for q_i in q.numpy()]  # shape [batch_size]
        p_list = [p_i.decode("utf-8") for p_i in p.numpy()]  # shape [batch_size]

        # shape [batch_size, neg_samples], decode each row
        n_list = []
        for row in n.numpy():
            # row is shape [neg_samples], each is a tf.string
            decoded = [neg.decode("utf-8") for neg in row]
            n_list.append(decoded)

        # 2) Tokenize queries & positives
        q_enc = self.tokenizer(
            q_list,
            padding="max_length",
            truncation=True,
            max_length=max_len,
            return_tensors="np"
        )
        p_enc = self.tokenizer(
            p_list,
            padding="max_length",
            truncation=True,
            max_length=max_len,
            return_tensors="np"
        )

        # 3) Tokenize negatives
        # Flatten [batch_size, neg_samples] -> single list
        flattened_negatives = [neg for row in n_list for neg in row]
        if len(flattened_negatives) == 0:
            # No negatives at all: return a zero array
            n_ids = np.zeros((len(q_list), neg_samples, max_len), dtype=np.int32)
        else:
            n_enc = self.tokenizer(
                flattened_negatives,
                padding="max_length",
                truncation=True,
                max_length=max_len,
                return_tensors="np"
            )
            # shape [batch_size * neg_samples, max_len]
            n_input_ids = n_enc["input_ids"]

            # Reshape to [batch_size, neg_samples, max_len],
            # handling cases where a row has fewer negatives
            batch_size = len(q_list)
            n_ids_list = []
            for i in range(batch_size):
                start_idx = i * neg_samples
                end_idx = start_idx + neg_samples
                row_negs = n_input_ids[start_idx:end_idx]

                # If fewer negatives, pad with zeros
                if row_negs.shape[0] < neg_samples:
                    deficit = neg_samples - row_negs.shape[0]
                    pad_arr = np.zeros((deficit, max_len), dtype=np.int32)
                    row_negs = np.concatenate([row_negs, pad_arr], axis=0)

                n_ids_list.append(row_negs)

            # Stack them -> shape [batch_size, neg_samples, max_len]
            n_ids = np.stack(n_ids_list, axis=0)

        # 4) Return as np.int32 arrays
        q_ids = q_enc["input_ids"].astype(np.int32)  # shape [batch_size, max_len]
        p_ids = p_enc["input_ids"].astype(np.int32)  # shape [batch_size, max_len]
        n_ids = n_ids.astype(np.int32)               # shape [batch_size, neg_samples, max_len]

        return q_ids, p_ids, n_ids

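    # Illustrative sanity check (not part of the module), assuming `bot` was built with
    # max_length=512 and neg_samples=3:
    #
    #   q = tf.constant(["hello"]); p = tf.constant(["hi there!"])
    #   n = tf.constant([["nope", "wrong", "<EMPTY_NEGATIVE>"]])
    #   q_ids, p_ids, n_ids = bot._tokenize_triple_py(
    #       q, p, n, tf.constant(512), tf.constant(3))
    #   assert q_ids.shape == (1, 512) and n_ids.shape == (1, 3, 512)
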
deduplicate_augmented_dialogues.py
ADDED
@@ -0,0 +1,74 @@
import json
from pathlib import Path
import logging
from typing import List, Dict
from collections import defaultdict

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def load_json_file(file_path: str) -> List[Dict]:
    """Load and parse a JSON file."""
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except json.JSONDecodeError as e:
        logger.error(f"Error parsing JSON from {file_path}: {e}")
        return []
    except Exception as e:
        logger.error(f"Error reading file {file_path}: {e}")
        return []

def combine_json_files(input_directory: str, output_file: str):
    """
    Combine multiple JSON files while removing duplicates based on dialogue_id.

    Args:
        input_directory: Directory containing JSON files to process
        output_file: Path to save the combined output
    """
    # Track unique dialogues and their source files
    dialogue_map = {}
    duplicate_count = 0

    # Process all JSON files in the directory
    input_path = Path(input_directory)
    for json_file in input_path.glob('*.json'):
        logger.info(f"Processing {json_file}")

        data = load_json_file(str(json_file))

        # Process each dialogue in the file
        for dialogue in data:
            dialogue_id = dialogue.get('dialogue_id')

            if not dialogue_id:
                logger.warning(f"Found dialogue without ID in {json_file}")
                continue

            # Keep the first occurrence
            if dialogue_id in dialogue_map:
                duplicate_count += 1
                logger.debug(f"Duplicate dialogue_id found: {dialogue_id}")
            else:
                dialogue_map[dialogue_id] = dialogue

    # Convert the map of unique dialogues back to a list
    unique_dialogues = list(dialogue_map.values())

    # Save combined dialogues to a new file
    try:
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(unique_dialogues, f, indent=4)
        logger.info(f"Successfully combined files. Found {duplicate_count} duplicates.")
        logger.info(f"Total unique dialogues: {len(unique_dialogues)}")
    except Exception as e:
        logger.error(f"Error writing output file: {e}")

# Usage example
if __name__ == "__main__":
    combine_json_files(
        input_directory="/Users/joe/Desktop/Grad School/CSC525/CSC525_mod8_option2_joseph_armani/processed_outputs",
        output_file="augmented_dialogues.json"
    )

run_model_train.py
CHANGED
@@ -5,7 +5,6 @@ from response_quality_checker import ResponseQualityChecker
 from chatbot_validator import ChatbotValidator
 from training_plotter import TrainingPlotter
 
-
 # Configure logging
 from logger_config import config_logger
 logger = config_logger(__name__)
@@ -38,7 +37,7 @@ def main():
     env = EnvironmentSetup()
     env.initialize()
 
-    DEBUG_SAMPLES =
+    DEBUG_SAMPLES = 5
     EPOCHS = 5 if DEBUG_SAMPLES else 20
     TRAINING_DATA_PATH = 'processed_outputs/batch_group_0010.json'
 
@@ -47,13 +46,14 @@
 
     # Initialize configuration
     config = ChatbotConfig(
-        embedding_dim=
+        embedding_dim=768, # DistilBERT
         max_context_token_limit=512,
        freeze_embeddings=False,
     )
 
     # Load training data
     dialogues = RetrievalChatbot.load_training_data(data_path=TRAINING_DATA_PATH, debug_samples=DEBUG_SAMPLES)
+    print(dialogues)
 
     # Initialize chatbot and verify FAISS index
     #with env.strategy.scope():