JoeArmani committed
Commit · 74af405
1 Parent(s): 2183656

data processing pipeline

Browse files:
- chatbot_model.py  +133 -637
- requirements.txt  +1 -0
- run_data_preparer.py  +182 -0
- tf_data_pipeline.py  +734 -0
chatbot_model.py
CHANGED
@@ -2,7 +2,7 @@ import time
 from transformers import TFAutoModel, AutoTokenizer
 import tensorflow as tf
 import numpy as np
-from typing import
 import math
 from dataclasses import dataclass
 import json
@@ -10,6 +10,7 @@ from pathlib import Path
 import datetime
 import faiss
 import gc
 from response_quality_checker import ResponseQualityChecker
 from cross_encoder_reranker import CrossEncoderReranker
 from conversation_summarizer import DeviceAwareModel, Summarizer
@@ -24,14 +25,12 @@ logger = config_logger(__name__)
 @dataclass
 class ChatbotConfig:
     """Configuration for the RetrievalChatbot."""
-    vocab_size: int = 30526  # DistilBERT vocab size + special tokens
     max_context_token_limit: int = 512
     embedding_dim: int = 768
     encoder_units: int = 256
     num_attention_heads: int = 8
     dropout_rate: float = 0.2
     l2_reg_weight: float = 0.001
-    margin: float = 0.3
     learning_rate: float = 0.001
     min_text_length: int = 3
     max_context_turns: int = 5
@@ -39,16 +38,19 @@ class ChatbotConfig:
     pretrained_model: str = 'distilbert-base-uncased'
     dtype: str = 'float32'
     freeze_embeddings: bool = False
-    embedding_batch_size: int =
-
 
-    def to_dict(self) ->
         """Convert config to dictionary."""
-        return {k: str(v) if isinstance(v, Path) else v
                 for k, v in self.__dict__.items()}
 
     @classmethod
-    def from_dict(cls, config_dict:
         """Create config from dictionary."""
         return cls(**{k: v for k, v in config_dict.items()
                       if k in cls.__dataclass_fields__})
@@ -59,24 +61,17 @@ class EncoderModel(tf.keras.Model):
         self,
         config: ChatbotConfig,
         name: str = "encoder",
-        shared_weights: bool = False,
         **kwargs
     ):
         super().__init__(name=name, **kwargs)
         self.config = config
-        self.shared_weights = shared_weights
 
         # Load pretrained model
         self.pretrained = TFAutoModel.from_pretrained(config.pretrained_model)
 
-        # Freeze
-        self.
-
-            if i < 1:  # freeze first layer
-                layer_module.trainable = False
-            else:
-                layer_module.trainable = True
-
         # Pooling layer (Global Average Pooling)
         self.pooler = tf.keras.layers.GlobalAveragePooling1D()
@@ -90,9 +85,27 @@ class EncoderModel(tf.keras.Model):
         # Dropout and normalization
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
         self.normalize = tf.keras.layers.Lambda(
-            lambda x: tf.nn.l2_normalize(x, axis=1)
         )
 
     def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
         """Forward pass."""
         # Get pretrained embeddings
@@ -112,46 +125,33 @@ class EncoderModel(tf.keras.Model):
         config = super().get_config()
         config.update({
             "config": self.config.to_dict(),
-            "shared_weights": self.shared_weights,
             "name": self.name
         })
         return config
 
 class RetrievalChatbot(DeviceAwareModel):
     """Retrieval-based chatbot using pretrained embeddings and FAISS for similarity search."""
-    def __init__(
-
-
-
         self.config = config
         self.strategy = strategy
-        self.
-
-        if reranker is None:
-            logger.info("Creating default CrossEncoderReranker...")
-            reranker = CrossEncoderReranker(model_name="cross-encoder/ms-marco-MiniLM-L-12-v2")
-        self.reranker = reranker
-
-        if summarizer is None:
-            logger.info("Creating default Summarizer...")
-            summarizer = Summarizer(device=self.device)
-        self.summarizer = summarizer
-
-        # Special tokens
-        self.special_tokens = {
-            "user": "<USER>",
-            "assistant": "<ASSISTANT>",
-            "context": "<CONTEXT>",
-            "sep": "<SEP>"
-        }
-
-        # Initialize tokenizer and add special tokens
-        self.tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
-        self.tokenizer.add_special_tokens(
-            {'additional_special_tokens': list(self.special_tokens.values())}
-        )
 
         self.memory_monitor = GPUMemoryMonitor()
         self.min_batch_size = 8
         self.max_batch_size = 128
         self.current_batch_size = 32
@@ -166,9 +166,62 @@ class RetrievalChatbot(DeviceAwareModel):
             "train_metrics": {},
             "val_metrics": {}
         }
-
     def build_models(self):
-        """Initialize the shared encoder."""
         logger.info("Building encoder model...")
         tf.keras.backend.clear_session()
 
@@ -176,6 +229,7 @@ class RetrievalChatbot(DeviceAwareModel):
         self.encoder = EncoderModel(
             self.config,
             name="shared_encoder",
         )
 
         # Resize token embeddings after adding special tokens
@@ -183,31 +237,14 @@ class RetrievalChatbot(DeviceAwareModel):
         self.encoder.pretrained.resize_token_embeddings(new_vocab_size)
         logger.info(f"Token embeddings resized to: {new_vocab_size}")
 
-        # Initialize FAISS index
         self._initialize_faiss()
-        # Compute embeddings after FAISS is initialized and moved
-        self._compute_and_index_embeddings()
 
-        #
-
-            # First try: from config
-            embedding_dim = self.encoder.pretrained.config.dim
-            logger.info("Got embedding dim from config")
-        except AttributeError:
-            try:
-                # Second try: from word embeddings
-                embedding_dim = self.encoder.pretrained.distilbert.embeddings.word_embeddings.embedding_dim
-                logger.info("Got embedding dim from word embeddings")
-            except AttributeError:
-                try:
-                    # Third try: from embeddings module
-                    embedding_dim = self.encoder.pretrained.distilbert.embeddings.embedding_dim
-                    logger.info("Got embedding dim from embeddings module")
-                except AttributeError:
-                    # Fallback to config value
-                    embedding_dim = self.config.embedding_dim
-                    logger.info("Using config embedding dim")
 
         vocab_size = len(self.tokenizer)
 
         logger.info(f"Encoder Embedding Dimension: {embedding_dim}")
@@ -217,29 +254,6 @@ class RetrievalChatbot(DeviceAwareModel):
         else:
             logger.error("Vocabulary size is less than embedding dimension.")
             raise ValueError("Vocabulary size is less than embedding dimension.")
-
-    def _collect_responses(self, dialogues: List[dict]) -> Tuple[List[str], List[str]]:
-        """Collect all unique responses from dialogues."""
-        logger.info("Collecting responses from dialogues...")
-
-        responses = []
-        try:
-            progress_bar = tqdm(dialogues, desc="Collecting assistant responses")
-        except ImportError:
-            progress_bar = dialogues
-            logger.info("Progress bar disabled - continuing without visual progress")
-
-        for dialogue in progress_bar:
-            turns = dialogue.get('turns', [])
-            for turn in turns:
-                if turn.get('speaker') == 'assistant' and 'text' in turn:
-                    responses.append(turn['text'].strip())
-
-        # Remove duplicates
-        unique_responses = list(set(responses))
-        logger.info(f"Found {len(unique_responses)} unique responses.")
-
-        return responses, unique_responses
 
     def _adjust_batch_size(self) -> None:
         """Dynamically adjust batch size based on GPU memory usage."""
@@ -288,6 +302,7 @@ class RetrievalChatbot(DeviceAwareModel):
             logger.warning(f"Using CPU due to GPU initialization error: {e}")
 
         # TODO: figure out buf with faiss-gpu
         try:
             # Create appropriate index based on dataset size
             if len(self.unique_responses) < 1000:
@@ -860,33 +875,33 @@ class RetrievalChatbot(DeviceAwareModel):
         logger.info(f"Models and tokenizer loaded from {load_dir}.")
         return chatbot
 
-    @staticmethod
-    def load_training_data(data_path: Union[str, Path], debug_samples: Optional[int] = None) -> List[dict]:
-
-
 
-
-
-
 
-
-
-
-
-
-
-
-
 
-
-
 
-
-
-
 
-
-
 
     def train_streaming(
         self,
@@ -1336,522 +1351,3 @@ class RetrievalChatbot(DeviceAwareModel):
 
         conversation_parts.append(f"{self.special_tokens['user']} {query}")
         return "\n".join(conversation_parts)
-
-class TFDataPipeline:
-    def __init__(
-        self,
-        embedding_batch_size,
-        tokenizer,
-        encoder,
-        index,
-        response_pool,
-        max_length: int,
-        neg_samples: int,
-    ):
-        self.embedding_batch_size = embedding_batch_size
-        self.tokenizer = tokenizer
-        self.encoder = encoder
-        self.index = index  # CPU version of the index
-        self.response_pool = response_pool
-        self.max_length = max_length
-        self.neg_samples = neg_samples
-        self.embedding_batch_size = 16 if len(response_pool) < 100 else 64
-        self.search_batch_size = 8 if len(response_pool) < 100 else 32
-        self.max_batch_size = 32 if len(response_pool) < 100 else 256
-        self.memory_monitor = GPUMemoryMonitor()
-        self.max_retries = 3
-
-        # In-memory cache for embeddings
-        self.query_embeddings_cache = {}
-
-    def _extract_pairs_from_dialogue(self, dialogue: dict) -> List[Tuple[str, str]]:
-        """Extract query-response pairs from a dialogue."""
-        pairs = []
-        turns = dialogue.get('turns', [])
-
-        for i in range(len(turns) - 1):
-            current_turn = turns[i]
-            next_turn = turns[i+1]
-
-            if (current_turn.get('speaker') == 'user' and
-                next_turn.get('speaker') == 'assistant' and
-                'text' in current_turn and
-                'text' in next_turn):
-
-                query = current_turn['text'].strip()
-                positive = next_turn['text'].strip()
-                pairs.append((query, positive))
-
-        return pairs
-
-    def estimate_total_pairs(self, dialogues: List[dict]) -> int:
-        """Estimate total number of training pairs including hard negatives."""
-        base_pairs = sum(
-            len([
-                1 for i in range(len(d.get('turns', [])) - 1)
-                if (d['turns'][i].get('speaker') == 'user' and
-                    d['turns'][i+1].get('speaker') == 'assistant')
-            ])
-            for d in dialogues
-        )
-        # Account for hard negatives
-        return base_pairs * (1 + self.neg_samples)
-
-    def _find_hard_negatives_batch(self, queries: List[str], positives: List[str]) -> List[List[str]]:
-        """Find hard negatives for a batch of queries with error handling and retries."""
-        retry_count = 0
-        total_responses = len(self.response_pool)
-
-        while retry_count < self.max_retries:
-            try:
-                query_embeddings = np.vstack([
-                    self.query_embeddings_cache[q] for q in queries
-                ]).astype(np.float32)
-
-                query_embeddings = np.ascontiguousarray(query_embeddings)
-                faiss.normalize_L2(query_embeddings)
-
-                k = 1  # TODO: try higher k for better results
-                #logger.debug(f"Searching with k={k} among {total_responses} responses")
-
-                distances, indices = self.index.search(query_embeddings, k)
-
-                all_negatives = []
-                for query_indices, query, positive in zip(indices, queries, positives):
-                    negatives = []
-                    positive_strip = positive.strip()
-                    seen = {positive_strip}
-
-                    for idx in query_indices:
-                        if idx >= 0 and idx < total_responses:
-                            candidate = self.response_pool[idx].strip()
-                            if candidate and candidate not in seen:
-                                seen.add(candidate)
-                                negatives.append(candidate)
-                                if len(negatives) >= self.neg_samples:
-                                    break
-
-                    # Pad with a special empty negative if necessary
-                    while len(negatives) < self.neg_samples:
-                        negatives.append("<EMPTY_NEGATIVE>")  # Use a special token
-
-                    all_negatives.append(negatives)
-
-                return all_negatives
-
-            except Exception as e:
-                retry_count += 1
-                logger.warning(f"Hard negative search attempt {retry_count} failed: {e}")
-                if retry_count == self.max_retries:
-                    logger.error("Max retries reached for hard negative search")
-                    return [["<EMPTY_NEGATIVE>"] * self.neg_samples for _ in queries]  # Return empty negatives for all queries
-                gc.collect()
-                if tf.config.list_physical_devices('GPU'):
-                    tf.keras.backend.clear_session()
-
-    def _tokenize_negatives_tf(self, negatives):
-        """Tokenizes negatives using tf.py_function."""
-        # Handle the case where negatives is an empty tensor
-        if tf.size(negatives) == 0:
-            return tf.zeros([0, self.neg_samples, self.max_length], dtype=tf.int32)
-
-        # Convert EagerTensor to a list of strings
-        negatives_list = []
-        for neg_list in negatives.numpy():
-            decoded_negs = [neg.decode("utf-8") for neg in neg_list if neg]  # Filter out empty strings
-            negatives_list.append(decoded_negs)
-
-        # Flatten the list of lists
-        flattened_negatives = [neg for sublist in negatives_list for neg in sublist]
-
-        # Tokenize the flattened negatives
-        if flattened_negatives:
-            n_tokens = self.tokenizer(
-                flattened_negatives,
-                padding='max_length',
-                truncation=True,
-                max_length=self.max_length,
-                return_tensors='tf'
-            )
-            # Reshape the tokens
-            n_tokens_reshaped = tf.reshape(n_tokens['input_ids'], [-1, self.neg_samples, self.max_length])
-            return n_tokens_reshaped
-        else:
-            return tf.zeros([0, self.neg_samples, self.max_length], dtype=tf.int32)
-
-    def _compute_embeddings(self, queries: List[str]) -> None:
-        """Computes and caches embeddings for new queries."""
-        new_queries = [q for q in queries if q not in self.query_embeddings_cache]
-        if not new_queries:
-            return  # All queries already cached
-
-        new_embeddings = []
-        for i in range(0, len(new_queries), self.embedding_batch_size):
-            batch_queries = new_queries[i:i + self.embedding_batch_size]
-
-            encoded = self.tokenizer(
-                batch_queries,
-                padding=True,
-                truncation=True,
-                max_length=self.max_length,
-                return_tensors='tf'
-            )
-
-            # Compute embeddings on CPU
-            with tf.device('/CPU:0'):
-                batch_embeddings = self.encoder(encoded['input_ids'], training=False).numpy()
-
-            new_embeddings.extend(batch_embeddings)
-
-        # Update cache with new embeddings
-        for query, emb in zip(new_queries, new_embeddings):
-            self.query_embeddings_cache[query] = emb
-
-    def data_generator(self, dialogues: List[dict]) -> Generator[Tuple[str, str, List[str]], None, None]:
-        """
-        Generates training examples: (query, positive, hard_negatives).
-        Wrapped the outer loop with tqdm for progress tracking.
-        """
-        total_dialogues = len(dialogues)
-        logger.debug(f"Total dialogues to process: {total_dialogues}")
-
-        # Initialize tqdm progress bar
-        with tqdm(total=total_dialogues, desc="Processing Dialogues", unit="dialogue") as pbar:
-            for dialogue in dialogues:
-                pairs = self._extract_pairs_from_dialogue(dialogue)
-                for query, positive in pairs:
-                    # Ensure embeddings are computed, find hard negatives, etc.
-                    self._compute_embeddings([query])
-                    hard_negatives = self._find_hard_negatives_batch([query], [positive])[0]
-                    yield (query, positive, hard_negatives)
-                pbar.update(1)
-
-    def _prepare_batch(self, queries: tf.Tensor, positives: tf.Tensor, negatives: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
-        """Prepares a batch of data for training."""
-
-        # Convert EagerTensors to lists of strings
-        queries_list = [query.decode("utf-8") for query in queries.numpy()]
-        positives_list = [pos.decode("utf-8") for pos in positives.numpy()]
-
-        # Tokenize queries and positives
-        q_tokens = self.tokenizer(queries_list, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf')
-        p_tokens = self.tokenizer(positives_list, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf')
-
-        # Decode negatives and ensure they are lists of strings
-        negatives_list = []
-        for neg_list in negatives.numpy():
-            decoded_negs = [neg.decode("utf-8") for neg in neg_list if neg]  # Filter out empty strings
-            negatives_list.append(decoded_negs)
-
-        # Flatten negatives for tokenization if there are any valid negatives
-        flattened_negatives = [neg for sublist in negatives_list for neg in sublist if neg]
-
-        # Tokenize negatives if there are any
-        n_tokens_reshaped = None
-        if flattened_negatives:
-            n_tokens = self.tokenizer(flattened_negatives, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf')
-
-            # Reshape n_tokens to match the expected shape based on the number of negatives per query
-            # This part may need adjustment if the number of negatives varies per query
-            n_tokens_reshaped = tf.reshape(n_tokens['input_ids'], [len(queries_list), -1, self.max_length])
-        else:
-            # Create a placeholder tensor for the case where there are no negatives
-            n_tokens_reshaped = tf.zeros([len(queries_list), 0, self.max_length], dtype=tf.int32)
-
-        # Ensure n_tokens_reshaped has a consistent shape even when there are no negatives
-        # Adjust shape to [batch_size, num_neg_samples, max_length]
-        if n_tokens_reshaped.shape[1] != self.neg_samples:
-            # Pad or truncate the second dimension to match neg_samples
-            padding = tf.zeros([len(queries_list), tf.maximum(0, self.neg_samples - n_tokens_reshaped.shape[1]), self.max_length], dtype=tf.int32)
-            n_tokens_reshaped = tf.concat([n_tokens_reshaped, padding], axis=1)
-            n_tokens_reshaped = n_tokens_reshaped[:, :self.neg_samples, :]
-
-        # Concatenate the positive and negative examples along the 'neg_samples' dimension
-        combined_p_n_tokens = tf.concat([tf.expand_dims(p_tokens['input_ids'], axis=1), n_tokens_reshaped], axis=1)
-
-        return q_tokens['input_ids'], combined_p_n_tokens
-
-    def get_tf_dataset(self, dialogues: List[dict], batch_size: int) -> tf.data.Dataset:
-        """
-        Creates a tf.data.Dataset for streaming training that yields
-        (input_ids_query, input_ids_positive, input_ids_negatives).
-        """
-        # 1) Start with a generator dataset
-        dataset = tf.data.Dataset.from_generator(
-            lambda: self.data_generator(dialogues),
-            output_signature=(
-                tf.TensorSpec(shape=(), dtype=tf.string),    # Query (single string)
-                tf.TensorSpec(shape=(), dtype=tf.string),    # Positive (single string)
-                tf.TensorSpec(shape=(None,), dtype=tf.string)  # Hard Negatives (list of strings)
-            )
-        )
-
-        # 2) Batch the raw strings
-        dataset = dataset.batch(batch_size)
-
-        # 3) Now map them through a tokenize step (via py_function)
-        dataset = dataset.map(
-            lambda q, p, n: self._tokenize_triple(q, p, n),
-            num_parallel_calls=1 #tf.data.AUTOTUNE
-        )
-
-        dataset = dataset.prefetch(tf.data.AUTOTUNE)
-        return dataset
-
-    def _tokenize_triple(
-        self,
-        q: tf.Tensor,
-        p: tf.Tensor,
-        n: tf.Tensor
-    ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
-        """
-        Wraps a Python function via tf.py_function to convert tf.Tensors of strings
-        -> Python lists of strings -> HF tokenizer -> Tensors of IDs.
-
-        q is shape [batch_size], p is shape [batch_size],
-        n is shape [batch_size, neg_samples] (i.e., each row is a list of negatives).
-        """
-        # Use tf.py_function with limited parallelism
-        q_ids, p_ids, n_ids = tf.py_function(
-            func=self._tokenize_triple_py,
-            inp=[q, p, n, tf.constant(self.max_length), tf.constant(self.neg_samples)],
-            Tout=[tf.int32, tf.int32, tf.int32]
-        )
-
-        # Manually set shape information
-        q_ids.set_shape([None, self.max_length])                    # [batch_size, max_length]
-        p_ids.set_shape([None, self.max_length])                    # [batch_size, max_length]
-        n_ids.set_shape([None, self.neg_samples, self.max_length])  # [batch_size, neg_samples, max_length]
-
-        return q_ids, p_ids, n_ids
-    # def _tokenize_triple(
-    #     self,
-    #     q: tf.Tensor,
-    #     p: tf.Tensor,
-    #     n: tf.Tensor
-    # ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
-    #     """
-    #     Wraps a Python function via tf.py_function to convert tf.Tensors of strings
-    #     -> Python lists of strings -> HF tokenizer -> Tensors of IDs.
-
-    #     q is shape [batch_size], p is shape [batch_size],
-    #     n is shape [batch_size, None] (i.e. each row is a variable number of negatives).
-    #     """
-    #     # Use tf.py_function
-    #     # We pass in self.max_length as well, so we can do it in one shot.
-    #     q_ids, p_ids, n_ids = tf.py_function(
-    #         func=self._tokenize_triple_py,
-    #         inp=[q, p, n, tf.constant(self.max_length), tf.constant(self.neg_samples)],
-    #         Tout=[tf.int32, tf.int32, tf.int32]
-    #     )
-
-    #     # We must manually set shape information so that TF data pipeline knows the dimensions
-    #     q_ids.set_shape([None, self.max_length])  # [batch_size, max_length]
-    #     p_ids.set_shape([None, self.max_length])  # [batch_size, max_length]
-    #     n_ids.set_shape([None, self.neg_samples, self.max_length])
-    #     # The negative dimension is set to `self.neg_samples` for consistency.
-
-    #     return q_ids, p_ids, n_ids
-
-    def _tokenize_triple_py(
-        self,
-        q: tf.Tensor,
-        p: tf.Tensor,
-        n: tf.Tensor,
-        max_len: tf.Tensor,
-        neg_samples: tf.Tensor
-    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
-        """
-        Python function that:
-        - Decodes each tf.string Tensor to a Python list of strings
-        - Calls the HF tokenizer
-        - Reshapes negatives
-        - Returns np.array of int32s for (q_ids, p_ids, n_ids).
-
-        q: shape [batch_size], p: shape [batch_size]
-        n: shape [batch_size, neg_samples]
-        max_len: scalar int
-        neg_samples: scalar int
-        """
-        max_len = int(max_len.numpy())  # Convert to Python int
-        neg_samples = int(neg_samples.numpy())
-
-        # 1) Convert Tensors -> Python lists of strings
-        q_list = [q_i.decode("utf-8") for q_i in q.numpy()]  # shape [batch_size]
-        p_list = [p_i.decode("utf-8") for p_i in p.numpy()]  # shape [batch_size]
-
-        # shape [batch_size, neg_samples], decode each row
-        n_list = []
-        for row in n.numpy():
-            # row is shape [neg_samples], each is a tf.string
-            decoded = [neg.decode("utf-8") for neg in row]
-            n_list.append(decoded)
-
-        # 2) Tokenize queries & positives
-        q_enc = self.tokenizer(
-            q_list,
-            padding="max_length",
-            truncation=True,
-            max_length=max_len,
-            return_tensors="np"
-        )
-        p_enc = self.tokenizer(
-            p_list,
-            padding="max_length",
-            truncation=True,
-            max_length=max_len,
-            return_tensors="np"
-        )
-
-        # 3) Tokenize negatives
-        # Flatten [batch_size, neg_samples] -> single list
-        flattened_negatives = [neg for row in n_list for neg in row]
-        if len(flattened_negatives) == 0:
-            # No negatives at all: return a zero array
-            n_ids = np.zeros((len(q_list), neg_samples, max_len), dtype=np.int32)
-        else:
-            n_enc = self.tokenizer(
-                flattened_negatives,
-                padding="max_length",
-                truncation=True,
-                max_length=max_len,
-                return_tensors="np"
-            )
-            # shape [batch_size * neg_samples, max_len]
-            n_input_ids = n_enc["input_ids"]
-
-            # We want to reshape to [batch_size, neg_samples, max_len]
-            # Handle cases where there might be fewer negatives
-            batch_size = len(q_list)
-            n_ids_list = []
-            for i in range(batch_size):
-                start_idx = i * neg_samples
-                end_idx = start_idx + neg_samples
-                row_negs = n_input_ids[start_idx:end_idx]
-
-                # If fewer negatives, pad with zeros
-                if row_negs.shape[0] < neg_samples:
-                    deficit = neg_samples - row_negs.shape[0]
-                    pad_arr = np.zeros((deficit, max_len), dtype=np.int32)
-                    row_negs = np.concatenate([row_negs, pad_arr], axis=0)
-
-                n_ids_list.append(row_negs)
-
-            # stack them -> shape [batch_size, neg_samples, max_len]
-            n_ids = np.stack(n_ids_list, axis=0)
-
-        # 4) Return as np.int32 arrays
-        q_ids = q_enc["input_ids"].astype(np.int32)  # shape [batch_size, max_len]
-        p_ids = p_enc["input_ids"].astype(np.int32)  # shape [batch_size, max_len]
-        n_ids = n_ids.astype(np.int32)               # shape [batch_size, neg_samples, max_len]
-
-        return q_ids, p_ids, n_ids
-    # def _tokenize_triple_py(
-    #     self,
-    #     q: tf.Tensor,
-    #     p: tf.Tensor,
-    #     n: tf.Tensor,
-    #     max_len: tf.Tensor,
-    #     neg_samples: tf.Tensor
-    # ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
-    #     """
-    #     Python function that:
-    #     - Decodes each tf.string Tensor to a Python list of strings
-    #     - Calls the HF tokenizer
-    #     - Reshapes negatives
-    #     - Returns np.array of int32s for (q_ids, p_ids, n_ids).
-
-    #     q: shape [batch_size], p: shape [batch_size]
-    #     n: shape [batch_size, None]
-    #     max_len: scalar int
-    #     neg_samples: scalar int
-    #     """
-    #     max_len = int(max_len.numpy())  # convert to python int
-    #     neg_samples = int(neg_samples.numpy())
-
-    #     # 1) Convert Tensors -> Python lists of strings
-    #     q_list = [q_i.decode("utf-8") for q_i in q.numpy()]  # shape [batch_size]
-    #     p_list = [p_i.decode("utf-8") for p_i in p.numpy()]  # shape [batch_size]
-
-    #     # shape [batch_size, variable_negatives], decode each row
-    #     n_list = []
-    #     for row in n.numpy():
-    #         # row is shape [N], each is a tf.string
-    #         decoded = [neg.decode("utf-8") for neg in row]
-    #         n_list.append(decoded)
-
-    #     # 2) Tokenize queries & positives
-    #     q_enc = self.tokenizer(
-    #         q_list,
-    #         padding="max_length",
-    #         truncation=True,
-    #         max_length=max_len,
-    #         return_tensors="np"  # you can do return_tensors="tf", but "np" is often simpler here
-    #     )
-    #     p_enc = self.tokenizer(
-    #         p_list,
-    #         padding="max_length",
-    #         truncation=True,
-    #         max_length=max_len,
-    #         return_tensors="np"
-    #     )
-
-    #     # 3) Tokenize negatives
-    #     # Flatten [batch_size, variable_negatives] -> single list
-    #     flattened_negatives = [neg for row in n_list for neg in row]
-    #     if len(flattened_negatives) == 0:
-    #         # No negatives at all: return a zero array
-    #         n_ids = np.zeros((len(q_list), neg_samples, max_len), dtype=np.int32)
-    #     else:
-    #         n_enc = self.tokenizer(
-    #             flattened_negatives,
-    #             padding="max_length",
-    #             truncation=True,
-    #             max_length=max_len,
-    #             return_tensors="np"
-    #         )
-    #         # shape [batch_size * total_negatives, max_len]
-    #         n_input_ids = n_enc["input_ids"]
-
-    #         # We want to reshape to [batch_size, neg_samples, max_len].
-    #         # If each row truly has exactly `neg_samples` (or fewer), we can do:
-    #         #   n_input_ids = n_input_ids.reshape(len(q_list), neg_samples, max_len)
-    #         # But if the rows have variable # of negatives, we must clamp or pad.
-    #         # For simplicity, let's just "take first neg_samples" per row
-    #         # and pad if fewer.
-
-    #         # We'll do it row by row:
-    #         batch_size = len(q_list)
-    #         row_offsets = 0
-    #         n_ids_list = []
-    #         for row_idx in range(batch_size):
-    #             row_negs = n_list[row_idx]
-    #             row_count = len(row_negs)
-
-    #             # slice from the flattened array
-    #             row_slice = n_input_ids[row_offsets:row_offsets + row_count]
-    #             row_offsets += row_count
-
-    #             # Now pick out up to neg_samples
-    #             row_slice = row_slice[:neg_samples]
-
-    #             # If fewer than neg_samples, pad
-    #             if row_slice.shape[0] < neg_samples:
-    #                 deficit = neg_samples - row_slice.shape[0]
-    #                 pad_arr = np.zeros((deficit, max_len), dtype=np.int32)
-    #                 row_slice = np.concatenate([row_slice, pad_arr], axis=0)
-
-    #             # row_slice is now shape [neg_samples, max_len]
-    #             n_ids_list.append(row_slice)
-
-    #         # stack them -> shape [batch_size, neg_samples, max_len]
-    #         n_ids = np.stack(n_ids_list, axis=0)
-
-    #         # 4) Return as np.int32 arrays (tokenizer should already return int32,
-    #         #    but we can cast to be sure)
-    #     q_ids = q_enc["input_ids"].astype(np.int32)  # shape [batch_size, max_len]
-    #     p_ids = p_enc["input_ids"].astype(np.int32)  # shape [batch_size, max_len]
-    #     n_ids = n_ids.astype(np.int32)  # shape [batch_size, neg_samples, max_len]
-
-    #     return q_ids, p_ids, n_ids
-
 from transformers import TFAutoModel, AutoTokenizer
 import tensorflow as tf
 import numpy as np
+from typing import List, Tuple, Dict, Optional, Union, Any
 import math
 from dataclasses import dataclass
 import json
 
 import datetime
 import faiss
 import gc
+from tf_data_pipeline import TFDataPipeline
 from response_quality_checker import ResponseQualityChecker
 from cross_encoder_reranker import CrossEncoderReranker
 from conversation_summarizer import DeviceAwareModel, Summarizer
 
 @dataclass
 class ChatbotConfig:
     """Configuration for the RetrievalChatbot."""
     max_context_token_limit: int = 512
     embedding_dim: int = 768
     encoder_units: int = 256
     num_attention_heads: int = 8
     dropout_rate: float = 0.2
     l2_reg_weight: float = 0.001
     learning_rate: float = 0.001
     min_text_length: int = 3
     max_context_turns: int = 5
 
     pretrained_model: str = 'distilbert-base-uncased'
     dtype: str = 'float32'
     freeze_embeddings: bool = False
+    embedding_batch_size: int = 64
+    search_batch_size: int = 64
+    max_batch_size: int = 64
+    neg_samples: int = 3
+    max_retries: int = 3
 
+    def to_dict(self) -> Dict:
         """Convert config to dictionary."""
+        return {k: (str(v) if isinstance(v, Path) else v)
                 for k, v in self.__dict__.items()}
 
     @classmethod
+    def from_dict(cls, config_dict: Dict) -> 'ChatbotConfig':
         """Create config from dictionary."""
         return cls(**{k: v for k, v in config_dict.items()
                       if k in cls.__dataclass_fields__})
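A quick illustration of how the new to_dict()/from_dict() pair can be used, e.g. to persist the config next to saved weights (not part of the commit; values and file layout are assumed):

    import json
    from chatbot_model import ChatbotConfig

    config = ChatbotConfig(neg_samples=5)
    blob = json.dumps(config.to_dict())                    # Path values are stringified by to_dict()
    restored = ChatbotConfig.from_dict(json.loads(blob))   # keys not in the dataclass are ignored
    assert restored.neg_samples == 5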
 
         self,
         config: ChatbotConfig,
         name: str = "encoder",
         **kwargs
     ):
         super().__init__(name=name, **kwargs)
         self.config = config
 
         # Load pretrained model
         self.pretrained = TFAutoModel.from_pretrained(config.pretrained_model)
 
+        # Freeze layers based on config
+        self._freeze_layers()
+
         # Pooling layer (Global Average Pooling)
         self.pooler = tf.keras.layers.GlobalAveragePooling1D()
 
         # Dropout and normalization
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
         self.normalize = tf.keras.layers.Lambda(
+            lambda x: tf.nn.l2_normalize(x, axis=1),
+            name="l2_normalize"
         )
 
+    def _freeze_layers(self):
+        """Freeze layers of the pretrained model based on configuration."""
+        if self.config.freeze_embeddings:
+            self.pretrained.trainable = False
+            logger.info("All pretrained layers frozen.")
+        else:
+            # Freeze only the first 'n' transformer layers
+            for i, layer in enumerate(self.pretrained.layers):
+                if isinstance(layer, tf.keras.layers.Layer):
+                    if hasattr(layer, 'trainable'):
+                        # Freeze the first transformer block
+                        if i < 1:
+                            layer.trainable = False
+                            logger.info(f"Layer {i} frozen.")
+                        else:
+                            layer.trainable = True
+
     def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
         """Forward pass."""
         # Get pretrained embeddings
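Illustrative only (not in the commit): one way to check what _freeze_layers() leaves trainable after the encoder is built; assumes the module layout above and network access for the pretrained weights:

    import tensorflow as tf
    from chatbot_model import ChatbotConfig, EncoderModel

    encoder = EncoderModel(ChatbotConfig(freeze_embeddings=False))
    for i, layer in enumerate(encoder.pretrained.layers):
        print(i, layer.name, layer.trainable)
    print("trainable weight tensors:", len(encoder.trainable_weights))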
 
         config = super().get_config()
         config.update({
             "config": self.config.to_dict(),
             "name": self.name
         })
         return config
 
 class RetrievalChatbot(DeviceAwareModel):
     """Retrieval-based chatbot using pretrained embeddings and FAISS for similarity search."""
+    def __init__(
+        self,
+        config: ChatbotConfig,
+        dialogues: List[dict] = [],
+        device: str = None,
+        strategy=None,
+        reranker: Optional[CrossEncoderReranker] = None,
+        summarizer: Optional[Summarizer] = None
+    ):
+        super().__init__()
         self.config = config
         self.strategy = strategy
+        self.device = device or self._setup_default_device()
 
+        # Initialize reranker, summarizer, tokenizer, and memory monitor
+        self.reranker = reranker or self._initialize_reranker()
+        self.summarizer = summarizer or self._initialize_summarizer()
+        self.tokenizer = self._initialize_tokenizer()
         self.memory_monitor = GPUMemoryMonitor()
+
+        # Initialize models
         self.min_batch_size = 8
         self.max_batch_size = 128
         self.current_batch_size = 32
 
             "train_metrics": {},
             "val_metrics": {}
         }
+
+    def _setup_default_device(self) -> str:
+        """Set up default device if none is provided."""
+        if tf.config.list_physical_devices('GPU'):
+            return 'GPU'
+        else:
+            return 'CPU'
+
+    def _initialize_reranker(self) -> CrossEncoderReranker:
+        """Initialize the CrossEncoderReranker."""
+        logger.info("Initializing default CrossEncoderReranker...")
+        return CrossEncoderReranker(model_name="cross-encoder/ms-marco-MiniLM-L-12-v2")
+
+    def _initialize_summarizer(self) -> Summarizer:
+        """Initialize the Summarizer."""
+        logger.info("Initializing default Summarizer...")
+        return Summarizer(device=self.device)
+
+    def _initialize_tokenizer(self) -> AutoTokenizer:
+        """Initialize the tokenizer and add special tokens."""
+        logger.info("Initializing tokenizer and adding special tokens...")
+        tokenizer = AutoTokenizer.from_pretrained(self.config.pretrained_model)
+        special_tokens = {
+            "user": "<USER>",
+            "assistant": "<ASSISTANT>",
+            "context": "<CONTEXT>",
+            "sep": "<SEP>"
+        }
+        tokenizer.add_special_tokens(
+            {'additional_special_tokens': list(special_tokens.values())}
+        )
+        return tokenizer
+
+    def _collect_responses(self, dialogues: List[dict]) -> Tuple[List[str], List[str]]:
+        """
+        Collect unique responses from dialogues.
+        Returns:
+            response_pool: List of all possible responses.
+            unique_responses: List of unique responses.
+        """
+        logger.info("Collecting unique responses from dialogues...")
+        responses = set()
+        for dialogue in dialogues:
+            turns = dialogue.get('turns', [])
+            for turn in turns:
+                if turn.get('speaker') == 'assistant' and 'text' in turn:
+                    response = turn['text'].strip()
+                    if len(response) >= self.config.min_text_length:
+                        responses.add(response)
+        response_pool = list(responses)
+        unique_responses = list(responses)  # Assuming uniqueness
+        logger.info(f"Collected {len(response_pool)} unique responses.")
+        return response_pool, unique_responses
+
     def build_models(self):
+        """Initialize the shared encoder and FAISS index."""
         logger.info("Building encoder model...")
         tf.keras.backend.clear_session()
 
         self.encoder = EncoderModel(
             self.config,
             name="shared_encoder",
+            shared_weights=True  # If weight sharing is intended
         )
 
         # Resize token embeddings after adding special tokens
         self.encoder.pretrained.resize_token_embeddings(new_vocab_size)
         logger.info(f"Token embeddings resized to: {new_vocab_size}")
 
+        # Initialize FAISS index
         self._initialize_faiss()
 
+        # Compute and index embeddings
+        self._compute_and_index_embeddings()
 
+        # Retrieve embedding dimension from encoder
+        embedding_dim = self.config.embedding_dim
         vocab_size = len(self.tokenizer)
 
         logger.info(f"Encoder Embedding Dimension: {embedding_dim}")
 
         else:
             logger.error("Vocabulary size is less than embedding dimension.")
             raise ValueError("Vocabulary size is less than embedding dimension.")
 
     def _adjust_batch_size(self) -> None:
         """Dynamically adjust batch size based on GPU memory usage."""
 
             logger.warning(f"Using CPU due to GPU initialization error: {e}")
 
         # TODO: figure out buf with faiss-gpu
+        # TODO: consider IndexIVFFlat in the future (speed).
         try:
             # Create appropriate index based on dataset size
             if len(self.unique_responses) < 1000:
 
         logger.info(f"Models and tokenizer loaded from {load_dir}.")
         return chatbot
 
+    # @staticmethod
+    # def load_training_data(data_path: Union[str, Path], debug_samples: Optional[int] = None) -> List[dict]:
+    #     """
+    #     Load training data from a JSON file.
 
+    #     Args:
+    #         data_path (Union[str, Path]): Path to the JSON file containing dialogues.
+    #         debug_samples (Optional[int]): Number of samples to load for debugging.
 
+    #     Returns:
+    #         List[dict]: List of dialogue dictionaries.
+    #     """
+    #     logger.info(f"Loading training data from {data_path}...")
+    #     data_path = Path(data_path)
+    #     if not data_path.exists():
+    #         logger.error(f"Data file {data_path} does not exist.")
+    #         return []
 
+    #     with open(data_path, 'r', encoding='utf-8') as f:
+    #         dialogues = json.load(f)
 
+    #     if debug_samples is not None:
+    #         dialogues = dialogues[:debug_samples]
+    #         logger.info(f"Debug mode: Limited to {debug_samples} dialogues")
 
+    #     logger.info(f"Loaded {len(dialogues)} dialogues.")
+    #     return dialogues
 
     def train_streaming(
         self,
 
         conversation_parts.append(f"{self.special_tokens['user']} {query}")
         return "\n".join(conversation_parts)
requirements.txt
CHANGED
@@ -1,4 +1,5 @@
 faiss-cpu>=1.7.0   # Required for Facebook AI Similarity Search
+h5py>=3.1.0        # For saving and loading models
 ipython>=8.0.0     # For interactive Python
 loguru>=0.7.0      # Enhanced logging (optional but recommended)
 matplotlib>=3.5.0  # For validation plotting
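The new h5py pin is what Keras' HDF5 checkpoint format relies on; a minimal illustration (not from the commit):

    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
    model.build((None, 8))
    model.save_weights('weights.h5')  # .h5 output requires h5py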
run_data_preparer.py
ADDED
@@ -0,0 +1,182 @@
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import faiss
|
| 4 |
+
import pickle
|
| 5 |
+
from transformers import AutoTokenizer
|
| 6 |
+
from tqdm.auto import tqdm
|
| 7 |
+
from chatbot_model import ChatbotConfig, EncoderModel
|
| 8 |
+
from environment_setup import EnvironmentSetup
|
| 9 |
+
from tf_data_pipeline import TFDataPipeline
|
| 10 |
+
from logger_config import config_logger
|
| 11 |
+
|
| 12 |
+
logger = config_logger(__name__)
|
| 13 |
+
|
| 14 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 15 |
+
|
| 16 |
+
def cleanup_test_indices(faiss_dir, test_prefix='test_'):
|
| 17 |
+
test_files = [f for f in os.listdir(faiss_dir) if f.startswith(test_prefix)]
|
| 18 |
+
for file in test_files:
|
| 19 |
+
file_path = os.path.join(faiss_dir, file)
|
| 20 |
+
os.remove(file_path)
|
| 21 |
+
logger.info(f"Removed test FAISS index file: {file_path}")
|
| 22 |
+
|
| 23 |
+
def main():
|
| 24 |
+
# Constants
|
| 25 |
+
MODELS_DIR = 'models'
|
| 26 |
+
PROCESSED_DATA_DIR = 'processed_outputs'
|
| 27 |
+
CACHE_DIR = 'cache'
|
| 28 |
+
TOKENIZER_DIR = os.path.join(MODELS_DIR, 'tokenizer')
|
| 29 |
+
FAISS_INDICES_DIR = os.path.join(MODELS_DIR, 'faiss_indices')
|
| 30 |
+
TF_RECORD_DIR = 'training_data'
|
| 31 |
+
FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_production.index')
|
| 32 |
+
FAISS_INDEX_TEST_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_test.index')
|
| 33 |
+
ENVIRONMENT = 'test' # or 'production'
|
| 34 |
+
if ENVIRONMENT == 'test':
|
| 35 |
+
FAISS_INDEX_PATH = FAISS_INDEX_TEST_PATH
|
| 36 |
+
else:
|
| 37 |
+
FAISS_INDEX_PATH = FAISS_INDEX_PRODUCTION_PATH
|
| 38 |
+
JSON_TRAINING_DATA_PATH = os.path.join(PROCESSED_DATA_DIR, 'augmented_dialogues.json')
|
| 39 |
+
CACHE_FILE = os.path.join(CACHE_DIR, 'query_embeddings_cache.pkl')
|
| 40 |
+
TF_RECORD_PATH = os.path.join(TF_RECORD_DIR, 'training_data.tfrecord')
|
| 41 |
+
DEBUG_SAMPLES = None
|
| 42 |
+
|
| 43 |
+
# Ensure output directories exist
|
| 44 |
+
os.makedirs(MODELS_DIR, exist_ok=True)
|
| 45 |
+
os.makedirs(PROCESSED_DATA_DIR, exist_ok=True)
|
| 46 |
+
os.makedirs(CACHE_DIR, exist_ok=True)
|
| 47 |
+
os.makedirs(TOKENIZER_DIR, exist_ok=True)
|
| 48 |
+
os.makedirs(FAISS_INDICES_DIR, exist_ok=True)
|
| 49 |
+
os.makedirs(TF_RECORD_DIR, exist_ok=True)
|
| 50 |
+
|
| 51 |
+
# Initialize configuration
|
| 52 |
+
config = ChatbotConfig()
|
| 53 |
+
logger.info(f"Chatbot Configuration: {config}")
|
| 54 |
+
|
| 55 |
+
# Initialize tokenizer
|
| 56 |
+
try:
|
| 57 |
+
tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
|
| 58 |
+
logger.info(f"Tokenizer '{config.pretrained_model}' loaded successfully.")
|
| 59 |
+
except Exception as e:
|
| 60 |
+
logger.error(f"Failed to load tokenizer: {e}")
|
| 61 |
+
sys.exit(1)
|
| 62 |
+
|
| 63 |
+
# Add special tokens
|
| 64 |
+
try:
|
| 65 |
+
tokenizer.add_special_tokens({'additional_special_tokens': ['<EMPTY_NEGATIVE>']})
|
| 66 |
+
logger.info("Added special tokens to tokenizer.")
|
| 67 |
+
except Exception as e:
|
| 68 |
+
logger.error(f"Failed to add special tokens: {e}")
|
| 69 |
+
sys.exit(1)
|
| 70 |
+
|
| 71 |
+
# Initialize encoder model
|
| 72 |
+
try:
|
| 73 |
+
encoder = EncoderModel(config=config)
|
| 74 |
+
logger.info("EncoderModel initialized successfully.")
|
| 75 |
+
except Exception as e:
|
| 76 |
+
logger.error(f"Failed to initialize EncoderModel: {e}")
|
| 77 |
+
+        sys.exit(1)
+
+    # Resize token embeddings in encoder to match tokenizer
+    try:
+        encoder.pretrained.resize_token_embeddings(len(tokenizer))
+        logger.info(f"Token embeddings resized to: {len(tokenizer)}")
+    except Exception as e:
+        logger.error(f"Failed to resize token embeddings: {e}")
+        sys.exit(1)
+
+    # Load JSON dialogues
+    try:
+        dialogues = TFDataPipeline.load_json_training_data(JSON_TRAINING_DATA_PATH, DEBUG_SAMPLES)
+        logger.info(f"Loaded {len(dialogues)} dialogues from {JSON_TRAINING_DATA_PATH}.")
+    except Exception as e:
+        logger.error(f"Failed to load dialogues: {e}")
+        sys.exit(1)
+
+    # Load or initialize query_embeddings_cache
+    try:
+        if os.path.exists(CACHE_FILE):
+            with open(CACHE_FILE, 'rb') as f:
+                query_embeddings_cache = pickle.load(f)
+            logger.info(f"Loaded {len(query_embeddings_cache)} query embeddings from {CACHE_FILE}.")
+        else:
+            query_embeddings_cache = {}
+            logger.info("Initialized empty query embeddings cache.")
+    except Exception as e:
+        logger.error(f"Failed to load or initialize query embeddings cache: {e}")
+        sys.exit(1)
+
+    # Initialize TFDataPipeline
+    try:
+        data_pipeline = TFDataPipeline(
+            config=config,
+            tokenizer=tokenizer,
+            encoder=encoder,
+            index_file_path=FAISS_INDEX_PATH,
+            response_pool=[],
+            max_length=config.max_context_token_limit,
+            neg_samples=config.neg_samples,
+            query_embeddings_cache=query_embeddings_cache,
+            max_retries=config.max_retries
+        )
+        logger.info("TFDataPipeline initialized successfully.")
+    except Exception as e:
+        logger.error(f"Failed to initialize TFDataPipeline: {e}")
+        sys.exit(1)
+
+    # Collect unique assistant responses from dialogues
+    try:
+        response_pool = data_pipeline.collect_responses(dialogues)
+        data_pipeline.response_pool = response_pool
+        logger.info(f"Collected {len(response_pool)} unique assistant responses from dialogues.")
+    except Exception as e:
+        logger.error(f"Failed to collect responses: {e}")
+        sys.exit(1)
+
+    # Compute and add response embeddings to FAISS index
+    try:
+        logger.info("Computing and adding response embeddings to FAISS index...")
+        data_pipeline._compute_and_index_response_embeddings()
+        logger.info("Response embeddings computed and added to FAISS index.")
+    except Exception as e:
+        logger.error(f"Failed to compute or add response embeddings: {e}")
+        sys.exit(1)
+
+    # Save FAISS index
+    try:
+        logger.info(f"Saving FAISS index to {FAISS_INDEX_PATH}...")
+        faiss.write_index(data_pipeline.index, FAISS_INDEX_PATH)
+        logger.info("FAISS index saved successfully.")
+    except Exception as e:
+        logger.error(f"Failed to save FAISS index: {e}")
+        sys.exit(1)
+
+    # Prepare and save training data as TFRecords
+    try:
+        logger.info("Starting data preparation and saving as TFRecord...")
+        data_pipeline.prepare_and_save_data(dialogues, TF_RECORD_PATH)
+        logger.info(f"Data saved as TFRecord at {TF_RECORD_PATH}.")
+    except Exception as e:
+        logger.error(f"Failed during data preparation and saving: {e}")
+        sys.exit(1)
+
+    # Save query embeddings cache
+    try:
+        with open(CACHE_FILE, 'wb') as f:
+            pickle.dump(data_pipeline.query_embeddings_cache, f)
+        logger.info(f"Saved {len(data_pipeline.query_embeddings_cache)} query embeddings to {CACHE_FILE}.")
+    except Exception as e:
+        logger.error(f"Failed to save query embeddings cache: {e}")
+        sys.exit(1)
+
+    # Save tokenizer (including special tokens)
+    try:
+        tokenizer.save_pretrained(TOKENIZER_DIR)
+        logger.info(f"Tokenizer saved to {TOKENIZER_DIR}.")
+    except Exception as e:
+        logger.error(f"Failed to save tokenizer: {e}")
+        sys.exit(1)
+
+    logger.info("Data preparation pipeline completed successfully.")
+
+if __name__ == "__main__":
+    main()
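
For orientation, a minimal sketch of how the artifacts written by this script could be loaded back later. The file paths below are placeholders standing in for TOKENIZER_DIR, FAISS_INDEX_PATH, and CACHE_FILE; only the formats (saved tokenizer, FAISS index file, pickled dict) follow from the save calls above.

# Hypothetical reload of the prepared artifacts (not part of the commit; paths are placeholders).
import pickle
import faiss
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("models/tokenizer")        # TOKENIZER_DIR
index = faiss.read_index("models/faiss_index.bin")                   # FAISS_INDEX_PATH
with open("cache/query_embeddings.pkl", "rb") as f:                  # CACHE_FILE
    query_embeddings_cache = pickle.load(f)
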
tf_data_pipeline.py
ADDED
@@ -0,0 +1,734 @@
+import os
+import gc
+import numpy as np
+import faiss
+import tensorflow as tf
+import h5py
+from tqdm import tqdm
+import json
+from pathlib import Path
+from typing import Union, Optional, List, Tuple, Generator
+from transformers import AutoTokenizer
+from gpu_monitor import GPUMemoryMonitor
+
+from logger_config import config_logger
+logger = config_logger(__name__)
+
+class TFDataPipeline:
+    def __init__(
+        self,
+        config,
+        tokenizer,
+        encoder,
+        index_file_path: str,
+        response_pool: List[str],
+        max_length: int,
+        query_embeddings_cache: dict,
+        neg_samples: int = 3,
+        index_type: str = 'IndexFlatIP',
+        nlist: int = 100,
+        max_retries: int = 3
+    ):
+        self.config = config
+        self.tokenizer = tokenizer
+        self.encoder = encoder
+        self.index_file_path = index_file_path
+        self.response_pool = response_pool
+        self.max_length = max_length
+        self.neg_samples = neg_samples
+        self.query_embeddings_cache = query_embeddings_cache  # In-memory cache for query embeddings
+        self.index_type = index_type
+        self.nlist = nlist
+        self.embedding_batch_size = 16 if len(response_pool) < 100 else 64
+        self.search_batch_size = 16 if len(response_pool) < 100 else 64
+        self.max_batch_size = 16 if len(response_pool) < 100 else 64
+        self.memory_monitor = GPUMemoryMonitor()
+        self.max_retries = max_retries
+
+        if os.path.exists(index_file_path):
+            logger.info(f"Loading existing FAISS index from {index_file_path}...")
+            self.index = faiss.read_index(index_file_path)
+            self.validate_faiss_index()
+            logger.info("FAISS index loaded and validated successfully.")
+        else:
+            # Initialize a new FAISS index
+            dimension = self.encoder.config.embedding_dim
+            self.index = faiss.IndexFlatIP(dimension)
+            logger.info(f"Initialized FAISS IndexFlatIP with dimension {dimension}.")
+
+        if not self.index.is_trained and self.query_embeddings_cache:
+            # Train and seed the index from cached query embeddings.
+            # (IndexFlatIP needs no training, so this only applies to trainable index types.)
+            cached_embeddings = np.array(list(self.query_embeddings_cache.values())).astype(np.float32)
+            self.index.train(cached_embeddings)
+            self.index.add(cached_embeddings)
+
+    def validate_faiss_index(self):
+        """Validates that the FAISS index has the correct dimensionality."""
+        expected_dim = self.encoder.config.embedding_dim
+        if self.index.d != expected_dim:
+            logger.error(f"FAISS index dimension {self.index.d} does not match encoder embedding dimension {expected_dim}.")
+            raise ValueError("FAISS index dimensionality mismatch.")
+        logger.info("FAISS index dimension validated successfully.")
+
+    def save_embeddings_cache_hdf5(self, cache_file_path: str):
+        """Save the embeddings cache to an HDF5 file."""
+        with h5py.File(cache_file_path, 'w') as hf:
+            for query, emb in self.query_embeddings_cache.items():
+                hf.create_dataset(query, data=emb)
+        logger.info(f"Embeddings cache saved to {cache_file_path}.")
+
+    def load_embeddings_cache_hdf5(self, cache_file_path: str):
+        """Load the embeddings cache from an HDF5 file."""
+        with h5py.File(cache_file_path, 'r') as hf:
+            for query in hf.keys():
+                self.query_embeddings_cache[query] = hf[query][:]
+        logger.info(f"Embeddings cache loaded from {cache_file_path}.")
+
+    def save_faiss_index(self, index_file_path: str):
+        faiss.write_index(self.index, index_file_path)
+        logger.info(f"FAISS index saved to {index_file_path}")
+
+    def load_faiss_index(self, index_file_path: str):
+        self.index = faiss.read_index(index_file_path)
+        logger.info(f"FAISS index loaded from {index_file_path}")
+
+    def save_tokenizer(self, tokenizer_dir: str):
+        self.tokenizer.save_pretrained(tokenizer_dir)
+        logger.info(f"Tokenizer saved to {tokenizer_dir}")
+
+    def load_tokenizer(self, tokenizer_dir: str):
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
+        logger.info(f"Tokenizer loaded from {tokenizer_dir}")
+
+    def estimate_total_pairs(self, dialogues: List[dict]) -> int:
+        """Estimate the total number of training pairs, including hard negatives."""
+        base_pairs = sum(
+            len([
+                1 for i in range(len(d.get('turns', [])) - 1)
+                if (d['turns'][i].get('speaker') == 'user' and
+                    d['turns'][i+1].get('speaker') == 'assistant')
+            ])
+            for d in dialogues
+        )
+        # Account for hard negatives
+        return base_pairs * (1 + self.neg_samples)
+
+    @staticmethod
+    def load_json_training_data(data_path: Union[str, Path], debug_samples: Optional[int] = None) -> List[dict]:
+        """
+        Load training data from a JSON file.
+
+        Args:
+            data_path (Union[str, Path]): Path to the JSON file containing dialogues.
+            debug_samples (Optional[int]): Number of samples to load for debugging.
+
+        Returns:
+            List[dict]: List of dialogue dictionaries.
+        """
+        logger.info(f"Loading training data from {data_path}...")
+        data_path = Path(data_path)
+        if not data_path.exists():
+            logger.error(f"Data file {data_path} does not exist.")
+            return []
+
+        with open(data_path, 'r', encoding='utf-8') as f:
+            dialogues = json.load(f)
+
+        if debug_samples is not None:
+            dialogues = dialogues[:debug_samples]
+            logger.info(f"Debug mode: limited to {debug_samples} dialogues.")
+
+        logger.info(f"Loaded {len(dialogues)} dialogues.")
+        return dialogues
+
+    def collect_responses(self, dialogues: List[dict]) -> List[str]:
+        """Extract unique assistant responses from dialogues."""
+        response_set = set()
+        for dialogue in dialogues:
+            turns = dialogue.get('turns', [])
+            for turn in turns:
+                speaker = turn.get('speaker')
+                text = turn.get('text', '').strip()
+                if speaker == 'assistant' and text:
+                    # Keep any non-empty response up to the length limit
+                    if len(text) <= self.max_length:
+                        response_set.add(text)
+        logger.info(f"Collected {len(response_set)} unique assistant responses from dialogues.")
+        return list(response_set)
+
+    def _extract_pairs_from_dialogue(self, dialogue: dict) -> List[Tuple[str, str]]:
+        """Extract query-response pairs from a dialogue."""
+        pairs = []
+        turns = dialogue.get('turns', [])
+
+        for i in range(len(turns) - 1):
+            current_turn = turns[i]
+            next_turn = turns[i+1]
+
+            if (current_turn.get('speaker') == 'user' and
+                next_turn.get('speaker') == 'assistant' and
+                'text' in current_turn and
+                'text' in next_turn):
+
+                query = current_turn['text'].strip()
+                positive = next_turn['text'].strip()
+                pairs.append((query, positive))
+
+        return pairs
+
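
The pair-extraction and response-collection helpers above read the keys 'turns', 'speaker', and 'text'; a minimal dialogue structure inferred from those keys looks like the following (an illustrative sample, not a file from the repository):

# Hypothetical dialogue record matching what _extract_pairs_from_dialogue / collect_responses expect.
example_dialogues = [
    {
        "turns": [
            {"speaker": "user", "text": "How do I reset my password?"},
            {"speaker": "assistant", "text": "Open Settings, choose Security, then select Reset Password."},
        ]
    }
]
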
+    def _compute_and_index_response_embeddings(self):
+        """
+        Computes embeddings for the response pool and adds them to the FAISS index.
+        """
+        logger.info("Computing embeddings for the response pool...")
+
+        # Log the first few responses and their types for debugging
+        for idx, response in enumerate(self.response_pool[:5], 1):
+            logger.debug(f"Response {idx}: {response} (Type: {type(response)})")
+
+        # Ensure all responses are strings
+        if not all(isinstance(response, str) for response in self.response_pool):
+            logger.error("All elements in response_pool must be strings.")
+            raise ValueError("Invalid data type in response_pool.")
+
+        # Tokenize the full response pool
+        encoded_responses = self.tokenizer(
+            self.response_pool,
+            padding=True,
+            truncation=True,
+            max_length=self.max_length,
+            return_tensors='tf'
+        )
+        response_ids = encoded_responses['input_ids']
+
+        # Compute embeddings in batches
+        batch_size = getattr(self, 'embedding_batch_size', 64)  # Default to 64 if not set
+        embeddings = []
+        for i in range(0, len(response_ids), batch_size):
+            batch_ids = response_ids[i:i+batch_size]
+            batch_embeddings = self.encoder(batch_ids, training=False).numpy()
+            # Normalize embeddings so inner product behaves like cosine similarity
+            faiss.normalize_L2(batch_embeddings)
+            embeddings.append(batch_embeddings)
+
+        if embeddings:
+            embeddings = np.vstack(embeddings).astype(np.float32)
+            logger.info(f"Adding {len(embeddings)} response embeddings to FAISS index...")
+            self.index.add(embeddings)
+            logger.info("Response embeddings added to FAISS index.")
+        else:
+            logger.warning("No embeddings to add to FAISS index.")
+
+        # Sanity check: verify the number of embeddings in the FAISS index
+        logger.info(f"Total embeddings in FAISS index after addition: {self.index.ntotal}")
+
def _find_hard_negatives_batch(self, queries: List[str], positives: List[str]) -> List[List[str]]:
|
| 230 |
+
"""Find hard negatives for a batch of queries with error handling and retries."""
|
| 231 |
+
retry_count = 0
|
| 232 |
+
total_responses = len(self.response_pool)
|
| 233 |
+
|
| 234 |
+
# Set k to be neg_samples + additional candidates to improve negative selection
|
| 235 |
+
k = self.neg_samples + 0
|
| 236 |
+
|
| 237 |
+
while retry_count < self.max_retries:
|
| 238 |
+
try:
|
| 239 |
+
# Compute embeddings in sub-batches to manage memory
|
| 240 |
+
batch_size = 128 # Example sub-batch size; adjust as needed
|
| 241 |
+
query_embeddings = []
|
| 242 |
+
for i in range(0, len(queries), batch_size):
|
| 243 |
+
sub_queries = queries[i:i + batch_size]
|
| 244 |
+
sub_embeddings = np.vstack([
|
| 245 |
+
self.query_embeddings_cache[q] for q in sub_queries
|
| 246 |
+
]).astype(np.float32)
|
| 247 |
+
faiss.normalize_L2(sub_embeddings)
|
| 248 |
+
query_embeddings.append(sub_embeddings)
|
| 249 |
+
query_embeddings = np.vstack(query_embeddings)
|
| 250 |
+
|
| 251 |
+
# Ensure contiguous memory layout
|
| 252 |
+
query_embeddings = np.ascontiguousarray(query_embeddings)
|
| 253 |
+
|
| 254 |
+
# Perform FAISS search on CPU
|
| 255 |
+
distances, indices = self.index.search(query_embeddings, k)
|
| 256 |
+
|
| 257 |
+
all_negatives = []
|
| 258 |
+
for query_indices, query, positive in zip(indices, queries, positives):
|
| 259 |
+
negatives = []
|
| 260 |
+
positive_strip = positive.strip()
|
| 261 |
+
seen = {positive_strip}
|
| 262 |
+
|
| 263 |
+
for idx in query_indices:
|
| 264 |
+
if idx >= 0 and idx < total_responses:
|
| 265 |
+
candidate = self.response_pool[idx].strip()
|
| 266 |
+
if candidate and candidate not in seen:
|
| 267 |
+
seen.add(candidate)
|
| 268 |
+
negatives.append(candidate)
|
| 269 |
+
if len(negatives) >= self.neg_samples:
|
| 270 |
+
break
|
| 271 |
+
|
| 272 |
+
# If not enough negatives are found, pad with a special token
|
| 273 |
+
while len(negatives) < self.neg_samples:
|
| 274 |
+
negatives.append("<EMPTY_NEGATIVE>") # Use a special token
|
| 275 |
+
|
| 276 |
+
all_negatives.append(negatives)
|
| 277 |
+
|
| 278 |
+
return all_negatives
|
| 279 |
+
|
| 280 |
+
except KeyError as ke:
|
| 281 |
+
retry_count += 1
|
| 282 |
+
logger.warning(f"Hard negative search attempt {retry_count} failed due to missing embeddings: {ke}")
|
| 283 |
+
if retry_count == self.max_retries:
|
| 284 |
+
logger.error("Max retries reached for hard negative search due to missing embeddings.")
|
| 285 |
+
return [["<EMPTY_NEGATIVE>"] * self.neg_samples for _ in queries]
|
| 286 |
+
# Perform memory cleanup
|
| 287 |
+
gc.collect()
|
| 288 |
+
if tf.config.list_physical_devices('GPU'):
|
| 289 |
+
tf.keras.backend.clear_session()
|
| 290 |
+
except Exception as e:
|
| 291 |
+
retry_count += 1
|
| 292 |
+
logger.warning(f"Hard negative search attempt {retry_count} failed: {e}")
|
| 293 |
+
if retry_count == self.max_retries:
|
| 294 |
+
logger.error("Max retries reached for hard negative search.")
|
| 295 |
+
return [["<EMPTY_NEGATIVE>"] * self.neg_samples for _ in queries]
|
| 296 |
+
# Perform memory cleanup
|
| 297 |
+
gc.collect()
|
| 298 |
+
if tf.config.list_physical_devices('GPU'):
|
| 299 |
+
tf.keras.backend.clear_session()
|
| 300 |
+
|
| 301 |
+
def _tokenize_and_encode(self, queries: List[str], positives: List[str], negatives: List[List[str]]) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
| 302 |
+
"""
|
| 303 |
+
Tokenize and encode the queries, positives, and negatives.
|
| 304 |
+
Returns:
|
| 305 |
+
query_ids: [batch_size, max_length]
|
| 306 |
+
positive_ids: [batch_size, max_length]
|
| 307 |
+
negative_ids: [batch_size, neg_samples, max_length]
|
| 308 |
+
"""
|
| 309 |
+
# Tokenize queries
|
| 310 |
+
q_enc = self.tokenizer(
|
| 311 |
+
queries,
|
| 312 |
+
padding="max_length",
|
| 313 |
+
truncation=True,
|
| 314 |
+
max_length=self.max_length,
|
| 315 |
+
return_tensors="np"
|
| 316 |
+
)
|
| 317 |
+
# Tokenize positives
|
| 318 |
+
p_enc = self.tokenizer(
|
| 319 |
+
positives,
|
| 320 |
+
padding="max_length",
|
| 321 |
+
truncation=True,
|
| 322 |
+
max_length=self.max_length,
|
| 323 |
+
return_tensors="np"
|
| 324 |
+
)
|
| 325 |
+
# Tokenize negatives
|
| 326 |
+
# Flatten negatives
|
| 327 |
+
flattened_negatives = [neg for sublist in negatives for neg in sublist]
|
| 328 |
+
if len(flattened_negatives) == 0:
|
| 329 |
+
# No negatives at all: return a zero array
|
| 330 |
+
n_ids = np.zeros((len(queries), self.neg_samples, self.max_length), dtype=np.int32)
|
| 331 |
+
else:
|
| 332 |
+
n_enc = self.tokenizer(
|
| 333 |
+
flattened_negatives,
|
| 334 |
+
padding="max_length",
|
| 335 |
+
truncation=True,
|
| 336 |
+
max_length=self.max_length,
|
| 337 |
+
return_tensors="np"
|
| 338 |
+
)
|
| 339 |
+
n_input_ids = n_enc["input_ids"]
|
| 340 |
+
|
| 341 |
+
# Reshape to [batch_size, neg_samples, max_length]
|
| 342 |
+
batch_size = len(queries)
|
| 343 |
+
n_ids = n_input_ids.reshape(batch_size, self.neg_samples, self.max_length)
|
| 344 |
+
|
| 345 |
+
# Convert to int32
|
| 346 |
+
query_ids = q_enc["input_ids"].astype(np.int32)
|
| 347 |
+
positive_ids = p_enc["input_ids"].astype(np.int32)
|
| 348 |
+
negative_ids = n_ids.astype(np.int32)
|
| 349 |
+
|
| 350 |
+
return query_ids, positive_ids, negative_ids
|
| 351 |
+
|
| 352 |
+
def prepare_and_save_data(self, dialogues: List[dict], tfrecord_file_path: str, batch_size: int = 32):
|
| 353 |
+
"""Processes dialogues in batches and saves to a TFRecord file."""
|
| 354 |
+
with tf.io.TFRecordWriter(tfrecord_file_path) as writer:
|
| 355 |
+
total_dialogues = len(dialogues)
|
| 356 |
+
logger.debug(f"Total dialogues to process: {total_dialogues}")
|
| 357 |
+
|
| 358 |
+
with tqdm(total=total_dialogues, desc="Processing Dialogues", unit="dialogue") as pbar:
|
| 359 |
+
for i in range(0, total_dialogues, batch_size):
|
| 360 |
+
batch_dialogues = dialogues[i:i+batch_size]
|
| 361 |
+
# Process each batch_dialogues
|
| 362 |
+
# Extract pairs, find negatives, tokenize, and serialize
|
| 363 |
+
# Example:
|
| 364 |
+
for dialogue in batch_dialogues:
|
| 365 |
+
pairs = self._extract_pairs_from_dialogue(dialogue)
|
| 366 |
+
queries = []
|
| 367 |
+
positives = []
|
| 368 |
+
|
| 369 |
+
for query, positive in pairs:
|
| 370 |
+
queries.append(query)
|
| 371 |
+
positives.append(positive)
|
| 372 |
+
|
| 373 |
+
if queries:
|
| 374 |
+
# **Compute and cache query embeddings before searching**
|
| 375 |
+
self._compute_embeddings(queries)
|
| 376 |
+
|
| 377 |
+
# Find hard negatives
|
| 378 |
+
hard_negatives = self._find_hard_negatives_batch(queries, positives)
|
| 379 |
+
|
| 380 |
+
for idx, negatives in enumerate(hard_negatives[:5]): # Log first 5 examples
|
| 381 |
+
logger.debug(f"Query: {queries[idx]}")
|
| 382 |
+
logger.debug(f"Positive: {positives[idx]}")
|
| 383 |
+
logger.debug(f"Hard Negatives: {negatives}")
|
| 384 |
+
# Tokenize and encode
|
| 385 |
+
query_ids, positive_ids, negative_ids = self._tokenize_and_encode(queries, positives, hard_negatives)
|
| 386 |
+
|
| 387 |
+
# Serialize each example and write to TFRecord
|
| 388 |
+
for q_id, p_id, n_id in zip(query_ids, positive_ids, negative_ids):
|
| 389 |
+
feature = {
|
| 390 |
+
'query_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=q_id)),
|
| 391 |
+
'positive_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=p_id)),
|
| 392 |
+
'negative_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=n_id.flatten())),
|
| 393 |
+
}
|
| 394 |
+
example = tf.train.Example(features=tf.train.Features(feature=feature))
|
| 395 |
+
writer.write(example.SerializeToString())
|
| 396 |
+
|
| 397 |
+
pbar.update(len(batch_dialogues))
|
| 398 |
+
logger.info(f"Data preparation complete. TFRecord saved at {tfrecord_file_path}")
|
| 399 |
+
|
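
The feature spec written above implies a matching parse function on the training side; a minimal sketch, assuming max_length and neg_samples are the same values used during preparation and that the TFRecord path shown is only a placeholder:

# Hypothetical parser for the TFRecord written by prepare_and_save_data (not part of the commit).
def parse_example(serialized, max_length=512, neg_samples=3):
    feature_spec = {
        'query_ids': tf.io.FixedLenFeature([max_length], tf.int64),
        'positive_ids': tf.io.FixedLenFeature([max_length], tf.int64),
        'negative_ids': tf.io.FixedLenFeature([neg_samples * max_length], tf.int64),
    }
    parsed = tf.io.parse_single_example(serialized, feature_spec)
    q = tf.cast(parsed['query_ids'], tf.int32)
    p = tf.cast(parsed['positive_ids'], tf.int32)
    n = tf.cast(tf.reshape(parsed['negative_ids'], [neg_samples, max_length]), tf.int32)
    return q, p, n

dataset = tf.data.TFRecordDataset("training_data/training_data.tfrecord").map(parse_example)
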
+    def _tokenize_negatives_tf(self, negatives):
+        """Tokenizes negatives inside a tf.py_function."""
+        # Handle the case where negatives is an empty tensor
+        if tf.size(negatives) == 0:
+            return tf.zeros([0, self.neg_samples, self.max_length], dtype=tf.int32)
+
+        # Convert the EagerTensor to a list of lists of strings
+        negatives_list = []
+        for neg_list in negatives.numpy():
+            decoded_negs = [neg.decode("utf-8") for neg in neg_list if neg]  # Filter out empty strings
+            negatives_list.append(decoded_negs)
+
+        # Flatten the list of lists
+        flattened_negatives = [neg for sublist in negatives_list for neg in sublist]
+
+        # Tokenize the flattened negatives
+        if flattened_negatives:
+            n_tokens = self.tokenizer(
+                flattened_negatives,
+                padding='max_length',
+                truncation=True,
+                max_length=self.max_length,
+                return_tensors='tf'
+            )
+            # Reshape to [num_queries, neg_samples, max_length]
+            n_tokens_reshaped = tf.reshape(n_tokens['input_ids'], [-1, self.neg_samples, self.max_length])
+            return n_tokens_reshaped
+        else:
+            return tf.zeros([0, self.neg_samples, self.max_length], dtype=tf.int32)
+
+    def _compute_embeddings(self, queries: List[str]) -> None:
+        """Compute and cache embeddings for any queries not already in the cache."""
+        new_queries = [q for q in queries if q not in self.query_embeddings_cache]
+        if not new_queries:
+            return  # All queries already cached
+
+        # Compute embeddings for new queries
+        new_embeddings = []
+        for i in range(0, len(new_queries), self.embedding_batch_size):
+            batch_queries = new_queries[i:i + self.embedding_batch_size]
+            encoded = self.tokenizer(
+                batch_queries,
+                padding=True,
+                truncation=True,
+                max_length=self.max_length,
+                return_tensors='tf'
+            )
+            batch_embeddings = self.encoder(encoded['input_ids'], training=False).numpy()
+            faiss.normalize_L2(batch_embeddings)
+            new_embeddings.extend(batch_embeddings)
+
+        # Update the cache
+        for query, emb in zip(new_queries, new_embeddings):
+            self.query_embeddings_cache[query] = emb
+
+    def data_generator(self, dialogues: List[dict]) -> Generator[Tuple[str, str, List[str]], None, None]:
+        """
+        Generates training examples: (query, positive, hard_negatives).
+        The outer loop is wrapped with tqdm for progress tracking.
+        """
+        total_dialogues = len(dialogues)
+        logger.debug(f"Total dialogues to process: {total_dialogues}")
+
+        with tqdm(total=total_dialogues, desc="Processing Dialogues", unit="dialogue") as pbar:
+            for dialogue in dialogues:
+                pairs = self._extract_pairs_from_dialogue(dialogue)
+                for query, positive in pairs:
+                    # Ensure embeddings are cached, then find hard negatives
+                    self._compute_embeddings([query])
+                    hard_negatives = self._find_hard_negatives_batch([query], [positive])[0]
+                    yield (query, positive, hard_negatives)
+                pbar.update(1)
+
+    def _prepare_batch(self, queries: tf.Tensor, positives: tf.Tensor, negatives: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
+        """Prepares a batch of data for training."""
+
+        # Convert EagerTensors to lists of strings
+        queries_list = [query.decode("utf-8") for query in queries.numpy()]
+        positives_list = [pos.decode("utf-8") for pos in positives.numpy()]
+
+        # Tokenize queries and positives
+        q_tokens = self.tokenizer(queries_list, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf')
+        p_tokens = self.tokenizer(positives_list, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf')
+
+        # Decode negatives and ensure they are lists of strings
+        negatives_list = []
+        for neg_list in negatives.numpy():
+            decoded_negs = [neg.decode("utf-8") for neg in neg_list if neg]  # Filter out empty strings
+            negatives_list.append(decoded_negs)
+
+        # Flatten negatives for tokenization if there are any valid negatives
+        flattened_negatives = [neg for sublist in negatives_list for neg in sublist if neg]
+
+        # Tokenize negatives if there are any
+        if flattened_negatives:
+            n_tokens = self.tokenizer(flattened_negatives, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf')
+            # Reshape to [batch_size, num_negatives_per_query, max_length]
+            # (may need adjustment if the number of negatives varies per query)
+            n_tokens_reshaped = tf.reshape(n_tokens['input_ids'], [len(queries_list), -1, self.max_length])
+        else:
+            # Placeholder tensor for the case where there are no negatives
+            n_tokens_reshaped = tf.zeros([len(queries_list), 0, self.max_length], dtype=tf.int32)
+
+        # Ensure a consistent shape of [batch_size, neg_samples, max_length]
+        if n_tokens_reshaped.shape[1] != self.neg_samples:
+            # Pad or truncate the second dimension to match neg_samples
+            padding = tf.zeros([len(queries_list), tf.maximum(0, self.neg_samples - n_tokens_reshaped.shape[1]), self.max_length], dtype=tf.int32)
+            n_tokens_reshaped = tf.concat([n_tokens_reshaped, padding], axis=1)
+            n_tokens_reshaped = n_tokens_reshaped[:, :self.neg_samples, :]
+
+        # Concatenate the positive and negative examples along the 'neg_samples' dimension
+        combined_p_n_tokens = tf.concat([tf.expand_dims(p_tokens['input_ids'], axis=1), n_tokens_reshaped], axis=1)
+
+        return q_tokens['input_ids'], combined_p_n_tokens
+
+    def get_tf_dataset(self, dialogues: List[dict], batch_size: int) -> tf.data.Dataset:
+        """
+        Creates a tf.data.Dataset for streaming training that yields
+        (input_ids_query, input_ids_positive, input_ids_negatives).
+        """
+        # 1) Start with a generator dataset
+        dataset = tf.data.Dataset.from_generator(
+            lambda: self.data_generator(dialogues),
+            output_signature=(
+                tf.TensorSpec(shape=(), dtype=tf.string),                   # Query (single string)
+                tf.TensorSpec(shape=(), dtype=tf.string),                   # Positive (single string)
+                tf.TensorSpec(shape=(self.neg_samples,), dtype=tf.string)   # Hard negatives (list of strings)
+            )
+        )
+
+        # 2) Batch the raw strings
+        dataset = dataset.batch(batch_size, drop_remainder=True)
+
+        # 3) Map them through a tokenize step using tf.py_function
+        dataset = dataset.map(
+            lambda q, p, n: self._tokenize_triple(q, p, n),
+            num_parallel_calls=1  # tf.data.AUTOTUNE
+        )
+
+        dataset = dataset.prefetch(tf.data.AUTOTUNE)
+        return dataset
+
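
For context, a rough sketch of how this streaming dataset might be consumed during training; the `pipeline` instance, `dialogues` list, and batch size of 32 are assumptions, not code from the commit:

# Hypothetical consumption of the streaming dataset (not part of the commit).
train_ds = pipeline.get_tf_dataset(dialogues, batch_size=32)
for q_ids, p_ids, n_ids in train_ds.take(1):
    # q_ids, p_ids: [32, max_length]; n_ids: [32, neg_samples, max_length]
    print(q_ids.shape, p_ids.shape, n_ids.shape)
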
+    def _tokenize_triple(
+        self,
+        q: tf.Tensor,
+        p: tf.Tensor,
+        n: tf.Tensor
+    ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
+        """
+        Wraps a Python function via tf.py_function to convert tf.Tensors of strings
+        -> Python lists of strings -> HF tokenizer -> Tensors of IDs.
+
+        q is shape [batch_size], p is shape [batch_size],
+        n is shape [batch_size, neg_samples] (i.e., each row is a list of negatives).
+        """
+        # Use tf.py_function with limited parallelism
+        q_ids, p_ids, n_ids = tf.py_function(
+            func=self._tokenize_triple_py,
+            inp=[q, p, n, tf.constant(self.max_length), tf.constant(self.neg_samples)],
+            Tout=[tf.int32, tf.int32, tf.int32]
+        )
+
+        # Manually set shape information
+        q_ids.set_shape([None, self.max_length])                     # [batch_size, max_length]
+        p_ids.set_shape([None, self.max_length])                     # [batch_size, max_length]
+        n_ids.set_shape([None, self.neg_samples, self.max_length])   # [batch_size, neg_samples, max_length]
+
+        return q_ids, p_ids, n_ids
+
+    def _tokenize_triple_py(
+        self,
+        q: tf.Tensor,
+        p: tf.Tensor,
+        n: tf.Tensor,
+        max_len: tf.Tensor,
+        neg_samples: tf.Tensor
+    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+        """
+        Python function that:
+          - Decodes each tf.string Tensor to a Python list of strings
+          - Calls the HF tokenizer
+          - Reshapes negatives
+          - Returns np.arrays of int32 for (q_ids, p_ids, n_ids).
+
+        q: shape [batch_size], p: shape [batch_size]
+        n: shape [batch_size, neg_samples]
+        max_len: scalar int
+        neg_samples: scalar int
+        """
+        max_len = int(max_len.numpy())  # Convert to Python int
+        neg_samples = int(neg_samples.numpy())
+
+        # 1) Convert Tensors -> Python lists of strings
+        q_list = [q_i.decode("utf-8") for q_i in q.numpy()]  # shape [batch_size]
+        p_list = [p_i.decode("utf-8") for p_i in p.numpy()]  # shape [batch_size]
+
+        # shape [batch_size, neg_samples], decode each row
+        n_list = []
+        for row in n.numpy():
+            # row is shape [neg_samples], each entry a tf.string
+            decoded = [neg.decode("utf-8") for neg in row]
+            n_list.append(decoded)
+
+        # 2) Tokenize queries & positives
+        q_enc = self.tokenizer(
+            q_list,
+            padding="max_length",
+            truncation=True,
+            max_length=max_len,
+            return_tensors="np"
+        )
+        p_enc = self.tokenizer(
+            p_list,
+            padding="max_length",
+            truncation=True,
+            max_length=max_len,
+            return_tensors="np"
+        )
+
+        # 3) Tokenize negatives
+        # Flatten [batch_size, neg_samples] -> single list
+        flattened_negatives = [neg for row in n_list for neg in row]
+        if len(flattened_negatives) == 0:
+            # No negatives at all: return a zero array
+            n_ids = np.zeros((len(q_list), neg_samples, max_len), dtype=np.int32)
+        else:
+            n_enc = self.tokenizer(
+                flattened_negatives,
+                padding="max_length",
+                truncation=True,
+                max_length=max_len,
+                return_tensors="np"
+            )
+            # shape [batch_size * neg_samples, max_len]
+            n_input_ids = n_enc["input_ids"]
+
+            # Reshape to [batch_size, neg_samples, max_len],
+            # handling rows that ended up with fewer negatives
+            batch_size = len(q_list)
+            n_ids_list = []
+            for i in range(batch_size):
+                start_idx = i * neg_samples
+                end_idx = start_idx + neg_samples
+                row_negs = n_input_ids[start_idx:end_idx]
+
+                # If fewer negatives, pad with zeros
+                if row_negs.shape[0] < neg_samples:
+                    deficit = neg_samples - row_negs.shape[0]
+                    pad_arr = np.zeros((deficit, max_len), dtype=np.int32)
+                    row_negs = np.concatenate([row_negs, pad_arr], axis=0)
+
+                n_ids_list.append(row_negs)
+
+            # Stack -> shape [batch_size, neg_samples, max_len]
+            n_ids = np.stack(n_ids_list, axis=0)
+
+        # 4) Return as np.int32 arrays
+        q_ids = q_enc["input_ids"].astype(np.int32)  # shape [batch_size, max_len]
+        p_ids = p_enc["input_ids"].astype(np.int32)  # shape [batch_size, max_len]
+        n_ids = n_ids.astype(np.int32)               # shape [batch_size, neg_samples, max_len]
+
+        return q_ids, p_ids, n_ids
+