JoeArmani committed · Commit 300fe5d · Parent(s): d2df8da

updates through 4th iteration

Files changed:
- .gitignore +4 -0
- augmented_combined_dataset.json +0 -0
- back_translator.py +53 -22
- chatbot.py +261 -0
- chatbot2.py +839 -0
- chatbot3.py +824 -0
- chatbot4.py +1291 -0
- dialogue_augmenter.py +329 -171
- main.py +3 -3
- paraphraser.py +14 -3
- pipeline_config.py +0 -1
- processing_pipeline.py +61 -19
- quality_metrics.py +10 -92
- readme.md +4 -2
- requirements.txt +12 -10
- response_quality_checker.py +164 -0
- run_model.py +162 -0
- run_model2.py +340 -0
- run_model3.py +434 -0
- run_model4.py +237 -0
- setup.py +61 -6
- test_trained_model.py +0 -0
.gitignore
CHANGED

@@ -159,3 +159,7 @@ datasets/*
 
 processed_outputs/*
 !processed_outputs/.gitkeep
+
+chatbot_training/
+checkpoints/
+.DS_Store
augmented_combined_dataset.json
DELETED

The diff for this file is too large to render. See raw diff.
back_translator.py
CHANGED

@@ -3,6 +3,8 @@ from transformers import (
     MarianTokenizer,
 )
 
+# Retained for reference but removed from the final code.
+# This method did not seem helpful for this retrieval-based chatbot.
 class BackTranslator:
     """
     Perform Back-translation with pivot language. English -> German -> Spanish -> English

@@ -20,7 +22,7 @@ class BackTranslator:
         self.tokenizer_pivot_forward = MarianTokenizer.from_pretrained(pivot_forward_model_name)
         self.model_pivot_forward = MarianMTModel.from_pretrained(pivot_forward_model_name)
 
-        # Pivot translation
+        # Pivot translation (German to Spanish)
         pivot_backward_model_name = f'Helsinki-NLP/opus-mt-{pivot_lang}-{target_lang}'
         self.tokenizer_pivot_backward = MarianTokenizer.from_pretrained(pivot_backward_model_name)
         self.model_pivot_backward = MarianMTModel.from_pretrained(pivot_backward_model_name)

@@ -29,28 +31,57 @@ class BackTranslator:
         backward_model_name = f'Helsinki-NLP/opus-mt-{target_lang}-{source_lang}'
         self.tokenizer_backward = MarianTokenizer.from_pretrained(backward_model_name)
         self.model_backward = MarianMTModel.from_pretrained(backward_model_name)
+
+        # Set models to eval mode
+        self.model_pivot_forward.eval()
+        self.model_pivot_backward.eval()
+        self.model_backward.eval()
 
-    def back_translate(self, text):
+    def back_translate(self, text, device=None):
+        try:
+            # Move models to device if specified
+            if device is not None:
+                self.model_pivot_forward = self.model_pivot_forward.to(device)
+                self.model_pivot_backward = self.model_pivot_backward.to(device)
+                self.model_backward = self.model_backward.to(device)
+
+            # Forward translation (English to German)
+            encoded_pivot = self.tokenizer_pivot_forward([text], padding=True,
+                                                         truncation=True, return_tensors='pt')
+            if device is not None:
+                encoded_pivot = {k: v.to(device) for k, v in encoded_pivot.items()}
 
-        generated_pivot = self.model_pivot_forward.generate(**encoded_pivot)
-        pivot_text = self.tokenizer_pivot_forward.batch_decode(generated_pivot, skip_special_tokens=True)[0]
+            generated_pivot = self.model_pivot_forward.generate(**encoded_pivot)
+            if device is not None:
+                generated_pivot = generated_pivot.cpu()
+            pivot_text = self.tokenizer_pivot_forward.batch_decode(generated_pivot,
+                                                                   skip_special_tokens=True)[0]
+
+            # Pivot translation (German to Spanish)
+            encoded_back_pivot = self.tokenizer_pivot_backward([pivot_text], padding=True,
+                                                               truncation=True, return_tensors='pt')
+            if device is not None:
+                encoded_back_pivot = {k: v.to(device) for k, v in encoded_back_pivot.items()}
+
+            retranslated_pivot = self.model_pivot_backward.generate(**encoded_back_pivot)
+            if device is not None:
+                retranslated_pivot = retranslated_pivot.cpu()
+            tgt_text_back = self.tokenizer_pivot_backward.batch_decode(retranslated_pivot,
+                                                                       skip_special_tokens=True)[0]
+
+            # Backward translation (Spanish to English)
+            encoded_back = self.tokenizer_backward([tgt_text_back], padding=True,
+                                                   truncation=True, return_tensors='pt')
+            if device is not None:
+                encoded_back = {k: v.to(device) for k, v in encoded_back.items()}
+
+            retranslated = self.model_backward.generate(**encoded_back)
+            if device is not None:
+                retranslated = retranslated.cpu()
+            src_text = self.tokenizer_backward.batch_decode(retranslated,
+                                                            skip_special_tokens=True)[0]
+
+            return src_text
+        except Exception as e:
+            print(f"Error in back translation: {e}")
+            return text
chatbot.py
ADDED
@@ -0,0 +1,261 @@
import numpy as np
import tensorflow as tf
import keras
print(tf.__version__)
print(keras.__version__)
import spacy
import random
from tqdm import trange

class RetrievalChatbot:
    def __init__(
        self,
        vocab_size: int = 10000,
        max_sequence_length: int = 80,
        embedding_dim: int = 256,
        lstm_units: int = 256,
        num_attention_heads: int = 8,
        margin: float = 0.3
    ):
        self.vocab_size = vocab_size
        self.max_sequence_length = max_sequence_length
        self.embedding_dim = embedding_dim
        self.lstm_units = lstm_units
        self.num_attention_heads = num_attention_heads
        self.margin = margin

        self.nlp = spacy.load('en_core_web_md')
        self.tokenizer = tf.keras.preprocessing.text.Tokenizer(
            num_words=vocab_size,
            oov_token="<OOV>"
        )

        self.query_encoder_model, self.response_encoder_model = self._build_encoders()

    def _positional_encoding(self, position: int, d_model: int) -> tf.Tensor:
        angles = np.arange(position)[:, np.newaxis] / np.power(
            10000,
            (2 * (np.arange(d_model)[np.newaxis, :] // 2)) / d_model
        )
        sines = np.sin(angles[:, 0::2])
        cosines = np.cos(angles[:, 1::2])
        pos_encoding = np.concatenate([sines, cosines], axis=-1)
        pos_encoding = pos_encoding[np.newaxis, ...]
        return tf.cast(pos_encoding, dtype=tf.float32)

    def _build_single_encoder(self, name_prefix: str):
        input_layer = tf.keras.Input(shape=(self.max_sequence_length,), name=f"{name_prefix}_input")
        embedding = tf.keras.layers.Embedding(
            self.vocab_size,
            self.embedding_dim,
            mask_zero=True,
            name=f"{name_prefix}_embedding"
        )(input_layer)

        pos_encoding = self._positional_encoding(self.max_sequence_length, self.embedding_dim)
        x = embedding + pos_encoding

        # # Multi-head attention
        # attention_output = tf.keras.layers.MultiHeadAttention(
        #     num_heads=self.num_attention_heads,
        #     key_dim=self.embedding_dim // self.num_attention_heads
        # )(x, x)
        # x = tf.keras.layers.LayerNormalization()(x + attention_output)

        for i in range(2):
            lstm_out = tf.keras.layers.LSTM(
                self.lstm_units,
                return_sequences=True,
                kernel_regularizer=tf.keras.regularizers.l2(0.01),
                name=f"{name_prefix}_lstm_{i}"
            )(x)
            x = tf.keras.layers.LayerNormalization()(x + lstm_out)

        encoder_output = tf.keras.layers.LSTM(
            self.lstm_units,
            name=f"{name_prefix}_final_lstm"
        )(x)
        encoder_output = tf.keras.layers.Dropout(0.2)(encoder_output)
        encoder_output = tf.keras.layers.Lambda(lambda t: tf.nn.l2_normalize(t, axis=1))(encoder_output)

        return tf.keras.Model(input_layer, encoder_output, name=f"{name_prefix}_encoder")

    def _build_encoders(self):
        query_encoder = self._build_single_encoder("query")
        response_encoder = self._build_single_encoder("response")
        return query_encoder, response_encoder

    def _spacy_similarity(self, text1: str, text2: str) -> float:
        doc1 = self.nlp(text1)
        doc2 = self.nlp(text2)
        print('doc1:', doc1)
        print('doc2:', doc2)
        print('doc1.similarity(doc2):', doc1.similarity(doc2))
        return doc1.similarity(doc2)

    def prepare_dataset(self, dialogues: list, neg_samples_per_pos=3):
        # Create triplets: (query, positive, negative)
        response_pool = [
            turn['text'] for d in dialogues for turn in d['turns'] if turn['speaker'] == 'assistant'
        ]
        queries, positives, negatives = [], [], []

        for dialogue in dialogues:
            turns = dialogue['turns']
            for i in range(0, len(turns)-1):
                if turns[i]['speaker'] == 'user' and turns[i+1]['speaker'] == 'assistant':
                    q = turns[i]['text']
                    p = turns[i+1]['text']

                    # Find negatives using spaCy similarity
                    neg_candidates = []
                    attempts = 0
                    while len(neg_candidates) < neg_samples_per_pos and attempts < 200:
                        cand = random.choice(response_pool)
                        if cand != p:
                            sim = self._spacy_similarity(cand, p)
                            # Choose thresholds that produce hard negatives
                            if 0.4 < sim < 0.9:
                                neg_candidates.append(cand)
                        attempts += 1

                    if len(neg_candidates) == neg_samples_per_pos:
                        for neg in neg_candidates:
                            queries.append(q)
                            positives.append(p)
                            negatives.append(neg)

        # Fit tokenizer
        all_text = queries + positives + negatives
        self.tokenizer.fit_on_texts(all_text)

        def seq_pad(txts):
            seq = self.tokenizer.texts_to_sequences(txts)
            return tf.keras.preprocessing.sequence.pad_sequences(seq, maxlen=self.max_sequence_length, padding='post')

        q_pad = seq_pad(queries)
        p_pad = seq_pad(positives)
        n_pad = seq_pad(negatives)

        return q_pad, p_pad, n_pad

    def triplet_loss(self, q_emb, p_emb, n_emb):
        pos_dist = tf.reduce_sum(tf.square(q_emb - p_emb), axis=1)
        neg_dist = tf.reduce_sum(tf.square(q_emb - n_emb), axis=1)
        loss = tf.maximum(0.0, self.margin + pos_dist - neg_dist)
        return tf.reduce_mean(loss)

    def train_with_triplet_loss(
        self, q_pad, p_pad, n_pad,
        epochs=3,
        batch_size=16,
        validation_split=0.2,
        early_stopping_patience=3,
        use_tqdm=True
    ):
        train_losses = []
        val_losses = []

        total_samples = len(q_pad)
        idxs = np.arange(total_samples)
        np.random.shuffle(idxs)
        train_size = int((1 - validation_split) * total_samples)

        train_idxs = idxs[:train_size]
        val_idxs = idxs[train_size:]

        q_train, p_train, n_train = q_pad[train_idxs], p_pad[train_idxs], n_pad[train_idxs]
        q_val, p_val, n_val = q_pad[val_idxs], p_pad[val_idxs], n_pad[val_idxs]

        optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
        best_val_loss = float('inf')
        wait = 0

        for epoch in range(epochs):
            # Shuffle training data each epoch
            perm = np.random.permutation(len(q_train))
            q_train, p_train, n_train = q_train[perm], p_train[perm], n_train[perm]

            num_batches = len(q_train) // batch_size
            epoch_train_loss = 0.0

            batch_iter = range(num_batches)
            if use_tqdm:
                batch_iter = trange(num_batches, desc=f"Epoch {epoch+1}/{epochs}")

            for i in batch_iter:
                q_batch = q_train[i*batch_size:(i+1)*batch_size]
                p_batch = p_train[i*batch_size:(i+1)*batch_size]
                n_batch = n_train[i*batch_size:(i+1)*batch_size]

                with tf.GradientTape() as tape:
                    q_emb = self.query_encoder_model(q_batch, training=True)
                    p_emb = self.response_encoder_model(p_batch, training=True)
                    n_emb = self.response_encoder_model(n_batch, training=True)
                    loss = self.triplet_loss(q_emb, p_emb, n_emb)

                grads = tape.gradient(
                    loss,
                    self.query_encoder_model.trainable_variables +
                    self.response_encoder_model.trainable_variables
                )
                optimizer.apply_gradients(zip(
                    grads,
                    self.query_encoder_model.trainable_variables +
                    self.response_encoder_model.trainable_variables
                ))
                epoch_train_loss += loss.numpy()

            epoch_train_loss /= num_batches

            # Validation loss
            val_batches = len(q_val) // batch_size
            epoch_val_loss = 0.0
            for i in range(val_batches):
                q_batch = q_val[i*batch_size:(i+1)*batch_size]
                p_batch = p_val[i*batch_size:(i+1)*batch_size]
                n_batch = n_val[i*batch_size:(i+1)*batch_size]

                q_emb = self.query_encoder_model(q_batch, training=False)
                p_emb = self.response_encoder_model(p_batch, training=False)
                n_emb = self.response_encoder_model(n_batch, training=False)
                v_loss = self.triplet_loss(q_emb, p_emb, n_emb)
                epoch_val_loss += v_loss.numpy()

            if val_batches > 0:
                epoch_val_loss /= val_batches

            train_losses.append(epoch_train_loss)
            val_losses.append(epoch_val_loss)

            print(f"Epoch {epoch+1}/{epochs}, Train Loss: {epoch_train_loss:.4f}, Val Loss: {epoch_val_loss:.4f}")

            # Early Stopping logic
            if epoch_val_loss < best_val_loss:
                best_val_loss = epoch_val_loss
                wait = 0
                # (Optional) Save best weights
            else:
                wait += 1
                if wait >= early_stopping_patience:
                    print("Early stopping triggered.")
                    break

        return train_losses, val_losses

    def encode_texts(self, texts, is_query=True):
        seq = self.tokenizer.texts_to_sequences(texts)
        pad_seq = tf.keras.preprocessing.sequence.pad_sequences(seq, maxlen=self.max_sequence_length, padding='post')
        if is_query:
            return self.query_encoder_model(pad_seq, training=False)
        else:
            return self.response_encoder_model(pad_seq, training=False)

    def retrieve_top_n(self, query: str, candidates: list, top_n=5):
        q_emb = self.encode_texts([query], is_query=True)      # shape (1, d)
        c_emb = self.encode_texts(candidates, is_query=False)  # shape (num_cand, d)
        sim = tf.matmul(q_emb, c_emb, transpose_b=True).numpy()[0]  # dot product similarity
        top_indices = np.argsort(sim)[::-1][:top_n]
        return [(candidates[i], sim[i]) for i in top_indices]
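
A hedged end-to-end sketch (not part of the commit) of how this first-iteration RetrievalChatbot could be exercised: prepare triplets, train the dual encoders with the margin-based triplet loss, then rank candidate responses. The toy dialogues only illustrate the {'turns': [{'speaker', 'text'}]} structure prepare_dataset() expects; with a corpus this small the negative sampler will not find enough hard negatives, so real training would use the full augmented dataset.

# Hypothetical driver, for illustration only.
from chatbot import RetrievalChatbot

dialogues = [
    {
        "turns": [
            {"speaker": "user", "text": "I need a hotel in the city centre for two nights."},
            {"speaker": "assistant", "text": "Sure, what price range are you looking for?"},
            {"speaker": "user", "text": "Something moderately priced, please."},
            {"speaker": "assistant", "text": "The Alpha Hotel is moderately priced and central. Shall I book it?"},
        ]
    },
    # ... many more dialogues in practice; a tiny corpus yields no triplets ...
]

bot = RetrievalChatbot(vocab_size=10000, max_sequence_length=80)
q_pad, p_pad, n_pad = bot.prepare_dataset(dialogues, neg_samples_per_pos=3)
train_losses, val_losses = bot.train_with_triplet_loss(q_pad, p_pad, n_pad, epochs=3, batch_size=16)

candidates = [t["text"] for d in dialogues for t in d["turns"] if t["speaker"] == "assistant"]
for response, score in bot.retrieve_top_n("Can you find me a place to stay?", candidates, top_n=3):
    print(f"{score:.3f}  {response}")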
chatbot2.py
ADDED
@@ -0,0 +1,839 @@
1 |
+
import numpy as np
|
2 |
+
import tensorflow as tf
|
3 |
+
import spacy
|
4 |
+
import random
|
5 |
+
from typing import List, Tuple, Dict, Optional, Union
|
6 |
+
from dataclasses import dataclass
|
7 |
+
from tqdm import tqdm
|
8 |
+
import logging
|
9 |
+
from pathlib import Path
|
10 |
+
import json
|
11 |
+
|
12 |
+
# Configure logging
|
13 |
+
logging.basicConfig(
|
14 |
+
level=logging.INFO,
|
15 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
16 |
+
)
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
|
19 |
+
@dataclass
|
20 |
+
class ChatbotConfig:
|
21 |
+
"""Configuration for the retrieval chatbot."""
|
22 |
+
vocab_size: int = 10000
|
23 |
+
max_sequence_length: int = 512
|
24 |
+
embedding_dim: int = 256
|
25 |
+
encoder_units: int = 256
|
26 |
+
num_attention_heads: int = 8
|
27 |
+
dropout_rate: float = 0.2
|
28 |
+
l2_reg_weight: float = 0.001
|
29 |
+
margin: float = 0.3
|
30 |
+
learning_rate: float = 0.001
|
31 |
+
min_text_length: int = 3 # Reduced from 10 to allow shorter responses
|
32 |
+
max_context_turns: int = 5
|
33 |
+
warmup_steps: int = 200
|
34 |
+
spacy_model: str = 'en_core_web_md'
|
35 |
+
|
36 |
+
def to_dict(self) -> dict:
|
37 |
+
"""Convert config to dictionary."""
|
38 |
+
return {k: str(v) if isinstance(v, Path) else v
|
39 |
+
for k, v in self.__dict__.items()}
|
40 |
+
|
41 |
+
@classmethod
|
42 |
+
def from_dict(cls, config_dict: dict) -> 'ChatbotConfig':
|
43 |
+
"""Create config from dictionary."""
|
44 |
+
return cls(**{k: v for k, v in config_dict.items()
|
45 |
+
if k in cls.__dataclass_fields__})
|
46 |
+
|
47 |
+
class TransformerBlock(tf.keras.layers.Layer):
|
48 |
+
"""Custom Transformer block with pre-layer normalization."""
|
49 |
+
def __init__(
|
50 |
+
self,
|
51 |
+
embed_dim: int,
|
52 |
+
num_heads: int,
|
53 |
+
ff_dim: int,
|
54 |
+
dropout: float = 0.1,
|
55 |
+
**kwargs
|
56 |
+
):
|
57 |
+
super().__init__(**kwargs)
|
58 |
+
self.embed_dim = embed_dim
|
59 |
+
self.num_heads = num_heads
|
60 |
+
self.ff_dim = ff_dim
|
61 |
+
self.dropout = dropout
|
62 |
+
|
63 |
+
self.attention = tf.keras.layers.MultiHeadAttention(
|
64 |
+
num_heads=num_heads,
|
65 |
+
key_dim=embed_dim // num_heads,
|
66 |
+
dropout=dropout
|
67 |
+
)
|
68 |
+
self.ffn = tf.keras.Sequential([
|
69 |
+
tf.keras.layers.Dense(ff_dim, activation="gelu"),
|
70 |
+
tf.keras.layers.Dense(embed_dim),
|
71 |
+
])
|
72 |
+
|
73 |
+
self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
|
74 |
+
self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
|
75 |
+
self.dropout1 = tf.keras.layers.Dropout(dropout)
|
76 |
+
self.dropout2 = tf.keras.layers.Dropout(dropout)
|
77 |
+
|
78 |
+
def call(self, inputs: tf.Tensor, training: bool, mask: Optional[tf.Tensor] = None) -> tf.Tensor:
|
79 |
+
# Pre-layer normalization
|
80 |
+
norm_inputs = self.layernorm1(inputs)
|
81 |
+
|
82 |
+
# Self-attention
|
83 |
+
attention_output = self.attention(
|
84 |
+
query=norm_inputs,
|
85 |
+
value=norm_inputs,
|
86 |
+
key=norm_inputs,
|
87 |
+
attention_mask=mask,
|
88 |
+
training=training
|
89 |
+
)
|
90 |
+
attention_output = self.dropout1(attention_output, training=training)
|
91 |
+
attention_output = inputs + attention_output
|
92 |
+
|
93 |
+
# Feed-forward network
|
94 |
+
norm_attention = self.layernorm2(attention_output)
|
95 |
+
ffn_output = self.ffn(norm_attention)
|
96 |
+
ffn_output = self.dropout2(ffn_output, training=training)
|
97 |
+
|
98 |
+
return attention_output + ffn_output
|
99 |
+
|
100 |
+
def get_config(self) -> dict:
|
101 |
+
config = super().get_config()
|
102 |
+
config.update({
|
103 |
+
"embed_dim": self.embed_dim,
|
104 |
+
"num_heads": self.num_heads,
|
105 |
+
"ff_dim": self.ff_dim,
|
106 |
+
"dropout": self.dropout,
|
107 |
+
})
|
108 |
+
return config
|
109 |
+
|
110 |
+
class EncoderModel(tf.keras.Model):
|
111 |
+
"""Dual encoder model with shared weights option."""
|
112 |
+
def __init__(
|
113 |
+
self,
|
114 |
+
config: ChatbotConfig,
|
115 |
+
name: str = "encoder",
|
116 |
+
shared_weights: bool = False,
|
117 |
+
**kwargs
|
118 |
+
):
|
119 |
+
super().__init__(name=name, **kwargs)
|
120 |
+
self.config = config
|
121 |
+
self.shared_weights = shared_weights
|
122 |
+
|
123 |
+
# Input embedding layer
|
124 |
+
self.embedding = tf.keras.layers.Embedding(
|
125 |
+
config.vocab_size,
|
126 |
+
config.embedding_dim,
|
127 |
+
mask_zero=True,
|
128 |
+
name=f"{name}_embedding"
|
129 |
+
)
|
130 |
+
|
131 |
+
# Positional encoding
|
132 |
+
self.pos_encoding = self._get_positional_encoding()
|
133 |
+
|
134 |
+
# Transformer blocks
|
135 |
+
self.transformer_blocks = [
|
136 |
+
TransformerBlock(
|
137 |
+
config.embedding_dim,
|
138 |
+
config.num_attention_heads,
|
139 |
+
config.encoder_units * 4,
|
140 |
+
config.dropout_rate,
|
141 |
+
name=f"{name}_transformer_{i}"
|
142 |
+
) for i in range(3)
|
143 |
+
]
|
144 |
+
|
145 |
+
# Final LSTM layer
|
146 |
+
self.final_lstm = tf.keras.layers.LSTM(
|
147 |
+
config.encoder_units,
|
148 |
+
kernel_regularizer=tf.keras.regularizers.l2(config.l2_reg_weight),
|
149 |
+
name=f"{name}_final_lstm"
|
150 |
+
)
|
151 |
+
|
152 |
+
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
|
153 |
+
self.normalize = tf.keras.layers.Lambda(
|
154 |
+
lambda x: tf.nn.l2_normalize(x, axis=1)
|
155 |
+
)
|
156 |
+
|
157 |
+
def _get_positional_encoding(self) -> tf.Tensor:
|
158 |
+
"""Generate positional encoding matrix."""
|
159 |
+
pos = np.arange(self.config.max_sequence_length)[:, np.newaxis]
|
160 |
+
i = np.arange(self.config.embedding_dim)[np.newaxis, :]
|
161 |
+
angle = pos / np.power(10000, (2 * (i // 2)) / self.config.embedding_dim)
|
162 |
+
|
163 |
+
pos_encoding = np.zeros_like(angle)
|
164 |
+
pos_encoding[:, 0::2] = np.sin(angle[:, 0::2])
|
165 |
+
pos_encoding[:, 1::2] = np.cos(angle[:, 1::2])
|
166 |
+
|
167 |
+
return tf.cast(pos_encoding[np.newaxis, ...], dtype=tf.float32)
|
168 |
+
|
169 |
+
def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
|
170 |
+
# Get input mask
|
171 |
+
mask = self.embedding.compute_mask(inputs)
|
172 |
+
mask = mask[:, tf.newaxis, tf.newaxis, :] # Add attention dims
|
173 |
+
|
174 |
+
# Embedding + positional encoding
|
175 |
+
x = self.embedding(inputs)
|
176 |
+
x = x + self.pos_encoding
|
177 |
+
|
178 |
+
# Apply transformer blocks
|
179 |
+
for transformer_block in self.transformer_blocks:
|
180 |
+
x = transformer_block(x, training=training, mask=mask)
|
181 |
+
|
182 |
+
# Final processing
|
183 |
+
x = self.final_lstm(x)
|
184 |
+
x = self.dropout(x, training=training)
|
185 |
+
return self.normalize(x)
|
186 |
+
|
187 |
+
class RetrievalChatbot:
|
188 |
+
"""Professional implementation of a retrieval-based chatbot."""
|
189 |
+
def __init__(self, config: ChatbotConfig):
|
190 |
+
self.config = config
|
191 |
+
self.nlp = spacy.load(config.spacy_model)
|
192 |
+
|
193 |
+
# Initialize tokenizer
|
194 |
+
self.tokenizer = tf.keras.preprocessing.text.Tokenizer(
|
195 |
+
num_words=config.vocab_size,
|
196 |
+
oov_token="<OOV>"
|
197 |
+
)
|
198 |
+
|
199 |
+
# Special tokens
|
200 |
+
self.special_tokens = {
|
201 |
+
"user": "<USER>",
|
202 |
+
"assistant": "<ASSISTANT>",
|
203 |
+
"context": "<CONTEXT>",
|
204 |
+
"sep": "<SEP>"
|
205 |
+
}
|
206 |
+
|
207 |
+
# Build models
|
208 |
+
self._build_models()
|
209 |
+
|
210 |
+
# Training history
|
211 |
+
self.history = {
|
212 |
+
"train_loss": [],
|
213 |
+
"val_loss": [],
|
214 |
+
"train_metrics": {},
|
215 |
+
"val_metrics": {}
|
216 |
+
}
|
217 |
+
|
218 |
+
# Initialize similarity cache
|
219 |
+
self.similarity_cache = {}
|
220 |
+
|
221 |
+
def _build_models(self):
|
222 |
+
"""Initialize the encoder models."""
|
223 |
+
# Query encoder
|
224 |
+
self.query_encoder = EncoderModel(
|
225 |
+
self.config,
|
226 |
+
name="query_encoder",
|
227 |
+
shared_weights=False
|
228 |
+
)
|
229 |
+
|
230 |
+
# Response encoder (can share weights with query encoder)
|
231 |
+
self.response_encoder = EncoderModel(
|
232 |
+
self.config,
|
233 |
+
name="response_encoder",
|
234 |
+
shared_weights=False
|
235 |
+
)
|
236 |
+
|
237 |
+
def save_models(self, save_dir: Union[str, Path]):
|
238 |
+
"""Save models and configuration."""
|
239 |
+
save_dir = Path(save_dir)
|
240 |
+
save_dir.mkdir(parents=True, exist_ok=True)
|
241 |
+
|
242 |
+
# Save config
|
243 |
+
with open(save_dir / "config.json", "w") as f:
|
244 |
+
json.dump(self.config.to_dict(), f, indent=2)
|
245 |
+
|
246 |
+
# Save models with proper extension
|
247 |
+
self.query_encoder.save(save_dir / "query_encoder.keras")
|
248 |
+
self.response_encoder.save(save_dir / "response_encoder.keras")
|
249 |
+
|
250 |
+
# Save tokenizer config
|
251 |
+
tokenizer_config = {
|
252 |
+
"word_index": self.tokenizer.word_index,
|
253 |
+
"word_counts": self.tokenizer.word_counts,
|
254 |
+
"document_count": self.tokenizer.document_count,
|
255 |
+
"index_docs": self.tokenizer.index_docs,
|
256 |
+
"index_word": self.tokenizer.index_word
|
257 |
+
}
|
258 |
+
with open(save_dir / "tokenizer_config.json", "w") as f:
|
259 |
+
json.dump(tokenizer_config, f)
|
260 |
+
|
261 |
+
@classmethod
|
262 |
+
def load_models(cls, load_dir: Union[str, Path]) -> 'RetrievalChatbot':
|
263 |
+
"""Load saved models and configuration."""
|
264 |
+
load_dir = Path(load_dir)
|
265 |
+
|
266 |
+
# Load config
|
267 |
+
with open(load_dir / "config.json", "r") as f:
|
268 |
+
config = ChatbotConfig.from_dict(json.load(f))
|
269 |
+
|
270 |
+
# Initialize chatbot
|
271 |
+
chatbot = cls(config)
|
272 |
+
|
273 |
+
# Load models with proper extension
|
274 |
+
chatbot.query_encoder = tf.keras.models.load_model(
|
275 |
+
load_dir / "query_encoder.keras",
|
276 |
+
custom_objects={"TransformerBlock": TransformerBlock}
|
277 |
+
)
|
278 |
+
chatbot.response_encoder = tf.keras.models.load_model(
|
279 |
+
load_dir / "response_encoder.keras",
|
280 |
+
custom_objects={"TransformerBlock": TransformerBlock}
|
281 |
+
)
|
282 |
+
|
283 |
+
# Load tokenizer config
|
284 |
+
with open(load_dir / "tokenizer_config.json", "r") as f:
|
285 |
+
tokenizer_config = json.load(f)
|
286 |
+
|
287 |
+
chatbot.tokenizer = tf.keras.preprocessing.text.Tokenizer(
|
288 |
+
num_words=config.vocab_size,
|
289 |
+
oov_token="<OOV>"
|
290 |
+
)
|
291 |
+
chatbot.tokenizer.word_index = tokenizer_config["word_index"]
|
292 |
+
chatbot.tokenizer.word_counts = tokenizer_config["word_counts"]
|
293 |
+
chatbot.tokenizer.document_count = tokenizer_config["document_count"]
|
294 |
+
chatbot.tokenizer.index_docs = tokenizer_config["index_docs"]
|
295 |
+
chatbot.tokenizer.index_word = tokenizer_config["index_word"]
|
296 |
+
|
297 |
+
return chatbot
|
298 |
+
|
299 |
+
def _improved_spacy_similarity(self, text1: str, text2: str) -> float:
|
300 |
+
"""Calculate semantic similarity between texts with preprocessing."""
|
301 |
+
def preprocess(text: str) -> str:
|
302 |
+
# Basic cleaning
|
303 |
+
text = ' '.join(text.split())
|
304 |
+
return text if text.strip() else "empty_document"
|
305 |
+
|
306 |
+
# Get cache key
|
307 |
+
cache_key = f"{hash(text1)}_{hash(text2)}"
|
308 |
+
if cache_key in self.similarity_cache:
|
309 |
+
return self.similarity_cache[cache_key]
|
310 |
+
|
311 |
+
# Process texts
|
312 |
+
text1, text2 = preprocess(text1), preprocess(text2)
|
313 |
+
doc1, doc2 = self.nlp(text1), self.nlp(text2)
|
314 |
+
|
315 |
+
# Calculate similarity
|
316 |
+
if doc1.has_vector and doc2.has_vector:
|
317 |
+
sim = doc1.similarity(doc2)
|
318 |
+
else:
|
319 |
+
# Fallback to token overlap similarity
|
320 |
+
tokens1 = {t.lower_ for t in doc1 if not t.is_stop and not t.is_punct}
|
321 |
+
tokens2 = {t.lower_ for t in doc2 if not t.is_stop and not t.is_punct}
|
322 |
+
intersection = len(tokens1.intersection(tokens2))
|
323 |
+
union = len(tokens1.union(tokens2))
|
324 |
+
sim = intersection / union if union > 0 else 0.0
|
325 |
+
|
326 |
+
# Cache result
|
327 |
+
self.similarity_cache[cache_key] = sim
|
328 |
+
return sim
|
329 |
+
|
330 |
+
def _smart_negative_sampling(
|
331 |
+
self,
|
332 |
+
positive: str,
|
333 |
+
response_pool: List[str],
|
334 |
+
n_samples: int,
|
335 |
+
max_attempts: int = 200,
|
336 |
+
similarity_bounds: Tuple[float, float] = (0.3, 0.8),
|
337 |
+
batch_size: int = 10
|
338 |
+
) -> List[str]:
|
339 |
+
"""Smart negative sampling with similarity bounds and batching."""
|
340 |
+
candidates = []
|
341 |
+
seen = set()
|
342 |
+
attempts = 0
|
343 |
+
|
344 |
+
while len(candidates) < n_samples and attempts < max_attempts:
|
345 |
+
# Batch process candidates
|
346 |
+
batch = random.sample(
|
347 |
+
response_pool,
|
348 |
+
min(batch_size, max_attempts - attempts)
|
349 |
+
)
|
350 |
+
|
351 |
+
for candidate in batch:
|
352 |
+
if candidate != positive and candidate not in seen:
|
353 |
+
seen.add(candidate)
|
354 |
+
sim = self._improved_spacy_similarity(candidate, positive)
|
355 |
+
|
356 |
+
# Check similarity bounds
|
357 |
+
if similarity_bounds[0] < sim < similarity_bounds[1]:
|
358 |
+
candidates.append(candidate)
|
359 |
+
if len(candidates) == n_samples:
|
360 |
+
break
|
361 |
+
|
362 |
+
attempts += len(batch)
|
363 |
+
|
364 |
+
return candidates
|
365 |
+
|
366 |
+
def train(
|
367 |
+
self,
|
368 |
+
q_pad: tf.Tensor,
|
369 |
+
p_pad: tf.Tensor,
|
370 |
+
n_pad: tf.Tensor,
|
371 |
+
epochs: int = 3,
|
372 |
+
batch_size: int = 32,
|
373 |
+
validation_split: float = 0.2,
|
374 |
+
checkpoint_dir: Optional[Union[str, Path]] = None
|
375 |
+
):
|
376 |
+
"""Train the model with improved training loop."""
|
377 |
+
# Setup training
|
378 |
+
total_samples = len(q_pad)
|
379 |
+
train_size = int((1 - validation_split) * total_samples)
|
380 |
+
|
381 |
+
# Split data
|
382 |
+
indices = np.random.permutation(total_samples)
|
383 |
+
train_idx, val_idx = indices[:train_size], indices[train_size:]
|
384 |
+
|
385 |
+
train_data = (q_pad[train_idx], p_pad[train_idx], n_pad[train_idx])
|
386 |
+
val_data = (q_pad[val_idx], p_pad[val_idx], n_pad[val_idx])
|
387 |
+
|
388 |
+
# Setup optimizer with learning rate schedule
|
389 |
+
steps_per_epoch = train_size // batch_size
|
390 |
+
total_steps = steps_per_epoch * epochs
|
391 |
+
|
392 |
+
lr_schedule = self._get_lr_schedule(
|
393 |
+
total_steps,
|
394 |
+
self.config.learning_rate,
|
395 |
+
self.config.warmup_steps
|
396 |
+
)
|
397 |
+
|
398 |
+
optimizer = tf.keras.optimizers.Adam(lr_schedule)
|
399 |
+
|
400 |
+
# Setup checkpointing
|
401 |
+
if checkpoint_dir:
|
402 |
+
checkpoint_dir = Path(checkpoint_dir)
|
403 |
+
checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
404 |
+
|
405 |
+
# Setup checkpoint callback with correct file format
|
406 |
+
checkpoint_template = str(checkpoint_dir / "model_epoch_{epoch:04d}.weights.h5")
|
407 |
+
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
|
408 |
+
checkpoint_template,
|
409 |
+
save_weights_only=True,
|
410 |
+
save_best_only=True,
|
411 |
+
monitor='val_loss',
|
412 |
+
mode='min',
|
413 |
+
verbose=1
|
414 |
+
)
|
415 |
+
|
416 |
+
# Training loop
|
417 |
+
best_val_loss = float('inf')
|
418 |
+
patience = 5
|
419 |
+
wait = 0
|
420 |
+
|
421 |
+
for epoch in range(epochs):
|
422 |
+
# Training
|
423 |
+
train_loss = self._train_epoch(
|
424 |
+
train_data,
|
425 |
+
optimizer,
|
426 |
+
batch_size,
|
427 |
+
training=True
|
428 |
+
)
|
429 |
+
|
430 |
+
# Validation
|
431 |
+
val_loss = self._train_epoch(
|
432 |
+
val_data,
|
433 |
+
optimizer,
|
434 |
+
batch_size,
|
435 |
+
training=False
|
436 |
+
)
|
437 |
+
|
438 |
+
# Update history
|
439 |
+
self.history['train_loss'].append(train_loss)
|
440 |
+
self.history['val_loss'].append(val_loss)
|
441 |
+
|
442 |
+
logger.info(
|
443 |
+
f"Epoch {epoch + 1}/{epochs} - "
|
444 |
+
f"train_loss: {train_loss:.4f} - "
|
445 |
+
f"val_loss: {val_loss:.4f}"
|
446 |
+
)
|
447 |
+
|
448 |
+
# Early stopping
|
449 |
+
if val_loss < best_val_loss:
|
450 |
+
best_val_loss = val_loss
|
451 |
+
wait = 0
|
452 |
+
if checkpoint_dir:
|
453 |
+
self.save_models(checkpoint_dir / f"best_model")
|
454 |
+
else:
|
455 |
+
wait += 1
|
456 |
+
if wait >= patience:
|
457 |
+
logger.info("Early stopping triggered")
|
458 |
+
break
|
459 |
+
|
460 |
+
def _get_lr_schedule(
|
461 |
+
self,
|
462 |
+
total_steps: int,
|
463 |
+
peak_lr: float,
|
464 |
+
warmup_steps: int
|
465 |
+
) -> tf.keras.optimizers.schedules.LearningRateSchedule:
|
466 |
+
"""Enhanced learning rate schedule with better error handling and logging."""
|
467 |
+
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
|
468 |
+
def __init__(
|
469 |
+
self,
|
470 |
+
total_steps: int,
|
471 |
+
peak_lr: float,
|
472 |
+
warmup_steps: int
|
473 |
+
):
|
474 |
+
super().__init__()
|
475 |
+
self.total_steps = tf.cast(total_steps, tf.float32)
|
476 |
+
self.peak_lr = tf.cast(peak_lr, tf.float32)
|
477 |
+
self.warmup_steps = tf.cast(max(1, warmup_steps), tf.float32) # Prevent 0
|
478 |
+
|
479 |
+
# Calculate and store constants
|
480 |
+
self.initial_lr = self.peak_lr * 0.1 # Start at 10% of peak
|
481 |
+
self.min_lr = self.peak_lr * 0.01 # Minimum 1% of peak
|
482 |
+
|
483 |
+
logger.info(f"Learning rate schedule initialized:")
|
484 |
+
logger.info(f" Initial LR: {float(self.initial_lr):.6f}")
|
485 |
+
logger.info(f" Peak LR: {float(self.peak_lr):.6f}")
|
486 |
+
logger.info(f" Min LR: {float(self.min_lr):.6f}")
|
487 |
+
logger.info(f" Warmup steps: {int(self.warmup_steps)}")
|
488 |
+
logger.info(f" Total steps: {int(self.total_steps)}")
|
489 |
+
|
490 |
+
def __call__(self, step):
|
491 |
+
step = tf.cast(step, tf.float32)
|
492 |
+
|
493 |
+
# Warmup phase
|
494 |
+
warmup_factor = tf.minimum(1.0, step / self.warmup_steps)
|
495 |
+
warmup_lr = self.initial_lr + (self.peak_lr - self.initial_lr) * warmup_factor
|
496 |
+
|
497 |
+
# Decay phase
|
498 |
+
decay_steps = tf.maximum(1.0, self.total_steps - self.warmup_steps)
|
499 |
+
decay_factor = (step - self.warmup_steps) / decay_steps
|
500 |
+
decay_factor = tf.minimum(tf.maximum(0.0, decay_factor), 1.0) # Clip to [0,1]
|
501 |
+
|
502 |
+
cosine_decay = 0.5 * (1.0 + tf.cos(tf.constant(np.pi) * decay_factor))
|
503 |
+
decay_lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay
|
504 |
+
|
505 |
+
# Choose between warmup and decay
|
506 |
+
final_lr = tf.where(step < self.warmup_steps, warmup_lr, decay_lr)
|
507 |
+
|
508 |
+
# Ensure learning rate is valid
|
509 |
+
final_lr = tf.maximum(self.min_lr, final_lr)
|
510 |
+
final_lr = tf.where(tf.math.is_finite(final_lr), final_lr, self.min_lr)
|
511 |
+
|
512 |
+
return final_lr
|
513 |
+
|
514 |
+
def get_config(self):
|
515 |
+
return {
|
516 |
+
"total_steps": self.total_steps,
|
517 |
+
"peak_lr": self.peak_lr,
|
518 |
+
"warmup_steps": self.warmup_steps,
|
519 |
+
}
|
520 |
+
|
521 |
+
return CustomSchedule(total_steps, peak_lr, warmup_steps)
|
522 |
+
|
523 |
+
@tf.function
|
524 |
+
def _train_step(
|
525 |
+
self,
|
526 |
+
q_batch: tf.Tensor,
|
527 |
+
p_batch: tf.Tensor,
|
528 |
+
n_batch: tf.Tensor,
|
529 |
+
optimizer: tf.keras.optimizers.Optimizer,
|
530 |
+
training: bool = True
|
531 |
+
) -> tf.Tensor:
|
532 |
+
"""Single training step with triplet loss."""
|
533 |
+
with tf.GradientTape() as tape:
|
534 |
+
# Get embeddings
|
535 |
+
q_emb = self.query_encoder(q_batch, training=training)
|
536 |
+
p_emb = self.response_encoder(p_batch, training=training)
|
537 |
+
n_emb = self.response_encoder(n_batch, training=training)
|
538 |
+
|
539 |
+
# Calculate triplet loss
|
540 |
+
pos_dist = tf.reduce_sum(tf.square(q_emb - p_emb), axis=1)
|
541 |
+
neg_dist = tf.reduce_sum(tf.square(q_emb - n_emb), axis=1)
|
542 |
+
|
543 |
+
loss = tf.maximum(0.0, self.config.margin + pos_dist - neg_dist)
|
544 |
+
loss = tf.reduce_mean(loss)
|
545 |
+
|
546 |
+
if training:
|
547 |
+
# Apply gradients
|
548 |
+
gradients = tape.gradient(
|
549 |
+
loss,
|
550 |
+
self.query_encoder.trainable_variables +
|
551 |
+
self.response_encoder.trainable_variables
|
552 |
+
)
|
553 |
+
optimizer.apply_gradients(zip(
|
554 |
+
gradients,
|
555 |
+
self.query_encoder.trainable_variables +
|
556 |
+
self.response_encoder.trainable_variables
|
557 |
+
))
|
558 |
+
|
559 |
+
return loss
|
560 |
+
|
561 |
+
def _train_epoch(
|
562 |
+
self,
|
563 |
+
data: Tuple[tf.Tensor, tf.Tensor, tf.Tensor],
|
564 |
+
optimizer: tf.keras.optimizers.Optimizer,
|
565 |
+
batch_size: int,
|
566 |
+
training: bool = True
|
567 |
+
) -> float:
|
568 |
+
"""Train for one epoch with enhanced logging and progress tracking."""
|
569 |
+
q_data, p_data, n_data = data
|
570 |
+
total_loss = 0
|
571 |
+
num_batches = len(q_data) // batch_size
|
572 |
+
|
573 |
+
# Log current learning rate at start of epoch
|
574 |
+
if training:
|
575 |
+
if hasattr(optimizer.learning_rate, '__call__'):
|
576 |
+
current_lr = optimizer.learning_rate(optimizer.iterations)
|
577 |
+
else:
|
578 |
+
current_lr = optimizer.learning_rate
|
579 |
+
logger.info(f"Current learning rate: {float(current_lr):.6f}")
|
580 |
+
|
581 |
+
# Shuffle data
|
582 |
+
indices = np.random.permutation(len(q_data))
|
583 |
+
q_data = q_data[indices]
|
584 |
+
p_data = p_data[indices]
|
585 |
+
n_data = n_data[indices]
|
586 |
+
|
587 |
+
# Create progress bar
|
588 |
+
mode = "Training" if training else "Validation"
|
589 |
+
pbar = tqdm(
|
590 |
+
total=num_batches,
|
591 |
+
desc=f"{mode} batches",
|
592 |
+
unit="batch",
|
593 |
+
dynamic_ncols=True # Automatically adjust width
|
594 |
+
)
|
595 |
+
|
596 |
+
# Process batches
|
597 |
+
for i in range(num_batches):
|
598 |
+
start_idx = i * batch_size
|
599 |
+
end_idx = start_idx + batch_size
|
600 |
+
|
601 |
+
batch_loss = self._train_step(
|
602 |
+
q_data[start_idx:end_idx],
|
603 |
+
p_data[start_idx:end_idx],
|
604 |
+
n_data[start_idx:end_idx],
|
605 |
+
optimizer,
|
606 |
+
training
|
607 |
+
)
|
608 |
+
total_loss += batch_loss
|
609 |
+
|
610 |
+
# Update progress bar with current loss
|
611 |
+
avg_loss = total_loss / (i + 1)
|
612 |
+
pbar.set_postfix({
|
613 |
+
'loss': f'{avg_loss:.4f}',
|
614 |
+
'lr': f'{float(current_lr):.6f}' if training else 'N/A'
|
615 |
+
})
|
616 |
+
pbar.update(1)
|
617 |
+
|
618 |
+
pbar.close()
|
619 |
+
return total_loss / num_batches if num_batches > 0 else 0
|
620 |
+
|
621 |
+
def _prepare_sequences(
|
622 |
+
self,
|
623 |
+
queries: List[str],
|
624 |
+
positives: List[str],
|
625 |
+
negatives: List[str]
|
626 |
+
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
|
627 |
+
"""Enhanced sequence preparation with logging and text preprocessing."""
|
628 |
+
logger.info("Preparing sequences...")
|
629 |
+
|
630 |
+
# Text cleaning function from old version
|
631 |
+
def clean_text(text: str) -> str:
|
632 |
+
# Remove excessive whitespace
|
633 |
+
text = ' '.join(text.split())
|
634 |
+
# Remove very long repetitive sequences
|
635 |
+
if len(text) > 500: # Add length limit
|
636 |
+
text = ' '.join(dict.fromkeys(text.split()))
|
637 |
+
return text
|
638 |
+
|
639 |
+
# Process texts with special tokens and cleaning
|
640 |
+
queries = [f"{self.special_tokens['user']} {clean_text(q)}" for q in queries]
|
641 |
+
positives = [f"{self.special_tokens['assistant']} {clean_text(p)}" for p in positives]
|
642 |
+
negatives = [f"{self.special_tokens['assistant']} {clean_text(n)}" for n in negatives]
|
643 |
+
|
644 |
+
# Fit tokenizer and log vocabulary statistics
|
645 |
+
all_texts = queries + positives + negatives
|
646 |
+
self.tokenizer.fit_on_texts(all_texts)
|
647 |
+
|
648 |
+
# Log vocabulary statistics
|
649 |
+
vocab_size = len(self.tokenizer.word_index)
|
650 |
+
logger.info(f"Vocabulary statistics:")
|
651 |
+
logger.info(f" Total unique tokens: {vocab_size}")
|
652 |
+
logger.info(f" Vocab limit: {self.config.vocab_size}")
|
653 |
+
|
654 |
+
# Log most common tokens
|
655 |
+
word_freq = sorted(
|
656 |
+
self.tokenizer.word_counts.items(),
|
657 |
+
key=lambda x: x[1],
|
658 |
+
reverse=True
|
659 |
+
)[:10]
|
660 |
+
logger.info("Most common tokens:")
|
661 |
+
for word, freq in word_freq:
|
662 |
+
logger.info(f" {word}: {freq}")
|
663 |
+
|
664 |
+
# Padding function from old version
|
665 |
+
def pad_sequences(texts: List[str]) -> tf.Tensor:
|
666 |
+
sequences = self.tokenizer.texts_to_sequences(texts)
|
667 |
+
return tf.keras.preprocessing.sequence.pad_sequences(
|
668 |
+
sequences,
|
669 |
+
maxlen=self.config.max_sequence_length,
|
670 |
+
padding='post',
|
671 |
+
truncating='post'
|
672 |
+
)
|
673 |
+
|
674 |
+
# Return padded sequences
|
675 |
+
return (
|
676 |
+
pad_sequences(queries),
|
677 |
+
pad_sequences(positives),
|
678 |
+
pad_sequences(negatives)
|
679 |
+
)
|
680 |
+
|
681 |
+
def prepare_dataset(
|
682 |
+
self,
|
683 |
+
dialogues: List[dict],
|
684 |
+
neg_samples_per_pos: int = 3,
|
685 |
+
debug_samples: Optional[int] = None
|
686 |
+
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
|
687 |
+
"""Prepare dataset with enhanced logging and statistics."""
|
688 |
+
logger.info("Preparing dataset...")
|
689 |
+
|
690 |
+
# Log dataset statistics
|
691 |
+
total_dialogues = len(dialogues)
|
692 |
+
total_turns = sum(len(d['turns']) for d in dialogues)
|
693 |
+
avg_turns = total_turns / total_dialogues
|
694 |
+
|
695 |
+
logger.info(f"Dataset statistics:")
|
696 |
+
logger.info(f" Total dialogues: {total_dialogues}")
|
697 |
+
logger.info(f" Total turns: {total_turns}")
|
698 |
+
logger.info(f" Average turns per dialogue: {avg_turns:.2f}")
|
699 |
+
|
700 |
+
# Extract and filter responses with logging
|
701 |
+
response_pool = []
|
702 |
+
skipped_short = 0
|
703 |
+
skipped_long = 0
|
704 |
+
|
705 |
+
for d in dialogues:
|
706 |
+
for turn in d['turns']:
|
707 |
+
if turn['speaker'] == 'assistant':
|
708 |
+
text = turn['text'].strip()
|
709 |
+
length = len(text.split())
|
710 |
+
if length < self.config.min_text_length:
|
711 |
+
skipped_short += 1
|
712 |
+
continue
|
713 |
+
if length > self.config.max_sequence_length:
|
714 |
+
skipped_long += 1
|
715 |
+
continue
|
716 |
+
response_pool.append(text)
|
717 |
+
|
718 |
+
logger.info(f"Response pool statistics:")
|
719 |
+
logger.info(f" Total responses: {len(response_pool)}")
|
720 |
+
logger.info(f" Skipped (too short): {skipped_short}")
|
721 |
+
logger.info(f" Skipped (too long): {skipped_long}")
|
722 |
+
|
723 |
+
# Process dialogues and create training examples
|
724 |
+
queries, positives, negatives = [], [], []
|
725 |
+
|
726 |
+
for dialogue in tqdm(dialogues, desc="Processing dialogues"):
|
727 |
+
turns = dialogue['turns']
|
728 |
+
for i in range(len(turns) - 1):
|
729 |
+
if turns[i]['speaker'] == 'user' and turns[i+1]['speaker'] == 'assistant':
|
730 |
+
query = turns[i]['text'].strip()
|
731 |
+
positive = turns[i+1]['text'].strip()
|
732 |
+
|
733 |
+
# Skip short texts
|
734 |
+
if (len(query.split()) < self.config.min_text_length or
|
735 |
+
len(positive.split()) < self.config.min_text_length): # Fixed
|
736 |
+
continue
|
737 |
+
|
738 |
+
# Get negative samples
|
739 |
+
neg_samples = self._smart_negative_sampling(
|
740 |
+
positive,
|
741 |
+
response_pool,
|
742 |
+
neg_samples_per_pos
|
743 |
+
)
|
744 |
+
|
745 |
+
if len(neg_samples) == neg_samples_per_pos:
|
746 |
+
for neg in neg_samples:
|
747 |
+
queries.append(query)
|
748 |
+
positives.append(positive)
|
749 |
+
negatives.append(neg)
|
750 |
+
|
751 |
+
# Log final dataset statistics
|
752 |
+
logger.info(f"Final dataset statistics:")
|
753 |
+
logger.info(f" Training examples: {len(queries)}")
|
754 |
+
logger.info(f" Unique queries: {len(set(queries))}")
|
755 |
+
logger.info(f" Unique responses: {len(set(positives))}")
|
756 |
+
|
757 |
+
return self._prepare_sequences(queries, positives, negatives)
|
758 |
+
|
759 |
+
def encode_query(self, query: str, context: Optional[List[Tuple[str, str]]] = None) -> tf.Tensor:
|
760 |
+
"""Encode a query with optional conversation context."""
|
761 |
+
# Prepare query with context
|
762 |
+
if context:
|
763 |
+
context_str = ' '.join([
|
764 |
+
f"{self.special_tokens['user']} {q} "
|
765 |
+
f"{self.special_tokens['assistant']} {r}"
|
766 |
+
for q, r in context[-self.config.max_context_turns:]
|
767 |
+
])
|
768 |
+
query = f"{context_str} {self.special_tokens['user']} {query}"
|
769 |
+
else:
|
770 |
+
query = f"{self.special_tokens['user']} {query}"
|
771 |
+
|
772 |
+
# Tokenize and pad
|
773 |
+
seq = self.tokenizer.texts_to_sequences([query])
|
774 |
+
padded_seq = tf.keras.preprocessing.sequence.pad_sequences(
|
775 |
+
seq,
|
776 |
+
maxlen=self.config.max_sequence_length,
|
777 |
+
padding='post',
|
778 |
+
truncating='post'
|
779 |
+
)
|
780 |
+
|
781 |
+
return self.query_encoder(padded_seq, training=False)
|
782 |
+
|
783 |
+
def encode_responses(self, responses: List[str]) -> tf.Tensor:
|
784 |
+
"""Encode a batch of responses."""
|
785 |
+
# Prepare responses
|
786 |
+
responses = [
|
787 |
+
f"{self.special_tokens['assistant']} {r}"
|
788 |
+
for r in responses
|
789 |
+
]
|
790 |
+
|
791 |
+
# Tokenize and pad
|
792 |
+
sequences = self.tokenizer.texts_to_sequences(responses)
|
793 |
+
padded_sequences = tf.keras.preprocessing.sequence.pad_sequences(
|
794 |
+
sequences,
|
795 |
+
maxlen=self.config.max_sequence_length,
|
796 |
+
padding='post',
|
797 |
+
truncating='post'
|
798 |
+
)
|
799 |
+
|
800 |
+
return self.response_encoder(padded_sequences, training=False)
|
801 |
+
|
802 |
+
def retrieve_responses(
|
803 |
+
self,
|
804 |
+
query: str,
|
805 |
+
candidates: List[str],
|
806 |
+
context: Optional[List[Tuple[str, str]]] = None,
|
807 |
+
top_k: int = 5
|
808 |
+
) -> List[Tuple[str, float]]:
|
809 |
+
"""Retrieve top-k responses for a query."""
|
810 |
+
# Encode query and candidates
|
811 |
+
q_emb = self.encode_query(query, context)
|
812 |
+
c_emb = self.encode_responses(candidates)
|
813 |
+
|
814 |
+
# Calculate similarities
|
815 |
+
similarities = tf.matmul(q_emb, c_emb, transpose_b=True).numpy()[0]
|
816 |
+
|
817 |
+
# Get top-k responses
|
818 |
+
top_indices = np.argsort(similarities)[::-1][:top_k]
|
819 |
+
|
820 |
+
return [(candidates[i], similarities[i]) for i in top_indices]
|
821 |
+
|
822 |
+
def chat(
|
823 |
+
self,
|
824 |
+
query: str,
|
825 |
+
response_pool: List[str],
|
826 |
+
conversation_history: Optional[List[Tuple[str, str]]] = None,
|
827 |
+
top_k: int = 5
|
828 |
+
) -> Tuple[str, List[Tuple[str, float]]]:
|
829 |
+
"""Interactive chat with response selection."""
|
830 |
+
# Get responses with scores
|
831 |
+
responses = self.retrieve_responses(
|
832 |
+
query,
|
833 |
+
response_pool,
|
834 |
+
conversation_history,
|
835 |
+
top_k
|
836 |
+
)
|
837 |
+
|
838 |
+
# Return best response and all candidates with scores
|
839 |
+
return responses[0][0], responses
|
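
For context, a hedged sketch (not part of the commit) of how the second-iteration RetrievalChatbot in chatbot2.py above might be driven, using the ChatbotConfig, prepare_dataset, train, save_models, and chat methods it defines. The dialogue corpus, checkpoint directory, and save path are placeholders, not values from the repository.

# Hypothetical driver, for illustration only.
from chatbot2 import ChatbotConfig, RetrievalChatbot

config = ChatbotConfig(max_sequence_length=512, embedding_dim=256, learning_rate=1e-3)
chatbot = RetrievalChatbot(config)

# Placeholder: load the augmented dialogue corpus here
# (same {'turns': [{'speaker', 'text'}]} structure as in chatbot.py).
dialogues = [...]

q_pad, p_pad, n_pad = chatbot.prepare_dataset(dialogues, neg_samples_per_pos=3)
chatbot.train(q_pad, p_pad, n_pad, epochs=5, batch_size=32, checkpoint_dir="checkpoints")
chatbot.save_models("chatbot_training/iteration2")

response_pool = [t["text"] for d in dialogues for t in d["turns"] if t["speaker"] == "assistant"]
best, ranked = chatbot.chat("I need a cheap hotel near the station.", response_pool, top_k=5)
print(best)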
chatbot3.py
ADDED
@@ -0,0 +1,824 @@
from transformers import TFAutoModel, AutoTokenizer
import tensorflow as tf
import numpy as np
from typing import List, Tuple, Dict, Optional, Union
from dataclasses import dataclass
import logging
import spacy
import random
import json
from tqdm import tqdm
from pathlib import Path

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

@dataclass
class ChatbotConfig:
    """Enhanced configuration with pretrained model settings."""
    vocab_size: int = 10000
    max_sequence_length: int = 512
    embedding_dim: int = 768  # Match DistilBERT's dimension
    encoder_units: int = 256
    num_attention_heads: int = 8
    dropout_rate: float = 0.2
    l2_reg_weight: float = 0.001
    margin: float = 0.3
    learning_rate: float = 0.001
    min_text_length: int = 3
    max_context_turns: int = 5
    warmup_steps: int = 200
    pretrained_model: str = 'distilbert-base-uncased'
    freeze_embeddings: bool = True
    spacy_model: str = 'en_core_web_md'

    def to_dict(self) -> dict:
        """Convert config to dictionary."""
        return {k: str(v) if isinstance(v, Path) else v
                for k, v in self.__dict__.items()}

    @classmethod
    def from_dict(cls, config_dict: dict) -> 'ChatbotConfig':
        """Create config from dictionary."""
        return cls(**{k: v for k, v in config_dict.items()
                      if k in cls.__dataclass_fields__})

class TransformerBlock(tf.keras.layers.Layer):
    """Custom Transformer block with pre-layer normalization."""
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        ff_dim: int,
        dropout: float = 0.1,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.dropout = dropout

        self.attention = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=embed_dim // num_heads,
            dropout=dropout
        )
        self.ffn = tf.keras.Sequential([
            tf.keras.layers.Dense(ff_dim, activation="gelu"),
            tf.keras.layers.Dense(embed_dim),
        ])

        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(dropout)
        self.dropout2 = tf.keras.layers.Dropout(dropout)

    def call(self, inputs: tf.Tensor, training: bool, mask: Optional[tf.Tensor] = None) -> tf.Tensor:
        # Pre-layer normalization
        norm_inputs = self.layernorm1(inputs)

        # Self-attention
        attention_output = self.attention(
            query=norm_inputs,
            value=norm_inputs,
            key=norm_inputs,
            attention_mask=mask,
            training=training
        )
        attention_output = self.dropout1(attention_output, training=training)
        attention_output = inputs + attention_output

        # Feed-forward network
        norm_attention = self.layernorm2(attention_output)
        ffn_output = self.ffn(norm_attention)
        ffn_output = self.dropout2(ffn_output, training=training)

        return attention_output + ffn_output

    def get_config(self) -> dict:
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "ff_dim": self.ff_dim,
            "dropout": self.dropout,
        })
        return config

class EncoderModel(tf.keras.Model):
    """Dual encoder model with pretrained embeddings."""
    def __init__(
        self,
        config: ChatbotConfig,
        name: str = "encoder",
        shared_weights: bool = False,
        **kwargs
    ):
        super().__init__(name=name, **kwargs)
        self.config = config
        self.shared_weights = shared_weights

        # Load pretrained model and tokenizer
        self.pretrained = TFAutoModel.from_pretrained(config.pretrained_model)

        # Freeze pretrained weights if specified
        if config.freeze_embeddings:
            self.pretrained.trainable = False

        # Transformer blocks for additional processing
        self.transformer_blocks = [
            TransformerBlock(
                config.embedding_dim,
                config.num_attention_heads,
                config.encoder_units * 4,
                config.dropout_rate,
                name=f"{name}_transformer_{i}"
            ) for i in range(2)  # Reduced number of blocks since we're using pretrained
        ]

        # Final LSTM layer
        self.final_lstm = tf.keras.layers.LSTM(
            config.encoder_units,
            kernel_regularizer=tf.keras.regularizers.l2(config.l2_reg_weight),
            name=f"{name}_final_lstm"
        )

        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
        self.normalize = tf.keras.layers.Lambda(
            lambda x: tf.nn.l2_normalize(x, axis=1)
        )

    def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
        # Get pretrained embeddings
        pretrained_outputs = self.pretrained(inputs, training=training)
        x = pretrained_outputs.last_hidden_state

        # Get attention mask from input
        attention_mask = tf.cast(tf.not_equal(inputs, 0), tf.float32)
        attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]

        # Apply transformer blocks
        for transformer_block in self.transformer_blocks:
            x = transformer_block(x, training=training, mask=attention_mask)

        # Final processing
        x = self.final_lstm(x)
        x = self.dropout(x, training=training)
        return self.normalize(x)

class RetrievalChatbot:
    """Modified chatbot using pretrained embeddings with full functionality."""
    def __init__(self, config: ChatbotConfig):
        self.config = config
        self.nlp = spacy.load(config.spacy_model)

        # Use HuggingFace tokenizer instead of Keras
        self.tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)

        # Special tokens
        self.special_tokens = {
            "user": "<USER>",
            "assistant": "<ASSISTANT>",
            "context": "<CONTEXT>",
            "sep": "<SEP>"
        }

        # Add special tokens to tokenizer
        self.tokenizer.add_special_tokens(
            {'additional_special_tokens': list(self.special_tokens.values())}
        )

        # Build models
        self._build_models()

        # Initialize training tracking
        self.history = {
            "train_loss": [],
            "val_loss": [],
            "train_metrics": {},
            "val_metrics": {}
        }

        self.similarity_cache = {}

    def _build_models(self):
        """Initialize the encoder models."""
        # Query encoder
        self.query_encoder = EncoderModel(
            self.config,
            name="query_encoder",
            shared_weights=False
        )

        # Response encoder (can share weights with query encoder)
        self.response_encoder = EncoderModel(
            self.config,
            name="response_encoder",
            shared_weights=False
        )

        # Resize token embeddings to match the tokenizer's vocab size
        new_vocab_size = len(self.tokenizer)
        self.query_encoder.pretrained.resize_token_embeddings(new_vocab_size)
        self.response_encoder.pretrained.resize_token_embeddings(new_vocab_size)

    def save_models(self, save_dir: Union[str, Path]):
        """Save models and configuration."""
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

        # Save config
        with open(save_dir / "config.json", "w") as f:
            json.dump(self.config.to_dict(), f, indent=2)

        # Save models
        self.query_encoder.pretrained.save_pretrained(save_dir / "query_encoder")
        self.response_encoder.pretrained.save_pretrained(save_dir / "response_encoder")

        # Save tokenizer
        self.tokenizer.save_pretrained(save_dir / "tokenizer")

    @classmethod
    def load_models(cls, load_dir: Union[str, Path]) -> 'RetrievalChatbot':
        """Load saved models and configuration."""
        load_dir = Path(load_dir)

        # Load config
        with open(load_dir / "config.json", "r") as f:
            config = ChatbotConfig.from_dict(json.load(f))

        # Initialize chatbot
        chatbot = cls(config)

        # Load models
        chatbot.query_encoder.pretrained = TFAutoModel.from_pretrained(
            load_dir / "query_encoder",
            config=config
        )
        chatbot.response_encoder.pretrained = TFAutoModel.from_pretrained(
            load_dir / "response_encoder",
            config=config
        )

        # Load tokenizer
        chatbot.tokenizer = AutoTokenizer.from_pretrained(load_dir / "tokenizer")

        return chatbot

    def _improved_spacy_similarity(self, text1: str, text2: str) -> float:
        """Calculate semantic similarity between texts with preprocessing."""
        def preprocess(text: str) -> str:
            # Basic cleaning
            text = ' '.join(text.split())
            return text if text.strip() else "empty_document"

        # Get cache key
        cache_key = f"{hash(text1)}_{hash(text2)}"
        if cache_key in self.similarity_cache:
            return self.similarity_cache[cache_key]

        # Process texts
        text1, text2 = preprocess(text1), preprocess(text2)
        doc1, doc2 = self.nlp(text1), self.nlp(text2)

        # Calculate similarity
        if doc1.has_vector and doc2.has_vector:
            sim = doc1.similarity(doc2)
        else:
            # Fallback to token overlap similarity
            tokens1 = {t.lower_ for t in doc1 if not t.is_stop and not t.is_punct}
            tokens2 = {t.lower_ for t in doc2 if not t.is_stop and not t.is_punct}
            intersection = len(tokens1.intersection(tokens2))
            union = len(tokens1.union(tokens2))
            sim = intersection / union if union > 0 else 0.0

        # Cache result
        self.similarity_cache[cache_key] = sim
        return sim

    def prepare_dataset(
        self,
        dialogues: List[dict],
        neg_samples_per_pos: int = 3,
        debug_samples: Optional[int] = None
    ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """Prepare dataset with enhanced logging and statistics."""
        logger.info("Preparing dataset...")

        # Apply debug_samples limit if specified
        if debug_samples is not None:
            dialogues = dialogues[:debug_samples]
            logger.info(f"Debug mode: Limited to {debug_samples} dialogues")

        # Log dataset statistics
        total_dialogues = len(dialogues)
        total_turns = sum(len(d['turns']) for d in dialogues)
        avg_turns = total_turns / total_dialogues if total_dialogues > 0 else 0

        logger.info(f"Dataset statistics:")
        logger.info(f"  Total dialogues: {total_dialogues}")
        logger.info(f"  Total turns: {total_turns}")
        logger.info(f"  Average turns per dialogue: {avg_turns:.2f}")

        # Extract and filter responses with logging
        response_pool = []
        skipped_short = 0
        skipped_long = 0

        for d in dialogues:
            for turn in d['turns']:
                if turn.get('speaker') == 'assistant' and 'text' in turn:
                    text = turn['text'].strip()
                    length = len(text.split())
                    if length < self.config.min_text_length:
                        skipped_short += 1
                        continue
                    if length > self.config.max_sequence_length:
                        skipped_long += 1
                        continue
                    response_pool.append(text)

        logger.info(f"Response pool statistics:")
        logger.info(f"  Total responses: {len(response_pool)}")
        logger.info(f"  Skipped (too short): {skipped_short}")
        logger.info(f"  Skipped (too long): {skipped_long}")

        # Process dialogues and create training examples
        queries, positives, negatives = [], [], []

        for dialogue in tqdm(dialogues, desc="Processing dialogues"):
            turns = dialogue.get('turns', [])
            for i in range(len(turns) - 1):
                current_turn = turns[i]
                next_turn = turns[i+1]

                if (current_turn.get('speaker') == 'user' and
                    next_turn.get('speaker') == 'assistant' and
                    'text' in current_turn and
                    'text' in next_turn):

                    query = current_turn['text'].strip()
                    positive = next_turn['text'].strip()

                    # Skip short texts
                    if (len(query.split()) < self.config.min_text_length or
                        len(positive.split()) < self.config.min_text_length):
                        continue

                    # Get negative samples
                    neg_samples = self._smart_negative_sampling(
                        positive,
                        response_pool,
                        neg_samples_per_pos
                    )

                    if len(neg_samples) == neg_samples_per_pos:
                        for neg in neg_samples:
                            queries.append(query)
                            positives.append(positive)
                            negatives.append(neg)
                    else:
                        logger.warning(f"Insufficient negative samples for positive response: '{positive}'")

        # Log final dataset statistics
        logger.info(f"Final dataset statistics:")
        logger.info(f"  Training examples: {len(queries)}")
        logger.info(f"  Unique queries: {len(set(queries))}")
        logger.info(f"  Unique responses: {len(set(positives))}")

        return self._prepare_sequences(queries, positives, negatives)

    def _smart_negative_sampling(
        self,
        positive: str,
        response_pool: List[str],
        n_samples: int,
        max_attempts: int = 200,
        similarity_bounds: Tuple[float, float] = (0.2, 0.9),
        batch_size: int = 10
    ) -> List[str]:
        """Smart negative sampling with similarity bounds and fallback strategies."""
        candidates = []
        seen = set()
        attempts = 0

        while len(candidates) < n_samples and attempts < max_attempts:
            remaining = min(batch_size, len(response_pool) - len(seen), max_attempts - attempts)
            if remaining <= 0:
                break
            batch = random.sample(
                [r for r in response_pool if r not in seen and r != positive],
                remaining
            )

            for candidate in batch:
                seen.add(candidate)
                sim = self._improved_spacy_similarity(candidate, positive)

                if similarity_bounds[0] < sim < similarity_bounds[1]:
                    candidates.append(candidate)
                    if len(candidates) == n_samples:
                        break

            attempts += len(batch)

        if len(candidates) < n_samples:
            logger.warning(f"Only found {len(candidates)} negative samples for positive response: '{positive}'")
            # Fallback to random negatives without similarity constraints
            fallback_needed = n_samples - len(candidates)
            available_negatives = [r for r in response_pool if r != positive and r not in seen]
            if available_negatives:
                additional_negatives = random.sample(
                    available_negatives,
                    min(fallback_needed, len(available_negatives))
                )
                candidates.extend(additional_negatives)

        return candidates

    def _prepare_sequences(
        self,
        queries: List[str],
        positives: List[str],
        negatives: List[str]
    ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """Modified sequence preparation for pretrained tokenizer."""
        logger.info("Preparing sequences...")

        # Process texts with special tokens
        queries = [f"{self.special_tokens['user']} {q}" for q in queries]
        positives = [f"{self.special_tokens['assistant']} {p}" for p in positives]
        negatives = [f"{self.special_tokens['assistant']} {n}" for n in negatives]

        # Tokenize using HuggingFace tokenizer
        def encode_batch(texts: List[str]) -> tf.Tensor:
            # HuggingFace tokenizer returns TensorFlow tensors when return_tensors='tf'
            encodings = self.tokenizer(
                texts,
                padding='max_length',
                truncation=True,
                max_length=self.config.max_sequence_length,
                return_tensors='tf'
            )
            return encodings['input_ids']

        # Encode all sequences
        q_tensor = encode_batch(queries)
        p_tensor = encode_batch(positives)
        n_tensor = encode_batch(negatives)

        # Log statistics about encoded sequences
        logger.info("Sequence statistics:")
        logger.info(f"  Query sequence shape: {q_tensor.shape}")
        logger.info(f"  Positive response sequence shape: {p_tensor.shape}")
        logger.info(f"  Negative response sequence shape: {n_tensor.shape}")

        return q_tensor, p_tensor, n_tensor

    def train(
        self,
        q_pad: tf.Tensor,
        p_pad: tf.Tensor,
        n_pad: tf.Tensor,
        epochs: int = 3,
        batch_size: int = 32,
        validation_split: float = 0.2,
        checkpoint_dir: Optional[Union[str, Path]] = None
    ):
        """Train the model with improved training loop."""
        # Setup training
        total_samples = tf.shape(q_pad)[0]
        train_size = int((1 - validation_split) * total_samples.numpy())

        # Shuffle and split data
        indices = tf.random.shuffle(tf.range(start=0, limit=total_samples, dtype=tf.int32))
        train_idx = indices[:train_size]
        val_idx = indices[train_size:]

        # Split data using TF indexing
        train_data = (
            tf.gather(q_pad, train_idx),
            tf.gather(p_pad, train_idx),
            tf.gather(n_pad, train_idx)
        )
        val_data = (
            tf.gather(q_pad, val_idx),
            tf.gather(p_pad, val_idx),
            tf.gather(n_pad, val_idx)
        )

        # Setup optimizer with learning rate schedule
        steps_per_epoch = train_size // batch_size
        total_steps = steps_per_epoch * epochs

        lr_schedule = self._get_lr_schedule(
            total_steps,
            self.config.learning_rate,
            self.config.warmup_steps
        )

        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

        # Setup checkpointing
        if checkpoint_dir:
            checkpoint_dir = Path(checkpoint_dir)
            checkpoint_dir.mkdir(parents=True, exist_ok=True)

            # Setup checkpoint callback with correct file format
            checkpoint_template = str(checkpoint_dir / "model_epoch_{epoch:04d}.weights.h5")
            checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
                checkpoint_template,
                save_weights_only=True,
                save_best_only=True,
                monitor='val_loss',
                mode='min',
                verbose=1
            )

        # Training loop
        best_val_loss = float('inf')
        patience = 5
        wait = 0

        for epoch in range(epochs):
            # Training
            train_loss = self._train_epoch(
                train_data,
                optimizer,
                batch_size,
                training=True
            )

            # Validation
            val_loss = self._train_epoch(
                val_data,
                optimizer,
                batch_size,
                training=False
            )

            # Update history
            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)

            logger.info(
                f"Epoch {epoch + 1}/{epochs} - "
                f"train_loss: {train_loss:.4f} - "
                f"val_loss: {val_loss:.4f}"
            )

            # Early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                wait = 0
                if checkpoint_dir:
                    self.save_models(checkpoint_dir / f"best_model")
            else:
                wait += 1
                if wait >= patience:
                    logger.info("Early stopping triggered")
                    break

    def _train_epoch(
        self,
        data: Tuple[tf.Tensor, tf.Tensor, tf.Tensor],
        optimizer: tf.keras.optimizers.Optimizer,
        batch_size: int,
        training: bool = True
    ) -> float:
        """Train for one epoch with enhanced logging and progress tracking."""
        q_data, p_data, n_data = data
        total_loss = 0.0
        num_batches = tf.shape(q_data)[0] // batch_size

        # Log current learning rate at start of epoch
        if training:
            if hasattr(optimizer.learning_rate, '__call__'):
                current_lr = optimizer.learning_rate(optimizer.iterations)
            else:
                current_lr = optimizer.learning_rate
            logger.info(f"Current learning rate: {float(current_lr):.6f}")

        # Create progress bar
        mode = "Training" if training else "Validation"
        pbar = tqdm(
            total=num_batches.numpy(),
            desc=f"{mode} batches",
            unit="batch",
            dynamic_ncols=True
        )

        # Process batches
        for i in range(num_batches):
            start_idx = i * batch_size
            end_idx = start_idx + batch_size

            batch_loss = self._train_step(
                q_data[start_idx:end_idx],
                p_data[start_idx:end_idx],
                n_data[start_idx:end_idx],
                optimizer,
                training
            )
            total_loss += batch_loss.numpy()

            # Update progress bar with current loss
            avg_loss = total_loss / (i + 1)
            pbar.set_postfix({
                'loss': f'{avg_loss:.4f}',
                'lr': f'{float(current_lr):.6f}' if training else 'N/A'
            })
            pbar.update(1)

        pbar.close()
        return total_loss / num_batches.numpy() if num_batches > 0 else 0.0

    @tf.function
    def _train_step(
        self,
        q_batch: tf.Tensor,
        p_batch: tf.Tensor,
        n_batch: tf.Tensor,
        optimizer: tf.keras.optimizers.Optimizer,
        training: bool = True
    ) -> tf.Tensor:
        """Single training step with triplet loss."""
        with tf.GradientTape() as tape:
            # Get embeddings
            q_emb = self.query_encoder(q_batch, training=training)
            p_emb = self.response_encoder(p_batch, training=training)
            n_emb = self.response_encoder(n_batch, training=training)

            # Calculate triplet loss
            pos_dist = tf.reduce_sum(tf.square(q_emb - p_emb), axis=1)
            neg_dist = tf.reduce_sum(tf.square(q_emb - n_emb), axis=1)

            loss = tf.maximum(0.0, self.config.margin + pos_dist - neg_dist)
            loss = tf.reduce_mean(loss)

        if training:
            # Apply gradients
            gradients = tape.gradient(
                loss,
                self.query_encoder.trainable_variables +
                self.response_encoder.trainable_variables
            )
            optimizer.apply_gradients(zip(
                gradients,
                self.query_encoder.trainable_variables +
                self.response_encoder.trainable_variables
            ))

        return loss

    def _get_lr_schedule(
        self,
        total_steps: int,
        peak_lr: float,
        warmup_steps: int
    ) -> tf.keras.optimizers.schedules.LearningRateSchedule:
        """Enhanced learning rate schedule with better error handling and logging."""
        class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
            def __init__(
                self,
                total_steps: int,
                peak_lr: float,
                warmup_steps: int
            ):
                super().__init__()
                self.total_steps = tf.cast(total_steps, tf.float32)
                self.peak_lr = tf.cast(peak_lr, tf.float32)
                self.warmup_steps = tf.cast(max(1, warmup_steps), tf.float32)  # Prevent 0

                # Calculate and store constants
                self.initial_lr = self.peak_lr * 0.1  # Start at 10% of peak
                self.min_lr = self.peak_lr * 0.01  # Minimum 1% of peak

                logger.info(f"Learning rate schedule initialized:")
                logger.info(f"  Initial LR: {float(self.initial_lr):.6f}")
                logger.info(f"  Peak LR: {float(self.peak_lr):.6f}")
                logger.info(f"  Min LR: {float(self.min_lr):.6f}")
                logger.info(f"  Warmup steps: {int(self.warmup_steps)}")
                logger.info(f"  Total steps: {int(self.total_steps)}")

            def __call__(self, step):
                step = tf.cast(step, tf.float32)

                # Warmup phase
                warmup_factor = tf.minimum(1.0, step / self.warmup_steps)
                warmup_lr = self.initial_lr + (self.peak_lr - self.initial_lr) * warmup_factor

                # Decay phase
                decay_steps = tf.maximum(1.0, self.total_steps - self.warmup_steps)
                decay_factor = (step - self.warmup_steps) / decay_steps
                decay_factor = tf.minimum(tf.maximum(0.0, decay_factor), 1.0)  # Clip to [0,1]

                cosine_decay = 0.5 * (1.0 + tf.cos(np.pi * decay_factor))
                decay_lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay

                # Choose between warmup and decay
                final_lr = tf.where(step < self.warmup_steps, warmup_lr, decay_lr)

                # Ensure learning rate is valid
                final_lr = tf.maximum(self.min_lr, final_lr)
                final_lr = tf.where(tf.math.is_finite(final_lr), final_lr, self.min_lr)

                return final_lr

            def get_config(self):
                return {
                    "total_steps": self.total_steps,
                    "peak_lr": self.peak_lr,
                    "warmup_steps": self.warmup_steps,
                }

        return CustomSchedule(total_steps, peak_lr, warmup_steps)

    def encode_query(self, query: str, context: Optional[List[Tuple[str, str]]] = None) -> tf.Tensor:
        """Encode a query with optional conversation context."""
        # Prepare query with context
        if context:
            context_str = ' '.join([
                f"{self.special_tokens['user']} {q} "
                f"{self.special_tokens['assistant']} {r}"
                for q, r in context[-self.config.max_context_turns:]
            ])
            query = f"{context_str} {self.special_tokens['user']} {query}"
        else:
            query = f"{self.special_tokens['user']} {query}"

        # Tokenize and pad using TensorFlow tensors
        encodings = self.tokenizer(
            [query],
            padding='max_length',
            truncation=True,
            max_length=self.config.max_sequence_length,
            return_tensors='tf'
        )
        input_ids = encodings['input_ids']

        return self.query_encoder(input_ids, training=False)

    def encode_responses(self, responses: List[str]) -> tf.Tensor:
        """Encode a batch of responses."""
        # Prepare responses
        responses = [
            f"{self.special_tokens['assistant']} {r}"
            for r in responses
        ]

        # Tokenize and pad using TensorFlow tensors
        encodings = self.tokenizer(
            responses,
            padding='max_length',
            truncation=True,
            max_length=self.config.max_sequence_length,
            return_tensors='tf'
        )
        input_ids = encodings['input_ids']

        return self.response_encoder(input_ids, training=False)

    def retrieve_responses(
        self,
        query: str,
        candidates: List[str],
        context: Optional[List[Tuple[str, str]]] = None,
        top_k: int = 5
    ) -> List[Tuple[str, float]]:
        """Retrieve top-k responses for a query."""
        # Encode query and candidates
        q_emb = self.encode_query(query, context)
        c_emb = self.encode_responses(candidates)

        # Calculate similarities
        similarities = tf.matmul(q_emb, c_emb, transpose_b=True).numpy()[0]

        # Get top-k responses
        top_indices = np.argsort(similarities)[::-1][:top_k]

        return [(candidates[i], similarities[i]) for i in top_indices]

    def chat(
        self,
        query: str,
        response_pool: List[str],
        conversation_history: Optional[List[Tuple[str, str]]] = None,
        top_k: int = 5
    ) -> Tuple[str, List[Tuple[str, float]]]:
        """Interactive chat with response selection."""
        # Get responses with scores
        responses = self.retrieve_responses(
            query,
            response_pool,
            conversation_history,
            top_k
        )

        # Return best response and all candidates with scores
        return responses[0][0], responses
chatbot4.py
ADDED
@@ -0,0 +1,1291 @@
from transformers import TFAutoModel, AutoTokenizer
import tensorflow as tf
import numpy as np
from typing import List, Tuple, Dict, Optional, Union, Any
from dataclasses import dataclass
import logging
import json
from tqdm import tqdm
from pathlib import Path
import faiss
from response_quality_checker import ResponseQualityChecker

policy = tf.keras.mixed_precision.Policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy(policy)

# Configure logging
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

@dataclass
class ChatbotConfig:
    """Configuration for the RetrievalChatbot."""
    vocab_size: int = 30526  # DistilBERT vocab size
    max_sequence_length: int = 512
    embedding_dim: int = 768  # Match DistilBERT's dimension
    encoder_units: int = 256
    num_attention_heads: int = 8
    dropout_rate: float = 0.2
    l2_reg_weight: float = 0.001
    margin: float = 0.3
    learning_rate: float = 0.001
    min_text_length: int = 3
    max_context_turns: int = 5
    warmup_steps: int = 200
    pretrained_model: str = 'distilbert-base-uncased'
    dtype: str = 'float32'
    freeze_embeddings: bool = False
    # Additional configurations can be added here

    def to_dict(self) -> dict:
        """Convert config to dictionary."""
        return {k: str(v) if isinstance(v, Path) else v
                for k, v in self.__dict__.items()}

    @classmethod
    def from_dict(cls, config_dict: dict) -> 'ChatbotConfig':
        """Create config from dictionary."""
        return cls(**{k: v for k, v in config_dict.items()
                      if k in cls.__dataclass_fields__})

class EncoderModel(tf.keras.Model):
    """Dual encoder model with pretrained embeddings."""
    def __init__(
        self,
        config: ChatbotConfig,
        name: str = "encoder",
        shared_weights: bool = False,
        **kwargs
    ):
        super().__init__(name=name, **kwargs)
        self.config = config
        self.shared_weights = shared_weights

        # Load pretrained model
        self.pretrained = TFAutoModel.from_pretrained(config.pretrained_model)

        # Freeze pretrained weights if specified
        self.pretrained.distilbert.embeddings.trainable = False
        for i, layer_module in enumerate(self.pretrained.distilbert.transformer.layer):
            if i < 3:  # freeze the first three transformer layers (i = 0, 1, 2)
                layer_module.trainable = False
            else:
                layer_module.trainable = True

        # Pooling layer (Global Average Pooling)
        self.pooler = tf.keras.layers.GlobalAveragePooling1D()

        # Dropout and normalization
        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
        self.normalize = tf.keras.layers.Lambda(
            lambda x: tf.nn.l2_normalize(x, axis=1)
        )

    def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Forward pass."""
        # Get pretrained embeddings
        pretrained_outputs = self.pretrained(inputs, training=training)
        x = pretrained_outputs.last_hidden_state  # Shape: [batch_size, seq_len, embedding_dim]

        # Apply pooling
        x = self.pooler(x)  # Shape: [batch_size, embedding_dim]

        # Apply dropout
        x = self.dropout(x, training=training)

        # L2 normalization
        x = self.normalize(x)  # Shape: [batch_size, embedding_dim]

        return x

    def get_config(self) -> dict:
        """Return the config of the model."""
        config = super().get_config()
        config.update({
            "config": self.config.to_dict(),
            "shared_weights": self.shared_weights,
            "name": self.name
        })
        return config

# class CustomLRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
#     def __init__(self, initial_lr, peak_lr, min_lr, warmup_steps, total_steps):
#         super().__init__()
#         self.initial_lr = initial_lr
#         self.peak_lr = peak_lr
#         self.min_lr = min_lr
#         self.warmup_steps = min(warmup_steps, total_steps // 2)  # Ensure warmup_steps <= total_steps
#         self.total_steps = total_steps

#     def __call__(self, step):
#         if step < self.warmup_steps:
#             # Linear warmup
#             lr = self.initial_lr + (self.peak_lr - self.initial_lr) * (step / self.warmup_steps)
#         else:
#             # Linear decay
#             decay_steps = self.total_steps - self.warmup_steps
#             if decay_steps > 0:
#                 lr = self.peak_lr - (self.peak_lr - self.min_lr) * ((step - self.warmup_steps) / decay_steps)
#             else:
#                 lr = self.peak_lr
#         return lr

#     def get_config(self):
#         return {
#             "initial_lr": self.initial_lr,
#             "peak_lr": self.peak_lr,
#             "min_lr": self.min_lr,
#             "warmup_steps": self.warmup_steps,
#             "total_steps": self.total_steps,
#         }

class RetrievalChatbot:
    """Retrieval-based chatbot using pretrained embeddings and FAISS for similarity search."""
    def __init__(self, config: ChatbotConfig, dialogues: List[dict] = []):
        self.config = config

        # Special tokens
        self.special_tokens = {
            "user": "<USER>",
            "assistant": "<ASSISTANT>",
            "context": "<CONTEXT>",
            "sep": "<SEP>"
        }

        # Initialize tokenizer and add special tokens
        self.tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
        self.tokenizer.add_special_tokens(
            {'additional_special_tokens': list(self.special_tokens.values())}
        )

        # Build encoders
        self._build_models()

        # Initialize FAISS index
        self._initialize_faiss()

        # Precompute and index response embeddings
        self._precompute_and_index_responses(dialogues)

        # Initialize training history
        self.history = {
            "train_loss": [],
            "val_loss": [],
            "train_metrics": {},
            "val_metrics": {}
        }

    def _build_models(self):
        """Initialize the shared encoder."""
        logger.info("Building encoder model...")

        # Shared encoder for both queries and responses
        self.encoder = EncoderModel(
            self.config,
            name="shared_encoder",
        )

        # Resize token embeddings after adding special tokens
        new_vocab_size = len(self.tokenizer)
        self.encoder.pretrained.resize_token_embeddings(new_vocab_size)
        logger.info(f"Token embeddings resized to: {new_vocab_size}")

        # Inspect embeddings attributes for debugging
        logger.info("Inspecting embeddings attributes:")
        for attr in dir(self.encoder.pretrained.distilbert.embeddings):
            if not attr.startswith('_'):
                logger.info(f"  {attr}")

        # Verify embedding layers without accessing word_embeddings directly
        embedding_dim = getattr(self.encoder.pretrained.distilbert.embeddings, 'embedding_dim', 'Unknown')
        vocab_size = getattr(self.encoder.pretrained.distilbert.embeddings, 'input_dim', len(self.tokenizer))
        logger.info(f"Encoder Embedding Dimension: {embedding_dim}")
        logger.info(f"Encoder Embedding Vocabulary Size: {vocab_size}")

        logger.info("Encoder model built and embeddings resized successfully.")
        for var in self.encoder.pretrained.trainable_variables:
            logger.info(f"{var.name}, {var.shape}")

    def check_trainable_variables(self):
        """Logs the trainable variables in the shared encoder."""
        logger.info("Checking trainable variables in shared_encoder:")
        for var in self.encoder.pretrained.trainable_variables:
            logger.info(f"  {var.name}, shape: {var.shape}")

        # logger.info("Checking trainable variables in response_encoder:")
        # for var in self.response_encoder.pretrained.trainable_variables:
        #     logger.info(f"  {var.name}, shape: {var.shape}")

    def _initialize_faiss(self):
        """Initialize FAISS index based on available resources."""
        logger.info("Initializing FAISS index...")
        # Determine if GPU FAISS is available
        try:
            res = faiss.StandardGpuResources()
            self.faiss_gpu = True
            logger.info("FAISS GPU resources initialized.")
        except Exception as e:
            self.faiss_gpu = False
            logger.info("FAISS GPU resources not available. Using FAISS CPU.")

        # Initialize FAISS index for Inner Product (for cosine similarity)
        if self.faiss_gpu:
            self.index = faiss.IndexFlatIP(self.config.embedding_dim)
            self.index = faiss.index_cpu_to_gpu(res, 0, self.index)
        else:
            self.index = faiss.IndexFlatIP(self.config.embedding_dim)
        logger.info("FAISS index initialized.")

    def verify_faiss_index(chatbot):
        """Verify that FAISS index matches the response pool."""
        indexed_size = chatbot.index.ntotal
        pool_size = len(chatbot.response_pool)
        logger.info(f"FAISS index size: {indexed_size}")
        logger.info(f"Response pool size: {pool_size}")
        if indexed_size != pool_size:
            logger.warning("Mismatch between FAISS index size and response pool size.")
        else:
            logger.info("FAISS index correctly matches the response pool.")

    def _precompute_and_index_responses(self, dialogues: List[dict]):
        """Precompute embeddings for all responses and index them using FAISS."""
        logger.info("Precomputing response embeddings and indexing with FAISS...")

        # Use tqdm for collecting responses
        responses = []
        for dialogue in tqdm(dialogues, desc="Collecting assistant responses"):
            turns = dialogue.get('turns', [])
            for turn in turns:
                if turn.get('speaker') == 'assistant' and 'text' in turn:
                    responses.append(turn['text'].strip())

        # Remove duplicates
        unique_responses = list(set(responses))
        logger.info(f"Found {len(unique_responses)} unique responses.")

        # Encode responses
        response_embeddings = self.encode_responses(unique_responses)
        response_embeddings = response_embeddings.numpy()

        # Ensure float32
        if response_embeddings.dtype != np.float32:
            logger.info(f"Converting embeddings from {response_embeddings.dtype} to float32.")
            response_embeddings = response_embeddings.astype('float32')

        # Ensure the array is contiguous in memory
        if not response_embeddings.flags['C_CONTIGUOUS']:
            logger.info("Making embeddings contiguous in memory.")
            response_embeddings = np.ascontiguousarray(response_embeddings)

        # Normalize embeddings for cosine similarity
        logger.info("Normalizing embeddings with FAISS.")
        faiss.normalize_L2(response_embeddings)

        # Add to FAISS index
        logger.info("Adding embeddings to FAISS index...")
        self.index.add(response_embeddings)
        logger.info(f"Indexed {self.index.ntotal} responses.")

        # Store responses and embeddings
        self.response_pool = unique_responses
        self.response_embeddings = response_embeddings
        logger.info("Precomputation and indexing completed.")

    def encode_responses(
        self,
        responses: List[str],
        batch_size: int = 64
    ) -> tf.Tensor:
        """
        Encodes a list of responses into embeddings, using chunked/batched processing
        to avoid running out of memory when there are many responses.

        Args:
            responses (List[str]): The list of response texts to encode.
            batch_size (int): How many responses to encode per chunk.
                Adjust based on available GPU/CPU memory.

        Returns:
            tf.Tensor: Tensor of shape (N, emb_dim) with all response embeddings.
        """
        logger.info(f"Encoding {len(responses)} responses in batches of size {batch_size}...")

        # We'll accumulate embeddings in a list and concatenate at the end
        all_embeddings = []

        # Set up a progress bar
        from tqdm import tqdm
        pbar = tqdm(total=len(responses), desc="Encoding responses")

        # Process the responses in chunks of 'batch_size'
        for start_idx in range(0, len(responses), batch_size):
            end_idx = start_idx + batch_size
            batch_texts = responses[start_idx:end_idx]

            # Tokenize the current batch
            encodings = self.tokenizer(
                batch_texts,
                padding='max_length',
                truncation=True,
                max_length=self.config.max_sequence_length,
                return_tensors='tf',
            )

            # Run the encoder forward pass
            input_ids = encodings['input_ids']
            embeddings_batch = self.encoder(input_ids, training=False)

            # Cast to float32 if needed
            if embeddings_batch.dtype != tf.float32:
                embeddings_batch = tf.cast(embeddings_batch, tf.float32)

            # Collect
            all_embeddings.append(embeddings_batch)

            # Update progress bar
            pbar.update(len(batch_texts))

        pbar.close()

        # Concatenate all batch embeddings along axis=0
        if len(all_embeddings) == 1:
            # Only one batch
            final_embeddings = all_embeddings[0]
        else:
            # Multiple batches, concatenate
            final_embeddings = tf.concat(all_embeddings, axis=0)

        logger.info(
            f"Finished encoding {len(responses)} responses. "
            f"Final shape: {final_embeddings.shape}"
        )
        return final_embeddings

    def encode_query(self, query: str, context: Optional[List[Tuple[str, str]]] = None) -> tf.Tensor:
        """Encode a query with optional conversation context."""
        # Prepare query with context
        if context:
            context_str = ' '.join([
                f"{self.special_tokens['user']} {q} "
                f"{self.special_tokens['assistant']} {r}"
                for q, r in context[-self.config.max_context_turns:]
            ])
            query = f"{context_str} {self.special_tokens['user']} {query}"
        else:
            query = f"{self.special_tokens['user']} {query}"

        # Tokenize and encode
        encodings = self.tokenizer(
            [query],
            padding='max_length',
            truncation=True,
            max_length=self.config.max_sequence_length,
            return_tensors='tf'
        )
        input_ids = encodings['input_ids']

        # Verify token IDs
        max_id = tf.reduce_max(input_ids).numpy()
        new_vocab_size = len(self.tokenizer)
        logger.info(f"Maximum input_id: {max_id}, Vocab Size: {new_vocab_size}")

        if max_id >= new_vocab_size:
            logger.error(f"Token ID {max_id} exceeds the vocabulary size {new_vocab_size}.")
            raise ValueError("Token ID exceeds vocabulary size.")

        # Get embeddings from the shared encoder
        return self.encoder(input_ids, training=False)

    def retrieve_responses_faiss(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Retrieve top-k responses using FAISS."""
        # Encode the query
        q_emb = self.encode_query(query)  # Shape: [1, embedding_dim]
        q_emb_np = q_emb.numpy().astype('float32')  # Ensure type matches FAISS requirements

        # Normalize the query embedding for cosine similarity
        faiss.normalize_L2(q_emb_np)

        # Search the FAISS index
        distances, indices = self.index.search(q_emb_np, top_k)

        # Map indices to responses and distances to similarities
        top_responses = []
        for i, idx in enumerate(indices[0]):
            if idx < len(self.response_pool):
                top_responses.append((self.response_pool[idx], float(distances[0][i])))
            else:
                logger.warning(f"FAISS returned invalid index {idx}. Skipping.")

        return top_responses

    def save_models(self, save_dir: Union[str, Path]):
        """Save models and configuration."""
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

        # Save config
        with open(save_dir / "config.json", "w") as f:
            json.dump(self.config.to_dict(), f, indent=2)

        # Save models
        self.encoder.pretrained.save_pretrained(save_dir / "shared_encoder")

        # Save tokenizer
        self.tokenizer.save_pretrained(save_dir / "tokenizer")

        logger.info(f"Models and tokenizer saved to {save_dir}.")

    @classmethod
    def load_models(cls, load_dir: Union[str, Path]) -> 'RetrievalChatbot':
        """Load saved models and configuration."""
        load_dir = Path(load_dir)

        # Load config
        with open(load_dir / "config.json", "r") as f:
            config = ChatbotConfig.from_dict(json.load(f))

        # Initialize chatbot
        chatbot = cls(config)

        # Load models
        chatbot.encoder.pretrained = TFAutoModel.from_pretrained(
            load_dir / "shared_encoder",
            config=config
        )
        # chatbot.response_encoder.pretrained = TFAutoModel.from_pretrained(
        #     load_dir / "response_encoder",
        #     config=config
        # )

        # Load tokenizer
        chatbot.tokenizer = AutoTokenizer.from_pretrained(load_dir / "tokenizer")

        logger.info(f"Models and tokenizer loaded from {load_dir}.")
        return chatbot

    @staticmethod
    def load_training_data(data_path: Union[str, Path], debug_samples: Optional[int] = None) -> List[dict]:
        """
        Load training data from a JSON file.

        Args:
            data_path (Union[str, Path]): Path to the JSON file containing dialogues.
            debug_samples (Optional[int]): Number of samples to load for debugging.

        Returns:
            List[dict]: List of dialogue dictionaries.
        """
        logger.info(f"Loading training data from {data_path}...")
        data_path = Path(data_path)
        if not data_path.exists():
            logger.error(f"Data file {data_path} does not exist.")
            return []

        with open(data_path, 'r', encoding='utf-8') as f:
            dialogues = json.load(f)

        if debug_samples is not None:
            dialogues = dialogues[:debug_samples]
            logger.info(f"Debug mode: Limited to {debug_samples} dialogues")

        logger.info(f"Loaded {len(dialogues)} dialogues.")
        return dialogues

    def prepare_dataset(
        self,
        dialogues: List[dict],
        debug_samples: int = None
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Prepares dataset for in-batch negatives:
        only returns (query, positive) pairs.
        """
        logger.info("Preparing in-batch dataset...")

        queries, positives = [], []

        for dialogue in dialogues:
            turns = dialogue.get('turns', [])
            for i in range(len(turns) - 1):
                current_turn = turns[i]
                next_turn = turns[i+1]

                if (current_turn.get('speaker') == 'user' and
                    next_turn.get('speaker') == 'assistant' and
                    'text' in current_turn and
                    'text' in next_turn):

                    query = current_turn['text'].strip()
                    positive = next_turn['text'].strip()

                    queries.append(query)
                    positives.append(positive)

        # Optional debug slicing
        if debug_samples is not None:
            queries = queries[:debug_samples]
            positives = positives[:debug_samples]
            logger.info(f"Debug mode: limited to {debug_samples} pairs.")

        logger.info(f"Prepared {len(queries)} (query, positive) pairs.")

        # Tokenize queries
        encoded_queries = self.tokenizer(
            queries,
            padding='max_length',
            truncation=True,
            max_length=self.config.max_sequence_length,
            return_tensors='tf'
        )
        # Tokenize positives
        encoded_positives = self.tokenizer(
            positives,
            padding='max_length',
            truncation=True,
            max_length=self.config.max_sequence_length,
            return_tensors='tf'
        )

        q_tensor = encoded_queries['input_ids']
        p_tensor = encoded_positives['input_ids']

        logger.info("Tokenized and padded sequences for in-batch training.")
        return q_tensor, p_tensor

    def train(
        self,
        q_pad: tf.Tensor,
        p_pad: tf.Tensor,
        epochs: int,
        batch_size: int,
        validation_split: float,
        checkpoint_dir: str,
        use_lr_schedule: bool = True,
        peak_lr: float = 2e-5,
        warmup_steps_ratio: float = 0.1,
        early_stopping_patience: int = 3,
        min_delta: float = 1e-4
    ):
        dataset_size = tf.shape(q_pad)[0].numpy()
        val_size = int(dataset_size * validation_split)
        train_size = dataset_size - val_size

        logger.info(f"Total samples: {dataset_size}")
        logger.info(f"Training samples: {train_size}")
        logger.info(f"Validation samples: {val_size}")

        steps_per_epoch = train_size // batch_size
        if train_size % batch_size != 0:
            steps_per_epoch += 1
        total_steps = steps_per_epoch * epochs
        logger.info(f"Total training steps (approx): {total_steps}")

        # 1) Set up LR schedule or fixed LR
        if use_lr_schedule:
            warmup_steps = int(total_steps * warmup_steps_ratio)
            lr_schedule = self._get_lr_schedule(
                total_steps=total_steps,
                peak_lr=peak_lr,
                warmup_steps=warmup_steps
            )
            self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
            logger.info("Using custom learning rate schedule.")
        else:
            self.optimizer = tf.keras.optimizers.Adam(learning_rate=peak_lr)
            logger.info("Using fixed learning rate.")

        # 2) Prepare data splits
        train_q = q_pad[:train_size]
        train_p = p_pad[:train_size]
        val_q = q_pad[train_size:]
        val_p = p_pad[train_size:]

        train_dataset = tf.data.Dataset.from_tensor_slices((train_q, train_p))
        train_dataset = train_dataset.shuffle(buffer_size=4096).batch(batch_size)

        val_dataset = tf.data.Dataset.from_tensor_slices((val_q, val_p))
        val_dataset = val_dataset.batch(batch_size)

        # 3) Checkpoint + manager
        checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, model=self.encoder)
        manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)

        # 4) TensorBoard setup
        import datetime
        import os
        from pathlib import Path

        log_dir = Path(checkpoint_dir) / "tensorboard_logs"
        log_dir.mkdir(parents=True, exist_ok=True)

        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        train_log_dir = str(log_dir / f"train_{current_time}")
        val_log_dir = str(log_dir / f"val_{current_time}")

        train_summary_writer = tf.summary.create_file_writer(train_log_dir)
        val_summary_writer = tf.summary.create_file_writer(val_log_dir)

        logger.info(f"TensorBoard logs will be saved in {log_dir}")

        # 5) Early stopping
        best_val_loss = float("inf")
        epochs_no_improve = 0

        logger.info("Beginning training loop...")
        global_step = 0

        from tqdm import tqdm
        for epoch in range(1, epochs + 1):
            logger.info(f"\n=== Epoch {epoch}/{epochs} ===")
            epoch_loss_avg = tf.keras.metrics.Mean()

            # Training loop
            with tqdm(total=steps_per_epoch, desc=f"Training Epoch {epoch}") as pbar:
                for (q_batch, p_batch) in train_dataset:
                    global_step += 1

                    # Train step
                    batch_loss = self._train_step(q_batch, p_batch)
                    epoch_loss_avg(batch_loss)

                    # Get current LR
                    if use_lr_schedule:
                        lr = self.optimizer.learning_rate
                        if isinstance(lr, tf.keras.optimizers.schedules.LearningRateSchedule):
                            # Get the current step
                            current_step = tf.cast(self.optimizer.iterations, tf.float32)
                            # Compute the current learning rate
                            current_lr = lr(current_step)
                        else:
                            # If learning_rate is not a schedule, use it directly
                            current_lr = lr
                        # Convert to float for logging
                        current_lr_value = float(current_lr.numpy())
                    else:
                        # If using fixed learning rate
                        current_lr_value = float(self.optimizer.learning_rate.numpy())

                    # Update tqdm
                    pbar.update(1)
                    pbar.set_postfix({
                        "loss": f"{batch_loss.numpy():.4f}",
                        "lr": f"{current_lr_value:.2e}"
|
677 |
+
})
|
678 |
+
|
679 |
+
# TensorBoard: log train metrics per step
|
680 |
+
with train_summary_writer.as_default():
|
681 |
+
tf.summary.scalar("loss", batch_loss, step=global_step)
|
682 |
+
tf.summary.scalar("learning_rate", current_lr_value, step=global_step)
|
683 |
+
|
684 |
+
# Validation
|
685 |
+
val_loss_avg = tf.keras.metrics.Mean()
|
686 |
+
for q_val, p_val in val_dataset:
|
687 |
+
q_enc = self.encoder(q_val, training=False)
|
688 |
+
p_enc = self.encoder(p_val, training=False)
|
689 |
+
sim_matrix = tf.matmul(q_enc, p_enc, transpose_b=True)
|
690 |
+
bs_val = tf.shape(q_enc)[0]
|
691 |
+
labels_val = tf.range(bs_val, dtype=tf.int32)
|
692 |
+
loss_val = tf.nn.sparse_softmax_cross_entropy_with_logits(
|
693 |
+
labels=labels_val,
|
694 |
+
logits=sim_matrix
|
695 |
+
)
|
696 |
+
val_loss_avg(tf.reduce_mean(loss_val))
|
697 |
+
|
698 |
+
train_loss = epoch_loss_avg.result().numpy()
|
699 |
+
val_loss = val_loss_avg.result().numpy()
|
700 |
+
|
701 |
+
logger.info(f"Epoch {epoch} Complete: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")
|
702 |
+
|
703 |
+
# TensorBoard: validation loss
|
704 |
+
with val_summary_writer.as_default():
|
705 |
+
tf.summary.scalar("val_loss", val_loss, step=epoch)
|
706 |
+
|
707 |
+
# Save checkpoint
|
708 |
+
manager.save()
|
709 |
+
|
710 |
+
# Update history
|
711 |
+
self.history['train_loss'].append(train_loss)
|
712 |
+
self.history['val_loss'].append(val_loss)
|
713 |
+
self.history.setdefault('learning_rate', []).append(float(current_lr_value))
|
714 |
+
|
715 |
+
# Early stopping
|
716 |
+
if val_loss < best_val_loss - min_delta:
|
717 |
+
best_val_loss = val_loss
|
718 |
+
epochs_no_improve = 0
|
719 |
+
logger.info(f"Validation loss improved to {val_loss:.4f}. Reset patience.")
|
720 |
+
else:
|
721 |
+
epochs_no_improve += 1
|
722 |
+
logger.info(f"No improvement this epoch. Patience: {epochs_no_improve}/{early_stopping_patience}")
|
723 |
+
if epochs_no_improve >= early_stopping_patience:
|
724 |
+
logger.info("Early stopping triggered.")
|
725 |
+
break
|
726 |
+
|
727 |
+
logger.info("In-batch training completed!")
|
728 |
+
|
729 |
+
@tf.function
|
730 |
+
def _train_step(self, q_batch, p_batch):
|
731 |
+
"""
|
732 |
+
Single training step using in-batch negatives.
|
733 |
+
q_batch: (batch_size, seq_len) int32 input_ids for queries
|
734 |
+
p_batch: (batch_size, seq_len) int32 input_ids for positives
|
735 |
+
"""
|
736 |
+
with tf.GradientTape() as tape:
|
737 |
+
# Encode queries and positives
|
738 |
+
q_enc = self.encoder(q_batch, training=True) # [B, emb_dim]
|
739 |
+
p_enc = self.encoder(p_batch, training=True) # [B, emb_dim]
|
740 |
+
|
741 |
+
# Compute similarity matrix: (B, B) = q_enc * p_enc^T
|
742 |
+
# If embeddings are L2-normalized, this is cosine similarity
|
743 |
+
sim_matrix = tf.matmul(q_enc, p_enc, transpose_b=True) # [B, B]
|
744 |
+
|
745 |
+
# Labels are just the diagonal indices
|
746 |
+
batch_size = tf.shape(q_enc)[0]
|
747 |
+
labels = tf.range(batch_size, dtype=tf.int32) # [0..B-1]
|
748 |
+
|
749 |
+
# Softmax cross-entropy
|
750 |
+
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
|
751 |
+
labels=labels,
|
752 |
+
logits=sim_matrix
|
753 |
+
)
|
754 |
+
loss = tf.reduce_mean(loss)
|
755 |
+
|
756 |
+
# Compute gradients for the pretrained DistilBERT variables only
|
757 |
+
train_vars = self.encoder.pretrained.trainable_variables
|
758 |
+
gradients = tape.gradient(loss, train_vars)
|
759 |
+
|
760 |
+
# Remove any None grads (in case some layers are frozen)
|
761 |
+
grads_and_vars = [(g, v) for g, v in zip(gradients, train_vars) if g is not None]
|
762 |
+
if grads_and_vars:
|
763 |
+
self.optimizer.apply_gradients(grads_and_vars)
|
764 |
+
|
765 |
+
return loss
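# Toy illustration (standalone sketch with random embeddings, not repository code)
# of the in-batch negative objective above: row i of the similarity matrix scores
# query i against every positive in the batch, so the "correct" class for row i is
# column i, which is exactly what labels = tf.range(batch_size) encodes.
import tensorflow as tf
q_toy = tf.math.l2_normalize(tf.random.normal([4, 8]), axis=1)   # 4 query embeddings
p_toy = tf.math.l2_normalize(tf.random.normal([4, 8]), axis=1)   # 4 positive embeddings
sims = tf.matmul(q_toy, p_toy, transpose_b=True)                 # [4, 4] cosine scores
toy_labels = tf.range(4)                                          # diagonal = true pairs
toy_loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=toy_labels, logits=sims))
print(float(toy_loss))  # one scalar per batch, as returned by _train_step above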
|
766 |
+
|
767 |
+
def _prepare_sequences(
|
768 |
+
self,
|
769 |
+
queries: List[str],
|
770 |
+
positives: List[str],
|
771 |
+
negatives: List[str]
|
772 |
+
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
|
773 |
+
"""Prepare and tokenize sequences for training."""
|
774 |
+
logger.info("Preparing sequences for training...")
|
775 |
+
|
776 |
+
# Handle empty lists
|
777 |
+
if not queries:
|
778 |
+
logger.error("No queries to encode. Skipping sequence preparation.")
|
779 |
+
return tf.constant([]), tf.constant([]), tf.constant([])
|
780 |
+
|
781 |
+
# Process texts with special tokens
|
782 |
+
queries = [f"{self.special_tokens['user']} {q}" for q in queries]
|
783 |
+
positives = [f"{self.special_tokens['assistant']} {p}" for p in positives]
|
784 |
+
negatives = [f"{self.special_tokens['assistant']} {n}" for n in negatives]
|
785 |
+
|
786 |
+
# Tokenize using HuggingFace tokenizer
|
787 |
+
def encode_batch(texts: List[str]) -> tf.Tensor:
|
788 |
+
if not texts:
|
789 |
+
logger.error("Empty text list provided to tokenizer.")
|
790 |
+
return tf.constant([])
|
791 |
+
encodings = self.tokenizer(
|
792 |
+
texts,
|
793 |
+
padding='max_length',
|
794 |
+
truncation=True,
|
795 |
+
max_length=self.config.max_sequence_length,
|
796 |
+
return_tensors='tf'
|
797 |
+
)
|
798 |
+
return encodings['input_ids']
|
799 |
+
|
800 |
+
# Encode all sequences
|
801 |
+
q_tensor = encode_batch(queries)
|
802 |
+
p_tensor = encode_batch(positives)
|
803 |
+
n_tensor = encode_batch(negatives)
|
804 |
+
|
805 |
+
# Log statistics about encoded sequences
|
806 |
+
logger.info("Sequence statistics:")
|
807 |
+
logger.info(f" Query sequence shape: {q_tensor.shape}")
|
808 |
+
logger.info(f" Positive response sequence shape: {p_tensor.shape}")
|
809 |
+
logger.info(f" Negative response sequence shape: {n_tensor.shape}")
|
810 |
+
|
811 |
+
return q_tensor, p_tensor, n_tensor
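# Side note (illustrative, standard Hugging Face tokenizer API; the exact token
# strings in self.special_tokens are an assumption here): for speaker prefixes such
# as '<user>' / '<assistant>' to survive tokenization as single tokens rather than
# being split into word pieces, they would typically be registered first, e.g.:
#   tokenizer.add_special_tokens({'additional_special_tokens': ['<user>', '<assistant>']})
#   model.resize_token_embeddings(len(tokenizer))  # only needed if new tokens were added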
|
812 |
+
|
813 |
+
def _get_lr_schedule(
|
814 |
+
self,
|
815 |
+
total_steps: int,
|
816 |
+
peak_lr: float,
|
817 |
+
warmup_steps: int
|
818 |
+
) -> tf.keras.optimizers.schedules.LearningRateSchedule:
|
819 |
+
"""Create a custom learning rate schedule with warmup and cosine decay."""
|
820 |
+
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
|
821 |
+
def __init__(
|
822 |
+
self,
|
823 |
+
total_steps: int,
|
824 |
+
peak_lr: float,
|
825 |
+
warmup_steps: int
|
826 |
+
):
|
827 |
+
super().__init__()
|
828 |
+
self.total_steps = tf.cast(total_steps, tf.float32)
|
829 |
+
self.peak_lr = tf.cast(peak_lr, tf.float32)
|
830 |
+
|
831 |
+
# Cap warmup_steps at 10% of total_steps (and keep at least 1 warmup step)
|
832 |
+
adjusted_warmup_steps = min(warmup_steps, max(1, total_steps // 10))
|
833 |
+
self.warmup_steps = tf.cast(adjusted_warmup_steps, tf.float32)
|
834 |
+
|
835 |
+
# Calculate and store constants
|
836 |
+
self.initial_lr = self.peak_lr * 0.1 # Start at 10% of peak
|
837 |
+
self.min_lr = self.peak_lr * 0.01 # Minimum 1% of peak
|
838 |
+
|
839 |
+
logger.info(f"Learning rate schedule initialized:")
|
840 |
+
logger.info(f" Initial LR: {float(self.initial_lr):.6f}")
|
841 |
+
logger.info(f" Peak LR: {float(self.peak_lr):.6f}")
|
842 |
+
logger.info(f" Min LR: {float(self.min_lr):.6f}")
|
843 |
+
logger.info(f" Warmup steps: {int(self.warmup_steps)}")
|
844 |
+
logger.info(f" Total steps: {int(self.total_steps)}")
|
845 |
+
|
846 |
+
def __call__(self, step):
|
847 |
+
step = tf.cast(step, tf.float32)
|
848 |
+
|
849 |
+
# Warmup phase
|
850 |
+
warmup_factor = tf.minimum(1.0, step / self.warmup_steps)
|
851 |
+
warmup_lr = self.initial_lr + (self.peak_lr - self.initial_lr) * warmup_factor
|
852 |
+
|
853 |
+
# Decay phase
|
854 |
+
decay_steps = tf.maximum(1.0, self.total_steps - self.warmup_steps)
|
855 |
+
decay_factor = (step - self.warmup_steps) / decay_steps
|
856 |
+
decay_factor = tf.minimum(tf.maximum(0.0, decay_factor), 1.0) # Clip to [0,1]
|
857 |
+
|
858 |
+
cosine_decay = 0.5 * (1.0 + tf.cos(np.pi * decay_factor))
|
859 |
+
decay_lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay
|
860 |
+
|
861 |
+
# Choose between warmup and decay
|
862 |
+
final_lr = tf.where(step < self.warmup_steps, warmup_lr, decay_lr)
|
863 |
+
|
864 |
+
# Ensure learning rate is valid
|
865 |
+
final_lr = tf.maximum(self.min_lr, final_lr)
|
866 |
+
final_lr = tf.where(tf.math.is_finite(final_lr), final_lr, self.min_lr)
|
867 |
+
|
868 |
+
return final_lr
|
869 |
+
|
870 |
+
def get_config(self):
|
871 |
+
return {
|
872 |
+
"total_steps": self.total_steps,
|
873 |
+
"peak_lr": self.peak_lr,
|
874 |
+
"warmup_steps": self.warmup_steps,
|
875 |
+
}
|
876 |
+
|
877 |
+
return CustomSchedule(total_steps, peak_lr, warmup_steps)
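# Sanity-check sketch (illustrative; assumes an instance `chatbot` of this class
# and arbitrary hyperparameters) to confirm the warmup -> peak -> cosine-decay shape:
#   schedule = chatbot._get_lr_schedule(total_steps=1000, peak_lr=2e-5, warmup_steps=100)
#   for step in (0, 50, 100, 500, 999):
#       print(step, float(schedule(step)))
# Expected pattern: ~2e-6 at step 0 (10% of peak), rising to ~2e-5 by the end of
# warmup, then decaying smoothly toward ~2e-7 (1% of peak) near the final step.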
|
878 |
+
|
879 |
+
def _cosine_similarity(self, emb1: np.ndarray, emb2: np.ndarray) -> np.ndarray:
|
880 |
+
"""Compute cosine similarity between two numpy arrays."""
|
881 |
+
normalized_emb1 = emb1 / np.linalg.norm(emb1, axis=1, keepdims=True)
|
882 |
+
normalized_emb2 = emb2 / np.linalg.norm(emb2, axis=1, keepdims=True)
|
883 |
+
return np.dot(normalized_emb1, normalized_emb2.T)
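# Example (illustrative): with inputs of shape (n, d) and (m, d), the result is an
# (n, m) matrix of pairwise cosine similarities.
#   a = np.array([[1.0, 0.0], [0.0, 1.0]])
#   b = np.array([[1.0, 0.0]])
#   self._cosine_similarity(a, b)  ->  [[1.0], [0.0]]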
|
884 |
+
|
885 |
+
def run_automatic_validation(
|
886 |
+
self,
|
887 |
+
quality_checker: 'ResponseQualityChecker',
|
888 |
+
num_examples: int = 5
|
889 |
+
) -> Dict[str, Any]:
|
890 |
+
"""
|
891 |
+
Run automatic validation with quality metrics using FAISS-based retrieval.
|
892 |
+
"""
|
893 |
+
logger.info("\n=== Running Automatic Validation ===")
|
894 |
+
|
895 |
+
test_queries = [
|
896 |
+
"Hello, how are you today?",
|
897 |
+
"What's the weather like?",
|
898 |
+
"Can you help me with a problem?",
|
899 |
+
"Tell me a joke",
|
900 |
+
"What time is it?",
|
901 |
+
"I need help with my homework",
|
902 |
+
"Where's a good place to eat?",
|
903 |
+
"What movies are playing?",
|
904 |
+
"How do I reset my password?",
|
905 |
+
"Can you recommend a book?"
|
906 |
+
]
|
907 |
+
|
908 |
+
test_queries = test_queries[:num_examples]
|
909 |
+
metrics_history = []
|
910 |
+
|
911 |
+
for i, query in enumerate(test_queries, 1):
|
912 |
+
logger.info(f"\nTest Case {i}:")
|
913 |
+
logger.info(f"Query: {query}")
|
914 |
+
|
915 |
+
# Get responses and scores using FAISS
|
916 |
+
responses = self.retrieve_responses_faiss(query, top_k=5)
|
917 |
+
|
918 |
+
# Check quality
|
919 |
+
quality_metrics = quality_checker.check_response_quality(query, responses)
|
920 |
+
metrics_history.append(quality_metrics)
|
921 |
+
|
922 |
+
# Log results
|
923 |
+
logger.info(f"Quality Metrics: {quality_metrics}")
|
924 |
+
logger.info("Top responses:")
|
925 |
+
for j, (response, score) in enumerate(responses[:3], 1):
|
926 |
+
logger.info(f"{j}. Score: {score:.4f}")
|
927 |
+
logger.info(f" Response: {response}")
|
928 |
+
if j == 1 and not quality_metrics.get('is_confident', False):
|
929 |
+
logger.info(" [Low Confidence - Would abstain from answering]")
|
930 |
+
|
931 |
+
# Calculate aggregate metrics
|
932 |
+
aggregate_metrics = {
|
933 |
+
'num_queries_tested': len(test_queries),
|
934 |
+
'avg_top_response_score': np.mean([m.get('top_score', 0) for m in metrics_history]),
|
935 |
+
'avg_diversity': np.mean([m.get('response_diversity', 0) for m in metrics_history]),
|
936 |
+
'avg_relevance': np.mean([m.get('query_response_relevance', 0) for m in metrics_history]),
|
937 |
+
'avg_length_score': np.mean([m.get('response_length_score', 0) for m in metrics_history]),
|
938 |
+
'avg_score_gap': np.mean([m.get('top_3_score_gap', 0) for m in metrics_history]),
|
939 |
+
'confidence_rate': np.mean([m.get('is_confident', False) for m in metrics_history]),
|
940 |
+
}
|
941 |
+
|
942 |
+
logger.info("\n=== Validation Summary ===")
|
943 |
+
for metric, value in aggregate_metrics.items():
|
944 |
+
logger.info(f"{metric}: {value:.4f}")
|
945 |
+
|
946 |
+
return aggregate_metrics
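# Usage sketch (assumes `chatbot` and a ResponseQualityChecker instance named
# `checker`; both names are illustrative):
#   summary = chatbot.run_automatic_validation(checker, num_examples=5)
#   print(summary['confidence_rate'], summary['avg_top_response_score'])
# The aggregate dict returned above is what a training script would log after each
# run or compare across checkpoints.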
|
947 |
+
|
948 |
+
def chat(
|
949 |
+
self,
|
950 |
+
query: str,
|
951 |
+
conversation_history: Optional[List[Tuple[str, str]]] = None,
|
952 |
+
quality_checker: Optional['ResponseQualityChecker'] = None,
|
953 |
+
top_k: int = 5
|
954 |
+
) -> Tuple[str, List[Tuple[str, float]], Dict[str, Any]]:
|
955 |
+
"""
|
956 |
+
Interactive chat function with quality checking using FAISS-based retrieval.
|
957 |
+
|
958 |
+
Args:
|
959 |
+
query (str): The user's input query.
|
960 |
+
conversation_history (Optional[List[Tuple[str, str]]]): List of past (user, assistant) exchanges.
|
961 |
+
quality_checker (Optional['ResponseQualityChecker']): Quality checker instance.
|
962 |
+
top_k (int): Number of top responses to retrieve.
|
963 |
+
|
964 |
+
Returns:
|
965 |
+
Tuple[str, List[Tuple[str, float]], Dict[str, Any]]: (Response, Candidates, Quality Metrics)
|
966 |
+
"""
|
967 |
+
# Retrieve responses using FAISS
|
968 |
+
responses = self.retrieve_responses_faiss(query, top_k)
|
969 |
+
|
970 |
+
# If no quality checker provided, return the top response
|
971 |
+
if quality_checker is None:
|
972 |
+
return responses[0][0] if responses else "I'm sorry, I don't have an answer for that.", responses, {}
|
973 |
+
|
974 |
+
# Check quality
|
975 |
+
quality_metrics = quality_checker.check_response_quality(query, responses)
|
976 |
+
|
977 |
+
if quality_metrics.get('is_confident', False):
|
978 |
+
return responses[0][0], responses, quality_metrics
|
979 |
+
else:
|
980 |
+
uncertainty_response = (
|
981 |
+
"I apologize, but I don't feel confident providing an answer to that "
|
982 |
+
"question at the moment. Could you please rephrase or ask something else?"
|
983 |
+
)
|
984 |
+
return uncertainty_response, responses, quality_metrics
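# Minimal REPL sketch (illustrative; assumes `chatbot` and `checker` instances exist):
#   history = []
#   while True:
#       user_input = input("You: ")
#       if user_input.lower() in {"quit", "exit"}:
#           break
#       reply, candidates, metrics = chatbot.chat(
#           user_input, conversation_history=history, quality_checker=checker, top_k=5)
#       history.append((user_input, reply))
#       print(f"Bot: {reply}")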
|
985 |
+
|
986 |
+
# TODO: consider removal
|
987 |
+
# def prepare_dataset(self, dialogues: List[dict], neg_samples_per_pos: int = 1, debug_samples: int = None) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
|
988 |
+
# """Prepares the dataset for training."""
|
989 |
+
# logger.info("Preparing dataset...")
|
990 |
+
|
991 |
+
# # Extract (query, positive, negative) triples
|
992 |
+
# queries, positives, negatives = [], [], []
|
993 |
+
|
994 |
+
# for dialogue in dialogues:
|
995 |
+
# turns = dialogue.get('turns', [])
|
996 |
+
# for i in range(len(turns) - 1):
|
997 |
+
# current_turn = turns[i]
|
998 |
+
# next_turn = turns[i+1]
|
999 |
+
|
1000 |
+
# if (current_turn.get('speaker') == 'user' and
|
1001 |
+
# next_turn.get('speaker') == 'assistant' and
|
1002 |
+
# 'text' in current_turn and
|
1003 |
+
# 'text' in next_turn):
|
1004 |
+
|
1005 |
+
# query = current_turn['text'].strip()
|
1006 |
+
# positive = next_turn['text'].strip()
|
1007 |
+
|
1008 |
+
# # Generate hard negative samples
|
1009 |
+
# hard_negatives = self.hard_negative_sampling(positive, n_samples=neg_samples_per_pos)
|
1010 |
+
# for negative in hard_negatives:
|
1011 |
+
# negatives.append(negative)
|
1012 |
+
# queries.append(query)
|
1013 |
+
# positives.append(positive)
|
1014 |
+
|
1015 |
+
# logger.info(f"Prepared {len(queries)} training examples.")
|
1016 |
+
|
1017 |
+
# # Tokenize and pad sequences
|
1018 |
+
# encoded_queries = self.tokenizer(
|
1019 |
+
# queries,
|
1020 |
+
# padding='max_length',
|
1021 |
+
# truncation=True,
|
1022 |
+
# max_length=self.config.max_sequence_length,
|
1023 |
+
# return_tensors='tf'
|
1024 |
+
# )
|
1025 |
+
# encoded_positives = self.tokenizer(
|
1026 |
+
# positives,
|
1027 |
+
# padding='max_length',
|
1028 |
+
# truncation=True,
|
1029 |
+
# max_length=self.config.max_sequence_length,
|
1030 |
+
# return_tensors='tf'
|
1031 |
+
# )
|
1032 |
+
# encoded_negatives = self.tokenizer(
|
1033 |
+
# negatives,
|
1034 |
+
# padding='max_length',
|
1035 |
+
# truncation=True,
|
1036 |
+
# max_length=self.config.max_sequence_length,
|
1037 |
+
# return_tensors='tf'
|
1038 |
+
# )
|
1039 |
+
|
1040 |
+
# q_tensor = encoded_queries['input_ids']
|
1041 |
+
# p_tensor = encoded_positives['input_ids']
|
1042 |
+
# n_tensor = encoded_negatives['input_ids']
|
1043 |
+
|
1044 |
+
# logger.info(f"Tokenized and padded sequences.")
|
1045 |
+
|
1046 |
+
# return q_tensor, p_tensor, n_tensor
|
1047 |
+
|
1048 |
+
|
1049 |
+
# # TODO: consider removal
|
1050 |
+
# def hard_negative_sampling(self, positive_response, n_samples=1):
|
1051 |
+
# """Select hard negatives based on cosine similarity."""
|
1052 |
+
# try:
|
1053 |
+
# # Ensure we don't request more negatives than available
|
1054 |
+
# max_neg_samples = len(self.response_pool) - 1 # Exclude the positive response
|
1055 |
+
# n_samples = min(n_samples, max_neg_samples)
|
1056 |
+
|
1057 |
+
# if n_samples <= 0:
|
1058 |
+
# logger.error("Not enough responses to sample negatives.")
|
1059 |
+
# return []
|
1060 |
+
|
1061 |
+
# # Encode the positive response using the chatbot's encode_responses method
|
1062 |
+
# pos_emb = self.encode_responses([positive_response]).numpy()
|
1063 |
+
# faiss.normalize_L2(pos_emb)
|
1064 |
+
# #logger.info(f"Normalized positive embedding for response: {positive_response}")
|
1065 |
+
|
1066 |
+
# # Search for the top n_samples + 1 most similar responses (including the positive itself)
|
1067 |
+
# D, I = self.index.search(pos_emb, n_samples + 1)
|
1068 |
+
# #logger.info(f"FAISS search results: {I}")
|
1069 |
+
|
1070 |
+
# # Exclude the positive response itself (assuming it's indexed)
|
1071 |
+
# negatives = []
|
1072 |
+
# for i in range(n_samples):
|
1073 |
+
# idx = I[0][i + 1] # Skip the first one as it's the positive
|
1074 |
+
# if idx < len(self.response_pool):
|
1075 |
+
# negative_response = self.response_pool[idx]
|
1076 |
+
# negatives.append(negative_response)
|
1077 |
+
# logger.info(f"Selected negative: {negative_response}")
|
1078 |
+
# else:
|
1079 |
+
# logger.warning(f"Index {idx} out of range for response_pool with size {len(self.response_pool)}.")
|
1080 |
+
|
1081 |
+
# return negatives
|
1082 |
+
# except Exception as e:
|
1083 |
+
# logger.error(f"An error occurred during hard negative sampling: {e}")
|
1084 |
+
# return []
|
1085 |
+
|
1086 |
+
# def train(
|
1087 |
+
# self,
|
1088 |
+
# q_pad: tf.Tensor,
|
1089 |
+
# p_pad: tf.Tensor,
|
1090 |
+
# n_pad: tf.Tensor,
|
1091 |
+
# epochs: int,
|
1092 |
+
# batch_size: int,
|
1093 |
+
# validation_split: float,
|
1094 |
+
# checkpoint_dir: str,
|
1095 |
+
# callbacks: Optional[List[tf.keras.callbacks.Callback]] = None
|
1096 |
+
# ):
|
1097 |
+
# """
|
1098 |
+
# Train the chatbot model.
|
1099 |
+
|
1100 |
+
# Args:
|
1101 |
+
# q_pad (tf.Tensor): Padded query input_ids.
|
1102 |
+
# p_pad (tf.Tensor): Padded positive response input_ids.
|
1103 |
+
# n_pad (tf.Tensor): Padded negative response input_ids.
|
1104 |
+
# epochs (int): Number of training epochs.
|
1105 |
+
# batch_size (int): Training batch size.
|
1106 |
+
# validation_split (float): Fraction of data to use for validation.
|
1107 |
+
# checkpoint_dir (str): Directory to save model checkpoints.
|
1108 |
+
# callbacks (list, optional): List of Keras callbacks.
|
1109 |
+
# """
|
1110 |
+
# dataset_size = tf.shape(q_pad)[0].numpy()
|
1111 |
+
# val_size = int(dataset_size * validation_split)
|
1112 |
+
# train_size = dataset_size - val_size
|
1113 |
+
|
1114 |
+
# logger.info(f"Total samples: {dataset_size}")
|
1115 |
+
# logger.info(f"Training samples: {train_size}")
|
1116 |
+
# logger.info(f"Validation samples: {val_size}")
|
1117 |
+
|
1118 |
+
# # Calculate steps_per_epoch
|
1119 |
+
# steps_per_epoch = train_size // batch_size
|
1120 |
+
# if train_size % batch_size != 0:
|
1121 |
+
# steps_per_epoch += 1
|
1122 |
+
# total_steps = steps_per_epoch * epochs
|
1123 |
+
|
1124 |
+
# logger.info(f"Total training steps: {total_steps}")
|
1125 |
+
|
1126 |
+
# # Initialize learning rate schedule with adjusted warmup_steps
|
1127 |
+
# lr_schedule = self._get_lr_schedule(
|
1128 |
+
# total_steps=total_steps,
|
1129 |
+
# peak_lr=self.config.learning_rate,
|
1130 |
+
# warmup_steps=self.config.warmup_steps
|
1131 |
+
# )
|
1132 |
+
|
1133 |
+
# # callbacks = []
|
1134 |
+
# # if checkpoint_dir:
|
1135 |
+
# # checkpoint_dir = Path(checkpoint_dir)
|
1136 |
+
# # checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
1137 |
+
|
1138 |
+
# # # Setup checkpoint callback with correct file format
|
1139 |
+
# # checkpoint_template = str(checkpoint_dir / "model_epoch_{epoch:04d}.weights.h5")
|
1140 |
+
# # checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
|
1141 |
+
# # checkpoint_template,
|
1142 |
+
# # save_weights_only=True,
|
1143 |
+
# # save_best_only=True,
|
1144 |
+
# # monitor='val_loss',
|
1145 |
+
# # mode='min',
|
1146 |
+
# # verbose=1
|
1147 |
+
# # )
|
1148 |
+
# # callbacks.append(checkpoint_callback)
|
1149 |
+
|
1150 |
+
# # # Early stopping callback
|
1151 |
+
# # early_stopping = tf.keras.callbacks.EarlyStopping(
|
1152 |
+
# # monitor='val_loss',
|
1153 |
+
# # patience=5,
|
1154 |
+
# # restore_best_weights=True,
|
1155 |
+
# # verbose=1
|
1156 |
+
# # )
|
1157 |
+
# # callbacks.append(early_stopping)
|
1158 |
+
|
1159 |
+
# # # TensorBoard callback
|
1160 |
+
# # tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='./logs')
|
1161 |
+
# # callbacks.append(tensorboard_callback)
|
1162 |
+
|
1163 |
+
# # Update optimizer with the new learning rate schedule
|
1164 |
+
# self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
|
1165 |
+
|
1166 |
+
# # Split the data
|
1167 |
+
# train_q = q_pad[:train_size]
|
1168 |
+
# train_p = p_pad[:train_size]
|
1169 |
+
# train_n = n_pad[:train_size]
|
1170 |
+
|
1171 |
+
# val_q = q_pad[train_size:]
|
1172 |
+
# val_p = p_pad[train_size:]
|
1173 |
+
# val_n = n_pad[train_size:]
|
1174 |
+
|
1175 |
+
# # Create TensorFlow datasets
|
1176 |
+
# train_dataset = tf.data.Dataset.from_tensor_slices((train_q, train_p, train_n))
|
1177 |
+
# train_dataset = train_dataset.shuffle(buffer_size=1000).batch(batch_size)
|
1178 |
+
|
1179 |
+
# val_dataset = tf.data.Dataset.from_tensor_slices((val_q, val_p, val_n))
|
1180 |
+
# val_dataset = val_dataset.batch(batch_size)
|
1181 |
+
|
1182 |
+
# # Log dataset sizes
|
1183 |
+
# logger.info(f"Training dataset batches: {len(list(train_dataset))}")
|
1184 |
+
# logger.info(f"Validation dataset batches: {len(list(val_dataset))}")
|
1185 |
+
|
1186 |
+
# # Create checkpoint manager
|
1187 |
+
# checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, model=self.encoder)
|
1188 |
+
# manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
|
1189 |
+
|
1190 |
+
# for epoch in range(1, epochs + 1):
|
1191 |
+
# logger.info(f"Epoch {epoch}/{epochs}")
|
1192 |
+
# epoch_loss_avg = tf.keras.metrics.Mean()
|
1193 |
+
|
1194 |
+
# # Training loop
|
1195 |
+
# for q_batch, p_batch, n_batch in train_dataset:
|
1196 |
+
# batch_loss = self._train_step(q_batch, p_batch, n_batch)
|
1197 |
+
# epoch_loss_avg(batch_loss)
|
1198 |
+
|
1199 |
+
# # Validation loop
|
1200 |
+
# val_loss_avg = tf.keras.metrics.Mean()
|
1201 |
+
# try:
|
1202 |
+
# for q_val, p_val, n_val in val_dataset:
|
1203 |
+
# # Encode queries, positives, and negatives without training
|
1204 |
+
# q_enc = self.encoder(q_val, training=False)
|
1205 |
+
# p_enc = self.encoder(p_val, training=False)
|
1206 |
+
# n_enc = self.encoder(n_val, training=False)
|
1207 |
+
|
1208 |
+
# # Compute cosine similarities
|
1209 |
+
# pos_sim = tf.reduce_sum(tf.multiply(q_enc, p_enc), axis=1)
|
1210 |
+
# neg_sim = tf.reduce_sum(tf.multiply(q_enc, n_enc), axis=1)
|
1211 |
+
|
1212 |
+
# # Ensure similarities are float32
|
1213 |
+
# pos_sim = tf.cast(pos_sim, tf.float32)
|
1214 |
+
# neg_sim = tf.cast(neg_sim, tf.float32)
|
1215 |
+
|
1216 |
+
# # Compute loss with margin
|
1217 |
+
# margin = tf.cast(self.config.margin, tf.float32)
|
1218 |
+
# loss = tf.maximum(0.0, margin - pos_sim + neg_sim)
|
1219 |
+
|
1220 |
+
# val_loss_avg(tf.reduce_mean(loss))
|
1221 |
+
|
1222 |
+
# # Optional: Log individual batch validation loss
|
1223 |
+
# logger.debug(f"Batch Validation Loss: {tf.reduce_mean(loss).numpy():.6f}")
|
1224 |
+
|
1225 |
+
# train_loss = epoch_loss_avg.result().numpy()
|
1226 |
+
# val_loss = val_loss_avg.result().numpy()
|
1227 |
+
# logger.info(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
|
1228 |
+
|
1229 |
+
# # Save checkpoint
|
1230 |
+
# manager.save()
|
1231 |
+
|
1232 |
+
# # Update history
|
1233 |
+
# self.history['train_loss'].append(train_loss)
|
1234 |
+
# self.history['val_loss'].append(val_loss)
|
1235 |
+
|
1236 |
+
# # Invoke callbacks if any
|
1237 |
+
# if callbacks:
|
1238 |
+
# for callback in callbacks:
|
1239 |
+
# callback.on_epoch_end(epoch, logs={'loss': train_loss, 'val_loss': val_loss})
|
1240 |
+
|
1241 |
+
# except tf.errors.OutOfRangeError:
|
1242 |
+
# logger.warning("Validation dataset is exhausted before expected.")
|
1243 |
+
# self.history['val_loss'].append(val_loss_avg.result().numpy())
|
1244 |
+
|
1245 |
+
# logger.info("Training completed.")
|
1246 |
+
|
1247 |
+
# @tf.function
|
1248 |
+
# def _train_step(self, q_batch, p_batch, n_batch):
|
1249 |
+
# """
|
1250 |
+
# Performs a single training step with query, positive, and negative batches.
|
1251 |
+
|
1252 |
+
# Args:
|
1253 |
+
# q_batch (tf.Tensor): Batch of query input_ids.
|
1254 |
+
# p_batch (tf.Tensor): Batch of positive response input_ids.
|
1255 |
+
# n_batch (tf.Tensor): Batch of negative response input_ids.
|
1256 |
+
|
1257 |
+
# Returns:
|
1258 |
+
# tf.Tensor: Mean loss for the batch.
|
1259 |
+
# """
|
1260 |
+
# with tf.GradientTape() as tape:
|
1261 |
+
# # Encode queries, positives, and negatives using the shared encoder
|
1262 |
+
# q_enc = self.encoder(q_batch, training=True) # Shape: (batch_size, embedding_dim)
|
1263 |
+
# p_enc = self.encoder(p_batch, training=True) # Shape: (batch_size, embedding_dim)
|
1264 |
+
# n_enc = self.encoder(n_batch, training=True) # Shape: (batch_size, embedding_dim)
|
1265 |
+
|
1266 |
+
# # Compute cosine similarities
|
1267 |
+
# pos_sim = tf.reduce_sum(tf.multiply(q_enc, p_enc), axis=1) # Shape: (batch_size,)
|
1268 |
+
# neg_sim = tf.reduce_sum(tf.multiply(q_enc, n_enc), axis=1) # Shape: (batch_size,)
|
1269 |
+
|
1270 |
+
# # Ensure similarities are float32
|
1271 |
+
# pos_sim = tf.cast(pos_sim, tf.float32)
|
1272 |
+
# neg_sim = tf.cast(neg_sim, tf.float32)
|
1273 |
+
|
1274 |
+
# # Compute loss with margin
|
1275 |
+
# margin = tf.cast(self.config.margin, tf.float32)
|
1276 |
+
# loss = tf.maximum(0.0, margin - pos_sim + neg_sim)
|
1277 |
+
|
1278 |
+
# # Compute gradients and update encoder weights
|
1279 |
+
# gradients = tape.gradient(loss, self.encoder.pretrained.trainable_variables)
|
1280 |
+
|
1281 |
+
# # Filter out None gradients (if any)
|
1282 |
+
# grads_and_vars = [
|
1283 |
+
# (g, v) for g, v in zip(gradients, self.encoder.pretrained.trainable_variables)
|
1284 |
+
# if g is not None
|
1285 |
+
# ]
|
1286 |
+
|
1287 |
+
# if grads_and_vars:
|
1288 |
+
# self.optimizer.apply_gradients(grads_and_vars)
|
1289 |
+
|
1290 |
+
# # Return mean loss
|
1291 |
+
# return tf.reduce_mean(loss)
|
dialogue_augmenter.py
CHANGED
@@ -3,11 +3,9 @@ import numpy as np
|
|
3 |
import torch
|
4 |
import tensorflow as tf
|
5 |
import tensorflow_hub as hub
|
6 |
-
import re
|
7 |
from pipeline_config import PipelineConfig
|
8 |
from quality_metrics import QualityMetrics
|
9 |
from paraphraser import Paraphraser
|
10 |
-
from back_translator import BackTranslator
|
11 |
import nlpaug.augmenter.word as naw
|
12 |
from concurrent.futures import ThreadPoolExecutor
|
13 |
from functools import lru_cache
|
@@ -29,9 +27,12 @@ class DialogueAugmenter:
|
|
29 |
print(f"Using device: {self.device}")
|
30 |
if self.use_gpu:
|
31 |
print(f"GPU Device: {torch.cuda.get_device_name(0)}")
|
|
|
32 |
|
33 |
-
# Load base models
|
34 |
self.quality_metrics = QualityMetrics(config)
|
|
|
|
|
|
|
35 |
self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
|
36 |
|
37 |
# Initialize augmentation models based on hardware
|
@@ -39,10 +40,6 @@ class DialogueAugmenter:
|
|
39 |
|
40 |
# Initialize caches
|
41 |
self.embedding_cache = {}
|
42 |
-
self.perplexity_cache = {}
|
43 |
-
|
44 |
-
# Compile regex patterns
|
45 |
-
self.spelling_pattern = re.compile(r'[a-zA-Z]{3,}')
|
46 |
|
47 |
# GPU memory management if available
|
48 |
if self.use_gpu:
|
@@ -57,25 +54,20 @@ class DialogueAugmenter:
|
|
57 |
def _initialize_augmentation_models(self):
|
58 |
"""Initialize augmentation models with appropriate device settings"""
|
59 |
# Advanced augmentation techniques
|
60 |
-
self.paraphraser = Paraphraser()
|
61 |
-
self.back_translator = BackTranslator()
|
62 |
-
|
63 |
if self.use_gpu:
|
64 |
-
# Move
|
65 |
self.paraphraser.model = self.paraphraser.model.to(self.device)
|
66 |
-
self.back_translator.model_pivot_forward = self.back_translator.model_pivot_forward.to(self.device)
|
67 |
-
self.back_translator.model_pivot_backward = self.back_translator.model_pivot_backward.to(self.device)
|
68 |
-
self.back_translator.model_backward = self.back_translator.model_backward.to(self.device)
|
69 |
|
70 |
# Basic augmentation techniques
|
71 |
self.word_augmenter = naw.SynonymAug(aug_src='wordnet')
|
72 |
-
self.spelling_augmenter = naw.SpellingAug()
|
73 |
|
74 |
self.augmenters = {
|
75 |
-
'advanced': [
|
|
|
|
|
76 |
'basic': [
|
77 |
('synonym', self.word_augmenter),
|
78 |
-
('spelling', self.spelling_augmenter)
|
79 |
]
|
80 |
}
|
81 |
|
@@ -103,52 +95,46 @@ class DialogueAugmenter:
|
|
103 |
|
104 |
def _quick_quality_check(self, variation: str, original: str) -> bool:
|
105 |
"""
|
106 |
-
|
107 |
"""
|
108 |
if self.config.debug:
|
109 |
print(f"\nQuick check for variation: {variation}")
|
110 |
|
111 |
-
# Stricter length check
|
112 |
orig_len = len(original.split())
|
113 |
var_len = len(variation.split())
|
114 |
-
|
115 |
-
# For very short texts (
|
116 |
if orig_len <= 3:
|
117 |
-
if var_len > orig_len * 3:
|
118 |
if self.config.debug:
|
119 |
print(f"Failed length check (short text): {var_len} vs {orig_len}")
|
120 |
return False
|
121 |
else:
|
122 |
-
if var_len > orig_len * 2:
|
123 |
if self.config.debug:
|
124 |
print(f"Failed length check (long text): {var_len} vs {orig_len}")
|
125 |
return False
|
126 |
-
|
127 |
-
#
|
128 |
stop_words = {'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'is', 'are', 'that', 'this', 'will', 'can'}
|
129 |
orig_words = set(w.lower() for w in original.split() if w.lower() not in stop_words)
|
130 |
var_words = set(w.lower() for w in variation.split() if w.lower() not in stop_words)
|
131 |
-
|
132 |
-
#
|
133 |
-
|
134 |
-
|
|
|
|
|
|
|
|
|
|
|
135 |
if self.config.debug:
|
136 |
-
print(
|
137 |
-
|
138 |
-
|
139 |
if self.config.debug:
|
140 |
print("Passed all quick checks")
|
141 |
return True
|
142 |
|
143 |
-
def _compute_metrics_parallel(self, original: str, candidates: List[str]) -> List[Dict[str, float]]:
|
144 |
-
"""Compute quality metrics for multiple candidates in parallel"""
|
145 |
-
with ThreadPoolExecutor(max_workers=4) as executor:
|
146 |
-
futures = [
|
147 |
-
executor.submit(self.quality_metrics.compute_metrics, original, candidate)
|
148 |
-
for candidate in candidates
|
149 |
-
]
|
150 |
-
return [future.result() for future in futures]
|
151 |
-
|
152 |
def _filter_variations_batch(self, variations: List[str], context: List[str], original_turn: str) -> List[str]:
|
153 |
"""
|
154 |
Filter variations using batched computations with detailed logging
|
@@ -162,12 +148,17 @@ class DialogueAugmenter:
|
|
162 |
print(f"Original turn: {original_turn}")
|
163 |
|
164 |
words = original_turn.split()
|
|
|
|
|
|
|
|
|
|
|
165 |
if len(words) < 3:
|
166 |
if self.config.debug:
|
167 |
print("Short text detected, using predefined variations")
|
168 |
short_text_variations = self._augment_short_text({'text': original_turn, 'speaker': ''})
|
169 |
return [var['text'] for var in short_text_variations]
|
170 |
-
|
171 |
# If this is the first turn (no context), be more lenient
|
172 |
if not context:
|
173 |
preliminary_filtered = variations
|
@@ -183,57 +174,85 @@ class DialogueAugmenter:
|
|
183 |
print(f"Passed quick check: {passed}")
|
184 |
if passed:
|
185 |
preliminary_filtered.append(var)
|
186 |
-
|
187 |
if self.config.debug:
|
188 |
print(f"Variations after quick check: {len(preliminary_filtered)}")
|
189 |
-
|
190 |
if not preliminary_filtered:
|
191 |
return []
|
192 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
# Only use last turn for coherence
|
194 |
recent_context = [context[-1]] if context else []
|
195 |
context_text = ' '.join(recent_context) if recent_context else ''
|
196 |
-
|
197 |
-
# Even more lenient thresholds
|
198 |
-
min_similarity = 0.1 # Further reduced
|
199 |
-
min_coherence = 0.05 # Further reduced
|
200 |
-
|
201 |
if context_text:
|
202 |
if self.config.debug:
|
203 |
print(f"\nContext text: {context_text}")
|
204 |
-
|
205 |
-
all_texts = [context_text] +
|
206 |
all_embeddings = self._compute_batch_embeddings(all_texts)
|
207 |
-
|
208 |
context_embedding = all_embeddings[0]
|
209 |
variation_embeddings = all_embeddings[1:]
|
210 |
-
|
211 |
# Vectorized similarity computation
|
212 |
context_similarities = cosine_similarity([context_embedding], variation_embeddings)[0]
|
213 |
-
|
214 |
# Response coherence check
|
215 |
if recent_context:
|
216 |
prev_embedding = self._compute_embedding(recent_context[-1])
|
217 |
response_coherence = cosine_similarity([prev_embedding], variation_embeddings)[0]
|
218 |
else:
|
219 |
response_coherence = np.ones_like(context_similarities)
|
220 |
-
|
221 |
-
# Combined scoring with detailed logging
|
222 |
filtered_variations = []
|
223 |
for i, (variation, sim, coh) in enumerate(zip(
|
224 |
-
|
225 |
-
# Use absolute values for scoring
|
226 |
combined_score = (
|
227 |
self.config.context_similarity_weight * abs(sim) +
|
228 |
self.config.response_coherence_weight * abs(coh)
|
229 |
)
|
230 |
-
|
231 |
if self.config.debug:
|
232 |
print(f"\nVariation: {variation}")
|
233 |
print(f"Context similarity: {sim:.3f}")
|
234 |
print(f"Response coherence: {coh:.3f}")
|
235 |
print(f"Combined score: {combined_score:.3f}")
|
236 |
-
|
237 |
# Accept if EITHER score is good enough
|
238 |
if (combined_score >= min_similarity or abs(coh) >= min_coherence):
|
239 |
filtered_variations.append(variation)
|
@@ -242,74 +261,71 @@ class DialogueAugmenter:
|
|
242 |
else:
|
243 |
if self.config.debug:
|
244 |
print("REJECTED")
|
245 |
-
|
246 |
# If we have enough variations, stop
|
247 |
if len(filtered_variations) >= self.config.max_variations_per_turn:
|
248 |
break
|
249 |
else:
|
250 |
-
filtered_variations =
|
251 |
-
|
252 |
if self.config.debug:
|
253 |
print(f"\nFinal filtered variations: {len(filtered_variations)}")
|
254 |
-
|
255 |
return filtered_variations
|
256 |
|
257 |
def _generate_variations_progressive(self, text: str, needed: int) -> List[str]:
|
258 |
"""
|
259 |
-
Generate variations progressively until we have enough good ones
|
|
|
260 |
"""
|
261 |
variations = set()
|
262 |
-
|
263 |
if self.config.debug:
|
264 |
print(f"\nAttempting to generate {needed} variations for text: {text}")
|
265 |
-
|
266 |
-
#
|
267 |
for augmenter in self.augmenters['advanced']:
|
268 |
if len(variations) >= needed:
|
269 |
break
|
270 |
-
|
271 |
try:
|
272 |
if isinstance(augmenter, Paraphraser):
|
273 |
if self.config.debug:
|
274 |
print("Trying paraphrase augmentation...")
|
275 |
-
new_vars = augmenter.paraphrase(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
276 |
if self.config.debug:
|
277 |
print(f"Paraphraser generated {len(new_vars)} variations")
|
278 |
-
|
279 |
-
if self.config.debug:
|
280 |
-
print("Trying back translation...")
|
281 |
-
new_vars = [augmenter.back_translate(text)]
|
282 |
-
if self.config.debug:
|
283 |
-
print(f"Back translator generated {len(new_vars)} variations")
|
284 |
-
|
285 |
valid_vars = [v for v in new_vars if v.strip() and v != text]
|
286 |
variations.update(valid_vars)
|
287 |
-
|
288 |
if self.config.debug:
|
289 |
print(f"Current unique variations: {len(variations)}")
|
290 |
-
|
291 |
except Exception as e:
|
292 |
print(f"Error in advanced augmentation: {str(e)}")
|
293 |
continue
|
294 |
-
|
295 |
# Try basic augmenters if needed
|
296 |
if len(variations) < needed:
|
297 |
if self.config.debug:
|
298 |
print("Not enough variations, trying basic augmenters...")
|
299 |
-
|
300 |
for aug_type, augmenter in self.augmenters['basic']:
|
301 |
if len(variations) >= needed:
|
302 |
break
|
303 |
-
|
304 |
try:
|
305 |
-
if aug_type == 'spelling' and self._is_technical_or_formal_text(text):
|
306 |
-
if self.config.debug:
|
307 |
-
print("Skipping spelling augmentation for technical text")
|
308 |
-
continue
|
309 |
-
|
310 |
if self.config.debug:
|
311 |
print(f"Trying {aug_type} augmentation...")
|
312 |
-
|
313 |
new_vars = augmenter.augment(text, n=2)
|
314 |
if isinstance(new_vars, list):
|
315 |
valid_vars = [v for v in new_vars if v.strip() and v != text]
|
@@ -317,21 +333,21 @@ class DialogueAugmenter:
|
|
317 |
else:
|
318 |
if new_vars.strip() and new_vars != text:
|
319 |
variations.add(new_vars)
|
320 |
-
|
321 |
if self.config.debug:
|
322 |
print(f"After {aug_type}, total variations: {len(variations)}")
|
323 |
-
|
324 |
except Exception as e:
|
325 |
print(f"Error in {aug_type} augmentation: {str(e)}")
|
326 |
continue
|
327 |
-
|
328 |
variations_list = list(variations)
|
329 |
-
|
330 |
if self.config.debug:
|
331 |
print(f"Final number of variations generated: {len(variations_list)}")
|
332 |
if not variations_list:
|
333 |
print("WARNING: No variations were generated!")
|
334 |
-
|
335 |
return variations_list
|
336 |
|
337 |
def augment_dialogue(self, dialogue: Dict) -> List[Dict]:
|
@@ -375,7 +391,8 @@ class DialogueAugmenter:
|
|
375 |
# Generate combinations with sampling
|
376 |
augmented_dialogues = self._generate_dialogue_combinations(
|
377 |
dialogue['dialogue_id'],
|
378 |
-
turn_variations
|
|
|
379 |
)
|
380 |
|
381 |
# Add original dialogue
|
@@ -392,47 +409,201 @@ class DialogueAugmenter:
|
|
392 |
|
393 |
return result
|
394 |
|
395 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
396 |
"""
|
397 |
-
|
|
|
|
|
398 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
399 |
augmented_dialogues = []
|
400 |
used_combinations = set()
|
401 |
-
|
402 |
-
def
|
403 |
if current_turns is None:
|
404 |
current_turns = []
|
405 |
|
406 |
-
if len(augmented_dialogues) >=
|
407 |
return
|
408 |
|
409 |
if turn_index == len(turn_variations):
|
|
|
410 |
dialogue_fingerprint = " | ".join(turn['text'] for turn in current_turns)
|
411 |
if dialogue_fingerprint not in used_combinations:
|
412 |
used_combinations.add(dialogue_fingerprint)
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
|
|
|
|
|
|
|
|
|
|
417 |
return
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
424 |
return
|
425 |
current_turns.append(variation)
|
426 |
-
|
427 |
current_turns.pop()
|
428 |
-
|
429 |
try:
|
430 |
-
|
431 |
except Exception as e:
|
432 |
print(f"Error in dialogue generation: {str(e)}")
|
433 |
return []
|
434 |
-
|
435 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
436 |
|
437 |
def _is_dialogue_duplicate(self, dialogue1: Dict, dialogue2: Dict) -> bool:
|
438 |
"""
|
@@ -445,11 +616,9 @@ class DialogueAugmenter:
|
|
445 |
def _augment_short_text(self, turn: Dict) -> List[Dict]:
|
446 |
"""
|
447 |
Special handling for very short texts with predefined variations.
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
Returns:
|
452 |
-
List[Dict]: List of variations for the short text
|
453 |
"""
|
454 |
text = turn['text']
|
455 |
common_variations = {
|
@@ -483,71 +652,60 @@ class DialogueAugmenter:
|
|
483 |
'Fantastic!', 'Amazing!', 'Terrific!'
|
484 |
]
|
485 |
}
|
486 |
-
|
487 |
-
# Try to find matching variations
|
488 |
text_lower = text.lower().rstrip('!.,?')
|
|
|
489 |
variations = []
|
490 |
-
|
491 |
-
# Check if text matches any of our predefined categories
|
492 |
for key, predefined_vars in common_variations.items():
|
493 |
if key in text_lower or text_lower in key:
|
494 |
variations.extend(predefined_vars)
|
495 |
|
496 |
-
# If no predefined variations found, generate simple variants
|
497 |
if not variations:
|
498 |
-
#
|
|
|
499 |
variations = [
|
500 |
-
|
501 |
-
|
502 |
-
|
503 |
]
|
504 |
|
505 |
# Add capitalization variations
|
506 |
-
variations.
|
507 |
-
|
508 |
-
if v.capitalize() not in variations
|
509 |
-
])
|
510 |
|
511 |
-
#
|
512 |
unique_variations = list(set(variations))
|
513 |
-
quality_variations = []
|
514 |
-
|
515 |
-
for var in unique_variations:
|
516 |
-
metrics = self.quality_metrics.compute_metrics(text, var)
|
517 |
-
quality_score = (
|
518 |
-
0.35 * metrics['semantic_similarity'] +
|
519 |
-
0.30 * (1.0 - metrics['perplexity'] / 100) +
|
520 |
-
0.15 * (1.0 - metrics['grammar_errors'] / 10) +
|
521 |
-
0.15 * metrics['content_preservation'] +
|
522 |
-
0.10 * metrics['type_token_ratio']
|
523 |
-
)
|
524 |
-
|
525 |
-
# More lenient quality threshold for short texts
|
526 |
-
if quality_score >= 0.5: # Lower threshold for short texts
|
527 |
-
quality_variations.append(var)
|
528 |
-
|
529 |
-
# Ensure we have at least some variations
|
530 |
-
if not quality_variations:
|
531 |
-
quality_variations = [text]
|
532 |
|
533 |
-
#
|
534 |
-
|
|
|
|
|
535 |
|
536 |
-
def
|
537 |
-
"""
|
538 |
-
|
539 |
-
|
540 |
-
|
541 |
-
|
542 |
-
|
543 |
-
|
544 |
-
|
545 |
-
|
546 |
-
|
547 |
-
|
|
|
|
|
|
|
|
|
|
|
548 |
|
549 |
-
|
550 |
-
|
551 |
-
|
|
|
|
|
|
|
|
|
|
|
552 |
|
553 |
-
return
|
|
|
3 |
import torch
|
4 |
import tensorflow as tf
|
5 |
import tensorflow_hub as hub
|
|
|
6 |
from pipeline_config import PipelineConfig
|
7 |
from quality_metrics import QualityMetrics
|
8 |
from paraphraser import Paraphraser
|
|
|
9 |
import nlpaug.augmenter.word as naw
|
10 |
from concurrent.futures import ThreadPoolExecutor
|
11 |
from functools import lru_cache
|
|
|
27 |
print(f"Using device: {self.device}")
|
28 |
if self.use_gpu:
|
29 |
print(f"GPU Device: {torch.cuda.get_device_name(0)}")
|
30 |
+
|
31 |
|
|
|
32 |
self.quality_metrics = QualityMetrics(config)
|
33 |
+
self.semantic_similarity_threshold = 0.75
|
34 |
+
|
35 |
+
# Load model
|
36 |
self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
|
37 |
|
38 |
# Initialize augmentation models based on hardware
|
|
|
40 |
|
41 |
# Initialize caches
|
42 |
self.embedding_cache = {}
|
|
|
|
|
|
|
|
|
43 |
|
44 |
# GPU memory management if available
|
45 |
if self.use_gpu:
|
|
|
54 |
def _initialize_augmentation_models(self):
|
55 |
"""Initialize augmentation models with appropriate device settings"""
|
56 |
# Advanced augmentation techniques
|
57 |
+
self.paraphraser = Paraphraser()
|
|
|
|
|
58 |
if self.use_gpu:
|
59 |
+
# Move model to GPU if available
|
60 |
self.paraphraser.model = self.paraphraser.model.to(self.device)
|
|
|
|
|
|
|
61 |
|
62 |
# Basic augmentation techniques
|
63 |
self.word_augmenter = naw.SynonymAug(aug_src='wordnet')
|
|
|
64 |
|
65 |
self.augmenters = {
|
66 |
+
'advanced': [
|
67 |
+
self.paraphraser,
|
68 |
+
],
|
69 |
'basic': [
|
70 |
('synonym', self.word_augmenter),
|
|
|
71 |
]
|
72 |
}
|
73 |
|
|
|
95 |
|
96 |
def _quick_quality_check(self, variation: str, original: str) -> bool:
|
97 |
"""
|
98 |
+
Preliminary quality check while maintaining reasonable pass rates
|
99 |
"""
|
100 |
if self.config.debug:
|
101 |
print(f"\nQuick check for variation: {variation}")
|
102 |
|
|
|
103 |
orig_len = len(original.split())
|
104 |
var_len = len(variation.split())
|
105 |
+
|
106 |
+
# For very short texts (<= 3 words), still allow more variation
|
107 |
if orig_len <= 3:
|
108 |
+
if var_len > orig_len * 3:
|
109 |
if self.config.debug:
|
110 |
print(f"Failed length check (short text): {var_len} vs {orig_len}")
|
111 |
return False
|
112 |
else:
|
113 |
+
if var_len > orig_len * 2:
|
114 |
if self.config.debug:
|
115 |
print(f"Failed length check (long text): {var_len} vs {orig_len}")
|
116 |
return False
|
117 |
+
|
118 |
+
# Adjust content overlap check based on length
|
119 |
stop_words = {'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'is', 'are', 'that', 'this', 'will', 'can'}
|
120 |
orig_words = set(w.lower() for w in original.split() if w.lower() not in stop_words)
|
121 |
var_words = set(w.lower() for w in variation.split() if w.lower() not in stop_words)
|
122 |
+
|
123 |
+
# If very short turn (less than 5 words), skip the content overlap check
|
124 |
+
if orig_len >= 5:
|
125 |
+
content_overlap = len(orig_words.intersection(var_words)) / len(orig_words) if orig_words else 0
|
126 |
+
if content_overlap < 0.2:
|
127 |
+
if self.config.debug:
|
128 |
+
print(f"Failed content check: overlap {content_overlap:.2f}")
|
129 |
+
return False
|
130 |
+
else:
|
131 |
if self.config.debug:
|
132 |
+
print("Short turn detected (<5 words), skipping content overlap check")
|
133 |
+
|
|
|
134 |
if self.config.debug:
|
135 |
print("Passed all quick checks")
|
136 |
return True
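# Worked example (assumed sentences, not from this repository) of the content-overlap
# rule above: overlap = |orig_words intersection var_words| / |orig_words|, computed
# over non-stop-words, with 0.2 as the cutoff for turns of 5+ words.
stop_words = {'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to',
              'for', 'is', 'are', 'that', 'this', 'will', 'can'}
original = "Can you recommend a good Italian restaurant nearby"
variation = "Could you suggest a good Italian place close by"
orig_words = {w.lower() for w in original.split() if w.lower() not in stop_words}
var_words = {w.lower() for w in variation.split() if w.lower() not in stop_words}
overlap = len(orig_words & var_words) / len(orig_words)
print(round(overlap, 2))  # "you", "good", "italian" survive in both -> 3/6 = 0.5, above 0.2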
|
137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
def _filter_variations_batch(self, variations: List[str], context: List[str], original_turn: str) -> List[str]:
|
139 |
"""
|
140 |
Filter variations using batched computations with detailed logging
|
|
|
148 |
print(f"Original turn: {original_turn}")
|
149 |
|
150 |
words = original_turn.split()
|
151 |
+
orig_len = len(words)
|
152 |
+
|
153 |
+
# If very short text, consider adjusting thresholds
|
154 |
+
is_very_short = orig_len < 5
|
155 |
+
|
156 |
if len(words) < 3:
|
157 |
if self.config.debug:
|
158 |
print("Short text detected, using predefined variations")
|
159 |
short_text_variations = self._augment_short_text({'text': original_turn, 'speaker': ''})
|
160 |
return [var['text'] for var in short_text_variations]
|
161 |
+
|
162 |
# If this is the first turn (no context), be more lenient
|
163 |
if not context:
|
164 |
preliminary_filtered = variations
|
|
|
174 |
print(f"Passed quick check: {passed}")
|
175 |
if passed:
|
176 |
preliminary_filtered.append(var)
|
177 |
+
|
178 |
if self.config.debug:
|
179 |
print(f"Variations after quick check: {len(preliminary_filtered)}")
|
180 |
+
|
181 |
if not preliminary_filtered:
|
182 |
return []
|
183 |
+
|
184 |
+
# Compute embeddings for original and variations
|
185 |
+
original_embedding = self._compute_embedding(original_turn)
|
186 |
+
variation_embeddings = self._compute_batch_embeddings(preliminary_filtered)
|
187 |
+
|
188 |
+
# Compute similarities
|
189 |
+
sims = cosine_similarity([original_embedding], variation_embeddings)[0]
|
190 |
+
|
191 |
+
# If very short turn, slightly lower the semantic similarity threshold
|
192 |
+
dynamic_sem_threshold = self.semantic_similarity_threshold
|
193 |
+
if is_very_short:
|
194 |
+
dynamic_sem_threshold = max(0.7, self.semantic_similarity_threshold - 0.05)
|
195 |
+
|
196 |
+
# Filter by semantic similarity threshold
|
197 |
+
refined_filtered = []
|
198 |
+
for var, sim in zip(preliminary_filtered, sims):
|
199 |
+
if sim >= dynamic_sem_threshold:
|
200 |
+
refined_filtered.append(var)
|
201 |
+
else:
|
202 |
+
if self.config.debug:
|
203 |
+
print(f"Variation '{var}' discarded due to low semantic similarity: {sim:.3f}")
|
204 |
+
|
205 |
+
if not refined_filtered:
|
206 |
+
return []
|
207 |
+
|
208 |
+
# Relax context coherence thresholds further if desired
|
209 |
+
# We already have min_similarity = 0.1, min_coherence = 0.05
|
210 |
+
# Let's lower them slightly more if the turn is very short:
|
211 |
+
if is_very_short:
|
212 |
+
min_similarity = 0.05
|
213 |
+
min_coherence = 0.02
|
214 |
+
else:
|
215 |
+
min_similarity = 0.1
|
216 |
+
min_coherence = 0.05
|
217 |
+
|
218 |
# Only use last turn for coherence
|
219 |
recent_context = [context[-1]] if context else []
|
220 |
context_text = ' '.join(recent_context) if recent_context else ''
|
221 |
+
|
|
|
|
|
|
|
|
|
222 |
if context_text:
|
223 |
if self.config.debug:
|
224 |
print(f"\nContext text: {context_text}")
|
225 |
+
|
226 |
+
all_texts = [context_text] + refined_filtered
|
227 |
all_embeddings = self._compute_batch_embeddings(all_texts)
|
228 |
+
|
229 |
context_embedding = all_embeddings[0]
|
230 |
variation_embeddings = all_embeddings[1:]
|
231 |
+
|
232 |
# Vectorized similarity computation
|
233 |
context_similarities = cosine_similarity([context_embedding], variation_embeddings)[0]
|
234 |
+
|
235 |
# Response coherence check
|
236 |
if recent_context:
|
237 |
prev_embedding = self._compute_embedding(recent_context[-1])
|
238 |
response_coherence = cosine_similarity([prev_embedding], variation_embeddings)[0]
|
239 |
else:
|
240 |
response_coherence = np.ones_like(context_similarities)
|
241 |
+
|
|
|
242 |
filtered_variations = []
|
243 |
for i, (variation, sim, coh) in enumerate(zip(
|
244 |
+
refined_filtered, context_similarities, response_coherence)):
|
|
|
245 |
combined_score = (
|
246 |
self.config.context_similarity_weight * abs(sim) +
|
247 |
self.config.response_coherence_weight * abs(coh)
|
248 |
)
|
249 |
+
|
250 |
if self.config.debug:
|
251 |
print(f"\nVariation: {variation}")
|
252 |
print(f"Context similarity: {sim:.3f}")
|
253 |
print(f"Response coherence: {coh:.3f}")
|
254 |
print(f"Combined score: {combined_score:.3f}")
|
255 |
+
|
256 |
# Accept if EITHER score is good enough
|
257 |
if (combined_score >= min_similarity or abs(coh) >= min_coherence):
|
258 |
filtered_variations.append(variation)
|
|
|
261 |
else:
|
262 |
if self.config.debug:
|
263 |
print("REJECTED")
|
264 |
+
|
265 |
# If we have enough variations, stop
|
266 |
if len(filtered_variations) >= self.config.max_variations_per_turn:
|
267 |
break
|
268 |
else:
|
269 |
+
filtered_variations = refined_filtered[:self.config.max_variations_per_turn]
|
270 |
+
|
271 |
if self.config.debug:
|
272 |
print(f"\nFinal filtered variations: {len(filtered_variations)}")
|
273 |
+
|
274 |
return filtered_variations
|
275 |
|
276 |
def _generate_variations_progressive(self, text: str, needed: int) -> List[str]:
|
277 |
"""
|
278 |
+
Generate variations progressively until we have enough good ones.
|
279 |
+
Adjust paraphraser parameters for closer paraphrases as needed.
|
280 |
"""
|
281 |
variations = set()
|
282 |
+
|
283 |
if self.config.debug:
|
284 |
print(f"\nAttempting to generate {needed} variations for text: {text}")
|
285 |
+
|
286 |
+
# Paraphraser parameters below already favor close paraphrases (fewer beams, no diversity penalty); adjust here if needed
|
287 |
for augmenter in self.augmenters['advanced']:
|
288 |
if len(variations) >= needed:
|
289 |
break
|
290 |
+
|
291 |
try:
|
292 |
if isinstance(augmenter, Paraphraser):
|
293 |
if self.config.debug:
|
294 |
print("Trying paraphrase augmentation...")
|
295 |
+
new_vars = augmenter.paraphrase(
|
296 |
+
text,
|
297 |
+
num_return_sequences=needed-len(variations),
|
298 |
+
device=self.device if self.use_gpu else None,
|
299 |
+
num_beams=4, # even fewer beams for more faithful paraphrases
|
300 |
+
num_beam_groups=1,
|
301 |
+
diversity_penalty=0.0
|
302 |
+
)
|
303 |
if self.config.debug:
|
304 |
print(f"Paraphraser generated {len(new_vars)} variations")
|
305 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
306 |
valid_vars = [v for v in new_vars if v.strip() and v != text]
|
307 |
variations.update(valid_vars)
|
308 |
+
|
309 |
if self.config.debug:
|
310 |
print(f"Current unique variations: {len(variations)}")
|
311 |
+
|
312 |
except Exception as e:
|
313 |
print(f"Error in advanced augmentation: {str(e)}")
|
314 |
continue
|
315 |
+
|
316 |
# Try basic augmenters if needed
|
317 |
if len(variations) < needed:
|
318 |
if self.config.debug:
|
319 |
print("Not enough variations, trying basic augmenters...")
|
320 |
+
|
321 |
for aug_type, augmenter in self.augmenters['basic']:
|
322 |
if len(variations) >= needed:
|
323 |
break
|
324 |
+
|
325 |
try:
|
|
|
|
|
|
|
|
|
|
|
326 |
if self.config.debug:
|
327 |
print(f"Trying {aug_type} augmentation...")
|
328 |
+
|
329 |
new_vars = augmenter.augment(text, n=2)
|
330 |
if isinstance(new_vars, list):
|
331 |
valid_vars = [v for v in new_vars if v.strip() and v != text]
|
|
|
333 |
else:
|
334 |
if new_vars.strip() and new_vars != text:
|
335 |
variations.add(new_vars)
|
336 |
+
|
337 |
if self.config.debug:
|
338 |
print(f"After {aug_type}, total variations: {len(variations)}")
|
339 |
+
|
340 |
except Exception as e:
|
341 |
print(f"Error in {aug_type} augmentation: {str(e)}")
|
342 |
continue
|
343 |
+
|
344 |
variations_list = list(variations)
|
345 |
+
|
346 |
if self.config.debug:
|
347 |
print(f"Final number of variations generated: {len(variations_list)}")
|
348 |
if not variations_list:
|
349 |
print("WARNING: No variations were generated!")
|
350 |
+
|
351 |
return variations_list
|
352 |
|
353 |
def augment_dialogue(self, dialogue: Dict) -> List[Dict]:
|
|
|
391 |
# Generate combinations with sampling
|
392 |
augmented_dialogues = self._generate_dialogue_combinations(
|
393 |
dialogue['dialogue_id'],
|
394 |
+
turn_variations,
|
395 |
+
dialogue
|
396 |
)
|
397 |
|
398 |
# Add original dialogue
|
|
|
409 |
|
410 |
return result
|
411 |
|
412 |
+
def _variation_score(self, original: str, variation: str) -> float:
|
413 |
+
"""
|
414 |
+
Compute a single numeric score for a variation to guide selection.
|
415 |
+
You could use semantic similarity, content preservation, etc.
|
416 |
+
Higher is better.
|
417 |
+
"""
|
418 |
+
metrics = self.quality_metrics.compute_metrics(original, variation)
|
419 |
+
# Example: Primarily semantic similarity, with a slight boost for content preservation
|
420 |
+
# Adjust as needed.
|
421 |
+
score = metrics['semantic_similarity'] * 0.7 + metrics['content_preservation'] * 0.3
|
422 |
+
return score
|
423 |
+
|
424 |
+
def _dialogue_quality_score(self, dialogue: Dict, original_dialogue: Dict) -> float:
|
425 |
"""
|
426 |
+
Compute a quality score for the entire augmented dialogue.
|
427 |
+
For example, average semantic similarity of turns to the original turns.
|
428 |
+
This is done after the dialogue is formed.
|
429 |
"""
|
430 |
+
original_texts = [t['text'] for t in original_dialogue['turns']]
|
431 |
+
aug_texts = [t['text'] for t in dialogue['turns']]
|
432 |
+
|
433 |
+
# Compute semantic similarity turn-by-turn and average it
|
434 |
+
scores = []
|
435 |
+
for orig, aug in zip(original_texts, aug_texts):
|
436 |
+
# Simple semantic similarity for scoring
|
437 |
+
emb_orig = self._compute_embedding(orig)
|
438 |
+
emb_aug = self._compute_embedding(aug)
|
439 |
+
sim = (emb_orig @ emb_aug) / (np.linalg.norm(emb_orig)*np.linalg.norm(emb_aug))
|
440 |
+
scores.append(sim)
|
441 |
+
|
442 |
+
# Could also incorporate diversity checks, content overlap, etc.
|
443 |
+
return float(np.mean(scores)) if scores else 0.0
|
444 |
+
|
445 |
+
def _generate_dialogue_combinations(self, dialogue_id: str, turn_variations: List[List[Dict]], original_dialogue: Dict) -> List[Dict]:
|
446 |
+
"""
|
447 |
+
Generate dialogue combinations using a more controlled approach:
|
448 |
+
- Include the original turn as a fallback variation for each turn.
|
449 |
+
- Sort variations by a quality score.
|
450 |
+
- Ensure a balanced augmentation by requiring at least some turns to be augmented.
|
451 |
+
- Over-generate and then select top dialogues by quality.
|
452 |
+
"""
|
453 |
+
# Over-generate factor: create more candidates than needed
|
454 |
+
over_generate_factor = self.config.augmentation_factor * 2
|
455 |
+
|
456 |
+
# Add the original turn as a fallback variation for each turn if not present
|
457 |
+
for i, turn_variants in enumerate(turn_variations):
|
458 |
+
original_turn_text = None
|
459 |
+
# Check if we previously stored original turn text with a marker or just use the original dialogue
|
460 |
+
# If you previously used "|ORIGINAL|" marker, handle it here. Otherwise, just get from original_dialogue.
|
461 |
+
original_turn_text = original_dialogue['turns'][i]['text']
|
462 |
+
|
463 |
+
# Add the original turn as a variation if not already included
|
464 |
+
if not any(v['text'] == original_turn_text for v in turn_variants):
|
465 |
+
turn_variants.append({
|
466 |
+
'speaker': original_dialogue['turns'][i]['speaker'],
|
467 |
+
'text': original_turn_text
|
468 |
+
})
|
469 |
+
|
470 |
+
# Sort variations by score
|
471 |
+
original_text = original_dialogue['turns'][i]['text']
|
472 |
+
turn_variants.sort(key=lambda v: self._variation_score(original_text, v['text']), reverse=True)
|
473 |
+
|
474 |
augmented_dialogues = []
|
475 |
used_combinations = set()
|
476 |
+
|
477 |
+
def generate_candidates(current_turns=None, turn_index=0):
|
478 |
if current_turns is None:
|
479 |
current_turns = []
|
480 |
|
481 |
+
if len(augmented_dialogues) >= over_generate_factor:
|
482 |
return
|
483 |
|
484 |
if turn_index == len(turn_variations):
|
485 |
+
# Completed a candidate dialogue
|
486 |
dialogue_fingerprint = " | ".join(turn['text'] for turn in current_turns)
|
487 |
if dialogue_fingerprint not in used_combinations:
|
488 |
used_combinations.add(dialogue_fingerprint)
|
489 |
+
# Check if we have enough augmented turns
|
490 |
+
aug_count = sum(1 for orig, curr in zip(original_dialogue['turns'], current_turns)
|
491 |
+
if orig['text'] != curr['text'])
|
492 |
+
# Require at least half the turns to be augmented, for example
|
493 |
+
if aug_count >= max(1, len(turn_variations)//2):
|
494 |
+
augmented_dialogues.append({
|
495 |
+
'dialogue_id': f"{dialogue_id}_aug_{len(augmented_dialogues)}",
|
496 |
+
'turns': current_turns.copy()
|
497 |
+
})
|
498 |
return
|
499 |
+
|
500 |
+
turn_candidates = turn_variations[turn_index]
|
501 |
+
|
502 |
+
# If no variations are available for this turn, let's just return without error.
|
503 |
+
# Normally, this shouldn't happen since we always add the original turn above.
|
504 |
+
if not turn_candidates:
|
505 |
+
# If you want to at least have the original turn, add it now:
|
506 |
+
original_text = original_dialogue['turns'][turn_index]['text']
|
507 |
+
turn_candidates.append({
|
508 |
+
'speaker': original_dialogue['turns'][turn_index]['speaker'],
|
509 |
+
'text': original_text
|
510 |
+
})
|
511 |
+
|
512 |
+
# After the fallback, if still empty for some reason, just return.
|
513 |
+
if not turn_candidates:
|
514 |
+
return
|
515 |
+
|
516 |
+
# Example strategy:
|
517 |
+
# 1. Always try the top variation (most semantically similar).
|
518 |
+
# 2. If available and allowed, pick a mid-ranked variation for diversity.
|
519 |
+
# 3. Include the original turn if not selected yet.
|
520 |
+
|
521 |
+
num_vars = min(self.config.max_sampled_variations, len(turn_candidates))
|
522 |
+
|
523 |
+
# Always include top variation
|
524 |
+
candidates_to_pick = [turn_candidates[0]]
|
525 |
+
|
526 |
+
# If we have more than 2 variations and can pick more, add a middle variation for diversity
|
527 |
+
if len(turn_candidates) > 2 and num_vars > 1:
|
528 |
+
mid_index = len(turn_candidates)//2
|
529 |
+
candidates_to_pick.append(turn_candidates[mid_index])
|
530 |
+
|
531 |
+
# If we still have room for another variation, try adding the original turn if not included
|
532 |
+
if num_vars > len(candidates_to_pick):
|
533 |
+
original_turn_text = original_dialogue['turns'][turn_index]['text']
|
534 |
+
orig_candidate = next((v for v in turn_candidates if v['text'] == original_turn_text), None)
|
535 |
+
if orig_candidate and orig_candidate not in candidates_to_pick:
|
536 |
+
candidates_to_pick.append(orig_candidate)
|
537 |
+
|
538 |
+
# Shuffle candidates to produce different dialogues
|
539 |
+
np.random.shuffle(candidates_to_pick)
|
540 |
+
|
541 |
+
for variation in candidates_to_pick:
|
542 |
+
if len(augmented_dialogues) >= over_generate_factor:
|
543 |
return
|
544 |
current_turns.append(variation)
|
545 |
+
generate_candidates(current_turns, turn_index + 1)
|
546 |
current_turns.pop()
|
547 |
+
|
548 |
try:
|
549 |
+
generate_candidates()
|
550 |
except Exception as e:
|
551 |
print(f"Error in dialogue generation: {str(e)}")
|
552 |
return []
|
553 |
+
|
554 |
+
# Over-generated set of augmented dialogues is now available
|
555 |
+
# Let's score them and pick the top ones
|
556 |
+
scored_dialogues = []
|
557 |
+
for d in augmented_dialogues:
|
558 |
+
score = self._dialogue_quality_score(d, original_dialogue)
|
559 |
+
scored_dialogues.append((score, d))
|
560 |
+
|
561 |
+
scored_dialogues.sort(key=lambda x: x[0], reverse=True)
|
562 |
+
# Pick top `augmentation_factor` dialogues
|
563 |
+
final_dialogues = [d for _, d in scored_dialogues[:self.config.augmentation_factor]]
|
564 |
+
|
565 |
+
return final_dialogues
|
566 |
+
# def _generate_dialogue_combinations(self, dialogue_id: str, turn_variations: List[List[Dict]]) -> List[Dict]:
|
567 |
+
# """
|
568 |
+
# Generate dialogue combinations using sampling
|
569 |
+
# """
|
570 |
+
# augmented_dialogues = []
|
571 |
+
# used_combinations = set()
|
572 |
+
|
573 |
+
# def generate_dialogues(current_turns=None, turn_index=0):
|
574 |
+
# if current_turns is None:
|
575 |
+
# current_turns = []
|
576 |
+
|
577 |
+
# if len(augmented_dialogues) >= self.config.augmentation_factor:
|
578 |
+
# return
|
579 |
+
|
580 |
+
# if turn_index == len(turn_variations):
|
581 |
+
# dialogue_fingerprint = " | ".join(turn['text'] for turn in current_turns)
|
582 |
+
# if dialogue_fingerprint not in used_combinations:
|
583 |
+
# used_combinations.add(dialogue_fingerprint)
|
584 |
+
# augmented_dialogues.append({
|
585 |
+
# 'dialogue_id': f"{dialogue_id}_aug_{len(augmented_dialogues)}",
|
586 |
+
# 'turns': current_turns.copy()
|
587 |
+
# })
|
588 |
+
# return
|
589 |
+
|
590 |
+
# variations = list(turn_variations[turn_index])
|
591 |
+
# np.random.shuffle(variations)
|
592 |
+
|
593 |
+
# for variation in variations[:self.config.max_sampled_variations]:
|
594 |
+
# if len(augmented_dialogues) >= self.config.augmentation_factor:
|
595 |
+
# return
|
596 |
+
# current_turns.append(variation)
|
597 |
+
# generate_dialogues(current_turns, turn_index + 1)
|
598 |
+
# current_turns.pop()
|
599 |
+
|
600 |
+
# try:
|
601 |
+
# generate_dialogues()
|
602 |
+
# except Exception as e:
|
603 |
+
# print(f"Error in dialogue generation: {str(e)}")
|
604 |
+
# return []
|
605 |
+
|
606 |
+
# return augmented_dialogues
|
607 |
|
608 |
def _is_dialogue_duplicate(self, dialogue1: Dict, dialogue2: Dict) -> bool:
|
609 |
"""
|
|
|
616 |
def _augment_short_text(self, turn: Dict) -> List[Dict]:
|
617 |
"""
|
618 |
Special handling for very short texts with predefined variations.
|
619 |
+
If predefined variations are found, return them directly.
|
620 |
+
Otherwise, produce simple punctuation and capitalization variants.
|
621 |
+
Skip heavy quality checks for efficiency. These variations are safe and minimal.
|
|
|
|
|
622 |
"""
|
623 |
text = turn['text']
|
624 |
common_variations = {
|
|
|
652 |
'Fantastic!', 'Amazing!', 'Terrific!'
|
653 |
]
|
654 |
}
|
655 |
+
|
|
|
656 |
text_lower = text.lower().rstrip('!.,?')
|
657 |
+
# Check if text matches any predefined category
|
658 |
variations = []
|
|
|
|
|
659 |
for key, predefined_vars in common_variations.items():
|
660 |
if key in text_lower or text_lower in key:
|
661 |
variations.extend(predefined_vars)
|
662 |
|
|
|
663 |
if not variations:
|
664 |
+
# Generate simple punctuation and capitalization variations if no predefined match
|
665 |
+
base = text.rstrip('!.,?')
|
666 |
variations = [
|
667 |
+
base + '!',
|
668 |
+
base + '.',
|
669 |
+
base
|
670 |
]
|
671 |
|
672 |
# Add capitalization variations
|
673 |
+
capitalized = [v.capitalize() for v in variations if v.capitalize() not in variations]
|
674 |
+
variations.extend(capitalized)
|
|
|
|
|
675 |
|
676 |
+
# Ensure uniqueness
|
677 |
unique_variations = list(set(variations))
|
|
|
|
|
|
|
678 |
|
679 |
+
# Directly return these variations, as they are minimal and trusted
|
680 |
+
# No further quality checks are needed
|
681 |
+
result_variations = unique_variations[:self.config.augmentation_factor]
|
682 |
+
return [{'speaker': turn['speaker'], 'text': v} for v in result_variations]
|
683 |
|
684 |
+
def process_batch(self, batch: List[Dict]) -> List[Dict]:
|
685 |
+
"""Process multiple dialogues at once to maximize GPU utilization"""
|
686 |
+
results = []
|
687 |
+
|
688 |
+
# Pre-compute embeddings for all texts in batch
|
689 |
+
all_texts = []
|
690 |
+
text_to_embedding = {}
|
691 |
+
|
692 |
+
for dialogue in batch:
|
693 |
+
for turn in dialogue['turns']:
|
694 |
+
all_texts.append(turn['text'])
|
695 |
+
|
696 |
+
# Batch compute embeddings
|
697 |
+
if all_texts:
|
698 |
+
embeddings = self._compute_batch_embeddings(all_texts)
|
699 |
+
for text, embedding in zip(all_texts, embeddings):
|
700 |
+
self.embedding_cache[text] = embedding
|
701 |
|
702 |
+
# Process each dialogue using cached embeddings
|
703 |
+
for dialogue in batch:
|
704 |
+
try:
|
705 |
+
augmented = self.augment_dialogue(dialogue)
|
706 |
+
results.extend(augmented)
|
707 |
+
except Exception as e:
|
708 |
+
print(f"Error processing dialogue {dialogue.get('dialogue_id', 'unknown')}: {e}")
|
709 |
+
continue
|
710 |
|
711 |
+
return results
|
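The over-generate-then-select step above can be exercised on its own. A minimal sketch, assuming the candidate dialogues have already been produced; select_top_dialogues, score_fn, and the toy scorer are illustrative names, not part of dialogue_augmenter.py:

import numpy as np

def select_top_dialogues(candidates, score_fn, augmentation_factor):
    """Score every over-generated candidate and keep the best `augmentation_factor` of them."""
    scored = [(score_fn(d), d) for d in candidates]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [d for _, d in scored[:augmentation_factor]]

# Toy scorer: reward candidates whose turns differ from the original dialogue.
original = {'turns': [{'text': 'hi'}, {'text': 'hello there'}]}
candidates = [
    {'turns': [{'text': 'hi'}, {'text': 'hello there'}]},
    {'turns': [{'text': 'hey'}, {'text': 'hello there'}]},
    {'turns': [{'text': 'hey'}, {'text': 'hi, how can I help?'}]},
]
score_fn = lambda d: float(np.mean(
    [o['text'] != a['text'] for o, a in zip(original['turns'], d['turns'])]
))
print(select_top_dialogues(candidates, score_fn, augmentation_factor=2))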
main.py
CHANGED
@@ -59,13 +59,13 @@ def main():
|
|
59 |
min_length=1,
|
60 |
max_length=512,
|
61 |
batch_size=32 if tf.config.list_physical_devices('GPU') else 16,
|
62 |
-
max_turns_per_dialogue=
|
63 |
-
max_variations_per_turn=
|
64 |
max_sampled_variations=2,
|
65 |
context_window_size=4,
|
66 |
max_complexity_threshold=100,
|
67 |
use_cache=False,
|
68 |
-
debug=
|
69 |
allowed_speakers=['user', 'assistant'],
|
70 |
required_fields=['dialogue_id', 'turns']
|
71 |
)
|
|
|
59 |
min_length=1,
|
60 |
max_length=512,
|
61 |
batch_size=32 if tf.config.list_physical_devices('GPU') else 16,
|
62 |
+
max_turns_per_dialogue=12,
|
63 |
+
max_variations_per_turn=4,
|
64 |
max_sampled_variations=2,
|
65 |
context_window_size=4,
|
66 |
max_complexity_threshold=100,
|
67 |
use_cache=False,
|
68 |
+
debug=True,
|
69 |
allowed_speakers=['user', 'assistant'],
|
70 |
required_fields=['dialogue_id', 'turns']
|
71 |
)
|
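The batch_size expression above only raises the batch size when TensorFlow can actually see a GPU. A standalone sketch of the same check (the values match the config above but are purely illustrative):

import tensorflow as tf

gpus = tf.config.list_physical_devices('GPU')
batch_size = 32 if gpus else 16
print(f"GPUs visible: {len(gpus)}; using batch_size={batch_size}")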
paraphraser.py
CHANGED
@@ -9,11 +9,18 @@ class Paraphraser:
|
|
9 |
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
10 |
self.model.eval()
|
11 |
|
12 |
-
def paraphrase(self, text, num_return_sequences=5, num_beams=
|
|
|
13 |
try:
|
14 |
input_text = "paraphrase: " + text + " </s>"
|
15 |
encoding = self.tokenizer.encode_plus(input_text, return_tensors="pt")
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
outputs = self.model.generate(
|
19 |
input_ids=input_ids,
|
@@ -24,7 +31,11 @@ class Paraphraser:
|
|
24 |
diversity_penalty=diversity_penalty,
|
25 |
early_stopping=True
|
26 |
)
|
27 |
-
|
|
|
|
|
|
|
|
|
28 |
return paraphrases
|
29 |
except Exception as e:
|
30 |
print(f"Error in paraphrasing: {e}")
|
|
|
9 |
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
10 |
self.model.eval()
|
11 |
|
12 |
+
def paraphrase(self, text, num_return_sequences=5, num_beams=5,
|
13 |
+
num_beam_groups=1, diversity_penalty=0.0, device=None):
|
14 |
try:
|
15 |
input_text = "paraphrase: " + text + " </s>"
|
16 |
encoding = self.tokenizer.encode_plus(input_text, return_tensors="pt")
|
17 |
+
|
18 |
+
# Move input tensors to specified device if provided
|
19 |
+
if device is not None:
|
20 |
+
input_ids = encoding["input_ids"].to(device)
|
21 |
+
self.model = self.model.to(device)
|
22 |
+
else:
|
23 |
+
input_ids = encoding["input_ids"]
|
24 |
|
25 |
outputs = self.model.generate(
|
26 |
input_ids=input_ids,
|
|
|
31 |
diversity_penalty=diversity_penalty,
|
32 |
early_stopping=True
|
33 |
)
|
34 |
+
|
35 |
+
# Move outputs back to CPU for tokenizer decoding
|
36 |
+
outputs = outputs.cpu() if device is not None else outputs
|
37 |
+
paraphrases = [self.tokenizer.decode(output, skip_special_tokens=True)
|
38 |
+
for output in outputs]
|
39 |
return paraphrases
|
40 |
except Exception as e:
|
41 |
print(f"Error in paraphrasing: {e}")
|
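A minimal usage sketch of the device-aware paraphrase() call above, assuming the Paraphraser constructor falls back to its default model; the query text and device selection are illustrative:

import torch
from paraphraser import Paraphraser

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
paraphraser = Paraphraser()
variants = paraphraser.paraphrase(
    "Book a table for two at a Korean restaurant in NYC.",
    num_return_sequences=3,
    num_beams=4,            # fewer beams -> more faithful paraphrases
    num_beam_groups=1,
    diversity_penalty=0.0,
    device=device,
)
for v in variants:
    print(v)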
pipeline_config.py
CHANGED
@@ -30,7 +30,6 @@ class PipelineConfig:
|
|
30 |
grammar_error_threshold: int = 2
|
31 |
rouge1_f1_threshold: float = 0.30
|
32 |
rouge2_f1_threshold: float = 0.15
|
33 |
-
perplexity_threshold: float = 50.0
|
34 |
|
35 |
# Response coherence thresholds
|
36 |
min_response_coherence: float = 0.3
|
|
|
30 |
grammar_error_threshold: int = 2
|
31 |
rouge1_f1_threshold: float = 0.30
|
32 |
rouge2_f1_threshold: float = 0.15
|
|
|
33 |
|
34 |
# Response coherence thresholds
|
35 |
min_response_coherence: float = 0.3
|
processing_pipeline.py
CHANGED
@@ -11,7 +11,6 @@ from pipeline_config import PipelineConfig
|
|
11 |
from dialogue_augmenter import DialogueAugmenter
|
12 |
from sklearn.feature_extraction.text import TfidfVectorizer
|
13 |
from sklearn.metrics.pairwise import cosine_similarity
|
14 |
-
from concurrent.futures import ProcessPoolExecutor
|
15 |
from typing import Set
|
16 |
|
17 |
class ProcessingPipeline:
|
@@ -33,7 +32,11 @@ class ProcessingPipeline:
|
|
33 |
self.use_gpu = torch.cuda.is_available()
|
34 |
self.batch_size = 32 if self.use_gpu else 8
|
35 |
self.use_multiprocessing = not self.use_gpu
|
36 |
-
|
|
|
|
|
|
|
|
|
37 |
if self.config.debug:
|
38 |
print(f"ProcessingPipeline initialized with:")
|
39 |
print(f"- GPU available: {self.use_gpu}")
|
@@ -75,7 +78,7 @@ class ProcessingPipeline:
|
|
75 |
text_to_dialogue_map[turn['text']] = dialogue['dialogue_id']
|
76 |
|
77 |
# Batch process embeddings
|
78 |
-
|
79 |
|
80 |
# Process dialogues with cached embeddings
|
81 |
for dialogue in batch:
|
@@ -89,16 +92,37 @@ class ProcessingPipeline:
|
|
89 |
print(f"Error processing batch: {str(e)}")
|
90 |
return results
|
91 |
|
|
|
|
|
|
|
|
92 |
def combine_results(self) -> Path:
|
93 |
-
"""Combine all
|
94 |
all_results = []
|
95 |
-
|
96 |
|
97 |
-
print(f"Combining {len(
|
98 |
-
for
|
99 |
-
with open(
|
100 |
-
|
101 |
-
all_results.extend(
|
102 |
|
103 |
# Save combined results
|
104 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
@@ -137,12 +161,13 @@ class ProcessingPipeline:
|
|
137 |
current_position = processed_count + batch_num + len(batch)
|
138 |
|
139 |
total_progress = (current_position / total_dialogues) * 100
|
140 |
-
batch_progress = (batch_num + 1) / ((len(remaining_dialogues) + self.batch_size - 1) // self.batch_size) * 100
|
141 |
|
142 |
-
print(
|
143 |
-
|
144 |
-
f"
|
145 |
-
|
|
|
|
|
146 |
|
147 |
# Process batch
|
148 |
batch_results = self._process_batch(batch)
|
@@ -152,20 +177,37 @@ class ProcessingPipeline:
|
|
152 |
batch_ids = {d['dialogue_id'] for d in batch}
|
153 |
processed_ids.update(batch_ids)
|
154 |
self._update_checkpoint(processed_ids)
|
155 |
-
|
|
|
|
|
|
|
|
156 |
print("\n" + "-" * 50)
|
157 |
print("Processing complete. Combining results...")
|
158 |
return self.combine_results()
|
159 |
|
160 |
def cleanup(self):
|
161 |
-
"""Clean up intermediate
|
|
|
162 |
batch_files = list(self.output_dir.glob("batch_*.json"))
|
163 |
for file in batch_files:
|
164 |
try:
|
165 |
file.unlink()
|
166 |
except Exception as e:
|
167 |
print(f"Error deleting {file}: {e}")
|
168 |
-
|
|
|
|
|
|
|
169 |
if self.checkpoint_file.exists():
|
170 |
try:
|
171 |
self.checkpoint_file.unlink()
|
@@ -276,4 +318,4 @@ class ProcessingPipeline:
|
|
276 |
"""
|
277 |
data_str = json.dumps(data, sort_keys=True)
|
278 |
hash_value = hashlib.md5(data_str.encode()).hexdigest()
|
279 |
-
return self.cache_dir / f"cache_{hash_value}.pkl"
|
|
|
11 |
from dialogue_augmenter import DialogueAugmenter
|
12 |
from sklearn.feature_extraction.text import TfidfVectorizer
|
13 |
from sklearn.metrics.pairwise import cosine_similarity
|
|
|
14 |
from typing import Set
|
15 |
|
16 |
class ProcessingPipeline:
|
|
|
32 |
self.use_gpu = torch.cuda.is_available()
|
33 |
self.batch_size = 32 if self.use_gpu else 8
|
34 |
self.use_multiprocessing = not self.use_gpu
|
35 |
+
|
36 |
+
# Counters for grouping batches
|
37 |
+
self.batch_counter = 0 # Count batches since last group combine
|
38 |
+
self.batch_group_number = 0 # How many groups have been created
|
39 |
+
|
40 |
if self.config.debug:
|
41 |
print(f"ProcessingPipeline initialized with:")
|
42 |
print(f"- GPU available: {self.use_gpu}")
|
|
|
78 |
text_to_dialogue_map[turn['text']] = dialogue['dialogue_id']
|
79 |
|
80 |
# Batch process embeddings
|
81 |
+
self.augmenter._compute_batch_embeddings(all_texts)
|
82 |
|
83 |
# Process dialogues with cached embeddings
|
84 |
for dialogue in batch:
|
|
|
92 |
print(f"Error processing batch: {str(e)}")
|
93 |
return results
|
94 |
|
95 |
+
def _combine_intermediate_batches(self):
|
96 |
+
"""
|
97 |
+
Combine all current batch_*.json files into a single batch_group_XXXX.json file,
|
98 |
+
then remove the batch_*.json files.
|
99 |
+
"""
|
100 |
+
batch_files = sorted(self.output_dir.glob("batch_*.json"))
|
101 |
+
if not batch_files:
|
102 |
+
return None # No files to combine
|
103 |
+
|
104 |
+
combined_data = []
|
105 |
+
for bf in batch_files:
|
106 |
+
with open(bf, 'r') as f:
|
107 |
+
combined_data.extend(json.load(f))
|
108 |
+
bf.unlink() # Remove the individual batch file after reading
|
109 |
+
|
110 |
+
self.batch_group_number += 1
|
111 |
+
group_file = self.output_dir / f"batch_group_{self.batch_group_number:04d}.json"
|
112 |
+
with open(group_file, 'w') as f:
|
113 |
+
json.dump(combined_data, f)
|
114 |
+
return group_file
|
115 |
+
|
116 |
def combine_results(self) -> Path:
|
117 |
+
"""Combine all batch_group_*.json files into final output"""
|
118 |
all_results = []
|
119 |
+
group_files = sorted(self.output_dir.glob("batch_group_*.json"))
|
120 |
|
121 |
+
print(f"Combining {len(group_files)} group files...")
|
122 |
+
for group_file in tqdm(group_files):
|
123 |
+
with open(group_file, 'r') as f:
|
124 |
+
group_data = json.load(f)
|
125 |
+
all_results.extend(group_data)
|
126 |
|
127 |
# Save combined results
|
128 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
|
161 |
current_position = processed_count + batch_num + len(batch)
|
162 |
|
163 |
total_progress = (current_position / total_dialogues) * 100
|
|
|
164 |
|
165 |
+
print('\033[K', end='')
|
166 |
+
print(f"Processing: {current_position}/{total_dialogues} dialogues "
|
167 |
+
f"({total_progress:.1f}% complete)")
|
168 |
+
print(f"Current batch: {batch_num//self.batch_size + 1} of "
|
169 |
+
f"{(len(remaining_dialogues) + self.batch_size - 1) // self.batch_size}")
|
170 |
+
print("-" * 50)
|
171 |
|
172 |
# Process batch
|
173 |
batch_results = self._process_batch(batch)
|
|
|
177 |
batch_ids = {d['dialogue_id'] for d in batch}
|
178 |
processed_ids.update(batch_ids)
|
179 |
self._update_checkpoint(processed_ids)
|
180 |
+
|
181 |
+
# Increment batch counter and combine if needed
|
182 |
+
self.batch_counter += 1
|
183 |
+
if self.batch_counter == 25:
|
184 |
+
# Combine these 25 batches into a group file
|
185 |
+
self._combine_intermediate_batches()
|
186 |
+
self.batch_counter = 0 # Reset counter after grouping
|
187 |
+
|
188 |
+
# If there are leftover batches less than 25
|
189 |
+
# combine them into one final group file
|
190 |
+
if self.batch_counter > 0:
|
191 |
+
self._combine_intermediate_batches()
|
192 |
+
self.batch_counter = 0
|
193 |
+
|
194 |
print("\n" + "-" * 50)
|
195 |
print("Processing complete. Combining results...")
|
196 |
return self.combine_results()
|
197 |
|
198 |
def cleanup(self):
|
199 |
+
"""Clean up intermediate files after successful processing"""
|
200 |
+
# Clean up any leftover batch files (should not exist if logic is correct)
|
201 |
batch_files = list(self.output_dir.glob("batch_*.json"))
|
202 |
for file in batch_files:
|
203 |
try:
|
204 |
file.unlink()
|
205 |
except Exception as e:
|
206 |
print(f"Error deleting {file}: {e}")
|
207 |
+
|
208 |
+
# We can also remove batch_group_*.json if desired after final combine
|
209 |
+
# but that might not be necessary if we want to keep them.
|
210 |
+
|
211 |
if self.checkpoint_file.exists():
|
212 |
try:
|
213 |
self.checkpoint_file.unlink()
|
|
|
318 |
"""
|
319 |
data_str = json.dumps(data, sort_keys=True)
|
320 |
hash_value = hashlib.md5(data_str.encode()).hexdigest()
|
321 |
+
return self.cache_dir / f"cache_{hash_value}.pkl"
|
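The grouping logic above rolls every 25 intermediate batch_*.json files into one batch_group_XXXX.json. A self-contained sketch of the same merge, run against a throwaway directory (file names and counts are illustrative):

import json
from pathlib import Path
from tempfile import TemporaryDirectory

with TemporaryDirectory() as tmp:
    out = Path(tmp)
    # Simulate three intermediate batch files.
    for i in range(3):
        with open(out / f"batch_{i:04d}.json", "w") as f:
            json.dump([{"dialogue_id": f"d{i}", "turns": []}], f)

    # Merge them the same way _combine_intermediate_batches does, then delete the parts.
    combined = []
    for bf in sorted(out.glob("batch_*.json")):
        with open(bf) as f:
            combined.extend(json.load(f))
        bf.unlink()

    group_file = out / "batch_group_0001.json"
    with open(group_file, "w") as f:
        json.dump(combined, f)
    print(f"{group_file.name} holds {len(combined)} dialogues")  # -> 3 dialogues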
quality_metrics.py
CHANGED
@@ -1,129 +1,47 @@
|
|
1 |
-
import torch
|
2 |
-
import tensorflow as tf
|
3 |
import tensorflow_hub as hub
|
4 |
-
from transformers import GPT2TokenizerFast, GPT2LMHeadModel
|
5 |
-
import language_tool_python
|
6 |
-
from rouge_score import rouge_scorer
|
7 |
import spacy
|
8 |
from sklearn.metrics.pairwise import cosine_similarity
|
9 |
-
import numpy as np
|
10 |
from typing import Dict
|
11 |
from pipeline_config import PipelineConfig
|
12 |
|
13 |
class QualityMetrics:
|
14 |
"""
|
15 |
-
|
16 |
"""
|
17 |
def __init__(self, config: PipelineConfig):
|
18 |
self.config = config
|
19 |
-
|
20 |
-
# Semantic similarity
|
21 |
self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
|
22 |
-
|
23 |
-
|
24 |
-
self.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
|
25 |
-
self.model = GPT2LMHeadModel.from_pretrained('gpt2')
|
26 |
-
self.model.eval()
|
27 |
-
|
28 |
-
# Grammar
|
29 |
-
self.language_tool = language_tool_python.LanguageTool('en-US')
|
30 |
-
|
31 |
-
# Lexical similarity
|
32 |
-
self.rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
|
33 |
-
|
34 |
-
# Diversity
|
35 |
-
self.nlp = spacy.load('en_core_web_sm')
|
36 |
-
|
37 |
-
def compute_perplexity(self, text):
|
38 |
-
try:
|
39 |
-
encodings = self.tokenizer(text, return_tensors='pt')
|
40 |
-
input_ids = encodings['input_ids']
|
41 |
-
with torch.no_grad():
|
42 |
-
outputs = self.model(input_ids, labels=input_ids)
|
43 |
-
loss = outputs.loss
|
44 |
-
perplexity = torch.exp(loss)
|
45 |
-
return perplexity.item()
|
46 |
-
except Exception as e:
|
47 |
-
print(f"Error computing perplexity for text '{text}': {e}")
|
48 |
-
return float('inf') # High perplexity value == poor quality
|
49 |
-
|
50 |
def compute_semantic_similarity(self, text1: str, text2: str) -> float:
|
51 |
-
"""
|
52 |
-
Compute semantic similarity between two texts using the Universal Sentence Encoder.
|
53 |
-
Args:
|
54 |
-
text1 (str): First text
|
55 |
-
text2 (str): Second text
|
56 |
-
Returns:
|
57 |
-
float: Cosine similarity score between the two texts (0-1)
|
58 |
-
"""
|
59 |
embeddings = self.use_model([text1, text2])
|
60 |
emb1, emb2 = embeddings[0].numpy(), embeddings[1].numpy()
|
61 |
return cosine_similarity([emb1], [emb2])[0][0]
|
62 |
|
63 |
def compute_metrics(self, original: str, augmented: str) -> Dict[str, float]:
|
64 |
-
"""
|
65 |
-
Compute quality metrics
|
66 |
-
"""
|
67 |
metrics = {}
|
68 |
-
|
69 |
-
# 1. Semantic Preservation
|
70 |
embeddings = self.use_model([original, augmented])
|
71 |
emb_orig, emb_aug = embeddings[0].numpy(), embeddings[1].numpy()
|
72 |
metrics['semantic_similarity'] = cosine_similarity([emb_orig], [emb_aug])[0][0]
|
73 |
|
74 |
-
#
|
75 |
-
metrics['perplexity'] = self.compute_perplexity(augmented)
|
76 |
-
metrics['grammar_errors'] = len(self.language_tool.check(augmented))
|
77 |
-
|
78 |
-
# 3. Lexical Diversity
|
79 |
doc_orig = self.nlp(original)
|
80 |
doc_aug = self.nlp(augmented)
|
81 |
-
|
82 |
-
# Type-token ratio with safety check
|
83 |
aug_tokens = [token.text.lower() for token in doc_aug]
|
84 |
metrics['type_token_ratio'] = len(set(aug_tokens)) / max(len(aug_tokens), 1)
|
85 |
|
86 |
-
|
87 |
-
|
88 |
-
aug_content = set([token.text.lower() for token in doc_aug if not token.is_stop])
|
89 |
-
|
90 |
-
# Safety check for empty content sets
|
91 |
if len(orig_content) == 0:
|
92 |
metrics['content_preservation'] = 1.0 if len(aug_content) == 0 else 0.0
|
93 |
else:
|
94 |
metrics['content_preservation'] = len(orig_content.intersection(aug_content)) / len(orig_content)
|
95 |
|
96 |
-
#
|
97 |
-
rouge_scores = self.rouge.score(original, augmented)
|
98 |
-
metrics['rouge1_f1'] = rouge_scores['rouge1'].fmeasure
|
99 |
-
metrics['rouge2_f1'] = rouge_scores['rouge2'].fmeasure
|
100 |
-
metrics['rougeL_f1'] = rouge_scores['rougeL'].fmeasure
|
101 |
-
|
102 |
-
# 5. Length Preservation with safety check
|
103 |
orig_words = len(original.split())
|
104 |
aug_words = len(augmented.split())
|
105 |
metrics['length_ratio'] = aug_words / max(orig_words, 1)
|
106 |
-
|
107 |
-
return metrics
|
108 |
|
109 |
-
|
110 |
-
"""
|
111 |
-
Enhanced quality threshold checking
|
112 |
-
"""
|
113 |
-
# Core quality checks
|
114 |
-
basic_quality = (
|
115 |
-
metrics['perplexity'] <= self.config.perplexity_threshold and
|
116 |
-
metrics['semantic_similarity'] >= self.config.semantic_similarity_threshold and
|
117 |
-
metrics['grammar_errors'] <= self.config.grammar_error_threshold
|
118 |
-
)
|
119 |
-
|
120 |
-
# Length preservation check
|
121 |
-
length_ok = 0.6 <= metrics['length_ratio'] <= 1.4
|
122 |
-
|
123 |
-
# Diversity check
|
124 |
-
diversity_ok = metrics['type_token_ratio'] >= 0.4
|
125 |
-
|
126 |
-
# Content preservation check
|
127 |
-
content_ok = metrics['content_preservation'] >= 0.6
|
128 |
-
|
129 |
-
return all([basic_quality, length_ok, diversity_ok, content_ok])
|
|
|
|
|
|
|
1 |
import tensorflow_hub as hub
|
|
|
|
|
|
|
2 |
import spacy
|
3 |
from sklearn.metrics.pairwise import cosine_similarity
|
|
|
4 |
from typing import Dict
|
5 |
from pipeline_config import PipelineConfig
|
6 |
|
7 |
class QualityMetrics:
|
8 |
"""
|
9 |
+
Quality metrics focusing on semantic similarity and basic lexical stats.
|
10 |
"""
|
11 |
def __init__(self, config: PipelineConfig):
|
12 |
self.config = config
|
|
|
|
|
13 |
self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
|
14 |
+
self.nlp = spacy.load('en_core_web_md')
|
15 |
+
|
|
|
|
|
|
|
16 |
def compute_semantic_similarity(self, text1: str, text2: str) -> float:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
embeddings = self.use_model([text1, text2])
|
18 |
emb1, emb2 = embeddings[0].numpy(), embeddings[1].numpy()
|
19 |
return cosine_similarity([emb1], [emb2])[0][0]
|
20 |
|
21 |
def compute_metrics(self, original: str, augmented: str) -> Dict[str, float]:
|
|
|
|
|
|
|
22 |
metrics = {}
|
23 |
+
# Semantic similarity
|
|
|
24 |
embeddings = self.use_model([original, augmented])
|
25 |
emb_orig, emb_aug = embeddings[0].numpy(), embeddings[1].numpy()
|
26 |
metrics['semantic_similarity'] = cosine_similarity([emb_orig], [emb_aug])[0][0]
|
27 |
|
28 |
+
# Lexical diversity & content preservation
|
|
|
|
|
|
|
|
|
29 |
doc_orig = self.nlp(original)
|
30 |
doc_aug = self.nlp(augmented)
|
31 |
+
|
|
|
32 |
aug_tokens = [token.text.lower() for token in doc_aug]
|
33 |
metrics['type_token_ratio'] = len(set(aug_tokens)) / max(len(aug_tokens), 1)
|
34 |
|
35 |
+
orig_content = {token.text.lower() for token in doc_orig if not token.is_stop}
|
36 |
+
aug_content = {token.text.lower() for token in doc_aug if not token.is_stop}
|
|
|
|
|
|
|
37 |
if len(orig_content) == 0:
|
38 |
metrics['content_preservation'] = 1.0 if len(aug_content) == 0 else 0.0
|
39 |
else:
|
40 |
metrics['content_preservation'] = len(orig_content.intersection(aug_content)) / len(orig_content)
|
41 |
|
42 |
+
# Length ratio
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
orig_words = len(original.split())
|
44 |
aug_words = len(augmented.split())
|
45 |
metrics['length_ratio'] = aug_words / max(orig_words, 1)
|
|
|
|
|
46 |
|
47 |
+
return metrics
|
|
|
|
|
|
|
|
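A minimal usage sketch of the slimmed-down QualityMetrics, assuming PipelineConfig can be constructed with its defaults; the sentences and the thresholds in the gate are illustrative, not values taken from the pipeline:

from pipeline_config import PipelineConfig
from quality_metrics import QualityMetrics

qm = QualityMetrics(PipelineConfig())
metrics = qm.compute_metrics(
    "I'd like to book a table for two tonight.",
    "Could you reserve a table for two this evening?",
)

# Example gate over the returned metrics (illustrative thresholds).
acceptable = (
    metrics["semantic_similarity"] >= 0.8
    and metrics["content_preservation"] >= 0.6
    and 0.6 <= metrics["length_ratio"] <= 1.4
)
print(metrics, acceptable)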
readme.md
CHANGED
@@ -11,14 +11,16 @@ This package automatically downloads the following models during installation:
|
|
11 |
- Universal Sentence Encoder v4 (TensorFlow Hub)
|
12 |
- ChatGPT Paraphraser T5-base
|
13 |
- Helsinki-NLP translation models (en-de, de-es, es-en)
|
14 |
-
-
|
15 |
-
- spaCy en_core_web_sm
|
16 |
- nltk wordnet and averaged_perceptron_tagger_eng models
|
17 |
|
18 |
## Install package
|
19 |
|
20 |
pip install -e .
|
21 |
|
|
|
|
|
|
|
22 |
## Description
|
23 |
|
24 |
This Python script demonstrates a complete pipeline for dialogue augmentation, including validation, optimization, and data augmentation.
|
|
|
11 |
- Universal Sentence Encoder v4 (TensorFlow Hub)
|
12 |
- ChatGPT Paraphraser T5-base
|
13 |
- Helsinki-NLP translation models (en-de, de-es, es-en)
|
14 |
+
- spaCy en_core_web_sm, en_core_web_md
|
|
|
15 |
- nltk wordnet and averaged_perceptron_tagger_eng models
|
16 |
|
17 |
## Install package
|
18 |
|
19 |
pip install -e .
|
20 |
|
21 |
+
On Linux with CUDA/GPU:
|
22 |
+
pip install faiss-gpu>=1.7.0
|
23 |
+
|
24 |
## Description
|
25 |
|
26 |
This Python script demonstrates a complete pipeline for dialogue augmentation, including validation, optimization, and data augmentation.
|
requirements.txt
CHANGED
@@ -1,12 +1,14 @@
|
|
1 |
-
|
|
|
2 |
numpy>=1.19.0 # General numerical computation
|
3 |
-
tqdm>=4.64.0 # Progress bar
|
4 |
-
torch>=1.10.0 # PyTorch, for deep learning
|
5 |
-
tensorflow>=2.6.0 # TensorFlow, for deep learning
|
6 |
-
tensorflow-hub>=0.12.0 # Pretrained model hub for TensorFlow
|
7 |
-
transformers>=4.21.0 # Hugging Face Transformers library
|
8 |
-
rouge-score>=0.1.2 # ROUGE metric for evaluation
|
9 |
-
language-tool-python>=2.7.1 # Grammar checking and text correction
|
10 |
scikit-learn>=1.0.0 # Machine learning tools
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
nlpaug>=1.1.0 # Data augmentation for NLP
|
2 |
+
nltk>=3.6.0 # Natural language toolkit
|
3 |
numpy>=1.19.0 # General numerical computation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
scikit-learn>=1.0.0 # Machine learning tools
|
5 |
+
sacremoses>=0.0.53 # Required for some HuggingFace models
|
6 |
+
sentencepiece>=0.1.99 # Required for HuggingFace transformers
|
7 |
+
spacy>=3.0.0 # Text processing and tokenization
|
8 |
+
tensorflow>=2.13.0 # TensorFlow, for deep learning
|
9 |
+
tensorflow-hub>=0.12.0 # Pretrained model hub for TensorFlow
|
10 |
+
tokenizers>=0.13.0 # Required for HuggingFace transformers
|
11 |
+
torch>=2.0.0 # PyTorch, for deep learning
|
12 |
+
tqdm>=4.64.0 # Progress bar
|
13 |
+
transformers>=4.30.0 # Hugging Face Transformers library
|
14 |
+
faiss-cpu>=1.7.0 # Required for Facebook AI Similarity Search
|
response_quality_checker.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
1 |
+
import numpy as np
|
2 |
+
from typing import List, Tuple, Dict, Any
|
3 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
4 |
+
from chatbot4 import RetrievalChatbot
|
5 |
+
|
6 |
+
class ResponseQualityChecker:
|
7 |
+
"""Handles quality checking and confidence scoring for chatbot responses."""
|
8 |
+
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
chatbot: RetrievalChatbot,
|
12 |
+
confidence_threshold: float = 0.5,
|
13 |
+
diversity_threshold: float = 0.1,
|
14 |
+
min_response_length: int = 3,
|
15 |
+
max_similarity_ratio: float = 0.9
|
16 |
+
):
|
17 |
+
self.confidence_threshold = confidence_threshold
|
18 |
+
self.diversity_threshold = diversity_threshold
|
19 |
+
self.min_response_length = min_response_length
|
20 |
+
self.max_similarity_ratio = max_similarity_ratio
|
21 |
+
self.chatbot = chatbot
|
22 |
+
|
23 |
+
def check_response_quality(
|
24 |
+
self,
|
25 |
+
query: str,
|
26 |
+
responses: List[Tuple[str, float]]
|
27 |
+
) -> Dict[str, Any]:
|
28 |
+
"""
|
29 |
+
Evaluate the quality of the responses based on various metrics.
|
30 |
+
"""
|
31 |
+
# Calculate diversity based on the responses themselves
|
32 |
+
diversity = self.calculate_diversity(responses)
|
33 |
+
|
34 |
+
# Calculate relevance based on some criteria
|
35 |
+
relevance = self.calculate_relevance(query, responses)
|
36 |
+
|
37 |
+
# Calculate length scores for each response
|
38 |
+
length_scores = [self._calculate_length_score(response) for response, _ in responses]
|
39 |
+
avg_length_score = np.mean(length_scores) if length_scores else 0.0
|
40 |
+
|
41 |
+
# Extract similarity scores
|
42 |
+
similarity_scores = [score for _, score in responses]
|
43 |
+
|
44 |
+
# Calculate score gap
|
45 |
+
score_gap = self._calculate_score_gap(similarity_scores, top_n=3)
|
46 |
+
|
47 |
+
# Aggregate metrics
|
48 |
+
metrics = {
|
49 |
+
'top_score': similarity_scores[0] if similarity_scores else 0.0,
|
50 |
+
'response_diversity': diversity,
|
51 |
+
'query_response_relevance': relevance,
|
52 |
+
'response_length_score': avg_length_score,
|
53 |
+
'top_3_score_gap': score_gap
|
54 |
+
}
|
55 |
+
|
56 |
+
# Determine overall confidence
|
57 |
+
is_confident = self._determine_confidence(metrics)
|
58 |
+
|
59 |
+
return {
|
60 |
+
'diversity': diversity,
|
61 |
+
'relevance': relevance,
|
62 |
+
'is_confident': is_confident,
|
63 |
+
'top_score': metrics['top_score'],
|
64 |
+
'response_diversity': metrics['response_diversity'],
|
65 |
+
'query_response_relevance': metrics['query_response_relevance'],
|
66 |
+
'response_length_score': metrics['response_length_score'],
|
67 |
+
'top_3_score_gap': metrics['top_3_score_gap']
|
68 |
+
}
|
69 |
+
|
70 |
+
def calculate_diversity(self, responses: List[Tuple[str, float]]) -> float:
|
71 |
+
"""
|
72 |
+
Calculate diversity as the average pairwise similarity between responses.
|
73 |
+
Lower similarity indicates higher diversity.
|
74 |
+
"""
|
75 |
+
if not responses:
|
76 |
+
return 0.0
|
77 |
+
|
78 |
+
# Encode responses
|
79 |
+
embeddings = [self.encode_text(response) for response, _ in responses]
|
80 |
+
if len(embeddings) < 2:
|
81 |
+
return 1.0 # Maximum diversity
|
82 |
+
|
83 |
+
# Compute pairwise cosine similarity
|
84 |
+
similarity_matrix = cosine_similarity(embeddings)
|
85 |
+
|
86 |
+
# Exclude diagonal
|
87 |
+
sum_similarities = np.sum(similarity_matrix) - len(responses)
|
88 |
+
num_pairs = len(responses) * (len(responses) - 1)
|
89 |
+
avg_similarity = sum_similarities / num_pairs if num_pairs > 0 else 0.0
|
90 |
+
diversity_score = 1 - avg_similarity # Higher value indicates more diversity
|
91 |
+
return diversity_score
|
92 |
+
|
93 |
+
def calculate_relevance(self, query: str, responses: List[Tuple[str, float]]) -> float:
|
94 |
+
"""
|
95 |
+
Calculate relevance as the average similarity between the query and each response.
|
96 |
+
"""
|
97 |
+
if not responses:
|
98 |
+
return 0.0
|
99 |
+
|
100 |
+
# Encode query
|
101 |
+
query_embedding = self.encode_query(query)
|
102 |
+
|
103 |
+
# Encode responses
|
104 |
+
response_embeddings = [self.encode_text(response) for response, _ in responses]
|
105 |
+
|
106 |
+
# Compute cosine similarity
|
107 |
+
similarities = cosine_similarity([query_embedding], response_embeddings)[0]
|
108 |
+
|
109 |
+
avg_relevance = np.mean(similarities) if similarities.size > 0 else 0.0
|
110 |
+
return avg_relevance
|
111 |
+
|
112 |
+
def _calculate_length_score(self, response: str) -> float:
|
113 |
+
"""Score based on response length appropriateness."""
|
114 |
+
length = len(response.split())
|
115 |
+
if length < self.min_response_length:
|
116 |
+
return length / self.min_response_length
|
117 |
+
return 1.0
|
118 |
+
|
119 |
+
def _calculate_score_gap(self, scores: List[float], top_n: int = 3) -> float:
|
120 |
+
"""
|
121 |
+
Calculate the average gap between the top N scores.
|
122 |
+
|
123 |
+
Args:
|
124 |
+
scores (List[float]): List of similarity scores.
|
125 |
+
top_n (int): Number of top scores to consider.
|
126 |
+
|
127 |
+
Returns:
|
128 |
+
float: Average score gap.
|
129 |
+
"""
|
130 |
+
if len(scores) < top_n + 1:
|
131 |
+
return 0.0
|
132 |
+
gaps = [scores[i] - scores[i + 1] for i in range(top_n)]
|
133 |
+
avg_gap = np.mean(gaps)
|
134 |
+
return avg_gap
|
135 |
+
|
136 |
+
def _determine_confidence(self, metrics: Dict[str, float]) -> bool:
|
137 |
+
"""
|
138 |
+
Determine if we're confident enough in the response.
|
139 |
+
|
140 |
+
Returns:
|
141 |
+
bool: True if we should use this response, False if we should abstain
|
142 |
+
"""
|
143 |
+
conditions = [
|
144 |
+
metrics['top_score'] >= self.confidence_threshold,
|
145 |
+
metrics['response_diversity'] >= self.diversity_threshold,
|
146 |
+
metrics['response_length_score'] >= 0.8,
|
147 |
+
metrics['query_response_relevance'] >= 0.3, # was 0.5
|
148 |
+
metrics['top_3_score_gap'] >= 0.05 # was 0.1
|
149 |
+
]
|
150 |
+
return all(conditions)
|
151 |
+
|
152 |
+
def encode_text(self, text: str) -> np.ndarray:
|
153 |
+
# 1) Turn text into a list if your encode_responses() expects a list.
|
154 |
+
# 2) Then call the method from the chatbot to get the embedding.
|
155 |
+
embedding_tensor = self.chatbot.encode_responses([text]) # returns tf.Tensor of shape (1, emb_dim)
|
156 |
+
embedding = embedding_tensor.numpy()[0].astype('float32') # shape: (emb_dim,)
|
157 |
+
embedding = embedding / np.linalg.norm(embedding) if np.linalg.norm(embedding) > 0 else embedding
|
158 |
+
return embedding
|
159 |
+
|
160 |
+
def encode_query(self, query: str) -> np.ndarray:
|
161 |
+
embedding_tensor = self.chatbot.encode_query(query) # returns tf.Tensor of shape (1, emb_dim)
|
162 |
+
embedding = embedding_tensor.numpy()[0].astype('float32') # shape: (emb_dim,)
|
163 |
+
embedding = embedding / np.linalg.norm(embedding) if np.linalg.norm(embedding) > 0 else embedding
|
164 |
+
return embedding
|
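A minimal usage sketch of the checker. StubChatbot is a hypothetical stand-in for any object exposing encode_query/encode_responses, and the (response, score) pairs are made-up retrieval results:

import numpy as np
import tensorflow as tf
from response_quality_checker import ResponseQualityChecker

class StubChatbot:
    """Stand-in exposing the two encode methods the checker calls."""
    def encode_responses(self, texts):
        return tf.constant(np.random.rand(len(texts), 8), dtype=tf.float32)
    def encode_query(self, query):
        return tf.constant(np.random.rand(1, 8), dtype=tf.float32)

checker = ResponseQualityChecker(chatbot=StubChatbot(), confidence_threshold=0.5)
query = "Can you recommend a Korean restaurant in NYC?"
responses = [
    ("Sure, there's a great Korean BBQ spot in Midtown.", 0.82),
    ("Try the Korean place near the station.", 0.74),
    ("I can help you book movie tickets.", 0.41),
]
result = checker.check_response_quality(query, responses)
print("confident:", result["is_confident"], "relevance:", round(float(result["relevance"]), 3))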
run_model.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
1 |
+
import json
|
2 |
+
import glob
|
3 |
+
import os
|
4 |
+
from chatbot import RetrievalChatbot
|
5 |
+
import tensorflow as tf
|
6 |
+
from sklearn.model_selection import train_test_split
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
|
9 |
+
def load_training_data(data_directory: str) -> list:
|
10 |
+
"""Load and combine dialogue data from multiple JSON files."""
|
11 |
+
all_dialogues = []
|
12 |
+
|
13 |
+
# Get all json files matching the pattern
|
14 |
+
pattern = os.path.join(data_directory, "batch_*.json")
|
15 |
+
json_files = sorted(glob.glob(pattern))
|
16 |
+
|
17 |
+
print(f"Found {len(json_files)} batch files")
|
18 |
+
|
19 |
+
for file_path in json_files:
|
20 |
+
try:
|
21 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
22 |
+
batch_dialogues = json.load(f)
|
23 |
+
all_dialogues.extend(batch_dialogues)
|
24 |
+
print(f"Loaded {len(batch_dialogues)} dialogues from {os.path.basename(file_path)}")
|
25 |
+
except Exception as e:
|
26 |
+
print(f"Error loading {file_path}: {str(e)}")
|
27 |
+
|
28 |
+
print(f"Total dialogues loaded: {len(all_dialogues)}")
|
29 |
+
return all_dialogues
|
30 |
+
|
31 |
+
def plot_training_history(train_losses, val_losses):
|
32 |
+
# Plot training and validation loss
|
33 |
+
plt.figure()
|
34 |
+
plt.plot(train_losses, label='Train Loss')
|
35 |
+
plt.plot(val_losses, label='Val Loss')
|
36 |
+
plt.xlabel('Epoch')
|
37 |
+
plt.ylabel('Triplet Loss')
|
38 |
+
plt.legend()
|
39 |
+
plt.show()
|
40 |
+
|
41 |
+
dialogues = load_training_data('processed_outputs')
|
42 |
+
|
43 |
+
# Initialize the chatbot
|
44 |
+
chatbot = RetrievalChatbot(
|
45 |
+
vocab_size=10000,
|
46 |
+
max_sequence_length=80,
|
47 |
+
embedding_dim=256,
|
48 |
+
lstm_units=256,
|
49 |
+
num_attention_heads=8,
|
50 |
+
margin=0.3
|
51 |
+
)
|
52 |
+
|
53 |
+
# Prepare the dataset for triplet training
|
54 |
+
q_pad, p_pad, n_pad = chatbot.prepare_dataset(dialogues, neg_samples_per_pos=3)
|
55 |
+
|
56 |
+
# Train with triplet loss
|
57 |
+
train_losses, val_losses = chatbot.train_with_triplet_loss(
|
58 |
+
q_pad, p_pad, n_pad,
|
59 |
+
epochs=1,
|
60 |
+
batch_size=32,
|
61 |
+
validation_split=0.2
|
62 |
+
)
|
63 |
+
|
64 |
+
plot_training_history(train_losses, val_losses)
|
65 |
+
|
66 |
+
# After training, test prediction
|
67 |
+
response_candidates = [turn['text'] for d in dialogues for turn in d['turns'] if turn['speaker'] == 'assistant']
|
68 |
+
|
69 |
+
# Test retrieval
|
70 |
+
test_query = "I'd like a recommendation for a Korean restaurant in NYC."
|
71 |
+
top_responses = chatbot.retrieve_top_n(test_query, response_candidates, top_n=5)
|
72 |
+
print("Top responses:")
|
73 |
+
for resp, score in top_responses:
|
74 |
+
print(f"Score: {score:.4f} - {resp}")
|
75 |
+
|
76 |
+
# Single-turn validation:
|
77 |
+
test_queries = [
|
78 |
+
"I want to book a Korean restaurant in NYC.",
|
79 |
+
"Can I get two tickets for 'What Men Want'?",
|
80 |
+
"What's the best time to watch the movie today?"
|
81 |
+
]
|
82 |
+
for query in test_queries:
|
83 |
+
top_responses = chatbot.retrieve_top_n(query, response_candidates, top_n=3)
|
84 |
+
print(f"\nQuery: {query}")
|
85 |
+
for resp, score in top_responses:
|
86 |
+
print(f"Score: {score:.4f} - {resp}")
|
87 |
+
|
88 |
+
# Multi-turn conversation:
|
89 |
+
multi_turn_history = []
|
90 |
+
|
91 |
+
def update_context(multi_turn_history, query, response, max_context_turns=3):
|
92 |
+
multi_turn_history.append((query, response))
|
93 |
+
if len(multi_turn_history) > max_context_turns:
|
94 |
+
multi_turn_history.pop(0)
|
95 |
+
|
96 |
+
def get_context_enhanced_query(multi_turn_history, query):
|
97 |
+
if not multi_turn_history:
|
98 |
+
return query
|
99 |
+
context = " ".join([f"User: {q} Assistant: {r}" for q, r in multi_turn_history])
|
100 |
+
return f"{context} User: {query}"
|
101 |
+
|
102 |
+
conversation_queries = [
|
103 |
+
"I'd like to watch a movie tonight.",
|
104 |
+
"Is there a showing of 'What Men Want'?",
|
105 |
+
"What time is the last show?",
|
106 |
+
"Can I get two tickets?"
|
107 |
+
]
|
108 |
+
|
109 |
+
for query in conversation_queries:
|
110 |
+
context_query = get_context_enhanced_query(multi_turn_history, query)
|
111 |
+
top_responses = chatbot.retrieve_top_n(context_query, response_candidates, top_n=3)
|
112 |
+
best_response = top_responses[0][0]
|
113 |
+
print(f"\nUser: {query}\nAssistant: {best_response}")
|
114 |
+
update_context(multi_turn_history, query, best_response)
|
115 |
+
|
116 |
+
|
117 |
+
|
118 |
+
|
119 |
+
|
120 |
+
|
121 |
+
|
122 |
+
|
123 |
+
#queries, responses, labels = chatbot.prepare_dataset(dialogues, neg_samples_per_pos=3)
|
124 |
+
|
125 |
+
#train_dialogues, val_dialogues = train_test_split(dialogues, test_size=0.2, random_state=20)
|
126 |
+
#query_train, query_val, response_train, response_val, labels_train, labels_val = train_test_split(queries, responses, labels, test_size=0.2, random_state=20)
|
127 |
+
|
128 |
+
# chatbot.model.compile(
|
129 |
+
# optimizer=tf.keras.optimizers.Adam(learning_rate=0.001, clipnorm=1.0),
|
130 |
+
# loss='binary_crossentropy',
|
131 |
+
# metrics=['accuracy']
|
132 |
+
# )
|
133 |
+
|
134 |
+
# # Train the model with early stopping to prevent overfitting
|
135 |
+
# callbacks = [
|
136 |
+
# tf.keras.callbacks.EarlyStopping(
|
137 |
+
# monitor='val_loss',
|
138 |
+
# patience=3,
|
139 |
+
# restore_best_weights=True
|
140 |
+
# ),
|
141 |
+
# tf.keras.callbacks.ReduceLROnPlateau(
|
142 |
+
# monitor='val_loss',
|
143 |
+
# factor=0.5,
|
144 |
+
# patience=2,
|
145 |
+
# min_lr=1e-6,
|
146 |
+
# verbose=1
|
147 |
+
# ),
|
148 |
+
# tf.keras.callbacks.ModelCheckpoint(
|
149 |
+
# 'chatbot_model.keras',
|
150 |
+
# monitor='val_loss',
|
151 |
+
# save_best_only=True
|
152 |
+
# )
|
153 |
+
# ]
|
154 |
+
|
155 |
+
# history = chatbot.model.fit(
|
156 |
+
# {'query_input': query_train, 'response_input': response_train},
|
157 |
+
# labels_train,
|
158 |
+
# validation_data=({'query_input': query_val, 'response_input': response_val}, labels_val),
|
159 |
+
# epochs=5,
|
160 |
+
# batch_size=32,
|
161 |
+
# callbacks=callbacks
|
162 |
+
# )
|
run_model2.py
ADDED
@@ -0,0 +1,340 @@
|
|
|
|
|
|
1 |
+
from chatbot2 import RetrievalChatbot, ChatbotConfig
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
import glob
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import logging
|
7 |
+
from pathlib import Path
|
8 |
+
from typing import List, Dict, Optional, Any, Tuple
|
9 |
+
import numpy as np
|
10 |
+
from datetime import datetime
|
11 |
+
from response_quality_checker import ResponseQualityChecker
|
12 |
+
|
13 |
+
# Configure logging
|
14 |
+
logging.basicConfig(
|
15 |
+
level=logging.INFO,
|
16 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
17 |
+
)
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
def load_training_data(data_directory: str, debug_samples: Optional[int] = None) -> list:
|
21 |
+
"""
|
22 |
+
Load and combine dialogue data from multiple JSON files.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
data_directory: Directory containing the dialogue files
|
26 |
+
debug_samples: If set, only load this many dialogues for debugging
|
27 |
+
"""
|
28 |
+
all_dialogues = []
|
29 |
+
data_directory = Path(data_directory)
|
30 |
+
|
31 |
+
# Get all json files matching the pattern
|
32 |
+
pattern = "batch_*.json"
|
33 |
+
json_files = sorted(data_directory.glob(pattern))
|
34 |
+
|
35 |
+
logger.info(f"Found {len(json_files)} batch files")
|
36 |
+
|
37 |
+
if debug_samples:
|
38 |
+
logger.info(f"Debug mode: Will load up to {debug_samples} dialogues")
|
39 |
+
|
40 |
+
for file_path in json_files:
|
41 |
+
try:
|
42 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
43 |
+
batch_dialogues = json.load(f)
|
44 |
+
|
45 |
+
# If in debug mode, only take what we need from this batch
|
46 |
+
if debug_samples is not None:
|
47 |
+
remaining_samples = debug_samples - len(all_dialogues)
|
48 |
+
if remaining_samples <= 0:
|
49 |
+
break
|
50 |
+
batch_dialogues = batch_dialogues[:remaining_samples]
|
51 |
+
|
52 |
+
all_dialogues.extend(batch_dialogues)
|
53 |
+
logger.info(f"Loaded {len(batch_dialogues)} dialogues from {file_path.name}")
|
54 |
+
|
55 |
+
# If we've reached our debug sample limit, stop loading
|
56 |
+
if debug_samples is not None and len(all_dialogues) >= debug_samples:
|
57 |
+
logger.info(f"Debug mode: Reached {debug_samples} samples, stopping load")
|
58 |
+
break
|
59 |
+
|
60 |
+
except Exception as e:
|
61 |
+
logger.error(f"Error loading {file_path}: {str(e)}")
|
62 |
+
|
63 |
+
total_loaded = len(all_dialogues)
|
64 |
+
if debug_samples:
|
65 |
+
logger.info(f"Debug mode: Loaded {total_loaded}/{debug_samples} requested dialogues")
|
66 |
+
else:
|
67 |
+
logger.info(f"Total dialogues loaded: {total_loaded}")
|
68 |
+
|
69 |
+
return all_dialogues
|
70 |
+
|
71 |
+
def plot_training_history(history: Dict[str, List[float]], save_dir: Path = None):
|
72 |
+
"""Plot and optionally save training history."""
|
73 |
+
# Create figure with two subplots
|
74 |
+
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))
|
75 |
+
|
76 |
+
# Plot losses
|
77 |
+
ax1.plot(history['train_loss'], label='Train Loss')
|
78 |
+
ax1.plot(history['val_loss'], label='Validation Loss')
|
79 |
+
ax1.set_xlabel('Epoch')
|
80 |
+
ax1.set_ylabel('Triplet Loss')
|
81 |
+
ax1.set_title('Training and Validation Loss')
|
82 |
+
ax1.legend()
|
83 |
+
ax1.grid(True)
|
84 |
+
|
85 |
+
# Plot learning rate if available
|
86 |
+
if 'learning_rate' in history:
|
87 |
+
ax2.plot(history['learning_rate'], label='Learning Rate')
|
88 |
+
ax2.set_xlabel('Step')
|
89 |
+
ax2.set_ylabel('Learning Rate')
|
90 |
+
ax2.set_title('Learning Rate Schedule')
|
91 |
+
ax2.legend()
|
92 |
+
ax2.grid(True)
|
93 |
+
|
94 |
+
plt.tight_layout()
|
95 |
+
|
96 |
+
# Save if directory provided
|
97 |
+
if save_dir:
|
98 |
+
save_dir = Path(save_dir)
|
99 |
+
save_dir.mkdir(parents=True, exist_ok=True)
|
100 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
101 |
+
plt.savefig(save_dir / f'training_history_{timestamp}.png')
|
102 |
+
|
103 |
+
plt.show()
|
104 |
+
|
105 |
+
def setup_training_directories(base_dir: str = "chatbot_training") -> Dict[str, Path]:
|
106 |
+
"""Setup directory structure for training artifacts."""
|
107 |
+
base_dir = Path(base_dir)
|
108 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
109 |
+
train_dir = base_dir / f"training_run_{timestamp}"
|
110 |
+
|
111 |
+
    directories = {
        'base': train_dir,
        'checkpoints': train_dir / 'checkpoints',
        'plots': train_dir / 'plots',
        'logs': train_dir / 'logs'
    }

    # Create directories
    for dir_path in directories.values():
        dir_path.mkdir(parents=True, exist_ok=True)

    return directories

def run_automatic_validation(
    chatbot,
    response_pool: List[str],
    quality_checker: ResponseQualityChecker,
    num_examples: int = 5
) -> Dict[str, Any]:
    """
    Run automatic validation with quality metrics.
    """
    logger.info("\n=== Running Automatic Validation ===")

    test_queries = [
        "Hello, how are you today?",
        "What's the weather like?",
        "Can you help me with a problem?",
        "Tell me a joke",
        "What time is it?",
        "I need help with my homework",
        "Where's a good place to eat?",
        "What movies are playing?",
        "How do I reset my password?",
        "Can you recommend a book?"
    ]

    test_queries = test_queries[:num_examples]
    metrics_history = []

    for i, query in enumerate(test_queries, 1):
        logger.info(f"\nTest Case {i}:")
        logger.info(f"Query: {query}")

        # Get responses and scores
        responses = chatbot.retrieve_responses(
            query,
            response_pool,
            context=None,
            top_k=5
        )

        # Check quality
        quality_metrics = quality_checker.check_response_quality(
            query, responses, response_pool
        )
        metrics_history.append(quality_metrics)

        # Log results
        logger.info(f"Quality Metrics: {quality_metrics}")
        logger.info("Top responses:")
        for j, (response, score) in enumerate(responses[:3], 1):
            logger.info(f"{j}. Score: {score:.4f}")
            logger.info(f" Response: {response}")
            if j == 1 and not quality_metrics['is_confident']:
                logger.info(" [Low Confidence - Would abstain from answering]")

    # Calculate aggregate metrics
    aggregate_metrics = {
        'num_queries_tested': len(test_queries),
        'avg_top_response_score': np.mean([m['top_score'] for m in metrics_history]),
        'avg_diversity': np.mean([m['response_diversity'] for m in metrics_history]),
        'avg_relevance': np.mean([m['query_response_relevance'] for m in metrics_history]),
        'confidence_rate': np.mean([m['is_confident'] for m in metrics_history]),
    }

    logger.info("\n=== Validation Summary ===")
    for metric, value in aggregate_metrics.items():
        logger.info(f"{metric}: {value:.4f}")

    return aggregate_metrics

def chat_with_quality_check(
    chatbot,
    query: str,
    response_pool: List[str],
    conversation_history: List[Tuple[str, str]],
    quality_checker: ResponseQualityChecker
) -> Tuple[Optional[str], List[Tuple[str, float]], Dict[str, Any]]:
    """
    Enhanced chat function with quality checking.
    """
    # Get responses and scores
    responses = chatbot.retrieve_responses(
        query,
        response_pool,
        conversation_history
    )

    # Check quality
    quality_metrics = quality_checker.check_response_quality(
        query, responses, response_pool
    )

    if quality_metrics['is_confident']:
        return responses[0][0], responses, quality_metrics
    else:
        uncertainty_response = (
            "I apologize, but I don't feel confident providing an answer to that "
            "question at the moment. Could you please rephrase or ask something else?"
        )
        return uncertainty_response, responses, quality_metrics

def get_total_steps(dialogues: List[Dict[str, Any]], batch_size: int, epochs: int) -> int:
    """
    Calculate total training steps based on dialogues and batch size.
    Assume 80% of data for training due to validation split
    """
    estimated_train_samples = int(len(dialogues) * 0.8)
    steps_per_epoch = estimated_train_samples // batch_size
    return steps_per_epoch * epochs

def main():
    DEBUG_SAMPLES = 350
    BATCH_SIZE = 32
    EPOCHS = 5 if DEBUG_SAMPLES else 10

    # Setup training directories
    dirs = setup_training_directories()

    # Load training data
    dialogues = load_training_data('processed_outputs', debug_samples=DEBUG_SAMPLES)
    total_steps = get_total_steps(dialogues, BATCH_SIZE, EPOCHS)

    # Initialize configuration
    config = ChatbotConfig(
        embedding_dim=32,  # TODO: 256
        encoder_units=32,  # TODO: 256
        num_attention_heads=2,  # TODO: 8
        warmup_steps=int(total_steps * 0.1),  # 10% of total steps for warmup
    )

    # Save configuration
    with open(dirs['base'] / 'config.json', 'w') as f:
        json.dump(config.to_dict(), f, indent=2)

    # Initialize chatbot
    chatbot = RetrievalChatbot(config)

    # Prepare dataset
    logger.info("Preparing dataset...")

    # Prepare and train with debug samples
    q_pad, p_pad, n_pad = chatbot.prepare_dataset(
        dialogues,
        neg_samples_per_pos=3,
        debug_samples=DEBUG_SAMPLES
    )

    # Train model
    logger.info("Starting training...")
    chatbot.train(
        q_pad, p_pad, n_pad,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        checkpoint_dir=dirs['checkpoints']
    )

    # Plot and save training history
    plot_training_history(chatbot.history, save_dir=dirs['plots'])

    # Save final model
    chatbot.save_models(dirs['base'] / 'final_model')

    # Prepare response pool for chat
    response_pool = [
        turn['text'] for d in dialogues
        for turn in d['turns'] if turn['speaker'] == 'assistant'
    ]

    # Initialize quality checker with appropriate thresholds
    quality_checker = ResponseQualityChecker(
        confidence_threshold=0.6 if not DEBUG_SAMPLES else 0.4,  # Lower threshold for debug
        diversity_threshold=0.2,
        min_response_length=10,
        max_similarity_ratio=0.9
    )

    # Run automatic validation
    validation_metrics = run_automatic_validation(
        chatbot,
        response_pool,
        quality_checker,
        num_examples=5 if DEBUG_SAMPLES else 10
    )

    # Log validation metrics
    logger.info(f"Validation Metrics: {validation_metrics}")

    # Now continue with interactive chat
    logger.info("\nStarting interactive chat session...")
    conversation_history = []

    while True:
        query = input("\nYou: ")
        if query.lower() in ['quit', 'exit', 'bye']:
            break

        try:
            response, candidates = chatbot.chat(
                query,
                response_pool,
                conversation_history
            )
            print(f"\nAssistant: {response}")

            # Print top alternative responses
            print("\nAlternative responses:")
            for resp, score in candidates[1:4]:
                print(f"Score: {score:.4f} - {resp}")

            # Update history
            conversation_history.append((query, response))

        except Exception as e:
            logger.error(f"Error during chat: {str(e)}")
            print("Sorry, I encountered an error. Please try again.")

if __name__ == "__main__":
    main()
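A minimal, self-contained sketch of the confidence-gating idea used in chat_with_quality_check above. The helper name gate_response and the 0.6 threshold are illustrative only; the real decision lives in ResponseQualityChecker.check_response_quality().

from typing import List, Optional, Tuple

def gate_response(candidates: List[Tuple[str, float]],
                  confidence_threshold: float = 0.6) -> Optional[str]:
    """Return the top-scoring candidate, or None to signal abstention."""
    if not candidates:
        return None
    best_text, best_score = candidates[0]
    return best_text if best_score >= confidence_threshold else None

# Confident retrieval -> answer; weak retrieval -> None, so the caller falls back to an apology.
print(gate_response([("Sure, I can help with that.", 0.82)]))
print(gate_response([("Paris is in France.", 0.31)]))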
run_model3.py
ADDED
@@ -0,0 +1,434 @@
from chatbot3 import RetrievalChatbot, ChatbotConfig
import os
import json
import glob
import tensorflow as tf
import matplotlib.pyplot as plt
import logging
from pathlib import Path
from typing import List, Dict, Optional, Any, Tuple
import numpy as np
from datetime import datetime
from response_quality_checker import ResponseQualityChecker
import torch
from transformers import TFAutoModel, AutoTokenizer


policy = tf.keras.mixed_precision.Policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy(policy)

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

def setup_model_cache(cache_dir: Optional[Path] = None) -> Path:
    """Setup and manage model cache directory."""
    if cache_dir is None:
        cache_dir = Path.home() / '.chatbot_cache'

    cache_dir.mkdir(parents=True, exist_ok=True)

    # Set environment variables for various libraries
    os.environ['TRANSFORMERS_CACHE'] = str(cache_dir / 'transformers')
    os.environ['TORCH_HOME'] = str(cache_dir / 'torch')
    os.environ['HF_HOME'] = str(cache_dir / 'huggingface')

    logger.info(f"Using cache directory: {cache_dir}")
    return cache_dir

def setup_gpu():
    """Configure GPU settings for optimal performance."""
    logger.info("Checking GPU availability...")

    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        try:
            # Allow memory growth to prevent taking all GPU memory at once
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            logger.info(f"Found {len(gpus)} GPU(s). Memory growth enabled.")

            # Log GPU info
            for gpu in gpus:
                logger.info(f"GPU Device: {gpu}")
            return True
        except Exception as e:
            logger.error(f"Error configuring GPU: {str(e)}")
            return False

    else:
        logger.info("No GPU found. Using CPU.")
        return False

def preload_models(config: ChatbotConfig, cache_dir: Path):
    """Preload and cache models."""
    logger.info("Preloading models...")

    # Cache DistilBERT
    model_name = config.pretrained_model
    cache_path = cache_dir / 'transformers' / model_name

    if not cache_path.exists():
        logger.info(f"Downloading and caching {model_name}...")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = TFAutoModel.from_pretrained(model_name)

        # Save to cache
        tokenizer.save_pretrained(cache_path)
        model.save_pretrained(cache_path)
    else:
        logger.info(f"Using cached model from {cache_path}")

    return cache_path

def load_training_data(data_directory: str, debug_samples: Optional[int] = None) -> list:
    """
    Load and combine dialogue data from multiple JSON files.

    Args:
        data_directory: Directory containing the dialogue files
        debug_samples: If set, only load this many dialogues for debugging
    """
    all_dialogues = []
    data_directory = Path(data_directory)

    # Get all json files matching the pattern
    pattern = "batch_*.json"
    json_files = sorted(data_directory.glob(pattern))

    logger.info(f"Found {len(json_files)} batch files")

    if debug_samples:
        logger.info(f"Debug mode: Will load up to {debug_samples} dialogues")

    for file_path in json_files:
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                batch_dialogues = json.load(f)

            # If in debug mode, only take what we need from this batch
            if debug_samples is not None:
                remaining_samples = debug_samples - len(all_dialogues)
                if remaining_samples <= 0:
                    break
                batch_dialogues = batch_dialogues[:remaining_samples]

            all_dialogues.extend(batch_dialogues)
            logger.info(f"Loaded {len(batch_dialogues)} dialogues from {file_path.name}")

            # If we've reached our debug sample limit, stop loading
            if debug_samples is not None and len(all_dialogues) >= debug_samples:
                logger.info(f"Debug mode: Reached {debug_samples} samples, stopping load")
                break

        except Exception as e:
            logger.error(f"Error loading {file_path}: {str(e)}")

    total_loaded = len(all_dialogues)
    if debug_samples:
        logger.info(f"Debug mode: Loaded {total_loaded}/{debug_samples} requested dialogues")
    else:
        logger.info(f"Total dialogues loaded: {total_loaded}")

    return all_dialogues

def plot_training_history(history: Dict[str, List[float]], save_dir: Path = None):
    """Plot and optionally save training history."""
    # Create figure with two subplots
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))

    # Plot losses
    ax1.plot(history['train_loss'], label='Train Loss')
    ax1.plot(history['val_loss'], label='Validation Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Triplet Loss')
    ax1.set_title('Training and Validation Loss')
    ax1.legend()
    ax1.grid(True)

    # Plot learning rate if available
    if 'learning_rate' in history:
        ax2.plot(history['learning_rate'], label='Learning Rate')
        ax2.set_xlabel('Step')
        ax2.set_ylabel('Learning Rate')
        ax2.set_title('Learning Rate Schedule')
        ax2.legend()
        ax2.grid(True)

    plt.tight_layout()

    # Save if directory provided
    if save_dir:
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        plt.savefig(save_dir / f'training_history_{timestamp}.png')

    plt.show()

def setup_training_directories(base_dir: str = "chatbot_training") -> Dict[str, Path]:
    """Setup directory structure for training artifacts."""
    base_dir = Path(base_dir)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    train_dir = base_dir / f"training_run_{timestamp}"

    directories = {
        'base': train_dir,
        'checkpoints': train_dir / 'checkpoints',
        'plots': train_dir / 'plots',
        'logs': train_dir / 'logs'
    }

    # Create directories
    for dir_path in directories.values():
        dir_path.mkdir(parents=True, exist_ok=True)

    return directories

def run_automatic_validation(
    chatbot,
    response_pool: List[str],
    quality_checker: ResponseQualityChecker,
    num_examples: int = 5
) -> Dict[str, Any]:
    """
    Run automatic validation with quality metrics.
    """
    logger.info("\n=== Running Automatic Validation ===")

    test_queries = [
        "Hello, how are you today?",
        "What's the weather like?",
        "Can you help me with a problem?",
        "Tell me a joke",
        "What time is it?",
        "I need help with my homework",
        "Where's a good place to eat?",
        "What movies are playing?",
        "How do I reset my password?",
        "Can you recommend a book?"
    ]

    test_queries = test_queries[:num_examples]
    metrics_history = []

    for i, query in enumerate(test_queries, 1):
        logger.info(f"\nTest Case {i}:")
        logger.info(f"Query: {query}")

        # Get responses and scores
        responses = chatbot.retrieve_responses(
            query,
            response_pool,
            context=None,
            top_k=5
        )

        # Check quality
        quality_metrics = quality_checker.check_response_quality(
            query, responses, response_pool
        )
        metrics_history.append(quality_metrics)

        # Log results
        logger.info(f"Quality Metrics: {quality_metrics}")
        logger.info("Top responses:")
        for j, (response, score) in enumerate(responses[:3], 1):
            logger.info(f"{j}. Score: {score:.4f}")
            logger.info(f" Response: {response}")
            if j == 1 and not quality_metrics['is_confident']:
                logger.info(" [Low Confidence - Would abstain from answering]")

    # Calculate aggregate metrics
    aggregate_metrics = {
        'num_queries_tested': len(test_queries),
        'avg_top_response_score': np.mean([m['top_score'] for m in metrics_history]),
        'avg_diversity': np.mean([m['response_diversity'] for m in metrics_history]),
        'avg_relevance': np.mean([m['query_response_relevance'] for m in metrics_history]),
        'confidence_rate': np.mean([m['is_confident'] for m in metrics_history]),
    }

    logger.info("\n=== Validation Summary ===")
    for metric, value in aggregate_metrics.items():
        logger.info(f"{metric}: {value:.4f}")

    return aggregate_metrics

def chat_with_quality_check(
    chatbot,
    query: str,
    response_pool: List[str],
    conversation_history: List[Tuple[str, str]],
    quality_checker: ResponseQualityChecker
) -> Tuple[Optional[str], List[Tuple[str, float]], Dict[str, Any]]:
    """
    Enhanced chat function with quality checking.
    """
    # Get responses and scores
    responses = chatbot.retrieve_responses(
        query,
        response_pool,
        conversation_history
    )

    # Check quality
    quality_metrics = quality_checker.check_response_quality(
        query, responses, response_pool
    )

    if quality_metrics['is_confident']:
        return responses[0][0], responses, quality_metrics
    else:
        uncertainty_response = (
            "I apologize, but I don't feel confident providing an answer to that "
            "question at the moment. Could you please rephrase or ask something else?"
        )
        return uncertainty_response, responses, quality_metrics

def get_total_steps(dialogues: List[Dict[str, Any]], batch_size: int, epochs: int) -> int:
    """
    Calculate total training steps based on dialogues and batch size.
    Assume 80% of data for training due to validation split
    """
    estimated_train_samples = int(len(dialogues) * 0.8)
    steps_per_epoch = estimated_train_samples // batch_size
    return steps_per_epoch * epochs

def main():
    # Set up GPU
    is_gpu = setup_gpu()

    DEBUG_SAMPLES = 350
    BATCH_SIZE = 64 if is_gpu else 32
    EPOCHS = 5 if DEBUG_SAMPLES else 10

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")

    # Set up caching
    cache_dir = setup_model_cache()

    # Set up training directories
    dirs = setup_training_directories()

    # Load training data
    dialogues = load_training_data('processed_outputs', debug_samples=DEBUG_SAMPLES)
    total_steps = get_total_steps(dialogues, BATCH_SIZE, EPOCHS)

    # Initialize configuration
    config = ChatbotConfig(
        embedding_dim=768,  # Match DistilBERT's dimension
        encoder_units=256,
        num_attention_heads=8,
        warmup_steps=int(total_steps * 0.1),
        learning_rate=0.0003,
        margin=0.5,
        pretrained_model='distilbert-base-uncased'
    )

    # Preload models
    preload_models(config, cache_dir)

    # Save configuration
    with open(dirs['base'] / 'config.json', 'w') as f:
        json.dump(config.to_dict(), f, indent=2)

    # Initialize chatbot
    chatbot = RetrievalChatbot(config)

    # Prepare dataset
    logger.info("Preparing dataset...")

    # Prepare and train with debug samples
    q_pad, p_pad, n_pad = chatbot.prepare_dataset(
        dialogues,
        neg_samples_per_pos=3,
        debug_samples=DEBUG_SAMPLES
    )

    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='./logs')

    # Train model
    logger.info("Starting training...")
    chatbot.train(
        q_pad, p_pad, n_pad,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        validation_split=0.2,
        checkpoint_dir=dirs['checkpoints'],
        callbacks=[tensorboard_callback]
    )

    # Plot and save training history
    plot_training_history(chatbot.history, save_dir=dirs['plots'])

    # Save final model
    chatbot.save_models(dirs['base'] / 'final_model')

    # Prepare response pool for chat
    response_pool = [
        turn['text'] for d in dialogues
        for turn in d['turns'] if turn['speaker'] == 'assistant'
    ]

    # Initialize quality checker with appropriate thresholds
    quality_checker = ResponseQualityChecker(
        confidence_threshold=0.6 if not DEBUG_SAMPLES else 0.4,  # Lower threshold for debug
        diversity_threshold=0.2,
        min_response_length=10,
        max_similarity_ratio=0.9
    )

    # Run automatic validation
    validation_metrics = run_automatic_validation(
        chatbot,
        response_pool,
        quality_checker,
        num_examples=5 if DEBUG_SAMPLES else 10
    )

    # Log validation metrics
    logger.info(f"Validation Metrics: {validation_metrics}")

    # Now continue with interactive chat
    logger.info("\nStarting interactive chat session...")
    conversation_history = []

    while True:
        query = input("\nYou: ")
        if query.lower() in ['quit', 'exit', 'bye']:
            break

        try:
            response, candidates, quality_metrics = chat_with_quality_check(
                chatbot,
                query,
                response_pool,
                conversation_history,
                quality_checker
            )
            print(f"\nAssistant: {response}")

            # Print top alternative responses if confident
            if quality_metrics['is_confident']:
                print("\nAlternative responses:")
                for resp, score in candidates[1:4]:
                    print(f"Score: {score:.4f} - {resp}")

                # Update history only for confident responses
                conversation_history.append((query, response))
            else:
                print("\nQuality metrics indicated low confidence:")
                print(f"Confidence score: {quality_metrics['top_score']:.4f}")
                print(f"Response relevance: {quality_metrics['query_response_relevance']:.4f}")

        except Exception as e:
            logger.error(f"Error during chat: {str(e)}")
            print("Sorry, I encountered an error. Please try again.")

if __name__ == "__main__":
    main()
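run_model3.py above reserves roughly 10% of the estimated training steps for learning-rate warmup (warmup_steps=int(total_steps * 0.1)). The schedule itself is implemented inside the chatbot modules and is not part of this diff; the sketch below is only one common way to express linear warmup followed by inverse-square-root decay as a Keras schedule, with made-up hyperparameters.

import tensorflow as tf

class WarmupThenDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Linear warmup to peak_lr, then inverse-sqrt decay (illustrative only)."""
    def __init__(self, peak_lr: float, warmup_steps: int):
        self.peak_lr = peak_lr
        self.warmup_steps = float(max(1, warmup_steps))

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        # Ramp linearly up to peak_lr, then decay as 1/sqrt(step / warmup_steps).
        warmup = self.peak_lr * step / self.warmup_steps
        decay = self.peak_lr * tf.math.rsqrt(tf.maximum(step, self.warmup_steps) / self.warmup_steps)
        return tf.where(step < self.warmup_steps, warmup, decay)

# Hypothetical values; the repo's peak LR and warmup length come from its own config.
optimizer = tf.keras.optimizers.Adam(learning_rate=WarmupThenDecay(peak_lr=3e-4, warmup_steps=100))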
run_model4.py
ADDED
@@ -0,0 +1,237 @@
from chatbot4 import RetrievalChatbot, ChatbotConfig
import os
import tensorflow as tf
import matplotlib.pyplot as plt
import logging
from pathlib import Path
from typing import List, Dict, Optional
from datetime import datetime
from response_quality_checker import ResponseQualityChecker

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

def setup_model_cache(cache_dir: Optional[Path] = None) -> Path:
    """Setup and manage model cache directory."""
    if cache_dir is None:
        cache_dir = Path.home() / '.chatbot_cache'

    cache_dir.mkdir(parents=True, exist_ok=True)

    # Set environment variables for various libraries
    os.environ['TRANSFORMERS_CACHE'] = str(cache_dir / 'transformers')
    os.environ['TORCH_HOME'] = str(cache_dir / 'torch')
    os.environ['HF_HOME'] = str(cache_dir / 'huggingface')

    logger.info(f"Using cache directory: {cache_dir}")
    return cache_dir

def setup_gpu():
    """Configure GPU settings for optimal performance."""
    logger.info("Checking GPU availability...")

    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        try:
            # Allow memory growth to prevent taking all GPU memory at once
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            logger.info(f"Found {len(gpus)} GPU(s). Memory growth enabled.")

            # Log GPU info
            for gpu in gpus:
                logger.info(f"GPU Device: {gpu}")
            return True
        except Exception as e:
            logger.error(f"Error configuring GPU: {str(e)}")
            return False

    else:
        logger.info("No GPU found. Using CPU.")
        return False

# def preload_models(config: ChatbotConfig, cache_dir: Path):
#     """Preload and cache models."""
#     logger.info("Preloading models...")

#     # Cache DistilBERT
#     model_name = config.pretrained_model
#     cache_path = cache_dir / 'transformers' / model_name

#     if not cache_path.exists():
#         logger.info(f"Downloading and caching {model_name}...")
#         tokenizer = AutoTokenizer.from_pretrained(model_name)
#         model = TFAutoModel.from_pretrained(model_name)

#         # Save to cache
#         tokenizer.save_pretrained(cache_path)
#         model.save_pretrained(cache_path)
#     else:
#         logger.info(f"Using cached model from {cache_path}")

#     return cache_path

def plot_training_history(history: Dict[str, List[float]], save_dir: Path = None):
    """Plot and optionally save training history."""
    # Create figure with two subplots
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))

    # Plot losses
    ax1.plot(history['train_loss'], label='Train Loss')
    ax1.plot(history['val_loss'], label='Validation Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Triplet Loss')
    ax1.set_title('Training and Validation Loss')
    ax1.legend()
    ax1.grid(True)

    # Plot learning rate if available
    if 'learning_rate' in history:
        ax2.plot(history['learning_rate'], label='Learning Rate')
        ax2.set_xlabel('Step')
        ax2.set_ylabel('Learning Rate')
        ax2.set_title('Learning Rate Schedule')
        ax2.legend()
        ax2.grid(True)

    plt.tight_layout()

    # Save if directory provided
    if save_dir:
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        plt.savefig(save_dir / f'training_history_{timestamp}.png')

    plt.show()

def setup_training_directories(base_dir: str = "chatbot_training") -> Dict[str, Path]:
    """Setup directory structure for training artifacts."""
    base_dir = Path(base_dir)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    train_dir = base_dir / f"training_run_{timestamp}"

    directories = {
        'base': train_dir,
        'checkpoints': train_dir / 'checkpoints',
        'plots': train_dir / 'plots',
        'logs': train_dir / 'logs'
    }

    # Create directories
    for dir_path in directories.values():
        dir_path.mkdir(parents=True, exist_ok=True)

    return directories

def main():
    # Set up GPU
    is_gpu = setup_gpu()

    DEBUG_SAMPLES = 2000
    BATCH_SIZE = 128 if is_gpu else 64
    EPOCHS = 5 if DEBUG_SAMPLES else 10

    # Set up caching
    cache_dir = setup_model_cache()

    # Set up training directories
    dirs = setup_training_directories()

    # Initialize configuration
    config = ChatbotConfig(
        embedding_dim=768,  # Match DistilBERT's dimension
        max_sequence_length=512,
        freeze_embeddings=False
    )

    # Preload models
    #preload_models(config, cache_dir)

    # Save configuration
    # with open(dirs['base'] / 'config.json', 'w') as f:
    #     json.dump(config.to_dict(), f, indent=4)

    # Load training data
    dialogues = RetrievalChatbot.load_training_data(data_path='processed_outputs/batch_group_0010.json', debug_samples=DEBUG_SAMPLES)

    # Initialize chatbot
    chatbot = RetrievalChatbot(config, dialogues)

    # Check trainable variables
    chatbot.check_trainable_variables()

    # Verify FAISS
    chatbot.verify_faiss_index()

    # Prepare dataset
    logger.info("Preparing dataset...")

    # Prepare and train with debug samples
    q_tensor, p_tensor = chatbot.prepare_dataset(dialogues)

    quality_checker = ResponseQualityChecker(chatbot=chatbot)

    # Train model
    logger.info("Starting training...")

    tf.config.optimizer.set_jit(True)  # XLA
    policy = tf.keras.mixed_precision.Policy('mixed_float16')
    tf.keras.mixed_precision.set_global_policy(policy)

    chatbot.train(
        q_pad=q_tensor,
        p_pad=p_tensor,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        validation_split=0.2,
        checkpoint_dir="checkpoints/",
        use_lr_schedule=True,  # Enable custom schedule
        peak_lr=2e-5,  # Peak learning rate
        warmup_steps_ratio=0.1,  # 10% warmup
        early_stopping_patience=3  # Stop if no improvement for 3 epochs
    )

    # Plot and save training history
    #plot_training_history(chatbot.history, save_dir=dirs['plots'])

    # Save final model
    chatbot.save_models(dirs['base'] / 'final_model')

    # Run automatic validation
    validation_metrics = chatbot.run_automatic_validation(quality_checker, num_examples=5)

    # Log validation metrics
    logger.info(f"Validation Metrics: {validation_metrics}")

    # Now continue with interactive chat
    logger.info("\nStarting interactive chat session...")
    conversation_history = []

    while True:
        user_input = input("You: ")
        if user_input.lower() in ['quit', 'exit', 'bye']:
            print("Assistant: Goodbye!")
            break

        response, candidates, metrics = chatbot.chat(
            query=user_input,
            conversation_history=None,  # Pass conversation history if available
            quality_checker=quality_checker,
            top_k=5
        )

        print(f"Assistant: {response}")

        # Optionally, display alternative responses
        if metrics.get('is_confident', False):
            print("\nAlternative responses:")
            for resp, score in candidates[1:4]:
                print(f"Score: {score:.4f} - {resp}")

if __name__ == "__main__":
    main()
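run_model4.py above is the first runner that leans on a FAISS index (chatbot.verify_faiss_index()) rather than scoring a raw response pool per query. The index is built inside chatbot4.py and is not shown in this diff; the following is only a generic sketch of the normalized inner-product (cosine) retrieval pattern FAISS is typically used for here, with random vectors standing in for encoder embeddings.

import numpy as np
import faiss  # faiss-cpu or faiss-gpu, installed by setup_faiss() in setup.py below

dim = 768  # DistilBERT-sized embeddings (illustrative)
embeddings = np.random.rand(1000, dim).astype('float32')
faiss.normalize_L2(embeddings)   # with unit vectors, inner product == cosine similarity

index = faiss.IndexFlatIP(dim)   # exact inner-product index
index.add(embeddings)

query = np.random.rand(1, dim).astype('float32')
faiss.normalize_L2(query)
scores, ids = index.search(query, 5)   # top-5 candidate responses
print(ids[0], scores[0])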
setup.py
CHANGED
@@ -1,6 +1,7 @@
from setuptools import setup, find_packages
import subprocess
import sys
+import platform

with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()
@@ -8,11 +9,21 @@ with open("README.md", "r", encoding="utf-8") as fh:
with open("requirements.txt", "r", encoding="utf-8") as fh:
    requirements = [line.strip() for line in fh if line.strip() and not line.startswith("#")]

-def
+def setup_spacy_models(models=['en_core_web_sm', 'en_core_web_md']):
    """
-    Download spaCy model.
+    Download the specified spaCy model.
+
+    Args:
+        models(List): List[str] of the names of the spaCy model to download.
    """
-
+    try:
+        for model in models:
+            print(f"Downloading spaCy model: {model}")
+            subprocess.check_call([sys.executable, "-m", "spacy", "download", model])
+            print(f"Successfully downloaded spaCy model: {model}")
+    except subprocess.CalledProcessError as e:
+        print(f"Error downloading spaCy model: {model}")
+        print(e)

def setup_models():
    """
@@ -22,10 +33,17 @@ def setup_models():
    from sklearn.feature_extraction.text import TfidfVectorizer
    from transformers import (
        AutoTokenizer,
+        AutoModel,
        GPT2TokenizerFast,
-        MarianTokenizer
+        MarianTokenizer,
+        DistilBertTokenizer,
+        DistilBertModel
    )

+    # Download DistilBERT for chatbot
+    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
+    model = DistilBertModel.from_pretrained('distilbert-base-uncased')
+
    # Download Universal Sentence Encoder
    _ = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')

@@ -63,12 +81,48 @@ def setup_nltk():
        except Exception as e:
            print(f"Warning: Could not download {package}: {str(e)}")

+def setup_faiss():
+    """
+    Download required faiss library.
+    """
+    current_os = platform.system()
+    cuda_available = False
+
+    # Function to check CUDA availability
+    def check_cuda():
+        try:
+            import torch
+            return torch.cuda.is_available()
+        except:
+            return False
+
+    if current_os == "Linux" and check_cuda():
+        # Attempt to install faiss-gpu
+        try:
+            print("Attempting to install faiss-gpu...")
+            subprocess.check_call([sys.executable, "-m", "pip", "install", "faiss-gpu>=1.7.0"])
+            print("Successfully installed faiss-gpu")
+            return
+        except subprocess.CalledProcessError:
+            print("Failed to install faiss-gpu. Falling back to faiss-cpu.")
+
+    # Install faiss-cpu as the default
+    try:
+        print("Installing faiss-cpu...")
+        subprocess.check_call([sys.executable, "-m", "pip", "install", "faiss-cpu>=1.7.0"])
+        print("Successfully installed faiss-cpu")
+    except subprocess.CalledProcessError as e:
+        print("Error installing faiss-cpu")
+        print(e)
+
setup(
    name="text-data-augmenter",
    version="0.1.0",
    author="Joe Armani",
    author_email="[email protected]",
    description="A tool for generating high-quality dialogue variations",
+    long_description=long_description,
+    long_description_content_type="text/markdown",
    packages=find_packages(),
    classifiers=[
        "Development Status :: 3 - Alpha",
@@ -95,6 +149,7 @@ setup(
)

if __name__ == '__main__':
-
+    setup_spacy_models()
    setup_models()
-    setup_nltk()
+    setup_nltk()
+    setup_faiss()
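setup_faiss() prefers faiss-gpu on Linux hosts with CUDA available and falls back to faiss-cpu everywhere else. A quick post-install sanity check (an assumption on my part, not something the repo ships) could look like:

import faiss

print(faiss.__version__)
# get_num_gpus() is only meaningful in GPU builds; guard so the check also runs on faiss-cpu.
print(faiss.get_num_gpus() if hasattr(faiss, "get_num_gpus") else "CPU build")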
test_trained_model.py
ADDED
File without changes