JoeArmani committed · Commit 74af405 · Parent: 2183656
data processing pipeline

Files changed:
- chatbot_model.py  +133 −637
- requirements.txt  +1 −0
- run_data_preparer.py  +182 −0
- tf_data_pipeline.py  +734 −0
chatbot_model.py
CHANGED
@@ -2,7 +2,7 @@ import time
 from transformers import TFAutoModel, AutoTokenizer
 import tensorflow as tf
 import numpy as np
-from typing import
+from typing import List, Tuple, Dict, Optional, Union, Any
 import math
 from dataclasses import dataclass
 import json
@@ -10,6 +10,7 @@ from pathlib import Path
 import datetime
 import faiss
 import gc
+from tf_data_pipeline import TFDataPipeline
 from response_quality_checker import ResponseQualityChecker
 from cross_encoder_reranker import CrossEncoderReranker
 from conversation_summarizer import DeviceAwareModel, Summarizer
@@ -24,14 +25,12 @@ logger = config_logger(__name__)
 @dataclass
 class ChatbotConfig:
     """Configuration for the RetrievalChatbot."""
-    vocab_size: int = 30526  # DistilBERT vocab size + special tokens
     max_context_token_limit: int = 512
     embedding_dim: int = 768
     encoder_units: int = 256
     num_attention_heads: int = 8
     dropout_rate: float = 0.2
     l2_reg_weight: float = 0.001
-    margin: float = 0.3
     learning_rate: float = 0.001
     min_text_length: int = 3
     max_context_turns: int = 5
@@ -39,16 +38,19 @@ class ChatbotConfig:
     pretrained_model: str = 'distilbert-base-uncased'
     dtype: str = 'float32'
     freeze_embeddings: bool = False
-    embedding_batch_size: int =
+    embedding_batch_size: int = 64
+    search_batch_size: int = 64
+    max_batch_size: int = 64
+    neg_samples: int = 3
+    max_retries: int = 3
 
-    def to_dict(self) ->
+    def to_dict(self) -> Dict:
         """Convert config to dictionary."""
-        return {k: str(v) if isinstance(v, Path) else v
+        return {k: (str(v) if isinstance(v, Path) else v)
                 for k, v in self.__dict__.items()}
 
     @classmethod
-    def from_dict(cls, config_dict:
+    def from_dict(cls, config_dict: Dict) -> 'ChatbotConfig':
         """Create config from dictionary."""
         return cls(**{k: v for k, v in config_dict.items()
                       if k in cls.__dataclass_fields__})
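
The new to_dict/from_dict pair round-trips cleanly because from_dict drops any key that is not a dataclass field; a minimal sketch (not part of the commit):

    config = ChatbotConfig(learning_rate=3e-4)
    restored = ChatbotConfig.from_dict(config.to_dict())
    assert restored.learning_rate == 3e-4  # unknown or stale keys are simply ignored
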
@@ -59,24 +61,17 @@ class EncoderModel(tf.keras.Model):
         self,
         config: ChatbotConfig,
         name: str = "encoder",
-        shared_weights: bool = False,
         **kwargs
     ):
         super().__init__(name=name, **kwargs)
         self.config = config
-        self.shared_weights = shared_weights
 
         # Load pretrained model
         self.pretrained = TFAutoModel.from_pretrained(config.pretrained_model)
 
-        # Freeze
-        self.
-
-            if i < 1:  # freeze first layer
-                layer_module.trainable = False
-            else:
-                layer_module.trainable = True
-
+        # Freeze layers based on config
+        self._freeze_layers()
+
         # Pooling layer (Global Average Pooling)
         self.pooler = tf.keras.layers.GlobalAveragePooling1D()
@@ -90,9 +85,27 @@ class EncoderModel(tf.keras.Model):
         # Dropout and normalization
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
         self.normalize = tf.keras.layers.Lambda(
-            lambda x: tf.nn.l2_normalize(x, axis=1)
+            lambda x: tf.nn.l2_normalize(x, axis=1),
+            name="l2_normalize"
         )
 
+    def _freeze_layers(self):
+        """Freeze layers of the pretrained model based on configuration."""
+        if self.config.freeze_embeddings:
+            self.pretrained.trainable = False
+            logger.info("All pretrained layers frozen.")
+        else:
+            # Freeze only the first 'n' transformer layers
+            for i, layer in enumerate(self.pretrained.layers):
+                if isinstance(layer, tf.keras.layers.Layer):
+                    if hasattr(layer, 'trainable'):
+                        # Freeze the first transformer block
+                        if i < 1:
+                            layer.trainable = False
+                            logger.info(f"Layer {i} frozen.")
+                        else:
+                            layer.trainable = True
+
     def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
         """Forward pass."""
         # Get pretrained embeddings
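
Worth knowing about the freezing loop above: for DistilBERT, pretrained.layers usually exposes a single main transformer layer, so the i < 1 branch effectively freezes the whole trunk rather than just one block. A quick check, as a sketch assuming the definitions above:

    cfg = ChatbotConfig(freeze_embeddings=True)
    encoder = EncoderModel(cfg)
    # With freeze_embeddings=True the entire pretrained trunk is non-trainable:
    assert len(encoder.pretrained.trainable_weights) == 0
    print([layer.name for layer in encoder.pretrained.layers])  # typically just ['distilbert']
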
@@ -112,46 +125,33 @@ class EncoderModel(tf.keras.Model):
         config = super().get_config()
         config.update({
             "config": self.config.to_dict(),
-            "shared_weights": self.shared_weights,
             "name": self.name
         })
         return config
 
 class RetrievalChatbot(DeviceAwareModel):
     """Retrieval-based chatbot using pretrained embeddings and FAISS for similarity search."""
-    def __init__(
-
-
-
+    def __init__(
+        self,
+        config: ChatbotConfig,
+        dialogues: List[dict] = [],
+        device: str = None,
+        strategy=None,
+        reranker: Optional[CrossEncoderReranker] = None,
+        summarizer: Optional[Summarizer] = None
+    ):
+        super().__init__()
         self.config = config
         self.strategy = strategy
-        self.
-
-        if reranker is None:
-            logger.info("Creating default CrossEncoderReranker...")
-            reranker = CrossEncoderReranker(model_name="cross-encoder/ms-marco-MiniLM-L-12-v2")
-        self.reranker = reranker
-
-        if summarizer is None:
-            logger.info("Creating default Summarizer...")
-            summarizer = Summarizer(device=self.device)
-        self.summarizer = summarizer
-
-        # Special tokens
-        self.special_tokens = {
-            "user": "<USER>",
-            "assistant": "<ASSISTANT>",
-            "context": "<CONTEXT>",
-            "sep": "<SEP>"
-        }
-
-        # Initialize tokenizer and add special tokens
-        self.tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
-        self.tokenizer.add_special_tokens(
-            {'additional_special_tokens': list(self.special_tokens.values())}
-        )
+        self.device = device or self._setup_default_device()
 
+        # Initialize reranker, summarizer, tokenizer, and memory monitor
+        self.reranker = reranker or self._initialize_reranker()
+        self.summarizer = summarizer or self._initialize_summarizer()
+        self.tokenizer = self._initialize_tokenizer()
         self.memory_monitor = GPUMemoryMonitor()
+
+        # Initialize models
         self.min_batch_size = 8
         self.max_batch_size = 128
         self.current_batch_size = 32
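
With the refactored constructor, wiring up a chatbot becomes a one-liner; a sketch, assuming DeviceAwareModel's base constructor takes no arguments:

    config = ChatbotConfig()
    chatbot = RetrievalChatbot(config)  # device, reranker, summarizer all defaulted

One caveat: the mutable default dialogues: List[dict] = [] is the classic shared-default pitfall in Python (the same list object is reused across instances); dialogues: Optional[List[dict]] = None with an in-body fallback would be safer.
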
@@ -166,9 +166,62 @@ class RetrievalChatbot(DeviceAwareModel):
             "train_metrics": {},
             "val_metrics": {}
         }
-
+
+    def _setup_default_device(self) -> str:
+        """Set up default device if none is provided."""
+        if tf.config.list_physical_devices('GPU'):
+            return 'GPU'
+        else:
+            return 'CPU'
+
+    def _initialize_reranker(self) -> CrossEncoderReranker:
+        """Initialize the CrossEncoderReranker."""
+        logger.info("Initializing default CrossEncoderReranker...")
+        return CrossEncoderReranker(model_name="cross-encoder/ms-marco-MiniLM-L-12-v2")
+
+    def _initialize_summarizer(self) -> Summarizer:
+        """Initialize the Summarizer."""
+        logger.info("Initializing default Summarizer...")
+        return Summarizer(device=self.device)
+
+    def _initialize_tokenizer(self) -> AutoTokenizer:
+        """Initialize the tokenizer and add special tokens."""
+        logger.info("Initializing tokenizer and adding special tokens...")
+        tokenizer = AutoTokenizer.from_pretrained(self.config.pretrained_model)
+        special_tokens = {
+            "user": "<USER>",
+            "assistant": "<ASSISTANT>",
+            "context": "<CONTEXT>",
+            "sep": "<SEP>"
+        }
+        tokenizer.add_special_tokens(
+            {'additional_special_tokens': list(special_tokens.values())}
+        )
+        return tokenizer
+
+    def _collect_responses(self, dialogues: List[dict]) -> Tuple[List[str], List[str]]:
+        """
+        Collect unique responses from dialogues.
+        Returns:
+            response_pool: List of all possible responses.
+            unique_responses: List of unique responses.
+        """
+        logger.info("Collecting unique responses from dialogues...")
+        responses = set()
+        for dialogue in dialogues:
+            turns = dialogue.get('turns', [])
+            for turn in turns:
+                if turn.get('speaker') == 'assistant' and 'text' in turn:
+                    response = turn['text'].strip()
+                    if len(response) >= self.config.min_text_length:
+                        responses.add(response)
+        response_pool = list(responses)
+        unique_responses = list(responses)  # Assuming uniqueness
+        logger.info(f"Collected {len(response_pool)} unique responses.")
+        return response_pool, unique_responses
+
     def build_models(self):
-        """Initialize the shared encoder."""
+        """Initialize the shared encoder and FAISS index."""
         logger.info("Building encoder model...")
         tf.keras.backend.clear_session()
@@ -176,6 +229,7 @@ class RetrievalChatbot(DeviceAwareModel):
         self.encoder = EncoderModel(
             self.config,
             name="shared_encoder",
+            shared_weights=True  # If weight sharing is intended
         )
 
         # Resize token embeddings after adding special tokens
@@ -183,31 +237,14 @@ class RetrievalChatbot(DeviceAwareModel):
         self.encoder.pretrained.resize_token_embeddings(new_vocab_size)
         logger.info(f"Token embeddings resized to: {new_vocab_size}")
 
-        # Initialize FAISS index
+        # Initialize FAISS index
         self._initialize_faiss()
-        # Compute embeddings after FAISS is initialized and moved
-        self._compute_and_index_embeddings()
 
-        #
-        try:
-            # First try: from config
-            embedding_dim = self.encoder.pretrained.config.dim
-            logger.info("Got embedding dim from config")
-        except AttributeError:
-            try:
-                # Second try: from word embeddings
-                embedding_dim = self.encoder.pretrained.distilbert.embeddings.word_embeddings.embedding_dim
-                logger.info("Got embedding dim from word embeddings")
-            except AttributeError:
-                try:
-                    # Third try: from embeddings module
-                    embedding_dim = self.encoder.pretrained.distilbert.embeddings.embedding_dim
-                    logger.info("Got embedding dim from embeddings module")
-                except AttributeError:
-                    # Fallback to config value
-                    embedding_dim = self.config.embedding_dim
-                    logger.info("Using config embedding dim")
+        # Compute and index embeddings
+        self._compute_and_index_embeddings()
+
+        # Retrieve embedding dimension from encoder
+        embedding_dim = self.config.embedding_dim
 
         vocab_size = len(self.tokenizer)
 
         logger.info(f"Encoder Embedding Dimension: {embedding_dim}")
@@ -217,29 +254,6 @@ class RetrievalChatbot(DeviceAwareModel):
         else:
             logger.error("Vocabulary size is less than embedding dimension.")
             raise ValueError("Vocabulary size is less than embedding dimension.")
-
-    def _collect_responses(self, dialogues: List[dict]) -> Tuple[List[str], List[str]]:
-        """Collect all unique responses from dialogues."""
-        logger.info("Collecting responses from dialogues...")
-
-        responses = []
-        try:
-            progress_bar = tqdm(dialogues, desc="Collecting assistant responses")
-        except ImportError:
-            progress_bar = dialogues
-            logger.info("Progress bar disabled - continuing without visual progress")
-
-        for dialogue in progress_bar:
-            turns = dialogue.get('turns', [])
-            for turn in turns:
-                if turn.get('speaker') == 'assistant' and 'text' in turn:
-                    responses.append(turn['text'].strip())
-
-        # Remove duplicates
-        unique_responses = list(set(responses))
-        logger.info(f"Found {len(unique_responses)} unique responses.")
-
-        return responses, unique_responses
 
     def _adjust_batch_size(self) -> None:
         """Dynamically adjust batch size based on GPU memory usage."""
@@ -288,6 +302,7 @@ class RetrievalChatbot(DeviceAwareModel):
             logger.warning(f"Using CPU due to GPU initialization error: {e}")
 
         # TODO: figure out bug with faiss-gpu
+        # TODO: consider IndexIVFFlat in the future (speed).
         try:
             # Create appropriate index based on dataset size
             if len(self.unique_responses) < 1000:
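
The IndexIVFFlat TODO trades exact search for speed; a minimal sketch of both index types under the same inner-product metric (dimension and data are placeholders, not the project's values):

    import faiss
    import numpy as np

    d = 768                                        # embedding dimension (placeholder)
    xb = np.random.rand(10_000, d).astype('float32')
    faiss.normalize_L2(xb)                         # inner product == cosine after L2 norm

    flat = faiss.IndexFlatIP(d)                    # exact search, no training needed
    flat.add(xb)

    quantizer = faiss.IndexFlatIP(d)
    ivf = faiss.IndexIVFFlat(quantizer, d, 100, faiss.METRIC_INNER_PRODUCT)
    ivf.train(xb)                                  # IVF indices must be trained before add()
    ivf.add(xb)
    ivf.nprobe = 8                                 # probe more lists for better recall
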
@@ -860,33 +875,33 @@ class RetrievalChatbot(DeviceAwareModel):
         logger.info(f"Models and tokenizer loaded from {load_dir}.")
         return chatbot
 
-    @staticmethod
-    def load_training_data(data_path: Union[str, Path], debug_samples: Optional[int] = None) -> List[dict]:
-        """
-        Load training data from a JSON file.
-
-        Args:
-            data_path (Union[str, Path]): Path to the JSON file containing dialogues.
-            debug_samples (Optional[int]): Number of samples to load for debugging.
-
-        Returns:
-            List[dict]: List of dialogue dictionaries.
-        """
-        logger.info(f"Loading training data from {data_path}...")
-        data_path = Path(data_path)
-        if not data_path.exists():
-            logger.error(f"Data file {data_path} does not exist.")
-            return []
-
-        with open(data_path, 'r', encoding='utf-8') as f:
-            dialogues = json.load(f)
-
-        if debug_samples is not None:
-            dialogues = dialogues[:debug_samples]
-            logger.info(f"Debug mode: Limited to {debug_samples} dialogues")
-
-        logger.info(f"Loaded {len(dialogues)} dialogues.")
-        return dialogues
+    # @staticmethod
+    # def load_training_data(data_path: Union[str, Path], debug_samples: Optional[int] = None) -> List[dict]:
+    #     """
+    #     Load training data from a JSON file.
+
+    #     Args:
+    #         data_path (Union[str, Path]): Path to the JSON file containing dialogues.
+    #         debug_samples (Optional[int]): Number of samples to load for debugging.
+
+    #     Returns:
+    #         List[dict]: List of dialogue dictionaries.
+    #     """
+    #     logger.info(f"Loading training data from {data_path}...")
+    #     data_path = Path(data_path)
+    #     if not data_path.exists():
+    #         logger.error(f"Data file {data_path} does not exist.")
+    #         return []
+
+    #     with open(data_path, 'r', encoding='utf-8') as f:
+    #         dialogues = json.load(f)
+
+    #     if debug_samples is not None:
+    #         dialogues = dialogues[:debug_samples]
+    #         logger.info(f"Debug mode: Limited to {debug_samples} dialogues")
+
+    #     logger.info(f"Loaded {len(dialogues)} dialogues.")
+    #     return dialogues
 
     def train_streaming(
         self,
@@ -1336,522 +1351,3 @@ class RetrievalChatbot(DeviceAwareModel):
 
         conversation_parts.append(f"{self.special_tokens['user']} {query}")
         return "\n".join(conversation_parts)
-
-class TFDataPipeline:
-    def __init__(
-        self,
-        embedding_batch_size,
-        tokenizer,
-        encoder,
-        index,
-        response_pool,
-        max_length: int,
-        neg_samples: int,
-    ):
-        self.embedding_batch_size = embedding_batch_size
-        self.tokenizer = tokenizer
-        self.encoder = encoder
-        self.index = index  # CPU version of the index
-        self.response_pool = response_pool
-        self.max_length = max_length
-        self.neg_samples = neg_samples
-        self.embedding_batch_size = 16 if len(response_pool) < 100 else 64
-        self.search_batch_size = 8 if len(response_pool) < 100 else 32
-        self.max_batch_size = 32 if len(response_pool) < 100 else 256
-        self.memory_monitor = GPUMemoryMonitor()
-        self.max_retries = 3
-
-        # In-memory cache for embeddings
-        self.query_embeddings_cache = {}
-
-    def _extract_pairs_from_dialogue(self, dialogue: dict) -> List[Tuple[str, str]]:
-        """Extract query-response pairs from a dialogue."""
-        pairs = []
-        turns = dialogue.get('turns', [])
-
-        for i in range(len(turns) - 1):
-            current_turn = turns[i]
-            next_turn = turns[i+1]
-
-            if (current_turn.get('speaker') == 'user' and
-                next_turn.get('speaker') == 'assistant' and
-                'text' in current_turn and
-                'text' in next_turn):
-
-                query = current_turn['text'].strip()
-                positive = next_turn['text'].strip()
-                pairs.append((query, positive))
-
-        return pairs
-
-    def estimate_total_pairs(self, dialogues: List[dict]) -> int:
-        """Estimate total number of training pairs including hard negatives."""
-        base_pairs = sum(
-            len([
-                1 for i in range(len(d.get('turns', [])) - 1)
-                if (d['turns'][i].get('speaker') == 'user' and
-                    d['turns'][i+1].get('speaker') == 'assistant')
-            ])
-            for d in dialogues
-        )
-        # Account for hard negatives
-        return base_pairs * (1 + self.neg_samples)
-
-    def _find_hard_negatives_batch(self, queries: List[str], positives: List[str]) -> List[List[str]]:
-        """Find hard negatives for a batch of queries with error handling and retries."""
-        retry_count = 0
-        total_responses = len(self.response_pool)
-
-        while retry_count < self.max_retries:
-            try:
-                query_embeddings = np.vstack([
-                    self.query_embeddings_cache[q] for q in queries
-                ]).astype(np.float32)
-
-                query_embeddings = np.ascontiguousarray(query_embeddings)
-                faiss.normalize_L2(query_embeddings)
-
-                k = 1  # TODO: try higher k for better results
-                #logger.debug(f"Searching with k={k} among {total_responses} responses")
-
-                distances, indices = self.index.search(query_embeddings, k)
-
-                all_negatives = []
-                for query_indices, query, positive in zip(indices, queries, positives):
-                    negatives = []
-                    positive_strip = positive.strip()
-                    seen = {positive_strip}
-
-                    for idx in query_indices:
-                        if idx >= 0 and idx < total_responses:
-                            candidate = self.response_pool[idx].strip()
-                            if candidate and candidate not in seen:
-                                seen.add(candidate)
-                                negatives.append(candidate)
-                                if len(negatives) >= self.neg_samples:
-                                    break
-
-                    # Pad with a special empty negative if necessary
-                    while len(negatives) < self.neg_samples:
-                        negatives.append("<EMPTY_NEGATIVE>")  # Use a special token
-
-                    all_negatives.append(negatives)
-
-                return all_negatives
-
-            except Exception as e:
-                retry_count += 1
-                logger.warning(f"Hard negative search attempt {retry_count} failed: {e}")
-                if retry_count == self.max_retries:
-                    logger.error("Max retries reached for hard negative search")
-                    return [["<EMPTY_NEGATIVE>"] * self.neg_samples for _ in queries]  # Return empty negatives for all queries
-                gc.collect()
-                if tf.config.list_physical_devices('GPU'):
-                    tf.keras.backend.clear_session()
-
-    def _tokenize_negatives_tf(self, negatives):
-        """Tokenizes negatives using tf.py_function."""
-        # Handle the case where negatives is an empty tensor
-        if tf.size(negatives) == 0:
-            return tf.zeros([0, self.neg_samples, self.max_length], dtype=tf.int32)
-
-        # Convert EagerTensor to a list of strings
-        negatives_list = []
-        for neg_list in negatives.numpy():
-            decoded_negs = [neg.decode("utf-8") for neg in neg_list if neg]  # Filter out empty strings
-            negatives_list.append(decoded_negs)
-
-        # Flatten the list of lists
-        flattened_negatives = [neg for sublist in negatives_list for neg in sublist]
-
-        # Tokenize the flattened negatives
-        if flattened_negatives:
-            n_tokens = self.tokenizer(
-                flattened_negatives,
-                padding='max_length',
-                truncation=True,
-                max_length=self.max_length,
-                return_tensors='tf'
-            )
-            # Reshape the tokens
-            n_tokens_reshaped = tf.reshape(n_tokens['input_ids'], [-1, self.neg_samples, self.max_length])
-            return n_tokens_reshaped
-        else:
-            return tf.zeros([0, self.neg_samples, self.max_length], dtype=tf.int32)
-
-    def _compute_embeddings(self, queries: List[str]) -> None:
-        """Computes and caches embeddings for new queries."""
-        new_queries = [q for q in queries if q not in self.query_embeddings_cache]
-        if not new_queries:
-            return  # All queries already cached
-
-        new_embeddings = []
-        for i in range(0, len(new_queries), self.embedding_batch_size):
-            batch_queries = new_queries[i:i + self.embedding_batch_size]
-
-            encoded = self.tokenizer(
-                batch_queries,
-                padding=True,
-                truncation=True,
-                max_length=self.max_length,
-                return_tensors='tf'
-            )
-
-            # Compute embeddings on CPU
-            with tf.device('/CPU:0'):
-                batch_embeddings = self.encoder(encoded['input_ids'], training=False).numpy()
-
-            new_embeddings.extend(batch_embeddings)
-
-        # Update cache with new embeddings
-        for query, emb in zip(new_queries, new_embeddings):
-            self.query_embeddings_cache[query] = emb
-
-    def data_generator(self, dialogues: List[dict]) -> Generator[Tuple[str, str, List[str]], None, None]:
-        """
-        Generates training examples: (query, positive, hard_negatives).
-        Wrapped the outer loop with tqdm for progress tracking.
-        """
-        total_dialogues = len(dialogues)
-        logger.debug(f"Total dialogues to process: {total_dialogues}")
-
-        # Initialize tqdm progress bar
-        with tqdm(total=total_dialogues, desc="Processing Dialogues", unit="dialogue") as pbar:
-            for dialogue in dialogues:
-                pairs = self._extract_pairs_from_dialogue(dialogue)
-                for query, positive in pairs:
-                    # Ensure embeddings are computed, find hard negatives, etc.
-                    self._compute_embeddings([query])
-                    hard_negatives = self._find_hard_negatives_batch([query], [positive])[0]
-                    yield (query, positive, hard_negatives)
-                pbar.update(1)
-
-    def _prepare_batch(self, queries: tf.Tensor, positives: tf.Tensor, negatives: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
-        """Prepares a batch of data for training."""
-
-        # Convert EagerTensors to lists of strings
-        queries_list = [query.decode("utf-8") for query in queries.numpy()]
-        positives_list = [pos.decode("utf-8") for pos in positives.numpy()]
-
-        # Tokenize queries and positives
-        q_tokens = self.tokenizer(queries_list, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf')
-        p_tokens = self.tokenizer(positives_list, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf')
-
-        # Decode negatives and ensure they are lists of strings
-        negatives_list = []
-        for neg_list in negatives.numpy():
-            decoded_negs = [neg.decode("utf-8") for neg in neg_list if neg]  # Filter out empty strings
-            negatives_list.append(decoded_negs)
-
-        # Flatten negatives for tokenization if there are any valid negatives
-        flattened_negatives = [neg for sublist in negatives_list for neg in sublist if neg]
-
-        # Tokenize negatives if there are any
-        n_tokens_reshaped = None
-        if flattened_negatives:
-            n_tokens = self.tokenizer(flattened_negatives, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf')
-
-            # Reshape n_tokens to match the expected shape based on the number of negatives per query
-            # This part may need adjustment if the number of negatives varies per query
-            n_tokens_reshaped = tf.reshape(n_tokens['input_ids'], [len(queries_list), -1, self.max_length])
-        else:
-            # Create a placeholder tensor for the case where there are no negatives
-            n_tokens_reshaped = tf.zeros([len(queries_list), 0, self.max_length], dtype=tf.int32)
-
-        # Ensure n_tokens_reshaped has a consistent shape even when there are no negatives
-        # Adjust shape to [batch_size, num_neg_samples, max_length]
-        if n_tokens_reshaped.shape[1] != self.neg_samples:
-            # Pad or truncate the second dimension to match neg_samples
-            padding = tf.zeros([len(queries_list), tf.maximum(0, self.neg_samples - n_tokens_reshaped.shape[1]), self.max_length], dtype=tf.int32)
-            n_tokens_reshaped = tf.concat([n_tokens_reshaped, padding], axis=1)
-            n_tokens_reshaped = n_tokens_reshaped[:, :self.neg_samples, :]
-
-        # Concatenate the positive and negative examples along the 'neg_samples' dimension
-        combined_p_n_tokens = tf.concat([tf.expand_dims(p_tokens['input_ids'], axis=1), n_tokens_reshaped], axis=1)
-
-        return q_tokens['input_ids'], combined_p_n_tokens
-
-    def get_tf_dataset(self, dialogues: List[dict], batch_size: int) -> tf.data.Dataset:
-        """
-        Creates a tf.data.Dataset for streaming training that yields
-        (input_ids_query, input_ids_positive, input_ids_negatives).
-        """
-        # 1) Start with a generator dataset
-        dataset = tf.data.Dataset.from_generator(
-            lambda: self.data_generator(dialogues),
-            output_signature=(
-                tf.TensorSpec(shape=(), dtype=tf.string),       # Query (single string)
-                tf.TensorSpec(shape=(), dtype=tf.string),       # Positive (single string)
-                tf.TensorSpec(shape=(None,), dtype=tf.string)   # Hard Negatives (list of strings)
-            )
-        )
-
-        # 2) Batch the raw strings
-        dataset = dataset.batch(batch_size)
-
-        # 3) Now map them through a tokenize step (via py_function)
-        dataset = dataset.map(
-            lambda q, p, n: self._tokenize_triple(q, p, n),
-            num_parallel_calls=1  #tf.data.AUTOTUNE
-        )
-
-        dataset = dataset.prefetch(tf.data.AUTOTUNE)
-        return dataset
-
-    def _tokenize_triple(
-        self,
-        q: tf.Tensor,
-        p: tf.Tensor,
-        n: tf.Tensor
-    ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
-        """
-        Wraps a Python function via tf.py_function to convert tf.Tensors of strings
-        -> Python lists of strings -> HF tokenizer -> Tensors of IDs.
-
-        q is shape [batch_size], p is shape [batch_size],
-        n is shape [batch_size, neg_samples] (i.e., each row is a list of negatives).
-        """
-        # Use tf.py_function with limited parallelism
-        q_ids, p_ids, n_ids = tf.py_function(
-            func=self._tokenize_triple_py,
-            inp=[q, p, n, tf.constant(self.max_length), tf.constant(self.neg_samples)],
-            Tout=[tf.int32, tf.int32, tf.int32]
-        )
-
-        # Manually set shape information
-        q_ids.set_shape([None, self.max_length])                    # [batch_size, max_length]
-        p_ids.set_shape([None, self.max_length])                    # [batch_size, max_length]
-        n_ids.set_shape([None, self.neg_samples, self.max_length])  # [batch_size, neg_samples, max_length]
-
-        return q_ids, p_ids, n_ids
-
-    # def _tokenize_triple(
-    #     self,
-    #     q: tf.Tensor,
-    #     p: tf.Tensor,
-    #     n: tf.Tensor
-    # ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
-    #     """
-    #     Wraps a Python function via tf.py_function to convert tf.Tensors of strings
-    #     -> Python lists of strings -> HF tokenizer -> Tensors of IDs.
-
-    #     q is shape [batch_size], p is shape [batch_size],
-    #     n is shape [batch_size, None] (i.e. each row is a variable number of negatives).
-    #     """
-    #     # Use tf.py_function
-    #     # We pass in self.max_length as well, so we can do it in one shot.
-    #     q_ids, p_ids, n_ids = tf.py_function(
-    #         func=self._tokenize_triple_py,
-    #         inp=[q, p, n, tf.constant(self.max_length), tf.constant(self.neg_samples)],
-    #         Tout=[tf.int32, tf.int32, tf.int32]
-    #     )
-
-    #     # We must manually set shape information so that TF data pipeline knows the dimensions
-    #     q_ids.set_shape([None, self.max_length])  # [batch_size, max_length]
-    #     p_ids.set_shape([None, self.max_length])  # [batch_size, max_length]
-    #     n_ids.set_shape([None, self.neg_samples, self.max_length])
-    #     # The negative dimension is set to `self.neg_samples` for consistency.
-
-    #     return q_ids, p_ids, n_ids
-
-    def _tokenize_triple_py(
-        self,
-        q: tf.Tensor,
-        p: tf.Tensor,
-        n: tf.Tensor,
-        max_len: tf.Tensor,
-        neg_samples: tf.Tensor
-    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
-        """
-        Python function that:
-         - Decodes each tf.string Tensor to a Python list of strings
-         - Calls the HF tokenizer
-         - Reshapes negatives
-         - Returns np.array of int32s for (q_ids, p_ids, n_ids).
-
-        q: shape [batch_size], p: shape [batch_size]
-        n: shape [batch_size, neg_samples]
-        max_len: scalar int
-        neg_samples: scalar int
-        """
-        max_len = int(max_len.numpy())  # Convert to Python int
-        neg_samples = int(neg_samples.numpy())
-
-        # 1) Convert Tensors -> Python lists of strings
-        q_list = [q_i.decode("utf-8") for q_i in q.numpy()]  # shape [batch_size]
-        p_list = [p_i.decode("utf-8") for p_i in p.numpy()]  # shape [batch_size]
-
-        # shape [batch_size, neg_samples], decode each row
-        n_list = []
-        for row in n.numpy():
-            # row is shape [neg_samples], each is a tf.string
-            decoded = [neg.decode("utf-8") for neg in row]
-            n_list.append(decoded)
-
-        # 2) Tokenize queries & positives
-        q_enc = self.tokenizer(
-            q_list,
-            padding="max_length",
-            truncation=True,
-            max_length=max_len,
-            return_tensors="np"
-        )
-        p_enc = self.tokenizer(
-            p_list,
-            padding="max_length",
-            truncation=True,
-            max_length=max_len,
-            return_tensors="np"
-        )
-
-        # 3) Tokenize negatives
-        # Flatten [batch_size, neg_samples] -> single list
-        flattened_negatives = [neg for row in n_list for neg in row]
-        if len(flattened_negatives) == 0:
-            # No negatives at all: return a zero array
-            n_ids = np.zeros((len(q_list), neg_samples, max_len), dtype=np.int32)
-        else:
-            n_enc = self.tokenizer(
-                flattened_negatives,
-                padding="max_length",
-                truncation=True,
-                max_length=max_len,
-                return_tensors="np"
-            )
-            # shape [batch_size * neg_samples, max_len]
-            n_input_ids = n_enc["input_ids"]
-
-            # We want to reshape to [batch_size, neg_samples, max_len]
-            # Handle cases where there might be fewer negatives
-            batch_size = len(q_list)
-            n_ids_list = []
-            for i in range(batch_size):
-                start_idx = i * neg_samples
-                end_idx = start_idx + neg_samples
-                row_negs = n_input_ids[start_idx:end_idx]
-
-                # If fewer negatives, pad with zeros
-                if row_negs.shape[0] < neg_samples:
-                    deficit = neg_samples - row_negs.shape[0]
-                    pad_arr = np.zeros((deficit, max_len), dtype=np.int32)
-                    row_negs = np.concatenate([row_negs, pad_arr], axis=0)
-
-                n_ids_list.append(row_negs)
-
-            # stack them -> shape [batch_size, neg_samples, max_len]
-            n_ids = np.stack(n_ids_list, axis=0)
-
-        # 4) Return as np.int32 arrays
-        q_ids = q_enc["input_ids"].astype(np.int32)  # shape [batch_size, max_len]
-        p_ids = p_enc["input_ids"].astype(np.int32)  # shape [batch_size, max_len]
-        n_ids = n_ids.astype(np.int32)               # shape [batch_size, neg_samples, max_len]
-
-        return q_ids, p_ids, n_ids
-
-    # def _tokenize_triple_py(
-    #     self,
-    #     q: tf.Tensor,
-    #     p: tf.Tensor,
-    #     n: tf.Tensor,
-    #     max_len: tf.Tensor,
-    #     neg_samples: tf.Tensor
-    # ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
-    #     """
-    #     Python function that:
-    #      - Decodes each tf.string Tensor to a Python list of strings
-    #      - Calls the HF tokenizer
-    #      - Reshapes negatives
-    #      - Returns np.array of int32s for (q_ids, p_ids, n_ids).
-
-    #     q: shape [batch_size], p: shape [batch_size]
-    #     n: shape [batch_size, None]
-    #     max_len: scalar int
-    #     neg_samples: scalar int
-    #     """
-    #     max_len = int(max_len.numpy())  # convert to python int
-    #     neg_samples = int(neg_samples.numpy())
-
-    #     # 1) Convert Tensors -> Python lists of strings
-    #     q_list = [q_i.decode("utf-8") for q_i in q.numpy()]  # shape [batch_size]
-    #     p_list = [p_i.decode("utf-8") for p_i in p.numpy()]  # shape [batch_size]
-
-    #     # shape [batch_size, variable_negatives], decode each row
-    #     n_list = []
-    #     for row in n.numpy():
-    #         # row is shape [N], each is a tf.string
-    #         decoded = [neg.decode("utf-8") for neg in row]
-    #         n_list.append(decoded)
-
-    #     # 2) Tokenize queries & positives
-    #     q_enc = self.tokenizer(
-    #         q_list,
-    #         padding="max_length",
-    #         truncation=True,
-    #         max_length=max_len,
-    #         return_tensors="np"  # you can do return_tensors="tf", but "np" is often simpler here
-    #     )
-    #     p_enc = self.tokenizer(
-    #         p_list,
-    #         padding="max_length",
-    #         truncation=True,
-    #         max_length=max_len,
-    #         return_tensors="np"
-    #     )
-
-    #     # 3) Tokenize negatives
-    #     # Flatten [batch_size, variable_negatives] -> single list
-    #     flattened_negatives = [neg for row in n_list for neg in row]
-    #     if len(flattened_negatives) == 0:
-    #         # No negatives at all: return a zero array
-    #         n_ids = np.zeros((len(q_list), neg_samples, max_len), dtype=np.int32)
-    #     else:
-    #         n_enc = self.tokenizer(
-    #             flattened_negatives,
-    #             padding="max_length",
-    #             truncation=True,
-    #             max_length=max_len,
-    #             return_tensors="np"
-    #         )
-    #         # shape [batch_size * total_negatives, max_len]
-    #         n_input_ids = n_enc["input_ids"]
-
-    #         # We want to reshape to [batch_size, neg_samples, max_len].
-    #         # If each row truly has exactly `neg_samples` (or fewer), we can do:
-    #         # n_input_ids = n_input_ids.reshape(len(q_list), neg_samples, max_len)
-    #         # But if the rows have variable # of negatives, we must clamp or pad.
-    #         # For simplicity, let's just "take first neg_samples" per row
-    #         # and pad if fewer.
-
-    #         # We'll do it row by row:
-    #         batch_size = len(q_list)
-    #         row_offsets = 0
-    #         n_ids_list = []
-    #         for row_idx in range(batch_size):
-    #             row_negs = n_list[row_idx]
-    #             row_count = len(row_negs)
-
-    #             # slice from the flattened array
-    #             row_slice = n_input_ids[row_offsets:row_offsets + row_count]
-    #             row_offsets += row_count
-
-    #             # Now pick out up to neg_samples
-    #             row_slice = row_slice[:neg_samples]
-
-    #             # If fewer than neg_samples, pad
-    #             if row_slice.shape[0] < neg_samples:
-    #                 deficit = neg_samples - row_slice.shape[0]
-    #                 pad_arr = np.zeros((deficit, max_len), dtype=np.int32)
-    #                 row_slice = np.concatenate([row_slice, pad_arr], axis=0)
-
-    #             # row_slice is now shape [neg_samples, max_len]
-    #             n_ids_list.append(row_slice)
-
-    #         # stack them -> shape [batch_size, neg_samples, max_len]
-    #         n_ids = np.stack(n_ids_list, axis=0)
-
-    #     # 4) Return as np.int32 arrays (tokenizer should already return int32,
-    #     # but we can cast to be sure)
-    #     q_ids = q_enc["input_ids"].astype(np.int32)  # shape [batch_size, max_len]
-    #     p_ids = p_enc["input_ids"].astype(np.int32)  # shape [batch_size, max_len]
-    #     n_ids = n_ids.astype(np.int32)               # shape [batch_size, neg_samples, max_len]
-
-    #     return q_ids, p_ids, n_ids
-
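
Taken together, this file's diff removes the in-module TFDataPipeline (everything above) and re-imports it from the new tf_data_pipeline.py; offline data preparation moves to run_data_preparer.py. One caveat worth flagging: build_models now passes shared_weights=True to EncoderModel even though this same commit removes that parameter from the constructor, and Keras validates unknown keyword arguments, so that call would likely raise a TypeError unless the parameter is reintroduced. The core idea the removed class implements, mining hard negatives via nearest-neighbor search over response embeddings, distills to a few lines (placeholder data, not the project's):

    import faiss
    import numpy as np

    response_pool = ["resp A", "resp B", "resp C", "resp D"]   # placeholder responses
    emb = np.random.rand(len(response_pool), 768).astype('float32')
    faiss.normalize_L2(emb)                                    # cosine similarity via inner product
    index = faiss.IndexFlatIP(768)
    index.add(emb)

    query_vec = np.random.rand(1, 768).astype('float32')
    faiss.normalize_L2(query_vec)
    _, idx = index.search(query_vec, 3)                        # nearest responses rank first
    positive = "resp B"
    hard_negatives = [response_pool[i] for i in idx[0] if response_pool[i] != positive][:2]
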
requirements.txt
CHANGED
@@ -1,4 +1,5 @@
 faiss-cpu>=1.7.0    # Required for Facebook AI Similarity Search
+h5py>=3.1.0         # For saving and loading models
 ipython>=8.0.0      # For interactive Python
 loguru>=0.7.0       # Enhanced logging (optional but recommended)
 matplotlib>=3.5.0   # For validation plotting
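
h5py enters the list because tf_data_pipeline.py (below) persists the query-embeddings cache as HDF5 datasets; unlike pickle, HDF5 stores numpy arrays natively and supports partial reads. A quick sanity check for the new dependency (a sketch, not part of the commit):

    import h5py
    print(h5py.version.info)  # prints the h5py and bundled HDF5 versions
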
run_data_preparer.py
ADDED
@@ -0,0 +1,182 @@
+import os
+import sys
+import faiss
+import pickle
+from transformers import AutoTokenizer
+from tqdm.auto import tqdm
+from chatbot_model import ChatbotConfig, EncoderModel
+from environment_setup import EnvironmentSetup
+from tf_data_pipeline import TFDataPipeline
+from logger_config import config_logger
+
+logger = config_logger(__name__)
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+def cleanup_test_indices(faiss_dir, test_prefix='test_'):
+    test_files = [f for f in os.listdir(faiss_dir) if f.startswith(test_prefix)]
+    for file in test_files:
+        file_path = os.path.join(faiss_dir, file)
+        os.remove(file_path)
+        logger.info(f"Removed test FAISS index file: {file_path}")
+
+def main():
+    # Constants
+    MODELS_DIR = 'models'
+    PROCESSED_DATA_DIR = 'processed_outputs'
+    CACHE_DIR = 'cache'
+    TOKENIZER_DIR = os.path.join(MODELS_DIR, 'tokenizer')
+    FAISS_INDICES_DIR = os.path.join(MODELS_DIR, 'faiss_indices')
+    TF_RECORD_DIR = 'training_data'
+    FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_production.index')
+    FAISS_INDEX_TEST_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_test.index')
+    ENVIRONMENT = 'test'  # or 'production'
+    if ENVIRONMENT == 'test':
+        FAISS_INDEX_PATH = FAISS_INDEX_TEST_PATH
+    else:
+        FAISS_INDEX_PATH = FAISS_INDEX_PRODUCTION_PATH
+    JSON_TRAINING_DATA_PATH = os.path.join(PROCESSED_DATA_DIR, 'augmented_dialogues.json')
+    CACHE_FILE = os.path.join(CACHE_DIR, 'query_embeddings_cache.pkl')
+    TF_RECORD_PATH = os.path.join(TF_RECORD_DIR, 'training_data.tfrecord')
+    DEBUG_SAMPLES = None
+
+    # Ensure output directories exist
+    os.makedirs(MODELS_DIR, exist_ok=True)
+    os.makedirs(PROCESSED_DATA_DIR, exist_ok=True)
+    os.makedirs(CACHE_DIR, exist_ok=True)
+    os.makedirs(TOKENIZER_DIR, exist_ok=True)
+    os.makedirs(FAISS_INDICES_DIR, exist_ok=True)
+    os.makedirs(TF_RECORD_DIR, exist_ok=True)
+
+    # Initialize configuration
+    config = ChatbotConfig()
+    logger.info(f"Chatbot Configuration: {config}")
+
+    # Initialize tokenizer
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
+        logger.info(f"Tokenizer '{config.pretrained_model}' loaded successfully.")
+    except Exception as e:
+        logger.error(f"Failed to load tokenizer: {e}")
+        sys.exit(1)
+
+    # Add special tokens
+    try:
+        tokenizer.add_special_tokens({'additional_special_tokens': ['<EMPTY_NEGATIVE>']})
+        logger.info("Added special tokens to tokenizer.")
+    except Exception as e:
+        logger.error(f"Failed to add special tokens: {e}")
+        sys.exit(1)
+
+    # Initialize encoder model
+    try:
+        encoder = EncoderModel(config=config)
+        logger.info("EncoderModel initialized successfully.")
+    except Exception as e:
+        logger.error(f"Failed to initialize EncoderModel: {e}")
+        sys.exit(1)
+
+    # Resize token embeddings in encoder to match tokenizer
+    try:
+        encoder.pretrained.resize_token_embeddings(len(tokenizer))
+        logger.info(f"Token embeddings resized to: {len(tokenizer)}")
+    except Exception as e:
+        logger.error(f"Failed to resize token embeddings: {e}")
+        sys.exit(1)
+
+    # Load JSON dialogues
+    try:
+        dialogues = TFDataPipeline.load_json_training_data(JSON_TRAINING_DATA_PATH, DEBUG_SAMPLES)
+        logger.info(f"Loaded {len(dialogues)} dialogues from {JSON_TRAINING_DATA_PATH}.")
+    except Exception as e:
+        logger.error(f"Failed to load dialogues: {e}")
+        sys.exit(1)
+
+    # Load or initialize query_embeddings_cache
+    try:
+        if os.path.exists(CACHE_FILE):
+            with open(CACHE_FILE, 'rb') as f:
+                query_embeddings_cache = pickle.load(f)
+            logger.info(f"Loaded {len(query_embeddings_cache)} query embeddings from {CACHE_FILE}.")
+        else:
+            query_embeddings_cache = {}
+            logger.info("Initialized empty query embeddings cache.")
+    except Exception as e:
+        logger.error(f"Failed to load or initialize query embeddings cache: {e}")
+        sys.exit(1)
+
+    # Initialize TFDataPipeline
+    try:
+        data_pipeline = TFDataPipeline(
+            config=config,
+            tokenizer=tokenizer,
+            encoder=encoder,
+            index_file_path=FAISS_INDEX_PATH,
+            response_pool=[],
+            max_length=config.max_context_token_limit,
+            neg_samples=config.neg_samples,
+            query_embeddings_cache=query_embeddings_cache,
+            max_retries=config.max_retries
+        )
+        logger.info("TFDataPipeline initialized successfully.")
+    except Exception as e:
+        logger.error(f"Failed to initialize TFDataPipeline: {e}")
+        sys.exit(1)
+
+    # Collect unique assistant responses from dialogues
+    try:
+        response_pool = data_pipeline.collect_responses(dialogues)
+        data_pipeline.response_pool = response_pool
+        logger.info(f"Collected {len(response_pool)} unique assistant responses from dialogues.")
+    except Exception as e:
+        logger.error(f"Failed to collect responses: {e}")
+        sys.exit(1)
+
+    # Compute and add response embeddings to FAISS index
+    try:
+        logger.info("Computing and adding response embeddings to FAISS index...")
+        data_pipeline._compute_and_index_response_embeddings()
+        logger.info("Response embeddings computed and added to FAISS index.")
+    except Exception as e:
+        logger.error(f"Failed to compute or add response embeddings: {e}")
+        sys.exit(1)
+
+    # Save FAISS index
+    try:
+        logger.info(f"Saving FAISS index to {FAISS_INDEX_PATH}...")
+        faiss.write_index(data_pipeline.index, FAISS_INDEX_PATH)
+        logger.info("FAISS index saved successfully.")
+    except Exception as e:
+        logger.error(f"Failed to save FAISS index: {e}")
+        sys.exit(1)
+
+    # Prepare and save training data as TFRecords
+    try:
+        logger.info("Starting data preparation and saving as TFRecord...")
+        data_pipeline.prepare_and_save_data(dialogues, TF_RECORD_PATH)
+        logger.info(f"Data saved as TFRecord at {TF_RECORD_PATH}.")
+    except Exception as e:
+        logger.error(f"Failed during data preparation and saving: {e}")
+        sys.exit(1)
+
+    # Save query embeddings cache
+    try:
+        with open(CACHE_FILE, 'wb') as f:
+            pickle.dump(data_pipeline.query_embeddings_cache, f)
+        logger.info(f"Saved {len(data_pipeline.query_embeddings_cache)} query embeddings to {CACHE_FILE}.")
+    except Exception as e:
+        logger.error(f"Failed to save query embeddings cache: {e}")
+        sys.exit(1)
+
+    # Save Tokenizer (including special tokens)
+    try:
+        tokenizer.save_pretrained(TOKENIZER_DIR)
+        logger.info(f"Tokenizer saved to {TOKENIZER_DIR}.")
+    except Exception as e:
+        logger.error(f"Failed to save tokenizer: {e}")
+        sys.exit(1)
+
+    logger.info("Data preparation pipeline completed successfully.")
+
+if __name__ == "__main__":
+    main()
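
After running the script (python run_data_preparer.py), the artifacts can be inspected directly; a sketch using the paths defined in main() above:

    import pickle
    import faiss

    index = faiss.read_index('models/faiss_indices/faiss_index_test.index')
    print(index.ntotal, index.d)  # number of indexed responses, embedding dimension
    with open('cache/query_embeddings_cache.pkl', 'rb') as f:
        cache = pickle.load(f)
    print(f"{len(cache)} cached query embeddings")
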
tf_data_pipeline.py
ADDED
@@ -0,0 +1,734 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
import gc
import json
import numpy as np
import faiss
import tensorflow as tf
import h5py
from tqdm import tqdm
from pathlib import Path
from typing import Union, Optional, List, Tuple, Generator
from transformers import AutoTokenizer

from gpu_monitor import GPUMemoryMonitor
from logger_config import config_logger

logger = config_logger(__name__)

class TFDataPipeline:
    def __init__(
        self,
        config,
        tokenizer,
        encoder,
        index_file_path: str,
        response_pool: List[str],
        max_length: int,
        query_embeddings_cache: dict,
        neg_samples: int = 3,
        index_type: str = 'IndexFlatIP',
        nlist: int = 100,
        max_retries: int = 3
    ):
        self.config = config
        self.tokenizer = tokenizer
        self.encoder = encoder
        self.index_file_path = index_file_path
        self.response_pool = response_pool
        self.max_length = max_length
        self.neg_samples = neg_samples
        self.query_embeddings_cache = query_embeddings_cache  # In-memory cache for query embeddings
        self.index_type = index_type
        self.nlist = nlist
        # Small response pools get small batches; larger pools can afford larger ones
        self.embedding_batch_size = 16 if len(response_pool) < 100 else 64
        self.search_batch_size = 16 if len(response_pool) < 100 else 64
        self.max_batch_size = 16 if len(response_pool) < 100 else 64
        self.memory_monitor = GPUMemoryMonitor()
        self.max_retries = max_retries

        if os.path.exists(index_file_path):
            logger.info(f"Loading existing FAISS index from {index_file_path}...")
            self.index = faiss.read_index(index_file_path)
            self.validate_faiss_index()
            logger.info("FAISS index loaded and validated successfully.")
        else:
            # Initialize a fresh FAISS index with the encoder's embedding dimension
            dimension = self.encoder.config.embedding_dim
            self.index = faiss.IndexFlatIP(dimension)
            logger.info(f"Initialized FAISS IndexFlatIP with dimension {dimension}.")

        # IndexFlatIP is always trained; this branch only applies to trainable
        # (e.g. IVF-style) indices, and needs a non-empty cache to train on.
        if not self.index.is_trained and self.query_embeddings_cache:
            cached_embeddings = np.array(
                list(self.query_embeddings_cache.values())
            ).astype(np.float32)
            self.index.train(cached_embeddings)
            self.index.add(cached_embeddings)
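    # Note: `index_type` and `nlist` are stored above but not otherwise used in
    # this constructor, which always builds an IndexFlatIP; they would only
    # matter for a trainable index such as faiss.IndexIVFFlat.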
    def validate_faiss_index(self):
        """Validates that the FAISS index has the correct dimensionality."""
        expected_dim = self.encoder.config.embedding_dim
        if self.index.d != expected_dim:
            logger.error(f"FAISS index dimension {self.index.d} does not match encoder embedding dimension {expected_dim}.")
            raise ValueError("FAISS index dimensionality mismatch.")
        logger.info("FAISS index dimension validated successfully.")
    def save_embeddings_cache_hdf5(self, cache_file_path: str):
        """Save the embeddings cache to an HDF5 file."""
        with h5py.File(cache_file_path, 'w') as hf:
            for query, emb in self.query_embeddings_cache.items():
                hf.create_dataset(query, data=emb)
        logger.info(f"Embeddings cache saved to {cache_file_path}.")

    def load_embeddings_cache_hdf5(self, cache_file_path: str):
        """Load the embeddings cache from an HDF5 file."""
        with h5py.File(cache_file_path, 'r') as hf:
            for query in hf.keys():
                self.query_embeddings_cache[query] = hf[query][:]
        logger.info(f"Embeddings cache loaded from {cache_file_path}.")

    def save_faiss_index(self, index_file_path: str):
        faiss.write_index(self.index, index_file_path)
        logger.info(f"FAISS index saved to {index_file_path}")

    def load_faiss_index(self, index_file_path: str):
        self.index = faiss.read_index(index_file_path)
        logger.info(f"FAISS index loaded from {index_file_path}")

    def save_tokenizer(self, tokenizer_dir: str):
        self.tokenizer.save_pretrained(tokenizer_dir)
        logger.info(f"Tokenizer saved to {tokenizer_dir}")

    def load_tokenizer(self, tokenizer_dir: str):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
        logger.info(f"Tokenizer loaded from {tokenizer_dir}")
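    # Note: h5py treats '/' in a dataset name as a group separator, so using raw
    # query text as keys above can silently create nested groups for queries
    # containing slashes. One possible hardening (an assumption, not part of
    # this commit's pipeline) is to key the cache by a digest instead:
    #
    #   import hashlib
    #   key = hashlib.sha256(query.encode("utf-8")).hexdigest()
    #   hf.create_dataset(key, data=emb)
    #   hf[key].attrs["query"] = query  # keep the original text recoverable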
    def estimate_total_pairs(self, dialogues: List[dict]) -> int:
        """Estimate total number of training pairs including hard negatives."""
        base_pairs = sum(
            len([
                1 for i in range(len(d.get('turns', [])) - 1)
                if (d['turns'][i].get('speaker') == 'user' and
                    d['turns'][i + 1].get('speaker') == 'assistant')
            ])
            for d in dialogues
        )
        # Each pair contributes one positive plus `neg_samples` hard negatives
        return base_pairs * (1 + self.neg_samples)
    @staticmethod
    def load_json_training_data(data_path: Union[str, Path], debug_samples: Optional[int] = None) -> List[dict]:
        """
        Load training data from a JSON file.

        Args:
            data_path (Union[str, Path]): Path to the JSON file containing dialogues.
            debug_samples (Optional[int]): Number of samples to load for debugging.

        Returns:
            List[dict]: List of dialogue dictionaries.
        """
        logger.info(f"Loading training data from {data_path}...")
        data_path = Path(data_path)
        if not data_path.exists():
            logger.error(f"Data file {data_path} does not exist.")
            return []

        with open(data_path, 'r', encoding='utf-8') as f:
            dialogues = json.load(f)

        if debug_samples is not None:
            dialogues = dialogues[:debug_samples]
            logger.info(f"Debug mode: limited to {debug_samples} dialogues.")

        logger.info(f"Loaded {len(dialogues)} dialogues.")
        return dialogues
    def collect_responses(self, dialogues: List[dict]) -> List[str]:
        """Extract unique assistant responses from dialogues."""
        response_set = set()
        for dialogue in dialogues:
            turns = dialogue.get('turns', [])
            for turn in turns:
                speaker = turn.get('speaker')
                text = turn.get('text', '').strip()
                if speaker == 'assistant' and text:
                    # Rough length filter: max_length is a token budget, but it is
                    # compared against the character count here, so this only
                    # screens out extremely long responses.
                    if len(text) <= self.max_length:
                        response_set.add(text)
        logger.info(f"Collected {len(response_set)} unique assistant responses from dialogues.")
        return list(response_set)
    def _extract_pairs_from_dialogue(self, dialogue: dict) -> List[Tuple[str, str]]:
        """Extract (user query, assistant response) pairs from a dialogue."""
        pairs = []
        turns = dialogue.get('turns', [])

        for i in range(len(turns) - 1):
            current_turn = turns[i]
            next_turn = turns[i + 1]

            if (current_turn.get('speaker') == 'user' and
                    next_turn.get('speaker') == 'assistant' and
                    'text' in current_turn and
                    'text' in next_turn):
                query = current_turn['text'].strip()
                positive = next_turn['text'].strip()
                pairs.append((query, positive))

        return pairs
    def _compute_and_index_response_embeddings(self):
        """Computes embeddings for the response pool and adds them to the FAISS index."""
        logger.info("Computing embeddings for the response pool...")

        # Log the first few responses to sanity-check contents and types
        for idx, response in enumerate(self.response_pool[:5], 1):
            logger.debug(f"Response {idx}: {response} (Type: {type(response)})")

        # Ensure all responses are strings
        if not all(isinstance(response, str) for response in self.response_pool):
            logger.error("All elements in response_pool must be strings.")
            raise ValueError("Invalid data type in response_pool.")

        # Tokenize the full response pool in one pass
        encoded_responses = self.tokenizer(
            self.response_pool,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors='tf'
        )
        response_ids = encoded_responses['input_ids']

        # Compute embeddings in batches
        batch_size = getattr(self, 'embedding_batch_size', 64)  # Default to 64 if not set
        embeddings = []
        for i in range(0, len(response_ids), batch_size):
            batch_ids = response_ids[i:i + batch_size]
            batch_embeddings = self.encoder(batch_ids, training=False).numpy()
            # Normalize so that inner-product search equals cosine similarity
            faiss.normalize_L2(batch_embeddings)
            embeddings.append(batch_embeddings)

        if embeddings:
            embeddings = np.vstack(embeddings).astype(np.float32)
            logger.info(f"Adding {len(embeddings)} response embeddings to FAISS index...")
            self.index.add(embeddings)
            logger.info("Response embeddings added to FAISS index.")
        else:
            logger.warning("No embeddings to add to FAISS index.")

        # Sanity check: verify the number of embeddings now in the index
        logger.info(f"Total embeddings in FAISS index after addition: {self.index.ntotal}")
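    # Note: with L2-normalized vectors (and the encoder elsewhere in this commit
    # already normalizes its outputs), the inner product computed by IndexFlatIP
    # equals cosine similarity, so the normalize_L2 calls above are cheap,
    # idempotent safety steps rather than a change of metric.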
    def _find_hard_negatives_batch(self, queries: List[str], positives: List[str]) -> List[List[str]]:
        """Find hard negatives for a batch of queries, with error handling and retries."""
        retry_count = 0
        total_responses = len(self.response_pool)

        # Retrieve k candidates per query; raising k above neg_samples gives the
        # filter below more candidates to choose from after skipping positives.
        k = self.neg_samples

        while retry_count < self.max_retries:
            try:
                # Gather cached query embeddings in sub-batches to manage memory
                batch_size = 128  # Sub-batch size; adjust as needed
                query_embeddings = []
                for i in range(0, len(queries), batch_size):
                    sub_queries = queries[i:i + batch_size]
                    sub_embeddings = np.vstack([
                        self.query_embeddings_cache[q] for q in sub_queries
                    ]).astype(np.float32)
                    faiss.normalize_L2(sub_embeddings)
                    query_embeddings.append(sub_embeddings)
                query_embeddings = np.vstack(query_embeddings)

                # FAISS requires a contiguous memory layout
                query_embeddings = np.ascontiguousarray(query_embeddings)

                # Perform FAISS search on CPU
                distances, indices = self.index.search(query_embeddings, k)

                all_negatives = []
                for query_indices, query, positive in zip(indices, queries, positives):
                    negatives = []
                    positive_strip = positive.strip()
                    seen = {positive_strip}

                    for idx in query_indices:
                        if 0 <= idx < total_responses:
                            candidate = self.response_pool[idx].strip()
                            if candidate and candidate not in seen:
                                seen.add(candidate)
                                negatives.append(candidate)
                                if len(negatives) >= self.neg_samples:
                                    break

                    # If not enough negatives were found, pad with a placeholder token
                    while len(negatives) < self.neg_samples:
                        negatives.append("<EMPTY_NEGATIVE>")

                    all_negatives.append(negatives)

                return all_negatives

            except KeyError as ke:
                retry_count += 1
                logger.warning(f"Hard negative search attempt {retry_count} failed due to missing embeddings: {ke}")
                if retry_count == self.max_retries:
                    logger.error("Max retries reached for hard negative search due to missing embeddings.")
                    return [["<EMPTY_NEGATIVE>"] * self.neg_samples for _ in queries]
                # Free memory before retrying
                gc.collect()
                if tf.config.list_physical_devices('GPU'):
                    tf.keras.backend.clear_session()
            except Exception as e:
                retry_count += 1
                logger.warning(f"Hard negative search attempt {retry_count} failed: {e}")
                if retry_count == self.max_retries:
                    logger.error("Max retries reached for hard negative search.")
                    return [["<EMPTY_NEGATIVE>"] * self.neg_samples for _ in queries]
                # Free memory before retrying
                gc.collect()
                if tf.config.list_physical_devices('GPU'):
                    tf.keras.backend.clear_session()
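    # Note: "<EMPTY_NEGATIVE>" is an ordinary string, not a registered special
    # token, so the tokenizer splits it into regular subword pieces. If a
    # distinguishable placeholder is wanted downstream, one option is:
    #
    #   tokenizer.add_special_tokens({'additional_special_tokens': ['<EMPTY_NEGATIVE>']})
    #
    # followed by resizing the model's token embeddings to match.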
    def _tokenize_and_encode(self, queries: List[str], positives: List[str], negatives: List[List[str]]) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Tokenize and encode the queries, positives, and negatives.

        Returns:
            query_ids: [batch_size, max_length]
            positive_ids: [batch_size, max_length]
            negative_ids: [batch_size, neg_samples, max_length]
        """
        # Tokenize queries
        q_enc = self.tokenizer(
            queries,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="np"
        )
        # Tokenize positives
        p_enc = self.tokenizer(
            positives,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="np"
        )
        # Tokenize negatives: flatten [batch_size, neg_samples] -> single list
        flattened_negatives = [neg for sublist in negatives for neg in sublist]
        if len(flattened_negatives) == 0:
            # No negatives at all: return a zero array
            n_ids = np.zeros((len(queries), self.neg_samples, self.max_length), dtype=np.int32)
        else:
            n_enc = self.tokenizer(
                flattened_negatives,
                padding="max_length",
                truncation=True,
                max_length=self.max_length,
                return_tensors="np"
            )
            n_input_ids = n_enc["input_ids"]

            # Reshape to [batch_size, neg_samples, max_length]
            batch_size = len(queries)
            n_ids = n_input_ids.reshape(batch_size, self.neg_samples, self.max_length)

        # Convert to int32
        query_ids = q_enc["input_ids"].astype(np.int32)
        positive_ids = p_enc["input_ids"].astype(np.int32)
        negative_ids = n_ids.astype(np.int32)

        return query_ids, positive_ids, negative_ids
    def prepare_and_save_data(self, dialogues: List[dict], tfrecord_file_path: str, batch_size: int = 32):
        """Processes dialogues in batches and saves them to a TFRecord file."""
        with tf.io.TFRecordWriter(tfrecord_file_path) as writer:
            total_dialogues = len(dialogues)
            logger.debug(f"Total dialogues to process: {total_dialogues}")

            with tqdm(total=total_dialogues, desc="Processing Dialogues", unit="dialogue") as pbar:
                for i in range(0, total_dialogues, batch_size):
                    batch_dialogues = dialogues[i:i + batch_size]
                    # For each dialogue: extract pairs, mine negatives, tokenize, serialize
                    for dialogue in batch_dialogues:
                        pairs = self._extract_pairs_from_dialogue(dialogue)
                        queries = []
                        positives = []

                        for query, positive in pairs:
                            queries.append(query)
                            positives.append(positive)

                        if queries:
                            # Compute and cache query embeddings before searching
                            self._compute_embeddings(queries)

                            # Find hard negatives
                            hard_negatives = self._find_hard_negatives_batch(queries, positives)

                            for idx, negatives in enumerate(hard_negatives[:5]):  # Log first few examples
                                logger.debug(f"Query: {queries[idx]}")
                                logger.debug(f"Positive: {positives[idx]}")
                                logger.debug(f"Hard Negatives: {negatives}")

                            # Tokenize and encode
                            query_ids, positive_ids, negative_ids = self._tokenize_and_encode(queries, positives, hard_negatives)

                            # Serialize each example and write to the TFRecord
                            # (a matching parse spec is sketched after this file listing)
                            for q_id, p_id, n_id in zip(query_ids, positive_ids, negative_ids):
                                feature = {
                                    'query_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=q_id)),
                                    'positive_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=p_id)),
                                    'negative_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=n_id.flatten())),
                                }
                                example = tf.train.Example(features=tf.train.Features(feature=feature))
                                writer.write(example.SerializeToString())

                    pbar.update(len(batch_dialogues))
        logger.info(f"Data preparation complete. TFRecord saved at {tfrecord_file_path}")
    def _tokenize_negatives_tf(self, negatives):
        """Tokenizes negatives inside a tf.py_function."""
        # Handle the case where negatives is an empty tensor
        if tf.size(negatives) == 0:
            return tf.zeros([0, self.neg_samples, self.max_length], dtype=tf.int32)

        # Convert EagerTensor rows to lists of strings
        negatives_list = []
        for neg_list in negatives.numpy():
            decoded_negs = [neg.decode("utf-8") for neg in neg_list if neg]  # Filter out empty strings
            negatives_list.append(decoded_negs)

        # Flatten the list of lists
        flattened_negatives = [neg for sublist in negatives_list for neg in sublist]

        # Tokenize the flattened negatives
        if flattened_negatives:
            n_tokens = self.tokenizer(
                flattened_negatives,
                padding='max_length',
                truncation=True,
                max_length=self.max_length,
                return_tensors='tf'
            )
            # Reshape to [batch_size, neg_samples, max_length]
            n_tokens_reshaped = tf.reshape(n_tokens['input_ids'], [-1, self.neg_samples, self.max_length])
            return n_tokens_reshaped
        else:
            return tf.zeros([0, self.neg_samples, self.max_length], dtype=tf.int32)
    def _compute_embeddings(self, queries: List[str]) -> None:
        """Computes and caches embeddings for any queries not already cached."""
        new_queries = [q for q in queries if q not in self.query_embeddings_cache]
        if not new_queries:
            return  # All queries already cached

        # Compute embeddings for new queries in batches
        new_embeddings = []
        for i in range(0, len(new_queries), self.embedding_batch_size):
            batch_queries = new_queries[i:i + self.embedding_batch_size]
            encoded = self.tokenizer(
                batch_queries,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors='tf'
            )
            batch_embeddings = self.encoder(encoded['input_ids'], training=False).numpy()
            faiss.normalize_L2(batch_embeddings)
            new_embeddings.extend(batch_embeddings)

        # Update the cache
        for query, emb in zip(new_queries, new_embeddings):
            self.query_embeddings_cache[query] = emb
    def data_generator(self, dialogues: List[dict]) -> Generator[Tuple[str, str, List[str]], None, None]:
        """
        Generates training examples: (query, positive, hard_negatives).
        The outer loop is wrapped with tqdm for progress tracking.
        """
        total_dialogues = len(dialogues)
        logger.debug(f"Total dialogues to process: {total_dialogues}")

        with tqdm(total=total_dialogues, desc="Processing Dialogues", unit="dialogue") as pbar:
            for dialogue in dialogues:
                pairs = self._extract_pairs_from_dialogue(dialogue)
                for query, positive in pairs:
                    # Ensure embeddings are cached, then mine hard negatives
                    self._compute_embeddings([query])
                    hard_negatives = self._find_hard_negatives_batch([query], [positive])[0]
                    yield (query, positive, hard_negatives)
                pbar.update(1)
    def _prepare_batch(self, queries: tf.Tensor, positives: tf.Tensor, negatives: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
        """Prepares a batch of data for training."""
        # Convert EagerTensors to lists of strings
        queries_list = [query.decode("utf-8") for query in queries.numpy()]
        positives_list = [pos.decode("utf-8") for pos in positives.numpy()]

        # Tokenize queries and positives
        q_tokens = self.tokenizer(queries_list, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf')
        p_tokens = self.tokenizer(positives_list, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf')

        # Decode negatives and ensure they are lists of strings
        negatives_list = []
        for neg_list in negatives.numpy():
            decoded_negs = [neg.decode("utf-8") for neg in neg_list if neg]  # Filter out empty strings
            negatives_list.append(decoded_negs)

        # Flatten negatives for tokenization if there are any valid negatives
        flattened_negatives = [neg for sublist in negatives_list for neg in sublist if neg]

        # Tokenize negatives if there are any
        if flattened_negatives:
            n_tokens = self.tokenizer(flattened_negatives, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf')
            # Reshape to [batch_size, num_negatives, max_length]; this assumes the
            # same number of negatives per query
            n_tokens_reshaped = tf.reshape(n_tokens['input_ids'], [len(queries_list), -1, self.max_length])
        else:
            # Placeholder tensor for the case where there are no negatives
            n_tokens_reshaped = tf.zeros([len(queries_list), 0, self.max_length], dtype=tf.int32)

        # Pad or truncate the negatives dimension to exactly neg_samples
        if n_tokens_reshaped.shape[1] != self.neg_samples:
            padding = tf.zeros([len(queries_list), tf.maximum(0, self.neg_samples - n_tokens_reshaped.shape[1]), self.max_length], dtype=tf.int32)
            n_tokens_reshaped = tf.concat([n_tokens_reshaped, padding], axis=1)
            n_tokens_reshaped = n_tokens_reshaped[:, :self.neg_samples, :]

        # Concatenate the positive and negatives along the samples dimension
        combined_p_n_tokens = tf.concat([tf.expand_dims(p_tokens['input_ids'], axis=1), n_tokens_reshaped], axis=1)

        return q_tokens['input_ids'], combined_p_n_tokens
    def get_tf_dataset(self, dialogues: List[dict], batch_size: int) -> tf.data.Dataset:
        """
        Creates a tf.data.Dataset for streaming training that yields
        (input_ids_query, input_ids_positive, input_ids_negatives).
        """
        # 1) Start with a generator dataset
        dataset = tf.data.Dataset.from_generator(
            lambda: self.data_generator(dialogues),
            output_signature=(
                tf.TensorSpec(shape=(), dtype=tf.string),                  # Query (single string)
                tf.TensorSpec(shape=(), dtype=tf.string),                  # Positive (single string)
                tf.TensorSpec(shape=(self.neg_samples,), dtype=tf.string)  # Hard negatives (list of strings)
            )
        )

        # 2) Batch the raw strings
        dataset = dataset.batch(batch_size, drop_remainder=True)

        # 3) Map them through a tokenize step using tf.py_function
        dataset = dataset.map(
            lambda q, p, n: self._tokenize_triple(q, p, n),
            num_parallel_calls=1  # keep serial; tf.data.AUTOTUNE is an option
        )

        dataset = dataset.prefetch(tf.data.AUTOTUNE)
        return dataset
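    # Typical usage, as a sketch (assumes a constructed pipeline instance and
    # dialogues loaded via load_json_training_data):
    #
    #   dataset = pipeline.get_tf_dataset(dialogues, batch_size=32)
    #   for q_ids, p_ids, n_ids in dataset.take(1):
    #       # q_ids, p_ids: [batch, max_length]; n_ids: [batch, neg_samples, max_length]
    #       pass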
    def _tokenize_triple(
        self,
        q: tf.Tensor,
        p: tf.Tensor,
        n: tf.Tensor
    ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """
        Wraps a Python function via tf.py_function to convert tf.Tensors of strings
        -> Python lists of strings -> HF tokenizer -> Tensors of IDs.

        q is shape [batch_size], p is shape [batch_size],
        n is shape [batch_size, neg_samples] (i.e., each row is a list of negatives).
        """
        # Use tf.py_function with limited parallelism
        q_ids, p_ids, n_ids = tf.py_function(
            func=self._tokenize_triple_py,
            inp=[q, p, n, tf.constant(self.max_length), tf.constant(self.neg_samples)],
            Tout=[tf.int32, tf.int32, tf.int32]
        )

        # Manually restore the shape information lost inside py_function
        q_ids.set_shape([None, self.max_length])                    # [batch_size, max_length]
        p_ids.set_shape([None, self.max_length])                    # [batch_size, max_length]
        n_ids.set_shape([None, self.neg_samples, self.max_length])  # [batch_size, neg_samples, max_length]

        return q_ids, p_ids, n_ids
    def _tokenize_triple_py(
        self,
        q: tf.Tensor,
        p: tf.Tensor,
        n: tf.Tensor,
        max_len: tf.Tensor,
        neg_samples: tf.Tensor
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Python function that:
          - Decodes each tf.string Tensor to a Python list of strings
          - Calls the HF tokenizer
          - Reshapes negatives
          - Returns np.array of int32s for (q_ids, p_ids, n_ids).

        q: shape [batch_size], p: shape [batch_size]
        n: shape [batch_size, neg_samples]
        max_len: scalar int
        neg_samples: scalar int
        """
        max_len = int(max_len.numpy())  # Convert to Python int
        neg_samples = int(neg_samples.numpy())

        # 1) Convert Tensors -> Python lists of strings
        q_list = [q_i.decode("utf-8") for q_i in q.numpy()]  # shape [batch_size]
        p_list = [p_i.decode("utf-8") for p_i in p.numpy()]  # shape [batch_size]

        # n has shape [batch_size, neg_samples]; decode each row
        n_list = []
        for row in n.numpy():
            decoded = [neg.decode("utf-8") for neg in row]
            n_list.append(decoded)

        # 2) Tokenize queries & positives
        q_enc = self.tokenizer(
            q_list,
            padding="max_length",
            truncation=True,
            max_length=max_len,
            return_tensors="np"
        )
        p_enc = self.tokenizer(
            p_list,
            padding="max_length",
            truncation=True,
            max_length=max_len,
            return_tensors="np"
        )

        # 3) Tokenize negatives: flatten [batch_size, neg_samples] -> single list
        flattened_negatives = [neg for row in n_list for neg in row]
        if len(flattened_negatives) == 0:
            # No negatives at all: return a zero array
            n_ids = np.zeros((len(q_list), neg_samples, max_len), dtype=np.int32)
        else:
            n_enc = self.tokenizer(
                flattened_negatives,
                padding="max_length",
                truncation=True,
                max_length=max_len,
                return_tensors="np"
            )
            # shape [batch_size * neg_samples, max_len]
            n_input_ids = n_enc["input_ids"]

            # Reshape to [batch_size, neg_samples, max_len], zero-padding any
            # row that came back with fewer negatives
            batch_size = len(q_list)
            n_ids_list = []
            for i in range(batch_size):
                start_idx = i * neg_samples
                end_idx = start_idx + neg_samples
                row_negs = n_input_ids[start_idx:end_idx]

                if row_negs.shape[0] < neg_samples:
                    deficit = neg_samples - row_negs.shape[0]
                    pad_arr = np.zeros((deficit, max_len), dtype=np.int32)
                    row_negs = np.concatenate([row_negs, pad_arr], axis=0)

                n_ids_list.append(row_negs)

            # Stack -> shape [batch_size, neg_samples, max_len]
            n_ids = np.stack(n_ids_list, axis=0)

        # 4) Return as np.int32 arrays
        q_ids = q_enc["input_ids"].astype(np.int32)  # [batch_size, max_len]
        p_ids = p_enc["input_ids"].astype(np.int32)  # [batch_size, max_len]
        n_ids = n_ids.astype(np.int32)               # [batch_size, neg_samples, max_len]

        return q_ids, p_ids, n_ids
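To make the TFRecord layout written by prepare_and_save_data concrete, a reader only has to invert the feature spec used at write time. A minimal sketch, assuming MAX_LENGTH and NEG_SAMPLES match the writer's max_length and neg_samples (the file path is illustrative):

    import tensorflow as tf

    MAX_LENGTH = 512   # must match the writer's max_length
    NEG_SAMPLES = 3    # must match the writer's neg_samples

    feature_spec = {
        'query_ids': tf.io.FixedLenFeature([MAX_LENGTH], tf.int64),
        'positive_ids': tf.io.FixedLenFeature([MAX_LENGTH], tf.int64),
        'negative_ids': tf.io.FixedLenFeature([NEG_SAMPLES * MAX_LENGTH], tf.int64),
    }

    def parse_example(serialized):
        ex = tf.io.parse_single_example(serialized, feature_spec)
        q_ids = tf.cast(ex['query_ids'], tf.int32)
        p_ids = tf.cast(ex['positive_ids'], tf.int32)
        n_ids = tf.cast(tf.reshape(ex['negative_ids'], [NEG_SAMPLES, MAX_LENGTH]), tf.int32)
        return q_ids, p_ids, n_ids

    dataset = (tf.data.TFRecordDataset('path/to/training_data.tfrecord')
               .map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)
               .batch(32)
               .prefetch(tf.data.AUTOTUNE))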