JoeArmani committed
Commit 300fe5d · 1 Parent(s): d2df8da

updates through 4th iteration

.gitignore CHANGED
@@ -159,3 +159,7 @@ datasets/*
159
 
160
  processed_outputs/*
161
  !processed_outputs/.gitkeep
162
+
163
+ chatbot_training/
164
+ checkpoints/
165
+ .DS_Store
augmented_combined_dataset.json DELETED
The diff for this file is too large to render. See raw diff
 
back_translator.py CHANGED
@@ -3,6 +3,8 @@ from transformers import (
3
  MarianTokenizer,
4
  )
5
 
6
  class BackTranslator:
7
  """
8
  Perform Back-translation with pivot language. English -> German -> Spanish -> English
@@ -20,7 +22,7 @@ class BackTranslator:
20
  self.tokenizer_pivot_forward = MarianTokenizer.from_pretrained(pivot_forward_model_name)
21
  self.model_pivot_forward = MarianMTModel.from_pretrained(pivot_forward_model_name)
22
 
23
- # Pivot translation model (German to Spanish)
24
  pivot_backward_model_name = f'Helsinki-NLP/opus-mt-{pivot_lang}-{target_lang}'
25
  self.tokenizer_pivot_backward = MarianTokenizer.from_pretrained(pivot_backward_model_name)
26
  self.model_pivot_backward = MarianMTModel.from_pretrained(pivot_backward_model_name)
@@ -29,28 +31,57 @@ class BackTranslator:
29
  backward_model_name = f'Helsinki-NLP/opus-mt-{target_lang}-{source_lang}'
30
  self.tokenizer_backward = MarianTokenizer.from_pretrained(backward_model_name)
31
  self.model_backward = MarianMTModel.from_pretrained(backward_model_name)
32
 
33
- def back_translate(self, text):
34
- """
35
- Perform back-translation through German and Spanish to generate text variations.
36
- Args:
37
- text (str): The input text to be back-translated
38
 
39
- Returns:
40
- str: The back-translated text
41
- """
42
- # 1. English to German
43
- encoded_pivot = self.tokenizer_pivot_forward([text], padding=True, truncation=True, return_tensors='pt')
44
- generated_pivot = self.model_pivot_forward.generate(**encoded_pivot)
45
- pivot_text = self.tokenizer_pivot_forward.batch_decode(generated_pivot, skip_special_tokens=True)[0]
46
 
47
- # 2. German to Spanish
48
- encoded_back_pivot = self.tokenizer_pivot_backward([pivot_text], padding=True, truncation=True, return_tensors='pt')
49
- retranslated_pivot = self.model_pivot_backward.generate(**encoded_back_pivot)
50
- tgt_text_back = self.tokenizer_pivot_backward.batch_decode(retranslated_pivot, skip_special_tokens=True)[0]
51
 
52
- # 3. Spanish to English
53
- encoded_back = self.tokenizer_backward([tgt_text_back], padding=True, truncation=True, return_tensors='pt')
54
- retranslated = self.model_backward.generate(**encoded_back)
55
- src_text = self.tokenizer_backward.batch_decode(retranslated, skip_special_tokens=True)[0]
56
- return src_text
3
  MarianTokenizer,
4
  )
5
 
6
+ # Retained for reference but no longer used in the final pipeline.
7
+ # Back-translation did not prove helpful for this retrieval-based chatbot.
8
  class BackTranslator:
9
  """
10
  Perform Back-translation with pivot language. English -> German -> Spanish -> English
 
22
  self.tokenizer_pivot_forward = MarianTokenizer.from_pretrained(pivot_forward_model_name)
23
  self.model_pivot_forward = MarianMTModel.from_pretrained(pivot_forward_model_name)
24
 
25
+ # Pivot translation (German to Spanish)
26
  pivot_backward_model_name = f'Helsinki-NLP/opus-mt-{pivot_lang}-{target_lang}'
27
  self.tokenizer_pivot_backward = MarianTokenizer.from_pretrained(pivot_backward_model_name)
28
  self.model_pivot_backward = MarianMTModel.from_pretrained(pivot_backward_model_name)
 
31
  backward_model_name = f'Helsinki-NLP/opus-mt-{target_lang}-{source_lang}'
32
  self.tokenizer_backward = MarianTokenizer.from_pretrained(backward_model_name)
33
  self.model_backward = MarianMTModel.from_pretrained(backward_model_name)
34
+
35
+ # Set models to eval mode
36
+ self.model_pivot_forward.eval()
37
+ self.model_pivot_backward.eval()
38
+ self.model_backward.eval()
39
 
40
+ def back_translate(self, text, device=None):
41
+ try:
42
+ # Move models to device if specified
43
+ if device is not None:
44
+ self.model_pivot_forward = self.model_pivot_forward.to(device)
45
+ self.model_pivot_backward = self.model_pivot_backward.to(device)
46
+ self.model_backward = self.model_backward.to(device)
47
+
48
+ # Forward translation (English to German)
49
+ encoded_pivot = self.tokenizer_pivot_forward([text], padding=True,
50
+ truncation=True, return_tensors='pt')
51
+ if device is not None:
52
+ encoded_pivot = {k: v.to(device) for k, v in encoded_pivot.items()}
53
 
54
+ generated_pivot = self.model_pivot_forward.generate(**encoded_pivot)
55
+ if device is not None:
56
+ generated_pivot = generated_pivot.cpu()
57
+ pivot_text = self.tokenizer_pivot_forward.batch_decode(generated_pivot,
58
+ skip_special_tokens=True)[0]
 
 
59
 
60
+ # Pivot translation (German to Spanish)
61
+ encoded_back_pivot = self.tokenizer_pivot_backward([pivot_text], padding=True,
62
+ truncation=True, return_tensors='pt')
63
+ if device is not None:
64
+ encoded_back_pivot = {k: v.to(device) for k, v in encoded_back_pivot.items()}
65
+
66
+ retranslated_pivot = self.model_pivot_backward.generate(**encoded_back_pivot)
67
+ if device is not None:
68
+ retranslated_pivot = retranslated_pivot.cpu()
69
+ tgt_text_back = self.tokenizer_pivot_backward.batch_decode(retranslated_pivot,
70
+ skip_special_tokens=True)[0]
71
 
72
+ # Backward translation (Spanish to English)
73
+ encoded_back = self.tokenizer_backward([tgt_text_back], padding=True,
74
+ truncation=True, return_tensors='pt')
75
+ if device is not None:
76
+ encoded_back = {k: v.to(device) for k, v in encoded_back.items()}
77
+
78
+ retranslated = self.model_backward.generate(**encoded_back)
79
+ if device is not None:
80
+ retranslated = retranslated.cpu()
81
+ src_text = self.tokenizer_backward.batch_decode(retranslated,
82
+ skip_special_tokens=True)[0]
83
+
84
+ return src_text
85
+ except Exception as e:
86
+ print(f"Error in back translation: {e}")
87
+ return text
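
For context, a minimal usage sketch of the class above. It is not part of the commit, and it assumes the constructor keeps the default English -> German -> Spanish -> English setup described in the docstring and that the Helsinki-NLP opus-mt checkpoints can be downloaded:

# Hypothetical usage sketch for BackTranslator (not part of this commit).
import torch
from back_translator import BackTranslator

translator = BackTranslator()  # assumed defaults: English -> German -> Spanish -> English
device = 'cuda' if torch.cuda.is_available() else None  # device is optional; None keeps everything on CPU

variant = translator.back_translate("How do I reset my password?", device=device)
print(variant)  # a re-worded paraphrase, or the original text if translation fails
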
chatbot.py ADDED
@@ -0,0 +1,261 @@
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ import keras
4
+ print(tf.__version__)
5
+ print(keras.__version__)
6
+ import spacy
7
+ import random
8
+ from tqdm import trange
9
+
10
+ class RetrievalChatbot:
11
+ def __init__(
12
+ self,
13
+ vocab_size: int = 10000,
14
+ max_sequence_length: int = 80,
15
+ embedding_dim: int = 256,
16
+ lstm_units: int = 256,
17
+ num_attention_heads: int = 8,
18
+ margin: float = 0.3
19
+ ):
20
+ self.vocab_size = vocab_size
21
+ self.max_sequence_length = max_sequence_length
22
+ self.embedding_dim = embedding_dim
23
+ self.lstm_units = lstm_units
24
+ self.num_attention_heads = num_attention_heads
25
+ self.margin = margin
26
+
27
+ self.nlp = spacy.load('en_core_web_md')
28
+ self.tokenizer = tf.keras.preprocessing.text.Tokenizer(
29
+ num_words=vocab_size,
30
+ oov_token="<OOV>"
31
+ )
32
+
33
+ self.query_encoder_model, self.response_encoder_model = self._build_encoders()
34
+
35
+ def _positional_encoding(self, position: int, d_model: int) -> tf.Tensor:
36
+ angles = np.arange(position)[:, np.newaxis] / np.power(
37
+ 10000,
38
+ (2 * (np.arange(d_model)[np.newaxis, :] // 2)) / d_model
39
+ )
40
+ sines = np.sin(angles[:, 0::2])
41
+ cosines = np.cos(angles[:, 1::2])
42
+ pos_encoding = np.concatenate([sines, cosines], axis=-1)
43
+ pos_encoding = pos_encoding[np.newaxis, ...]
44
+ return tf.cast(pos_encoding, dtype=tf.float32)
45
+
46
+ def _build_single_encoder(self, name_prefix: str):
47
+ input_layer = tf.keras.Input(shape=(self.max_sequence_length,), name=f"{name_prefix}_input")
48
+ embedding = tf.keras.layers.Embedding(
49
+ self.vocab_size,
50
+ self.embedding_dim,
51
+ mask_zero=True,
52
+ name=f"{name_prefix}_embedding"
53
+ )(input_layer)
54
+
55
+ pos_encoding = self._positional_encoding(self.max_sequence_length, self.embedding_dim)
56
+ x = embedding + pos_encoding
57
+
58
+ # # Multi-head attention
59
+ # attention_output = tf.keras.layers.MultiHeadAttention(
60
+ # num_heads=self.num_attention_heads,
61
+ # key_dim=self.embedding_dim // self.num_attention_heads
62
+ # )(x, x)
63
+ # x = tf.keras.layers.LayerNormalization()(x + attention_output)
64
+
65
+ for i in range(2):
66
+ lstm_out = tf.keras.layers.LSTM(
67
+ self.lstm_units,
68
+ return_sequences=True,
69
+ kernel_regularizer=tf.keras.regularizers.l2(0.01),
70
+ name=f"{name_prefix}_lstm_{i}"
71
+ )(x)
72
+ x = tf.keras.layers.LayerNormalization()(x + lstm_out)
73
+
74
+ encoder_output = tf.keras.layers.LSTM(
75
+ self.lstm_units,
76
+ name=f"{name_prefix}_final_lstm"
77
+ )(x)
78
+ encoder_output = tf.keras.layers.Dropout(0.2)(encoder_output)
79
+ encoder_output = tf.keras.layers.Lambda(lambda t: tf.nn.l2_normalize(t, axis=1))(encoder_output)
80
+
81
+ return tf.keras.Model(input_layer, encoder_output, name=f"{name_prefix}_encoder")
82
+
83
+ def _build_encoders(self):
84
+ query_encoder = self._build_single_encoder("query")
85
+ response_encoder = self._build_single_encoder("response")
86
+ return query_encoder, response_encoder
87
+
88
+ def _spacy_similarity(self, text1: str, text2: str) -> float:
89
+ doc1 = self.nlp(text1)
90
+ doc2 = self.nlp(text2)
91
+ print('doc1:', doc1)
92
+ print('doc2:', doc2)
93
+ print('doc1.similarity(doc2):', doc1.similarity(doc2))
94
+ return doc1.similarity(doc2)
95
+
96
+ def prepare_dataset(self, dialogues: list, neg_samples_per_pos=3):
97
+ # Create triplets: (query, positive, negative)
98
+ response_pool = [
99
+ turn['text'] for d in dialogues for turn in d['turns'] if turn['speaker'] == 'assistant'
100
+ ]
101
+ queries, positives, negatives = [], [], []
102
+
103
+ for dialogue in dialogues:
104
+ turns = dialogue['turns']
105
+ for i in range(0, len(turns)-1):
106
+ if turns[i]['speaker'] == 'user' and turns[i+1]['speaker'] == 'assistant':
107
+ q = turns[i]['text']
108
+ p = turns[i+1]['text']
109
+
110
+ # Find negatives using spaCy similarity
111
+ neg_candidates = []
112
+ attempts = 0
113
+ while len(neg_candidates) < neg_samples_per_pos and attempts < 200:
114
+ cand = random.choice(response_pool)
115
+ if cand != p:
116
+ sim = self._spacy_similarity(cand, p)
117
+ # Choose thresholds that produce hard negatives
118
+ if 0.4 < sim < 0.9:
119
+ neg_candidates.append(cand)
120
+ attempts += 1
121
+
122
+ if len(neg_candidates) == neg_samples_per_pos:
123
+ for neg in neg_candidates:
124
+ queries.append(q)
125
+ positives.append(p)
126
+ negatives.append(neg)
127
+
128
+ # Fit tokenizer
129
+ all_text = queries + positives + negatives
130
+ self.tokenizer.fit_on_texts(all_text)
131
+
132
+ def seq_pad(txts):
133
+ seq = self.tokenizer.texts_to_sequences(txts)
134
+ return tf.keras.preprocessing.sequence.pad_sequences(seq, maxlen=self.max_sequence_length, padding='post')
135
+
136
+ q_pad = seq_pad(queries)
137
+ p_pad = seq_pad(positives)
138
+ n_pad = seq_pad(negatives)
139
+
140
+ return q_pad, p_pad, n_pad
141
+
142
+ def triplet_loss(self, q_emb, p_emb, n_emb):
143
+ pos_dist = tf.reduce_sum(tf.square(q_emb - p_emb), axis=1)
144
+ neg_dist = tf.reduce_sum(tf.square(q_emb - n_emb), axis=1)
145
+ loss = tf.maximum(0.0, self.margin + pos_dist - neg_dist)
146
+ return tf.reduce_mean(loss)
147
+
148
+ def train_with_triplet_loss(
149
+ self, q_pad, p_pad, n_pad,
150
+ epochs=3,
151
+ batch_size=16,
152
+ validation_split=0.2,
153
+ early_stopping_patience=3,
154
+ use_tqdm=True
155
+ ):
156
+ train_losses = []
157
+ val_losses = []
158
+
159
+ total_samples = len(q_pad)
160
+ idxs = np.arange(total_samples)
161
+ np.random.shuffle(idxs)
162
+ train_size = int((1 - validation_split) * total_samples)
163
+
164
+ train_idxs = idxs[:train_size]
165
+ val_idxs = idxs[train_size:]
166
+
167
+ q_train, p_train, n_train = q_pad[train_idxs], p_pad[train_idxs], n_pad[train_idxs]
168
+ q_val, p_val, n_val = q_pad[val_idxs], p_pad[val_idxs], n_pad[val_idxs]
169
+
170
+ optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
171
+ best_val_loss = float('inf')
172
+ wait = 0
173
+
174
+ for epoch in range(epochs):
175
+ # Shuffle training data each epoch
176
+ perm = np.random.permutation(len(q_train))
177
+ q_train, p_train, n_train = q_train[perm], p_train[perm], n_train[perm]
178
+
179
+ num_batches = len(q_train) // batch_size
180
+ epoch_train_loss = 0.0
181
+
182
+ batch_iter = range(num_batches)
183
+ if use_tqdm:
184
+ batch_iter = trange(num_batches, desc=f"Epoch {epoch+1}/{epochs}")
185
+
186
+ for i in batch_iter:
187
+ q_batch = q_train[i*batch_size:(i+1)*batch_size]
188
+ p_batch = p_train[i*batch_size:(i+1)*batch_size]
189
+ n_batch = n_train[i*batch_size:(i+1)*batch_size]
190
+
191
+ with tf.GradientTape() as tape:
192
+ q_emb = self.query_encoder_model(q_batch, training=True)
193
+ p_emb = self.response_encoder_model(p_batch, training=True)
194
+ n_emb = self.response_encoder_model(n_batch, training=True)
195
+ loss = self.triplet_loss(q_emb, p_emb, n_emb)
196
+
197
+ grads = tape.gradient(
198
+ loss,
199
+ self.query_encoder_model.trainable_variables +
200
+ self.response_encoder_model.trainable_variables
201
+ )
202
+ optimizer.apply_gradients(zip(
203
+ grads,
204
+ self.query_encoder_model.trainable_variables +
205
+ self.response_encoder_model.trainable_variables
206
+ ))
207
+ epoch_train_loss += loss.numpy()
208
+
209
+ epoch_train_loss /= num_batches
210
+
211
+ # Validation loss
212
+ val_batches = len(q_val) // batch_size
213
+ epoch_val_loss = 0.0
214
+ for i in range(val_batches):
215
+ q_batch = q_val[i*batch_size:(i+1)*batch_size]
216
+ p_batch = p_val[i*batch_size:(i+1)*batch_size]
217
+ n_batch = n_val[i*batch_size:(i+1)*batch_size]
218
+
219
+ q_emb = self.query_encoder_model(q_batch, training=False)
220
+ p_emb = self.response_encoder_model(p_batch, training=False)
221
+ n_emb = self.response_encoder_model(n_batch, training=False)
222
+ v_loss = self.triplet_loss(q_emb, p_emb, n_emb)
223
+ epoch_val_loss += v_loss.numpy()
224
+
225
+ if val_batches > 0:
226
+ epoch_val_loss /= val_batches
227
+
228
+ train_losses.append(epoch_train_loss)
229
+ val_losses.append(epoch_val_loss)
230
+
231
+ print(f"Epoch {epoch+1}/{epochs}, Train Loss: {epoch_train_loss:.4f}, Val Loss: {epoch_val_loss:.4f}")
232
+
233
+ # Early Stopping logic
234
+ if epoch_val_loss < best_val_loss:
235
+ best_val_loss = epoch_val_loss
236
+ wait = 0
237
+ # (Optional) Save best weights
238
+ else:
239
+ wait += 1
240
+ if wait >= early_stopping_patience:
241
+ print("Early stopping triggered.")
242
+ break
243
+
244
+ return train_losses, val_losses
245
+
246
+ def encode_texts(self, texts, is_query=True):
247
+ seq = self.tokenizer.texts_to_sequences(texts)
248
+ pad_seq = tf.keras.preprocessing.sequence.pad_sequences(seq, maxlen=self.max_sequence_length, padding='post')
249
+ if is_query:
250
+ return self.query_encoder_model(pad_seq, training=False)
251
+ else:
252
+ return self.response_encoder_model(pad_seq, training=False)
253
+
254
+ def retrieve_top_n(self, query: str, candidates: list, top_n=5):
255
+ q_emb = self.encode_texts([query], is_query=True) # shape (1, d)
256
+ c_emb = self.encode_texts(candidates, is_query=False) # shape (num_cand, d)
257
+ sim = tf.matmul(q_emb, c_emb, transpose_b=True).numpy()[0] # dot product similarity
258
+ top_indices = np.argsort(sim)[::-1][:top_n]
259
+ return [(candidates[i], sim[i]) for i in top_indices]
260
+
261
+
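
For reference, a minimal driver for the RetrievalChatbot above. This is a sketch, not part of the commit; it assumes the spaCy en_core_web_md model is installed, and the dialogue format (a list of dicts with 'turns', each turn having 'speaker' and 'text') is inferred from prepare_dataset:

# Hypothetical driver for chatbot.py's RetrievalChatbot (not part of this commit).
from chatbot import RetrievalChatbot

dialogues = [
    {"turns": [
        {"speaker": "user", "text": "What time do you open on weekends?"},
        {"speaker": "assistant", "text": "We open at 9 am on Saturdays and Sundays."},
    ]},
    {"turns": [
        {"speaker": "user", "text": "Where has my package gone?"},
        {"speaker": "assistant", "text": "Your order has shipped and should arrive on Friday."},
    ]},
    # ... more dialogues in the same format, so negative sampling has a pool to draw from
]

bot = RetrievalChatbot(vocab_size=10000, max_sequence_length=80)
q_pad, p_pad, n_pad = bot.prepare_dataset(dialogues, neg_samples_per_pos=3)
train_losses, val_losses = bot.train_with_triplet_loss(q_pad, p_pad, n_pad, epochs=3, batch_size=16)

candidates = ["We open at 9 am on weekends.", "Your order has shipped and should arrive on Friday."]
print(bot.retrieve_top_n("When are you open on Saturday?", candidates, top_n=2))
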
chatbot2.py ADDED
@@ -0,0 +1,839 @@
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ import spacy
4
+ import random
5
+ from typing import List, Tuple, Dict, Optional, Union
6
+ from dataclasses import dataclass
7
+ from tqdm import tqdm
8
+ import logging
9
+ from pathlib import Path
10
+ import json
11
+
12
+ # Configure logging
13
+ logging.basicConfig(
14
+ level=logging.INFO,
15
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
16
+ )
17
+ logger = logging.getLogger(__name__)
18
+
19
+ @dataclass
20
+ class ChatbotConfig:
21
+ """Configuration for the retrieval chatbot."""
22
+ vocab_size: int = 10000
23
+ max_sequence_length: int = 512
24
+ embedding_dim: int = 256
25
+ encoder_units: int = 256
26
+ num_attention_heads: int = 8
27
+ dropout_rate: float = 0.2
28
+ l2_reg_weight: float = 0.001
29
+ margin: float = 0.3
30
+ learning_rate: float = 0.001
31
+ min_text_length: int = 3 # Reduced from 10 to allow shorter responses
32
+ max_context_turns: int = 5
33
+ warmup_steps: int = 200
34
+ spacy_model: str = 'en_core_web_md'
35
+
36
+ def to_dict(self) -> dict:
37
+ """Convert config to dictionary."""
38
+ return {k: str(v) if isinstance(v, Path) else v
39
+ for k, v in self.__dict__.items()}
40
+
41
+ @classmethod
42
+ def from_dict(cls, config_dict: dict) -> 'ChatbotConfig':
43
+ """Create config from dictionary."""
44
+ return cls(**{k: v for k, v in config_dict.items()
45
+ if k in cls.__dataclass_fields__})
46
+
47
+ class TransformerBlock(tf.keras.layers.Layer):
48
+ """Custom Transformer block with pre-layer normalization."""
49
+ def __init__(
50
+ self,
51
+ embed_dim: int,
52
+ num_heads: int,
53
+ ff_dim: int,
54
+ dropout: float = 0.1,
55
+ **kwargs
56
+ ):
57
+ super().__init__(**kwargs)
58
+ self.embed_dim = embed_dim
59
+ self.num_heads = num_heads
60
+ self.ff_dim = ff_dim
61
+ self.dropout = dropout
62
+
63
+ self.attention = tf.keras.layers.MultiHeadAttention(
64
+ num_heads=num_heads,
65
+ key_dim=embed_dim // num_heads,
66
+ dropout=dropout
67
+ )
68
+ self.ffn = tf.keras.Sequential([
69
+ tf.keras.layers.Dense(ff_dim, activation="gelu"),
70
+ tf.keras.layers.Dense(embed_dim),
71
+ ])
72
+
73
+ self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
74
+ self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
75
+ self.dropout1 = tf.keras.layers.Dropout(dropout)
76
+ self.dropout2 = tf.keras.layers.Dropout(dropout)
77
+
78
+ def call(self, inputs: tf.Tensor, training: bool, mask: Optional[tf.Tensor] = None) -> tf.Tensor:
79
+ # Pre-layer normalization
80
+ norm_inputs = self.layernorm1(inputs)
81
+
82
+ # Self-attention
83
+ attention_output = self.attention(
84
+ query=norm_inputs,
85
+ value=norm_inputs,
86
+ key=norm_inputs,
87
+ attention_mask=mask,
88
+ training=training
89
+ )
90
+ attention_output = self.dropout1(attention_output, training=training)
91
+ attention_output = inputs + attention_output
92
+
93
+ # Feed-forward network
94
+ norm_attention = self.layernorm2(attention_output)
95
+ ffn_output = self.ffn(norm_attention)
96
+ ffn_output = self.dropout2(ffn_output, training=training)
97
+
98
+ return attention_output + ffn_output
99
+
100
+ def get_config(self) -> dict:
101
+ config = super().get_config()
102
+ config.update({
103
+ "embed_dim": self.embed_dim,
104
+ "num_heads": self.num_heads,
105
+ "ff_dim": self.ff_dim,
106
+ "dropout": self.dropout,
107
+ })
108
+ return config
109
+
110
+ class EncoderModel(tf.keras.Model):
111
+ """Dual encoder model with shared weights option."""
112
+ def __init__(
113
+ self,
114
+ config: ChatbotConfig,
115
+ name: str = "encoder",
116
+ shared_weights: bool = False,
117
+ **kwargs
118
+ ):
119
+ super().__init__(name=name, **kwargs)
120
+ self.config = config
121
+ self.shared_weights = shared_weights
122
+
123
+ # Input embedding layer
124
+ self.embedding = tf.keras.layers.Embedding(
125
+ config.vocab_size,
126
+ config.embedding_dim,
127
+ mask_zero=True,
128
+ name=f"{name}_embedding"
129
+ )
130
+
131
+ # Positional encoding
132
+ self.pos_encoding = self._get_positional_encoding()
133
+
134
+ # Transformer blocks
135
+ self.transformer_blocks = [
136
+ TransformerBlock(
137
+ config.embedding_dim,
138
+ config.num_attention_heads,
139
+ config.encoder_units * 4,
140
+ config.dropout_rate,
141
+ name=f"{name}_transformer_{i}"
142
+ ) for i in range(3)
143
+ ]
144
+
145
+ # Final LSTM layer
146
+ self.final_lstm = tf.keras.layers.LSTM(
147
+ config.encoder_units,
148
+ kernel_regularizer=tf.keras.regularizers.l2(config.l2_reg_weight),
149
+ name=f"{name}_final_lstm"
150
+ )
151
+
152
+ self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
153
+ self.normalize = tf.keras.layers.Lambda(
154
+ lambda x: tf.nn.l2_normalize(x, axis=1)
155
+ )
156
+
157
+ def _get_positional_encoding(self) -> tf.Tensor:
158
+ """Generate positional encoding matrix."""
159
+ pos = np.arange(self.config.max_sequence_length)[:, np.newaxis]
160
+ i = np.arange(self.config.embedding_dim)[np.newaxis, :]
161
+ angle = pos / np.power(10000, (2 * (i // 2)) / self.config.embedding_dim)
162
+
163
+ pos_encoding = np.zeros_like(angle)
164
+ pos_encoding[:, 0::2] = np.sin(angle[:, 0::2])
165
+ pos_encoding[:, 1::2] = np.cos(angle[:, 1::2])
166
+
167
+ return tf.cast(pos_encoding[np.newaxis, ...], dtype=tf.float32)
168
+
169
+ def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
170
+ # Get input mask
171
+ mask = self.embedding.compute_mask(inputs)
172
+ mask = mask[:, tf.newaxis, tf.newaxis, :] # Add attention dims
173
+
174
+ # Embedding + positional encoding
175
+ x = self.embedding(inputs)
176
+ x = x + self.pos_encoding
177
+
178
+ # Apply transformer blocks
179
+ for transformer_block in self.transformer_blocks:
180
+ x = transformer_block(x, training=training, mask=mask)
181
+
182
+ # Final processing
183
+ x = self.final_lstm(x)
184
+ x = self.dropout(x, training=training)
185
+ return self.normalize(x)
186
+
187
+ class RetrievalChatbot:
188
+ """Professional implementation of a retrieval-based chatbot."""
189
+ def __init__(self, config: ChatbotConfig):
190
+ self.config = config
191
+ self.nlp = spacy.load(config.spacy_model)
192
+
193
+ # Initialize tokenizer
194
+ self.tokenizer = tf.keras.preprocessing.text.Tokenizer(
195
+ num_words=config.vocab_size,
196
+ oov_token="<OOV>"
197
+ )
198
+
199
+ # Special tokens
200
+ self.special_tokens = {
201
+ "user": "<USER>",
202
+ "assistant": "<ASSISTANT>",
203
+ "context": "<CONTEXT>",
204
+ "sep": "<SEP>"
205
+ }
206
+
207
+ # Build models
208
+ self._build_models()
209
+
210
+ # Training history
211
+ self.history = {
212
+ "train_loss": [],
213
+ "val_loss": [],
214
+ "train_metrics": {},
215
+ "val_metrics": {}
216
+ }
217
+
218
+ # Initialize similarity cache
219
+ self.similarity_cache = {}
220
+
221
+ def _build_models(self):
222
+ """Initialize the encoder models."""
223
+ # Query encoder
224
+ self.query_encoder = EncoderModel(
225
+ self.config,
226
+ name="query_encoder",
227
+ shared_weights=False
228
+ )
229
+
230
+ # Response encoder (can share weights with query encoder)
231
+ self.response_encoder = EncoderModel(
232
+ self.config,
233
+ name="response_encoder",
234
+ shared_weights=False
235
+ )
236
+
237
+ def save_models(self, save_dir: Union[str, Path]):
238
+ """Save models and configuration."""
239
+ save_dir = Path(save_dir)
240
+ save_dir.mkdir(parents=True, exist_ok=True)
241
+
242
+ # Save config
243
+ with open(save_dir / "config.json", "w") as f:
244
+ json.dump(self.config.to_dict(), f, indent=2)
245
+
246
+ # Save models with proper extension
247
+ self.query_encoder.save(save_dir / "query_encoder.keras")
248
+ self.response_encoder.save(save_dir / "response_encoder.keras")
249
+
250
+ # Save tokenizer config
251
+ tokenizer_config = {
252
+ "word_index": self.tokenizer.word_index,
253
+ "word_counts": self.tokenizer.word_counts,
254
+ "document_count": self.tokenizer.document_count,
255
+ "index_docs": self.tokenizer.index_docs,
256
+ "index_word": self.tokenizer.index_word
257
+ }
258
+ with open(save_dir / "tokenizer_config.json", "w") as f:
259
+ json.dump(tokenizer_config, f)
260
+
261
+ @classmethod
262
+ def load_models(cls, load_dir: Union[str, Path]) -> 'RetrievalChatbot':
263
+ """Load saved models and configuration."""
264
+ load_dir = Path(load_dir)
265
+
266
+ # Load config
267
+ with open(load_dir / "config.json", "r") as f:
268
+ config = ChatbotConfig.from_dict(json.load(f))
269
+
270
+ # Initialize chatbot
271
+ chatbot = cls(config)
272
+
273
+ # Load models with proper extension
274
+ chatbot.query_encoder = tf.keras.models.load_model(
275
+ load_dir / "query_encoder.keras",
276
+ custom_objects={"TransformerBlock": TransformerBlock}
277
+ )
278
+ chatbot.response_encoder = tf.keras.models.load_model(
279
+ load_dir / "response_encoder.keras",
280
+ custom_objects={"TransformerBlock": TransformerBlock}
281
+ )
282
+
283
+ # Load tokenizer config
284
+ with open(load_dir / "tokenizer_config.json", "r") as f:
285
+ tokenizer_config = json.load(f)
286
+
287
+ chatbot.tokenizer = tf.keras.preprocessing.text.Tokenizer(
288
+ num_words=config.vocab_size,
289
+ oov_token="<OOV>"
290
+ )
291
+ chatbot.tokenizer.word_index = tokenizer_config["word_index"]
292
+ chatbot.tokenizer.word_counts = tokenizer_config["word_counts"]
293
+ chatbot.tokenizer.document_count = tokenizer_config["document_count"]
294
+ chatbot.tokenizer.index_docs = tokenizer_config["index_docs"]
295
+ chatbot.tokenizer.index_word = tokenizer_config["index_word"]
296
+
297
+ return chatbot
298
+
299
+ def _improved_spacy_similarity(self, text1: str, text2: str) -> float:
300
+ """Calculate semantic similarity between texts with preprocessing."""
301
+ def preprocess(text: str) -> str:
302
+ # Basic cleaning
303
+ text = ' '.join(text.split())
304
+ return text if text.strip() else "empty_document"
305
+
306
+ # Get cache key
307
+ cache_key = f"{hash(text1)}_{hash(text2)}"
308
+ if cache_key in self.similarity_cache:
309
+ return self.similarity_cache[cache_key]
310
+
311
+ # Process texts
312
+ text1, text2 = preprocess(text1), preprocess(text2)
313
+ doc1, doc2 = self.nlp(text1), self.nlp(text2)
314
+
315
+ # Calculate similarity
316
+ if doc1.has_vector and doc2.has_vector:
317
+ sim = doc1.similarity(doc2)
318
+ else:
319
+ # Fallback to token overlap similarity
320
+ tokens1 = {t.lower_ for t in doc1 if not t.is_stop and not t.is_punct}
321
+ tokens2 = {t.lower_ for t in doc2 if not t.is_stop and not t.is_punct}
322
+ intersection = len(tokens1.intersection(tokens2))
323
+ union = len(tokens1.union(tokens2))
324
+ sim = intersection / union if union > 0 else 0.0
325
+
326
+ # Cache result
327
+ self.similarity_cache[cache_key] = sim
328
+ return sim
329
+
330
+ def _smart_negative_sampling(
331
+ self,
332
+ positive: str,
333
+ response_pool: List[str],
334
+ n_samples: int,
335
+ max_attempts: int = 200,
336
+ similarity_bounds: Tuple[float, float] = (0.3, 0.8),
337
+ batch_size: int = 10
338
+ ) -> List[str]:
339
+ """Smart negative sampling with similarity bounds and batching."""
340
+ candidates = []
341
+ seen = set()
342
+ attempts = 0
343
+
344
+ while len(candidates) < n_samples and attempts < max_attempts:
345
+ # Batch process candidates
346
+ batch = random.sample(
347
+ response_pool,
348
+ min(batch_size, max_attempts - attempts)
349
+ )
350
+
351
+ for candidate in batch:
352
+ if candidate != positive and candidate not in seen:
353
+ seen.add(candidate)
354
+ sim = self._improved_spacy_similarity(candidate, positive)
355
+
356
+ # Check similarity bounds
357
+ if similarity_bounds[0] < sim < similarity_bounds[1]:
358
+ candidates.append(candidate)
359
+ if len(candidates) == n_samples:
360
+ break
361
+
362
+ attempts += len(batch)
363
+
364
+ return candidates
365
+
366
+ def train(
367
+ self,
368
+ q_pad: tf.Tensor,
369
+ p_pad: tf.Tensor,
370
+ n_pad: tf.Tensor,
371
+ epochs: int = 3,
372
+ batch_size: int = 32,
373
+ validation_split: float = 0.2,
374
+ checkpoint_dir: Optional[Union[str, Path]] = None
375
+ ):
376
+ """Train the model with improved training loop."""
377
+ # Setup training
378
+ total_samples = len(q_pad)
379
+ train_size = int((1 - validation_split) * total_samples)
380
+
381
+ # Split data
382
+ indices = np.random.permutation(total_samples)
383
+ train_idx, val_idx = indices[:train_size], indices[train_size:]
384
+
385
+ train_data = (q_pad[train_idx], p_pad[train_idx], n_pad[train_idx])
386
+ val_data = (q_pad[val_idx], p_pad[val_idx], n_pad[val_idx])
387
+
388
+ # Setup optimizer with learning rate schedule
389
+ steps_per_epoch = train_size // batch_size
390
+ total_steps = steps_per_epoch * epochs
391
+
392
+ lr_schedule = self._get_lr_schedule(
393
+ total_steps,
394
+ self.config.learning_rate,
395
+ self.config.warmup_steps
396
+ )
397
+
398
+ optimizer = tf.keras.optimizers.Adam(lr_schedule)
399
+
400
+ # Setup checkpointing
401
+ if checkpoint_dir:
402
+ checkpoint_dir = Path(checkpoint_dir)
403
+ checkpoint_dir.mkdir(parents=True, exist_ok=True)
404
+
405
+ # Setup checkpoint callback with correct file format
406
+ checkpoint_template = str(checkpoint_dir / "model_epoch_{epoch:04d}.weights.h5")
407
+ checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
408
+ checkpoint_template,
409
+ save_weights_only=True,
410
+ save_best_only=True,
411
+ monitor='val_loss',
412
+ mode='min',
413
+ verbose=1
414
+ )
415
+
416
+ # Training loop
417
+ best_val_loss = float('inf')
418
+ patience = 5
419
+ wait = 0
420
+
421
+ for epoch in range(epochs):
422
+ # Training
423
+ train_loss = self._train_epoch(
424
+ train_data,
425
+ optimizer,
426
+ batch_size,
427
+ training=True
428
+ )
429
+
430
+ # Validation
431
+ val_loss = self._train_epoch(
432
+ val_data,
433
+ optimizer,
434
+ batch_size,
435
+ training=False
436
+ )
437
+
438
+ # Update history
439
+ self.history['train_loss'].append(train_loss)
440
+ self.history['val_loss'].append(val_loss)
441
+
442
+ logger.info(
443
+ f"Epoch {epoch + 1}/{epochs} - "
444
+ f"train_loss: {train_loss:.4f} - "
445
+ f"val_loss: {val_loss:.4f}"
446
+ )
447
+
448
+ # Early stopping
449
+ if val_loss < best_val_loss:
450
+ best_val_loss = val_loss
451
+ wait = 0
452
+ if checkpoint_dir:
453
+ self.save_models(checkpoint_dir / f"best_model")
454
+ else:
455
+ wait += 1
456
+ if wait >= patience:
457
+ logger.info("Early stopping triggered")
458
+ break
459
+
460
+ def _get_lr_schedule(
461
+ self,
462
+ total_steps: int,
463
+ peak_lr: float,
464
+ warmup_steps: int
465
+ ) -> tf.keras.optimizers.schedules.LearningRateSchedule:
466
+ """Enhanced learning rate schedule with better error handling and logging."""
467
+ class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
468
+ def __init__(
469
+ self,
470
+ total_steps: int,
471
+ peak_lr: float,
472
+ warmup_steps: int
473
+ ):
474
+ super().__init__()
475
+ self.total_steps = tf.cast(total_steps, tf.float32)
476
+ self.peak_lr = tf.cast(peak_lr, tf.float32)
477
+ self.warmup_steps = tf.cast(max(1, warmup_steps), tf.float32) # Prevent 0
478
+
479
+ # Calculate and store constants
480
+ self.initial_lr = self.peak_lr * 0.1 # Start at 10% of peak
481
+ self.min_lr = self.peak_lr * 0.01 # Minimum 1% of peak
482
+
483
+ logger.info(f"Learning rate schedule initialized:")
484
+ logger.info(f" Initial LR: {float(self.initial_lr):.6f}")
485
+ logger.info(f" Peak LR: {float(self.peak_lr):.6f}")
486
+ logger.info(f" Min LR: {float(self.min_lr):.6f}")
487
+ logger.info(f" Warmup steps: {int(self.warmup_steps)}")
488
+ logger.info(f" Total steps: {int(self.total_steps)}")
489
+
490
+ def __call__(self, step):
491
+ step = tf.cast(step, tf.float32)
492
+
493
+ # Warmup phase
494
+ warmup_factor = tf.minimum(1.0, step / self.warmup_steps)
495
+ warmup_lr = self.initial_lr + (self.peak_lr - self.initial_lr) * warmup_factor
496
+
497
+ # Decay phase
498
+ decay_steps = tf.maximum(1.0, self.total_steps - self.warmup_steps)
499
+ decay_factor = (step - self.warmup_steps) / decay_steps
500
+ decay_factor = tf.minimum(tf.maximum(0.0, decay_factor), 1.0) # Clip to [0,1]
501
+
502
+ cosine_decay = 0.5 * (1.0 + tf.cos(tf.constant(np.pi) * decay_factor))
503
+ decay_lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay
504
+
505
+ # Choose between warmup and decay
506
+ final_lr = tf.where(step < self.warmup_steps, warmup_lr, decay_lr)
507
+
508
+ # Ensure learning rate is valid
509
+ final_lr = tf.maximum(self.min_lr, final_lr)
510
+ final_lr = tf.where(tf.math.is_finite(final_lr), final_lr, self.min_lr)
511
+
512
+ return final_lr
513
+
514
+ def get_config(self):
515
+ return {
516
+ "total_steps": self.total_steps,
517
+ "peak_lr": self.peak_lr,
518
+ "warmup_steps": self.warmup_steps,
519
+ }
520
+
521
+ return CustomSchedule(total_steps, peak_lr, warmup_steps)
522
+
523
+ @tf.function
524
+ def _train_step(
525
+ self,
526
+ q_batch: tf.Tensor,
527
+ p_batch: tf.Tensor,
528
+ n_batch: tf.Tensor,
529
+ optimizer: tf.keras.optimizers.Optimizer,
530
+ training: bool = True
531
+ ) -> tf.Tensor:
532
+ """Single training step with triplet loss."""
533
+ with tf.GradientTape() as tape:
534
+ # Get embeddings
535
+ q_emb = self.query_encoder(q_batch, training=training)
536
+ p_emb = self.response_encoder(p_batch, training=training)
537
+ n_emb = self.response_encoder(n_batch, training=training)
538
+
539
+ # Calculate triplet loss
540
+ pos_dist = tf.reduce_sum(tf.square(q_emb - p_emb), axis=1)
541
+ neg_dist = tf.reduce_sum(tf.square(q_emb - n_emb), axis=1)
542
+
543
+ loss = tf.maximum(0.0, self.config.margin + pos_dist - neg_dist)
544
+ loss = tf.reduce_mean(loss)
545
+
546
+ if training:
547
+ # Apply gradients
548
+ gradients = tape.gradient(
549
+ loss,
550
+ self.query_encoder.trainable_variables +
551
+ self.response_encoder.trainable_variables
552
+ )
553
+ optimizer.apply_gradients(zip(
554
+ gradients,
555
+ self.query_encoder.trainable_variables +
556
+ self.response_encoder.trainable_variables
557
+ ))
558
+
559
+ return loss
560
+
561
+ def _train_epoch(
562
+ self,
563
+ data: Tuple[tf.Tensor, tf.Tensor, tf.Tensor],
564
+ optimizer: tf.keras.optimizers.Optimizer,
565
+ batch_size: int,
566
+ training: bool = True
567
+ ) -> float:
568
+ """Train for one epoch with enhanced logging and progress tracking."""
569
+ q_data, p_data, n_data = data
570
+ total_loss = 0
571
+ num_batches = len(q_data) // batch_size
572
+
573
+ # Log current learning rate at start of epoch
574
+ if training:
575
+ if hasattr(optimizer.learning_rate, '__call__'):
576
+ current_lr = optimizer.learning_rate(optimizer.iterations)
577
+ else:
578
+ current_lr = optimizer.learning_rate
579
+ logger.info(f"Current learning rate: {float(current_lr):.6f}")
580
+
581
+ # Shuffle data
582
+ indices = np.random.permutation(len(q_data))
583
+ q_data = q_data[indices]
584
+ p_data = p_data[indices]
585
+ n_data = n_data[indices]
586
+
587
+ # Create progress bar
588
+ mode = "Training" if training else "Validation"
589
+ pbar = tqdm(
590
+ total=num_batches,
591
+ desc=f"{mode} batches",
592
+ unit="batch",
593
+ dynamic_ncols=True # Automatically adjust width
594
+ )
595
+
596
+ # Process batches
597
+ for i in range(num_batches):
598
+ start_idx = i * batch_size
599
+ end_idx = start_idx + batch_size
600
+
601
+ batch_loss = self._train_step(
602
+ q_data[start_idx:end_idx],
603
+ p_data[start_idx:end_idx],
604
+ n_data[start_idx:end_idx],
605
+ optimizer,
606
+ training
607
+ )
608
+ total_loss += batch_loss
609
+
610
+ # Update progress bar with current loss
611
+ avg_loss = total_loss / (i + 1)
612
+ pbar.set_postfix({
613
+ 'loss': f'{avg_loss:.4f}',
614
+ 'lr': f'{float(current_lr):.6f}' if training else 'N/A'
615
+ })
616
+ pbar.update(1)
617
+
618
+ pbar.close()
619
+ return total_loss / num_batches if num_batches > 0 else 0
620
+
621
+ def _prepare_sequences(
622
+ self,
623
+ queries: List[str],
624
+ positives: List[str],
625
+ negatives: List[str]
626
+ ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
627
+ """Enhanced sequence preparation with logging and text preprocessing."""
628
+ logger.info("Preparing sequences...")
629
+
630
+ # Text cleaning function from old version
631
+ def clean_text(text: str) -> str:
632
+ # Remove excessive whitespace
633
+ text = ' '.join(text.split())
634
+ # Remove very long repetitive sequences
635
+ if len(text) > 500: # Add length limit
636
+ text = ' '.join(dict.fromkeys(text.split()))
637
+ return text
638
+
639
+ # Process texts with special tokens and cleaning
640
+ queries = [f"{self.special_tokens['user']} {clean_text(q)}" for q in queries]
641
+ positives = [f"{self.special_tokens['assistant']} {clean_text(p)}" for p in positives]
642
+ negatives = [f"{self.special_tokens['assistant']} {clean_text(n)}" for n in negatives]
643
+
644
+ # Fit tokenizer and log vocabulary statistics
645
+ all_texts = queries + positives + negatives
646
+ self.tokenizer.fit_on_texts(all_texts)
647
+
648
+ # Log vocabulary statistics
649
+ vocab_size = len(self.tokenizer.word_index)
650
+ logger.info(f"Vocabulary statistics:")
651
+ logger.info(f" Total unique tokens: {vocab_size}")
652
+ logger.info(f" Vocab limit: {self.config.vocab_size}")
653
+
654
+ # Log most common tokens
655
+ word_freq = sorted(
656
+ self.tokenizer.word_counts.items(),
657
+ key=lambda x: x[1],
658
+ reverse=True
659
+ )[:10]
660
+ logger.info("Most common tokens:")
661
+ for word, freq in word_freq:
662
+ logger.info(f" {word}: {freq}")
663
+
664
+ # Padding function from old version
665
+ def pad_sequences(texts: List[str]) -> tf.Tensor:
666
+ sequences = self.tokenizer.texts_to_sequences(texts)
667
+ return tf.keras.preprocessing.sequence.pad_sequences(
668
+ sequences,
669
+ maxlen=self.config.max_sequence_length,
670
+ padding='post',
671
+ truncating='post'
672
+ )
673
+
674
+ # Return padded sequences
675
+ return (
676
+ pad_sequences(queries),
677
+ pad_sequences(positives),
678
+ pad_sequences(negatives)
679
+ )
680
+
681
+ def prepare_dataset(
682
+ self,
683
+ dialogues: List[dict],
684
+ neg_samples_per_pos: int = 3,
685
+ debug_samples: Optional[int] = None
686
+ ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
687
+ """Prepare dataset with enhanced logging and statistics."""
688
+ logger.info("Preparing dataset...")
689
+
690
+ # Log dataset statistics
691
+ total_dialogues = len(dialogues)
692
+ total_turns = sum(len(d['turns']) for d in dialogues)
693
+ avg_turns = total_turns / total_dialogues
694
+
695
+ logger.info(f"Dataset statistics:")
696
+ logger.info(f" Total dialogues: {total_dialogues}")
697
+ logger.info(f" Total turns: {total_turns}")
698
+ logger.info(f" Average turns per dialogue: {avg_turns:.2f}")
699
+
700
+ # Extract and filter responses with logging
701
+ response_pool = []
702
+ skipped_short = 0
703
+ skipped_long = 0
704
+
705
+ for d in dialogues:
706
+ for turn in d['turns']:
707
+ if turn['speaker'] == 'assistant':
708
+ text = turn['text'].strip()
709
+ length = len(text.split())
710
+ if length < self.config.min_text_length:
711
+ skipped_short += 1
712
+ continue
713
+ if length > self.config.max_sequence_length:
714
+ skipped_long += 1
715
+ continue
716
+ response_pool.append(text)
717
+
718
+ logger.info(f"Response pool statistics:")
719
+ logger.info(f" Total responses: {len(response_pool)}")
720
+ logger.info(f" Skipped (too short): {skipped_short}")
721
+ logger.info(f" Skipped (too long): {skipped_long}")
722
+
723
+ # Process dialogues and create training examples
724
+ queries, positives, negatives = [], [], []
725
+
726
+ for dialogue in tqdm(dialogues, desc="Processing dialogues"):
727
+ turns = dialogue['turns']
728
+ for i in range(len(turns) - 1):
729
+ if turns[i]['speaker'] == 'user' and turns[i+1]['speaker'] == 'assistant':
730
+ query = turns[i]['text'].strip()
731
+ positive = turns[i+1]['text'].strip()
732
+
733
+ # Skip short texts
734
+ if (len(query.split()) < self.config.min_text_length or
735
+ len(positive.split()) < self.config.min_text_length): # Fixed
736
+ continue
737
+
738
+ # Get negative samples
739
+ neg_samples = self._smart_negative_sampling(
740
+ positive,
741
+ response_pool,
742
+ neg_samples_per_pos
743
+ )
744
+
745
+ if len(neg_samples) == neg_samples_per_pos:
746
+ for neg in neg_samples:
747
+ queries.append(query)
748
+ positives.append(positive)
749
+ negatives.append(neg)
750
+
751
+ # Log final dataset statistics
752
+ logger.info(f"Final dataset statistics:")
753
+ logger.info(f" Training examples: {len(queries)}")
754
+ logger.info(f" Unique queries: {len(set(queries))}")
755
+ logger.info(f" Unique responses: {len(set(positives))}")
756
+
757
+ return self._prepare_sequences(queries, positives, negatives)
758
+
759
+ def encode_query(self, query: str, context: Optional[List[Tuple[str, str]]] = None) -> tf.Tensor:
760
+ """Encode a query with optional conversation context."""
761
+ # Prepare query with context
762
+ if context:
763
+ context_str = ' '.join([
764
+ f"{self.special_tokens['user']} {q} "
765
+ f"{self.special_tokens['assistant']} {r}"
766
+ for q, r in context[-self.config.max_context_turns:]
767
+ ])
768
+ query = f"{context_str} {self.special_tokens['user']} {query}"
769
+ else:
770
+ query = f"{self.special_tokens['user']} {query}"
771
+
772
+ # Tokenize and pad
773
+ seq = self.tokenizer.texts_to_sequences([query])
774
+ padded_seq = tf.keras.preprocessing.sequence.pad_sequences(
775
+ seq,
776
+ maxlen=self.config.max_sequence_length,
777
+ padding='post',
778
+ truncating='post'
779
+ )
780
+
781
+ return self.query_encoder(padded_seq, training=False)
782
+
783
+ def encode_responses(self, responses: List[str]) -> tf.Tensor:
784
+ """Encode a batch of responses."""
785
+ # Prepare responses
786
+ responses = [
787
+ f"{self.special_tokens['assistant']} {r}"
788
+ for r in responses
789
+ ]
790
+
791
+ # Tokenize and pad
792
+ sequences = self.tokenizer.texts_to_sequences(responses)
793
+ padded_sequences = tf.keras.preprocessing.sequence.pad_sequences(
794
+ sequences,
795
+ maxlen=self.config.max_sequence_length,
796
+ padding='post',
797
+ truncating='post'
798
+ )
799
+
800
+ return self.response_encoder(padded_sequences, training=False)
801
+
802
+ def retrieve_responses(
803
+ self,
804
+ query: str,
805
+ candidates: List[str],
806
+ context: Optional[List[Tuple[str, str]]] = None,
807
+ top_k: int = 5
808
+ ) -> List[Tuple[str, float]]:
809
+ """Retrieve top-k responses for a query."""
810
+ # Encode query and candidates
811
+ q_emb = self.encode_query(query, context)
812
+ c_emb = self.encode_responses(candidates)
813
+
814
+ # Calculate similarities
815
+ similarities = tf.matmul(q_emb, c_emb, transpose_b=True).numpy()[0]
816
+
817
+ # Get top-k responses
818
+ top_indices = np.argsort(similarities)[::-1][:top_k]
819
+
820
+ return [(candidates[i], similarities[i]) for i in top_indices]
821
+
822
+ def chat(
823
+ self,
824
+ query: str,
825
+ response_pool: List[str],
826
+ conversation_history: Optional[List[Tuple[str, str]]] = None,
827
+ top_k: int = 5
828
+ ) -> Tuple[str, List[Tuple[str, float]]]:
829
+ """Interactive chat with response selection."""
830
+ # Get responses with scores
831
+ responses = self.retrieve_responses(
832
+ query,
833
+ response_pool,
834
+ conversation_history,
835
+ top_k
836
+ )
837
+
838
+ # Return best response and all candidates with scores
839
+ return responses[0][0], responses
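
Again for reference, a hypothetical sketch of driving the chatbot2.py version end to end. Same assumptions as the earlier sketch (spaCy model installed, dialogues as a list of dicts with 'turns'); it is not part of the commit:

# Hypothetical driver for chatbot2.py (not part of this commit).
from chatbot2 import ChatbotConfig, RetrievalChatbot

config = ChatbotConfig(max_sequence_length=128, margin=0.3, learning_rate=1e-3)
bot = RetrievalChatbot(config)

# dialogues: same list-of-dicts structure as in the chatbot.py sketch above
dialogues = [
    {"turns": [
        {"speaker": "user", "text": "What time do you open on weekends?"},
        {"speaker": "assistant", "text": "We open at 9 am on Saturdays and Sundays every week."},
    ]},
]

q_pad, p_pad, n_pad = bot.prepare_dataset(dialogues, neg_samples_per_pos=3)
bot.train(q_pad, p_pad, n_pad, epochs=3, batch_size=32, checkpoint_dir="checkpoints")

best, ranked = bot.chat(
    "When are you open on Saturday?",
    response_pool=["We open at 9 am on weekends.", "Your order has shipped."],
)
print(best)
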
chatbot3.py ADDED
@@ -0,0 +1,824 @@
1
+ from transformers import TFAutoModel, AutoTokenizer
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ from typing import List, Tuple, Dict, Optional, Union
5
+ from dataclasses import dataclass
6
+ import logging
7
+ import spacy
8
+ import random
9
+ import json
10
+ from tqdm import tqdm
11
+ from pathlib import Path
12
+
13
+ # Configure logging
14
+ logging.basicConfig(
15
+ level=logging.INFO,
16
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
17
+ )
18
+ logger = logging.getLogger(__name__)
19
+
20
+ @dataclass
21
+ class ChatbotConfig:
22
+ """Enhanced configuration with pretrained model settings."""
23
+ vocab_size: int = 10000
24
+ max_sequence_length: int = 512
25
+ embedding_dim: int = 768 # Match DistilBERT's dimension
26
+ encoder_units: int = 256
27
+ num_attention_heads: int = 8
28
+ dropout_rate: float = 0.2
29
+ l2_reg_weight: float = 0.001
30
+ margin: float = 0.3
31
+ learning_rate: float = 0.001
32
+ min_text_length: int = 3
33
+ max_context_turns: int = 5
34
+ warmup_steps: int = 200
35
+ pretrained_model: str = 'distilbert-base-uncased'
36
+ freeze_embeddings: bool = True
37
+ spacy_model: str = 'en_core_web_md'
38
+
39
+ def to_dict(self) -> dict:
40
+ """Convert config to dictionary."""
41
+ return {k: str(v) if isinstance(v, Path) else v
42
+ for k, v in self.__dict__.items()}
43
+
44
+ @classmethod
45
+ def from_dict(cls, config_dict: dict) -> 'ChatbotConfig':
46
+ """Create config from dictionary."""
47
+ return cls(**{k: v for k, v in config_dict.items()
48
+ if k in cls.__dataclass_fields__})
49
+
50
+ class TransformerBlock(tf.keras.layers.Layer):
51
+ """Custom Transformer block with pre-layer normalization."""
52
+ def __init__(
53
+ self,
54
+ embed_dim: int,
55
+ num_heads: int,
56
+ ff_dim: int,
57
+ dropout: float = 0.1,
58
+ **kwargs
59
+ ):
60
+ super().__init__(**kwargs)
61
+ self.embed_dim = embed_dim
62
+ self.num_heads = num_heads
63
+ self.ff_dim = ff_dim
64
+ self.dropout = dropout
65
+
66
+ self.attention = tf.keras.layers.MultiHeadAttention(
67
+ num_heads=num_heads,
68
+ key_dim=embed_dim // num_heads,
69
+ dropout=dropout
70
+ )
71
+ self.ffn = tf.keras.Sequential([
72
+ tf.keras.layers.Dense(ff_dim, activation="gelu"),
73
+ tf.keras.layers.Dense(embed_dim),
74
+ ])
75
+
76
+ self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
77
+ self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
78
+ self.dropout1 = tf.keras.layers.Dropout(dropout)
79
+ self.dropout2 = tf.keras.layers.Dropout(dropout)
80
+
81
+ def call(self, inputs: tf.Tensor, training: bool, mask: Optional[tf.Tensor] = None) -> tf.Tensor:
82
+ # Pre-layer normalization
83
+ norm_inputs = self.layernorm1(inputs)
84
+
85
+ # Self-attention
86
+ attention_output = self.attention(
87
+ query=norm_inputs,
88
+ value=norm_inputs,
89
+ key=norm_inputs,
90
+ attention_mask=mask,
91
+ training=training
92
+ )
93
+ attention_output = self.dropout1(attention_output, training=training)
94
+ attention_output = inputs + attention_output
95
+
96
+ # Feed-forward network
97
+ norm_attention = self.layernorm2(attention_output)
98
+ ffn_output = self.ffn(norm_attention)
99
+ ffn_output = self.dropout2(ffn_output, training=training)
100
+
101
+ return attention_output + ffn_output
102
+
103
+ def get_config(self) -> dict:
104
+ config = super().get_config()
105
+ config.update({
106
+ "embed_dim": self.embed_dim,
107
+ "num_heads": self.num_heads,
108
+ "ff_dim": self.ff_dim,
109
+ "dropout": self.dropout,
110
+ })
111
+ return config
112
+
113
+ class EncoderModel(tf.keras.Model):
114
+ """Dual encoder model with pretrained embeddings."""
115
+ def __init__(
116
+ self,
117
+ config: ChatbotConfig,
118
+ name: str = "encoder",
119
+ shared_weights: bool = False,
120
+ **kwargs
121
+ ):
122
+ super().__init__(name=name, **kwargs)
123
+ self.config = config
124
+ self.shared_weights = shared_weights
125
+
126
+ # Load pretrained model and tokenizer
127
+ self.pretrained = TFAutoModel.from_pretrained(config.pretrained_model)
128
+
129
+ # Freeze pretrained weights if specified
130
+ if config.freeze_embeddings:
131
+ self.pretrained.trainable = False
132
+
133
+ # Transformer blocks for additional processing
134
+ self.transformer_blocks = [
135
+ TransformerBlock(
136
+ config.embedding_dim,
137
+ config.num_attention_heads,
138
+ config.encoder_units * 4,
139
+ config.dropout_rate,
140
+ name=f"{name}_transformer_{i}"
141
+ ) for i in range(2) # Reduced number of blocks since we're using pretrained
142
+ ]
143
+
144
+ # Final LSTM layer
145
+ self.final_lstm = tf.keras.layers.LSTM(
146
+ config.encoder_units,
147
+ kernel_regularizer=tf.keras.regularizers.l2(config.l2_reg_weight),
148
+ name=f"{name}_final_lstm"
149
+ )
150
+
151
+ self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
152
+ self.normalize = tf.keras.layers.Lambda(
153
+ lambda x: tf.nn.l2_normalize(x, axis=1)
154
+ )
155
+
156
+ def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
157
+ # Get pretrained embeddings
158
+ pretrained_outputs = self.pretrained(inputs, training=training)
159
+ x = pretrained_outputs.last_hidden_state
160
+
161
+ # Get attention mask from input
162
+ attention_mask = tf.cast(tf.not_equal(inputs, 0), tf.float32)
163
+ attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
164
+
165
+ # Apply transformer blocks
166
+ for transformer_block in self.transformer_blocks:
167
+ x = transformer_block(x, training=training, mask=attention_mask)
168
+
169
+ # Final processing
170
+ x = self.final_lstm(x)
171
+ x = self.dropout(x, training=training)
172
+ return self.normalize(x)
173
+
174
+ class RetrievalChatbot:
175
+ """Modified chatbot using pretrained embeddings with full functionality."""
176
+ def __init__(self, config: ChatbotConfig):
177
+ self.config = config
178
+ self.nlp = spacy.load(config.spacy_model)
179
+
180
+ # Use HuggingFace tokenizer instead of Keras
181
+ self.tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
182
+
183
+ # Special tokens
184
+ self.special_tokens = {
185
+ "user": "<USER>",
186
+ "assistant": "<ASSISTANT>",
187
+ "context": "<CONTEXT>",
188
+ "sep": "<SEP>"
189
+ }
190
+
191
+ # Add special tokens to tokenizer
192
+ self.tokenizer.add_special_tokens(
193
+ {'additional_special_tokens': list(self.special_tokens.values())}
194
+ )
195
+
196
+ # Build models
197
+ self._build_models()
198
+
199
+ # Initialize training tracking
200
+ self.history = {
201
+ "train_loss": [],
202
+ "val_loss": [],
203
+ "train_metrics": {},
204
+ "val_metrics": {}
205
+ }
206
+
207
+ self.similarity_cache = {}
208
+
209
+ def _build_models(self):
210
+ """Initialize the encoder models."""
211
+ # Query encoder
212
+ self.query_encoder = EncoderModel(
213
+ self.config,
214
+ name="query_encoder",
215
+ shared_weights=False
216
+ )
217
+
218
+ # Response encoder (can share weights with query encoder)
219
+ self.response_encoder = EncoderModel(
220
+ self.config,
221
+ name="response_encoder",
222
+ shared_weights=False
223
+ )
224
+
225
+ # Resize token embeddings to match the tokenizer's vocab size
226
+ new_vocab_size = len(self.tokenizer)
227
+ self.query_encoder.pretrained.resize_token_embeddings(new_vocab_size)
228
+ self.response_encoder.pretrained.resize_token_embeddings(new_vocab_size)
229
+
230
+ def save_models(self, save_dir: Union[str, Path]):
231
+ """Save models and configuration."""
232
+ save_dir = Path(save_dir)
233
+ save_dir.mkdir(parents=True, exist_ok=True)
234
+
235
+ # Save config
236
+ with open(save_dir / "config.json", "w") as f:
237
+ json.dump(self.config.to_dict(), f, indent=2)
238
+
239
+ # Save models
240
+ self.query_encoder.pretrained.save_pretrained(save_dir / "query_encoder")
241
+ self.response_encoder.pretrained.save_pretrained(save_dir / "response_encoder")
242
+
243
+ # Save tokenizer
244
+ self.tokenizer.save_pretrained(save_dir / "tokenizer")
245
+
246
+ @classmethod
247
+ def load_models(cls, load_dir: Union[str, Path]) -> 'RetrievalChatbot':
248
+ """Load saved models and configuration."""
249
+ load_dir = Path(load_dir)
250
+
251
+ # Load config
252
+ with open(load_dir / "config.json", "r") as f:
253
+ config = ChatbotConfig.from_dict(json.load(f))
254
+
255
+ # Initialize chatbot
256
+ chatbot = cls(config)
257
+
258
+ # Load models
259
+ chatbot.query_encoder.pretrained = TFAutoModel.from_pretrained(
260
+ load_dir / "query_encoder",
261
+ config=config
262
+ )
263
+ chatbot.response_encoder.pretrained = TFAutoModel.from_pretrained(
264
+ load_dir / "response_encoder",
265
+ config=config
266
+ )
267
+
268
+ # Load tokenizer
269
+ chatbot.tokenizer = AutoTokenizer.from_pretrained(load_dir / "tokenizer")
270
+
271
+ return chatbot
272
+
273
+ def _improved_spacy_similarity(self, text1: str, text2: str) -> float:
274
+ """Calculate semantic similarity between texts with preprocessing."""
275
+ def preprocess(text: str) -> str:
276
+ # Basic cleaning
277
+ text = ' '.join(text.split())
278
+ return text if text.strip() else "empty_document"
279
+
280
+ # Get cache key
281
+ cache_key = f"{hash(text1)}_{hash(text2)}"
282
+ if cache_key in self.similarity_cache:
283
+ return self.similarity_cache[cache_key]
284
+
285
+ # Process texts
286
+ text1, text2 = preprocess(text1), preprocess(text2)
287
+ doc1, doc2 = self.nlp(text1), self.nlp(text2)
288
+
289
+ # Calculate similarity
290
+ if doc1.has_vector and doc2.has_vector:
291
+ sim = doc1.similarity(doc2)
292
+ else:
293
+ # Fallback to token overlap similarity
294
+ tokens1 = {t.lower_ for t in doc1 if not t.is_stop and not t.is_punct}
295
+ tokens2 = {t.lower_ for t in doc2 if not t.is_stop and not t.is_punct}
296
+ intersection = len(tokens1.intersection(tokens2))
297
+ union = len(tokens1.union(tokens2))
298
+ sim = intersection / union if union > 0 else 0.0
299
+
300
+ # Cache result
301
+ self.similarity_cache[cache_key] = sim
302
+ return sim
303
+
304
+ def prepare_dataset(
305
+ self,
306
+ dialogues: List[dict],
307
+ neg_samples_per_pos: int = 3,
308
+ debug_samples: Optional[int] = None
309
+ ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
310
+ """Prepare dataset with enhanced logging and statistics."""
311
+ logger.info("Preparing dataset...")
312
+
313
+ # Apply debug_samples limit if specified
314
+ if debug_samples is not None:
315
+ dialogues = dialogues[:debug_samples]
316
+ logger.info(f"Debug mode: Limited to {debug_samples} dialogues")
317
+
318
+ # Log dataset statistics
319
+ total_dialogues = len(dialogues)
320
+ total_turns = sum(len(d['turns']) for d in dialogues)
321
+ avg_turns = total_turns / total_dialogues if total_dialogues > 0 else 0
322
+
323
+ logger.info(f"Dataset statistics:")
324
+ logger.info(f" Total dialogues: {total_dialogues}")
325
+ logger.info(f" Total turns: {total_turns}")
326
+ logger.info(f" Average turns per dialogue: {avg_turns:.2f}")
327
+
328
+ # Extract and filter responses with logging
329
+ response_pool = []
330
+ skipped_short = 0
331
+ skipped_long = 0
332
+
333
+ for d in dialogues:
334
+ for turn in d['turns']:
335
+ if turn.get('speaker') == 'assistant' and 'text' in turn:
336
+ text = turn['text'].strip()
337
+ length = len(text.split())
338
+ if length < self.config.min_text_length:
339
+ skipped_short += 1
340
+ continue
341
+ if length > self.config.max_sequence_length:
342
+ skipped_long += 1
343
+ continue
344
+ response_pool.append(text)
345
+
346
+ logger.info(f"Response pool statistics:")
347
+ logger.info(f" Total responses: {len(response_pool)}")
348
+ logger.info(f" Skipped (too short): {skipped_short}")
349
+ logger.info(f" Skipped (too long): {skipped_long}")
350
+
351
+ # Process dialogues and create training examples
352
+ queries, positives, negatives = [], [], []
353
+
354
+ for dialogue in tqdm(dialogues, desc="Processing dialogues"):
355
+ turns = dialogue.get('turns', [])
356
+ for i in range(len(turns) - 1):
357
+ current_turn = turns[i]
358
+ next_turn = turns[i+1]
359
+
360
+ if (current_turn.get('speaker') == 'user' and
361
+ next_turn.get('speaker') == 'assistant' and
362
+ 'text' in current_turn and
363
+ 'text' in next_turn):
364
+
365
+ query = current_turn['text'].strip()
366
+ positive = next_turn['text'].strip()
367
+
368
+ # Skip short texts
369
+ if (len(query.split()) < self.config.min_text_length or
370
+ len(positive.split()) < self.config.min_text_length):
371
+ continue
372
+
373
+ # Get negative samples
374
+ neg_samples = self._smart_negative_sampling(
375
+ positive,
376
+ response_pool,
377
+ neg_samples_per_pos
378
+ )
379
+
380
+ if len(neg_samples) == neg_samples_per_pos:
381
+ for neg in neg_samples:
382
+ queries.append(query)
383
+ positives.append(positive)
384
+ negatives.append(neg)
385
+ else:
386
+ logger.warning(f"Insufficient negative samples for positive response: '{positive}'")
387
+
388
+ # Log final dataset statistics
389
+ logger.info(f"Final dataset statistics:")
390
+ logger.info(f" Training examples: {len(queries)}")
391
+ logger.info(f" Unique queries: {len(set(queries))}")
392
+ logger.info(f" Unique responses: {len(set(positives))}")
393
+
394
+ return self._prepare_sequences(queries, positives, negatives)
395
+
396
+ def _smart_negative_sampling(
397
+ self,
398
+ positive: str,
399
+ response_pool: List[str],
400
+ n_samples: int,
401
+ max_attempts: int = 200,
402
+ similarity_bounds: Tuple[float, float] = (0.2, 0.9),
403
+ batch_size: int = 10
404
+ ) -> List[str]:
405
+ """Smart negative sampling with similarity bounds and fallback strategies."""
406
+ candidates = []
407
+ seen = set()
408
+ attempts = 0
409
+
410
+ while len(candidates) < n_samples and attempts < max_attempts:
411
+ remaining = min(batch_size, len(response_pool) - len(seen), max_attempts - attempts)
412
+ if remaining <= 0:
413
+ break
414
+ batch = random.sample(
415
+ [r for r in response_pool if r not in seen and r != positive],
416
+ remaining
417
+ )
418
+
419
+ for candidate in batch:
420
+ seen.add(candidate)
421
+ sim = self._improved_spacy_similarity(candidate, positive)
422
+
423
+ if similarity_bounds[0] < sim < similarity_bounds[1]:
424
+ candidates.append(candidate)
425
+ if len(candidates) == n_samples:
426
+ break
427
+
428
+ attempts += len(batch)
429
+
430
+ if len(candidates) < n_samples:
431
+ logger.warning(f"Only found {len(candidates)} negative samples for positive response: '{positive}'")
432
+ # Fallback to random negatives without similarity constraints
433
+ fallback_needed = n_samples - len(candidates)
434
+ available_negatives = [r for r in response_pool if r != positive and r not in seen]
435
+ if available_negatives:
436
+ additional_negatives = random.sample(
437
+ available_negatives,
438
+ min(fallback_needed, len(available_negatives))
439
+ )
440
+ candidates.extend(additional_negatives)
441
+
442
+ return candidates
443
+
444
+ def _prepare_sequences(
445
+ self,
446
+ queries: List[str],
447
+ positives: List[str],
448
+ negatives: List[str]
449
+ ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
450
+ """Modified sequence preparation for pretrained tokenizer."""
451
+ logger.info("Preparing sequences...")
452
+
453
+ # Process texts with special tokens
454
+ queries = [f"{self.special_tokens['user']} {q}" for q in queries]
455
+ positives = [f"{self.special_tokens['assistant']} {p}" for p in positives]
456
+ negatives = [f"{self.special_tokens['assistant']} {n}" for n in negatives]
457
+
458
+ # Tokenize using HuggingFace tokenizer
459
+ def encode_batch(texts: List[str]) -> tf.Tensor:
460
+ # HuggingFace tokenizer returns TensorFlow tensors when return_tensors='tf'
461
+ encodings = self.tokenizer(
462
+ texts,
463
+ padding='max_length',
464
+ truncation=True,
465
+ max_length=self.config.max_sequence_length,
466
+ return_tensors='tf'
467
+ )
468
+ return encodings['input_ids']
469
+
470
+ # Encode all sequences
471
+ q_tensor = encode_batch(queries)
472
+ p_tensor = encode_batch(positives)
473
+ n_tensor = encode_batch(negatives)
474
+
475
+ # Log statistics about encoded sequences
476
+ logger.info("Sequence statistics:")
477
+ logger.info(f" Query sequence shape: {q_tensor.shape}")
478
+ logger.info(f" Positive response sequence shape: {p_tensor.shape}")
479
+ logger.info(f" Negative response sequence shape: {n_tensor.shape}")
480
+
481
+ return q_tensor, p_tensor, n_tensor
482
+
483
+ def train(
484
+ self,
485
+ q_pad: tf.Tensor,
486
+ p_pad: tf.Tensor,
487
+ n_pad: tf.Tensor,
488
+ epochs: int = 3,
489
+ batch_size: int = 32,
490
+ validation_split: float = 0.2,
491
+ checkpoint_dir: Optional[Union[str, Path]] = None
492
+ ):
493
+ """Train the model with improved training loop."""
494
+ # Setup training
495
+ total_samples = tf.shape(q_pad)[0]
496
+ train_size = int((1 - validation_split) * total_samples.numpy())
497
+
498
+ # Shuffle and split data
499
+ indices = tf.random.shuffle(tf.range(start=0, limit=total_samples, dtype=tf.int32))
500
+ train_idx = indices[:train_size]
501
+ val_idx = indices[train_size:]
502
+
503
+ # Split data using TF indexing
504
+ train_data = (
505
+ tf.gather(q_pad, train_idx),
506
+ tf.gather(p_pad, train_idx),
507
+ tf.gather(n_pad, train_idx)
508
+ )
509
+ val_data = (
510
+ tf.gather(q_pad, val_idx),
511
+ tf.gather(p_pad, val_idx),
512
+ tf.gather(n_pad, val_idx)
513
+ )
514
+
515
+ # Setup optimizer with learning rate schedule
516
+ steps_per_epoch = train_size // batch_size
517
+ total_steps = steps_per_epoch * epochs
518
+
519
+ lr_schedule = self._get_lr_schedule(
520
+ total_steps,
521
+ self.config.learning_rate,
522
+ self.config.warmup_steps
523
+ )
524
+
525
+ optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
526
+
527
+ # Setup checkpointing
528
+ if checkpoint_dir:
529
+ checkpoint_dir = Path(checkpoint_dir)
530
+ checkpoint_dir.mkdir(parents=True, exist_ok=True)
531
+
532
+ # Setup checkpoint callback with correct file format
533
+ checkpoint_template = str(checkpoint_dir / "model_epoch_{epoch:04d}.weights.h5")
534
+ checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
535
+ checkpoint_template,
536
+ save_weights_only=True,
537
+ save_best_only=True,
538
+ monitor='val_loss',
539
+ mode='min',
540
+ verbose=1
541
+ )
542
+
543
+ # Training loop
544
+ best_val_loss = float('inf')
545
+ patience = 5
546
+ wait = 0
547
+
548
+ for epoch in range(epochs):
549
+ # Training
550
+ train_loss = self._train_epoch(
551
+ train_data,
552
+ optimizer,
553
+ batch_size,
554
+ training=True
555
+ )
556
+
557
+ # Validation
558
+ val_loss = self._train_epoch(
559
+ val_data,
560
+ optimizer,
561
+ batch_size,
562
+ training=False
563
+ )
564
+
565
+ # Update history
566
+ self.history['train_loss'].append(train_loss)
567
+ self.history['val_loss'].append(val_loss)
568
+
569
+ logger.info(
570
+ f"Epoch {epoch + 1}/{epochs} - "
571
+ f"train_loss: {train_loss:.4f} - "
572
+ f"val_loss: {val_loss:.4f}"
573
+ )
574
+
575
+ # Early stopping
576
+ if val_loss < best_val_loss:
577
+ best_val_loss = val_loss
578
+ wait = 0
579
+ if checkpoint_dir:
580
+ self.save_models(checkpoint_dir / f"best_model")
581
+ else:
582
+ wait += 1
583
+ if wait >= patience:
584
+ logger.info("Early stopping triggered")
585
+ break
586
+
587
+ def _train_epoch(
588
+ self,
589
+ data: Tuple[tf.Tensor, tf.Tensor, tf.Tensor],
590
+ optimizer: tf.keras.optimizers.Optimizer,
591
+ batch_size: int,
592
+ training: bool = True
593
+ ) -> float:
594
+ """Train for one epoch with enhanced logging and progress tracking."""
595
+ q_data, p_data, n_data = data
596
+ total_loss = 0.0
597
+ num_batches = tf.shape(q_data)[0] // batch_size
598
+
599
+ # Log current learning rate at start of epoch
600
+ if training:
601
+ if hasattr(optimizer.learning_rate, '__call__'):
602
+ current_lr = optimizer.learning_rate(optimizer.iterations)
603
+ else:
604
+ current_lr = optimizer.learning_rate
605
+ logger.info(f"Current learning rate: {float(current_lr):.6f}")
606
+
607
+ # Create progress bar
608
+ mode = "Training" if training else "Validation"
609
+ pbar = tqdm(
610
+ total=num_batches.numpy(),
611
+ desc=f"{mode} batches",
612
+ unit="batch",
613
+ dynamic_ncols=True
614
+ )
615
+
616
+ # Process batches
617
+ for i in range(num_batches):
618
+ start_idx = i * batch_size
619
+ end_idx = start_idx + batch_size
620
+
621
+ batch_loss = self._train_step(
622
+ q_data[start_idx:end_idx],
623
+ p_data[start_idx:end_idx],
624
+ n_data[start_idx:end_idx],
625
+ optimizer,
626
+ training
627
+ )
628
+ total_loss += batch_loss.numpy()
629
+
630
+ # Update progress bar with current loss
631
+ avg_loss = total_loss / (i + 1)
632
+ pbar.set_postfix({
633
+ 'loss': f'{avg_loss:.4f}',
634
+ 'lr': f'{float(current_lr):.6f}' if training else 'N/A'
635
+ })
636
+ pbar.update(1)
637
+
638
+ pbar.close()
639
+ return total_loss / num_batches.numpy() if num_batches > 0 else 0.0
640
+
641
+ @tf.function
642
+ def _train_step(
643
+ self,
644
+ q_batch: tf.Tensor,
645
+ p_batch: tf.Tensor,
646
+ n_batch: tf.Tensor,
647
+ optimizer: tf.keras.optimizers.Optimizer,
648
+ training: bool = True
649
+ ) -> tf.Tensor:
650
+ """Single training step with triplet loss."""
651
+ with tf.GradientTape() as tape:
652
+ # Get embeddings
653
+ q_emb = self.query_encoder(q_batch, training=training)
654
+ p_emb = self.response_encoder(p_batch, training=training)
655
+ n_emb = self.response_encoder(n_batch, training=training)
656
+
657
+ # Calculate triplet loss
658
+ pos_dist = tf.reduce_sum(tf.square(q_emb - p_emb), axis=1)
659
+ neg_dist = tf.reduce_sum(tf.square(q_emb - n_emb), axis=1)
660
+
661
+ loss = tf.maximum(0.0, self.config.margin + pos_dist - neg_dist)
662
+ loss = tf.reduce_mean(loss)
663
+
664
+ if training:
665
+ # Apply gradients
666
+ gradients = tape.gradient(
667
+ loss,
668
+ self.query_encoder.trainable_variables +
669
+ self.response_encoder.trainable_variables
670
+ )
671
+ optimizer.apply_gradients(zip(
672
+ gradients,
673
+ self.query_encoder.trainable_variables +
674
+ self.response_encoder.trainable_variables
675
+ ))
676
+
677
+ return loss
678
+
679
+ def _get_lr_schedule(
680
+ self,
681
+ total_steps: int,
682
+ peak_lr: float,
683
+ warmup_steps: int
684
+ ) -> tf.keras.optimizers.schedules.LearningRateSchedule:
685
+ """Enhanced learning rate schedule with better error handling and logging."""
686
+ class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
687
+ def __init__(
688
+ self,
689
+ total_steps: int,
690
+ peak_lr: float,
691
+ warmup_steps: int
692
+ ):
693
+ super().__init__()
694
+ self.total_steps = tf.cast(total_steps, tf.float32)
695
+ self.peak_lr = tf.cast(peak_lr, tf.float32)
696
+ self.warmup_steps = tf.cast(max(1, warmup_steps), tf.float32) # Prevent 0
697
+
698
+ # Calculate and store constants
699
+ self.initial_lr = self.peak_lr * 0.1 # Start at 10% of peak
700
+ self.min_lr = self.peak_lr * 0.01 # Minimum 1% of peak
701
+
702
+ logger.info(f"Learning rate schedule initialized:")
703
+ logger.info(f" Initial LR: {float(self.initial_lr):.6f}")
704
+ logger.info(f" Peak LR: {float(self.peak_lr):.6f}")
705
+ logger.info(f" Min LR: {float(self.min_lr):.6f}")
706
+ logger.info(f" Warmup steps: {int(self.warmup_steps)}")
707
+ logger.info(f" Total steps: {int(self.total_steps)}")
708
+
709
+ def __call__(self, step):
710
+ step = tf.cast(step, tf.float32)
711
+
712
+ # Warmup phase
713
+ warmup_factor = tf.minimum(1.0, step / self.warmup_steps)
714
+ warmup_lr = self.initial_lr + (self.peak_lr - self.initial_lr) * warmup_factor
715
+
716
+ # Decay phase
717
+ decay_steps = tf.maximum(1.0, self.total_steps - self.warmup_steps)
718
+ decay_factor = (step - self.warmup_steps) / decay_steps
719
+ decay_factor = tf.minimum(tf.maximum(0.0, decay_factor), 1.0) # Clip to [0,1]
720
+
721
+ cosine_decay = 0.5 * (1.0 + tf.cos(np.pi * decay_factor))
722
+ decay_lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay
723
+
724
+ # Choose between warmup and decay
725
+ final_lr = tf.where(step < self.warmup_steps, warmup_lr, decay_lr)
726
+
727
+ # Ensure learning rate is valid
728
+ final_lr = tf.maximum(self.min_lr, final_lr)
729
+ final_lr = tf.where(tf.math.is_finite(final_lr), final_lr, self.min_lr)
730
+
731
+ return final_lr
732
+
733
+ def get_config(self):
734
+ return {
735
+ "total_steps": self.total_steps,
736
+ "peak_lr": self.peak_lr,
737
+ "warmup_steps": self.warmup_steps,
738
+ }
739
+
740
+ return CustomSchedule(total_steps, peak_lr, warmup_steps)
741
+
742
+ def encode_query(self, query: str, context: Optional[List[Tuple[str, str]]] = None) -> tf.Tensor:
743
+ """Encode a query with optional conversation context."""
744
+ # Prepare query with context
745
+ if context:
746
+ context_str = ' '.join([
747
+ f"{self.special_tokens['user']} {q} "
748
+ f"{self.special_tokens['assistant']} {r}"
749
+ for q, r in context[-self.config.max_context_turns:]
750
+ ])
751
+ query = f"{context_str} {self.special_tokens['user']} {query}"
752
+ else:
753
+ query = f"{self.special_tokens['user']} {query}"
754
+
755
+ # Tokenize and pad using TensorFlow tensors
756
+ encodings = self.tokenizer(
757
+ [query],
758
+ padding='max_length',
759
+ truncation=True,
760
+ max_length=self.config.max_sequence_length,
761
+ return_tensors='tf'
762
+ )
763
+ input_ids = encodings['input_ids']
764
+
765
+ return self.query_encoder(input_ids, training=False)
766
+
767
+ def encode_responses(self, responses: List[str]) -> tf.Tensor:
768
+ """Encode a batch of responses."""
769
+ # Prepare responses
770
+ responses = [
771
+ f"{self.special_tokens['assistant']} {r}"
772
+ for r in responses
773
+ ]
774
+
775
+ # Tokenize and pad using TensorFlow tensors
776
+ encodings = self.tokenizer(
777
+ responses,
778
+ padding='max_length',
779
+ truncation=True,
780
+ max_length=self.config.max_sequence_length,
781
+ return_tensors='tf'
782
+ )
783
+ input_ids = encodings['input_ids']
784
+
785
+ return self.response_encoder(input_ids, training=False)
786
+
787
+ def retrieve_responses(
788
+ self,
789
+ query: str,
790
+ candidates: List[str],
791
+ context: Optional[List[Tuple[str, str]]] = None,
792
+ top_k: int = 5
793
+ ) -> List[Tuple[str, float]]:
794
+ """Retrieve top-k responses for a query."""
795
+ # Encode query and candidates
796
+ q_emb = self.encode_query(query, context)
797
+ c_emb = self.encode_responses(candidates)
798
+
799
+ # Calculate similarities
800
+ similarities = tf.matmul(q_emb, c_emb, transpose_b=True).numpy()[0]
801
+
802
+ # Get top-k responses
803
+ top_indices = np.argsort(similarities)[::-1][:top_k]
804
+
805
+ return [(candidates[i], similarities[i]) for i in top_indices]
806
+
807
+ def chat(
808
+ self,
809
+ query: str,
810
+ response_pool: List[str],
811
+ conversation_history: Optional[List[Tuple[str, str]]] = None,
812
+ top_k: int = 5
813
+ ) -> Tuple[str, List[Tuple[str, float]]]:
814
+ """Interactive chat with response selection."""
815
+ # Get responses with scores
816
+ responses = self.retrieve_responses(
817
+ query,
818
+ response_pool,
819
+ conversation_history,
820
+ top_k
821
+ )
822
+
823
+ # Return best response and all candidates with scores
824
+ return responses[0][0], responses
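
For reference, `_train_step` above minimizes a margin-based triplet loss over the L2-normalized encoder outputs. A minimal NumPy sketch of that objective, using the config's default margin of 0.3 and toy 2-D embeddings (illustrative values only):

import numpy as np

margin = 0.3  # ChatbotConfig.margin default

# Toy, already L2-normalized embeddings
q_emb = np.array([[1.0, 0.0]])   # query
p_emb = np.array([[0.8, 0.6]])   # positive response
n_emb = np.array([[0.0, 1.0]])   # negative response

pos_dist = np.sum((q_emb - p_emb) ** 2, axis=1)   # squared distance to positive
neg_dist = np.sum((q_emb - n_emb) ** 2, axis=1)   # squared distance to negative
loss = np.maximum(0.0, margin + pos_dist - neg_dist).mean()
print(loss)  # 0.0 here, since neg_dist exceeds pos_dist by more than the margin
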
chatbot4.py ADDED
@@ -0,0 +1,1291 @@
1
+ from transformers import TFAutoModel, AutoTokenizer
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ from typing import List, Tuple, Dict, Optional, Union, Any
5
+ from dataclasses import dataclass
6
+ import logging
7
+ import json
8
+ from tqdm import tqdm
9
+ from pathlib import Path
10
+ import faiss
11
+ from response_quality_checker import ResponseQualityChecker
12
+
13
+ policy = tf.keras.mixed_precision.Policy('mixed_float16')
14
+ tf.keras.mixed_precision.set_global_policy(policy)
15
+
16
+ # Configure logging
17
+ logging.basicConfig(
18
+ level=logging.DEBUG,
19
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
20
+ )
21
+ logger = logging.getLogger(__name__)
22
+
23
+ @dataclass
24
+ class ChatbotConfig:
25
+ """Configuration for the RetrievalChatbot."""
26
+ vocab_size: int = 30526 # DistilBERT vocab size (30522) plus 4 added special tokens
27
+ max_sequence_length: int = 512
28
+ embedding_dim: int = 768 # Match DistilBERT's dimension
29
+ encoder_units: int = 256
30
+ num_attention_heads: int = 8
31
+ dropout_rate: float = 0.2
32
+ l2_reg_weight: float = 0.001
33
+ margin: float = 0.3
34
+ learning_rate: float = 0.001
35
+ min_text_length: int = 3
36
+ max_context_turns: int = 5
37
+ warmup_steps: int = 200
38
+ pretrained_model: str = 'distilbert-base-uncased'
39
+ dtype: str = 'float32'
40
+ freeze_embeddings: bool = False
41
+ # Additional configurations can be added here
42
+
43
+ def to_dict(self) -> dict:
44
+ """Convert config to dictionary."""
45
+ return {k: str(v) if isinstance(v, Path) else v
46
+ for k, v in self.__dict__.items()}
47
+
48
+ @classmethod
49
+ def from_dict(cls, config_dict: dict) -> 'ChatbotConfig':
50
+ """Create config from dictionary."""
51
+ return cls(**{k: v for k, v in config_dict.items()
52
+ if k in cls.__dataclass_fields__})
53
+
54
+ class EncoderModel(tf.keras.Model):
55
+ """Dual encoder model with pretrained embeddings."""
56
+ def __init__(
57
+ self,
58
+ config: ChatbotConfig,
59
+ name: str = "encoder",
60
+ shared_weights: bool = False,
61
+ **kwargs
62
+ ):
63
+ super().__init__(name=name, **kwargs)
64
+ self.config = config
65
+ self.shared_weights = shared_weights
66
+
67
+ # Load pretrained model
68
+ self.pretrained = TFAutoModel.from_pretrained(config.pretrained_model)
69
+
70
+ # Freeze pretrained weights if specified
71
+ self.pretrained.distilbert.embeddings.trainable = False
72
+ for i, layer_module in enumerate(self.pretrained.distilbert.transformer.layer):
73
+ if i < 3: # freeze the first 3 transformer layers
74
+ layer_module.trainable = False
75
+ else:
76
+ layer_module.trainable = True
77
+
78
+ # Pooling layer (Global Average Pooling)
79
+ self.pooler = tf.keras.layers.GlobalAveragePooling1D()
80
+
81
+ # Dropout and normalization
82
+ self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
83
+ self.normalize = tf.keras.layers.Lambda(
84
+ lambda x: tf.nn.l2_normalize(x, axis=1)
85
+ )
86
+
87
+ def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
88
+ """Forward pass."""
89
+ # Get pretrained embeddings
90
+ pretrained_outputs = self.pretrained(inputs, training=training)
91
+ x = pretrained_outputs.last_hidden_state # Shape: [batch_size, seq_len, embedding_dim]
92
+
93
+ # Apply pooling
94
+ x = self.pooler(x) # Shape: [batch_size, embedding_dim]
95
+
96
+ # Apply dropout
97
+ x = self.dropout(x, training=training)
98
+
99
+ # L2 normalization
100
+ x = self.normalize(x) # Shape: [batch_size, embedding_dim]
101
+
102
+ return x
103
+
104
+ def get_config(self) -> dict:
105
+ """Return the config of the model."""
106
+ config = super().get_config()
107
+ config.update({
108
+ "config": self.config.to_dict(),
109
+ "shared_weights": self.shared_weights,
110
+ "name": self.name
111
+ })
112
+ return config
113
+
114
+ # class CustomLRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
115
+ # def __init__(self, initial_lr, peak_lr, min_lr, warmup_steps, total_steps):
116
+ # super().__init__()
117
+ # self.initial_lr = initial_lr
118
+ # self.peak_lr = peak_lr
119
+ # self.min_lr = min_lr
120
+ # self.warmup_steps = min(warmup_steps, total_steps // 2) # Ensure warmup_steps <= total_steps
121
+ # self.total_steps = total_steps
122
+
123
+ # def __call__(self, step):
124
+ # if step < self.warmup_steps:
125
+ # # Linear warmup
126
+ # lr = self.initial_lr + (self.peak_lr - self.initial_lr) * (step / self.warmup_steps)
127
+ # else:
128
+ # # Linear decay
129
+ # decay_steps = self.total_steps - self.warmup_steps
130
+ # if decay_steps > 0:
131
+ # lr = self.peak_lr - (self.peak_lr - self.min_lr) * ((step - self.warmup_steps) / decay_steps)
132
+ # else:
133
+ # lr = self.peak_lr
134
+ # return lr
135
+
136
+ # def get_config(self):
137
+ # return {
138
+ # "initial_lr": self.initial_lr,
139
+ # "peak_lr": self.peak_lr,
140
+ # "min_lr": self.min_lr,
141
+ # "warmup_steps": self.warmup_steps,
142
+ # "total_steps": self.total_steps,
143
+ # }
144
+
145
+ class RetrievalChatbot:
146
+ """Retrieval-based chatbot using pretrained embeddings and FAISS for similarity search."""
147
+ def __init__(self, config: ChatbotConfig, dialogues: List[dict] = []):
148
+ self.config = config
149
+
150
+ # Special tokens
151
+ self.special_tokens = {
152
+ "user": "<USER>",
153
+ "assistant": "<ASSISTANT>",
154
+ "context": "<CONTEXT>",
155
+ "sep": "<SEP>"
156
+ }
157
+
158
+ # Initialize tokenizer and add special tokens
159
+ self.tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
160
+ self.tokenizer.add_special_tokens(
161
+ {'additional_special_tokens': list(self.special_tokens.values())}
162
+ )
163
+
164
+ # Build encoders
165
+ self._build_models()
166
+
167
+ # Initialize FAISS index
168
+ self._initialize_faiss()
169
+
170
+ # Precompute and index response embeddings
171
+ self._precompute_and_index_responses(dialogues)
172
+
173
+ # Initialize training history
174
+ self.history = {
175
+ "train_loss": [],
176
+ "val_loss": [],
177
+ "train_metrics": {},
178
+ "val_metrics": {}
179
+ }
180
+
181
+ def _build_models(self):
182
+ """Initialize the shared encoder."""
183
+ logger.info("Building encoder model...")
184
+
185
+ # Shared encoder for both queries and responses
186
+ self.encoder = EncoderModel(
187
+ self.config,
188
+ name="shared_encoder",
189
+ )
190
+
191
+ # Resize token embeddings after adding special tokens
192
+ new_vocab_size = len(self.tokenizer)
193
+ self.encoder.pretrained.resize_token_embeddings(new_vocab_size)
194
+ logger.info(f"Token embeddings resized to: {new_vocab_size}")
195
+
196
+ # Inspect embeddings attributes for debugging
197
+ logger.info("Inspecting embeddings attributes:")
198
+ for attr in dir(self.encoder.pretrained.distilbert.embeddings):
199
+ if not attr.startswith('_'):
200
+ logger.info(f" {attr}")
201
+
202
+ # Verify embedding layers without accessing word_embeddings directly
203
+ embedding_dim = getattr(self.encoder.pretrained.distilbert.embeddings, 'embedding_dim', 'Unknown')
204
+ vocab_size = getattr(self.encoder.pretrained.distilbert.embeddings, 'input_dim', len(self.tokenizer))
205
+ logger.info(f"Encoder Embedding Dimension: {embedding_dim}")
206
+ logger.info(f"Encoder Embedding Vocabulary Size: {vocab_size}")
207
+
208
+ logger.info("Encoder model built and embeddings resized successfully.")
209
+ for var in self.encoder.pretrained.trainable_variables:
210
+ logger.info(f"{var.name}, {var.shape}")
211
+
212
+ def check_trainable_variables(self):
213
+ """Logs the trainable variables in both encoders."""
214
+ logger.info("Checking trainable variables in shared_encoder:")
215
+ for var in self.encoder.pretrained.trainable_variables:
216
+ logger.info(f" {var.name}, shape: {var.shape}")
217
+
218
+ # logger.info("Checking trainable variables in response_encoder:")
219
+ # for var in self.response_encoder.pretrained.trainable_variables:
220
+ # logger.info(f" {var.name}, shape: {var.shape}")
221
+
222
+ def _initialize_faiss(self):
223
+ """Initialize FAISS index based on available resources."""
224
+ logger.info("Initializing FAISS index...")
225
+ # Determine if GPU FAISS is available
226
+ try:
227
+ res = faiss.StandardGpuResources()
228
+ self.faiss_gpu = True
229
+ logger.info("FAISS GPU resources initialized.")
230
+ except Exception as e:
231
+ self.faiss_gpu = False
232
+ logger.info("FAISS GPU resources not available. Using FAISS CPU.")
233
+
234
+ # Initialize FAISS index for Inner Product (for cosine similarity)
235
+ if self.faiss_gpu:
236
+ self.index = faiss.IndexFlatIP(self.config.embedding_dim)
237
+ self.index = faiss.index_cpu_to_gpu(res, 0, self.index)
238
+ else:
239
+ self.index = faiss.IndexFlatIP(self.config.embedding_dim)
240
+ logger.info("FAISS index initialized.")
241
+
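
Because `_initialize_faiss` builds an `IndexFlatIP` (exact inner-product search), L2-normalizing both the indexed vectors and the query turns the returned scores into cosine similarities. A small standalone sketch of that pattern, with random vectors and an illustrative dimension:

import faiss
import numpy as np

dim = 8
vectors = np.random.rand(100, dim).astype('float32')
faiss.normalize_L2(vectors)              # in-place L2 normalization

index = faiss.IndexFlatIP(dim)           # exact inner-product index
index.add(vectors)

query = np.random.rand(1, dim).astype('float32')
faiss.normalize_L2(query)

scores, ids = index.search(query, 5)     # top-5 cosine similarities and their row ids
print(scores[0], ids[0])
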
242
+ def verify_faiss_index(self):
243
+ """Verify that the FAISS index matches the response pool."""
244
+ indexed_size = self.index.ntotal
245
+ pool_size = len(self.response_pool)
246
+ logger.info(f"FAISS index size: {indexed_size}")
247
+ logger.info(f"Response pool size: {pool_size}")
248
+ if indexed_size != pool_size:
249
+ logger.warning("Mismatch between FAISS index size and response pool size.")
250
+ else:
251
+ logger.info("FAISS index correctly matches the response pool.")
252
+
253
+
254
+ def _precompute_and_index_responses(self, dialogues: List[dict]):
255
+ """Precompute embeddings for all responses and index them using FAISS."""
256
+ logger.info("Precomputing response embeddings and indexing with FAISS...")
257
+
258
+ # Use tqdm for collecting responses
259
+ responses = []
260
+ for dialogue in tqdm(dialogues, desc="Collecting assistant responses"):
261
+ turns = dialogue.get('turns', [])
262
+ for turn in turns:
263
+ if turn.get('speaker') == 'assistant' and 'text' in turn:
264
+ responses.append(turn['text'].strip())
265
+
266
+ # Remove duplicates
267
+ unique_responses = list(set(responses))
268
+ logger.info(f"Found {len(unique_responses)} unique responses.")
269
+
270
+ # Encode responses
271
+ response_embeddings = self.encode_responses(unique_responses)
272
+ response_embeddings = response_embeddings.numpy()
273
+
274
+ # Ensure float32
275
+ if response_embeddings.dtype != np.float32:
276
+ logger.info(f"Converting embeddings from {response_embeddings.dtype} to float32.")
277
+ response_embeddings = response_embeddings.astype('float32')
278
+
279
+ # Ensure the array is contiguous in memory
280
+ if not response_embeddings.flags['C_CONTIGUOUS']:
281
+ logger.info("Making embeddings contiguous in memory.")
282
+ response_embeddings = np.ascontiguousarray(response_embeddings)
283
+
284
+ # Normalize embeddings for cosine similarity
285
+ logger.info("Normalizing embeddings with FAISS.")
286
+ faiss.normalize_L2(response_embeddings)
287
+
288
+ # Add to FAISS index
289
+ logger.info("Adding embeddings to FAISS index...")
290
+ self.index.add(response_embeddings)
291
+ logger.info(f"Indexed {self.index.ntotal} responses.")
292
+
293
+ # Store responses and embeddings
294
+ self.response_pool = unique_responses
295
+ self.response_embeddings = response_embeddings
296
+ logger.info("Precomputation and indexing completed.")
297
+
298
+ def encode_responses(
299
+ self,
300
+ responses: List[str],
301
+ batch_size: int = 64
302
+ ) -> tf.Tensor:
303
+ """
304
+ Encodes a list of responses into embeddings, using chunked/batched processing
305
+ to avoid running out of memory when there are many responses.
306
+
307
+ Args:
308
+ responses (List[str]): The list of response texts to encode.
309
+ batch_size (int): How many responses to encode per chunk.
310
+ Adjust based on available GPU/CPU memory.
311
+
312
+ Returns:
313
+ tf.Tensor: Tensor of shape (N, emb_dim) with all response embeddings.
314
+ """
315
+ logger.info(f"Encoding {len(responses)} responses in batches of size {batch_size}...")
316
+
317
+ # We'll accumulate embeddings in a list and concatenate at the end
318
+ all_embeddings = []
319
+
320
+ # Set up a progress bar
321
+ from tqdm import tqdm
322
+ pbar = tqdm(total=len(responses), desc="Encoding responses")
323
+
324
+ # Process the responses in chunks of 'batch_size'
325
+ for start_idx in range(0, len(responses), batch_size):
326
+ end_idx = start_idx + batch_size
327
+ batch_texts = responses[start_idx:end_idx]
328
+
329
+ # Tokenize the current batch
330
+ encodings = self.tokenizer(
331
+ batch_texts,
332
+ padding='max_length',
333
+ truncation=True,
334
+ max_length=self.config.max_sequence_length,
335
+ return_tensors='tf',
336
+ )
337
+
338
+ # Run the encoder forward pass
339
+ input_ids = encodings['input_ids']
340
+ embeddings_batch = self.encoder(input_ids, training=False)
341
+
342
+ # Cast to float32 if needed
343
+ if embeddings_batch.dtype != tf.float32:
344
+ embeddings_batch = tf.cast(embeddings_batch, tf.float32)
345
+
346
+ # Collect
347
+ all_embeddings.append(embeddings_batch)
348
+
349
+ # Update progress bar
350
+ pbar.update(len(batch_texts))
351
+
352
+ pbar.close()
353
+
354
+ # Concatenate all batch embeddings along axis=0
355
+ if len(all_embeddings) == 1:
356
+ # Only one batch
357
+ final_embeddings = all_embeddings[0]
358
+ else:
359
+ # Multiple batches, concatenate
360
+ final_embeddings = tf.concat(all_embeddings, axis=0)
361
+
362
+ logger.info(
363
+ f"Finished encoding {len(responses)} responses. "
364
+ f"Final shape: {final_embeddings.shape}"
365
+ )
366
+ return final_embeddings
367
+
368
+ def encode_query(self, query: str, context: Optional[List[Tuple[str, str]]] = None) -> tf.Tensor:
369
+ """Encode a query with optional conversation context."""
370
+ # Prepare query with context
371
+ if context:
372
+ context_str = ' '.join([
373
+ f"{self.special_tokens['user']} {q} "
374
+ f"{self.special_tokens['assistant']} {r}"
375
+ for q, r in context[-self.config.max_context_turns:]
376
+ ])
377
+ query = f"{context_str} {self.special_tokens['user']} {query}"
378
+ else:
379
+ query = f"{self.special_tokens['user']} {query}"
380
+
381
+ # Tokenize and encode
382
+ encodings = self.tokenizer(
383
+ [query],
384
+ padding='max_length',
385
+ truncation=True,
386
+ max_length=self.config.max_sequence_length,
387
+ return_tensors='tf'
388
+ )
389
+ input_ids = encodings['input_ids']
390
+
391
+ # Verify token IDs
392
+ max_id = tf.reduce_max(input_ids).numpy()
393
+ new_vocab_size = len(self.tokenizer)
394
+ logger.info(f"Maximum input_id: {max_id}, Vocab Size: {new_vocab_size}")
395
+
396
+ if max_id >= new_vocab_size:
397
+ logger.error(f"Token ID {max_id} exceeds the vocabulary size {new_vocab_size}.")
398
+ raise ValueError("Token ID exceeds vocabulary size.")
399
+
400
+ # Get embeddings from the shared encoder
401
+ return self.encoder(input_ids, training=False)
402
+
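
For clarity, a minimal sketch of the string `encode_query` assembles before tokenization, assuming a single prior (user, assistant) exchange and the special tokens defined in __init__ (the message texts are illustrative only):

special_tokens = {"user": "<USER>", "assistant": "<ASSISTANT>"}
context = [("Hi, can you help me?", "Of course, what do you need?")]
query = "How do I reset my password?"

context_str = ' '.join(
    f"{special_tokens['user']} {q} {special_tokens['assistant']} {r}"
    for q, r in context
)
full_query = f"{context_str} {special_tokens['user']} {query}"
print(full_query)
# <USER> Hi, can you help me? <ASSISTANT> Of course, what do you need? <USER> How do I reset my password?
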
403
+ def retrieve_responses_faiss(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
404
+ """Retrieve top-k responses using FAISS."""
405
+ # Encode the query
406
+ q_emb = self.encode_query(query) # Shape: [1, embedding_dim]
407
+ q_emb_np = q_emb.numpy().astype('float32') # Ensure type matches FAISS requirements
408
+
409
+ # Normalize the query embedding for cosine similarity
410
+ faiss.normalize_L2(q_emb_np)
411
+
412
+ # Search the FAISS index
413
+ distances, indices = self.index.search(q_emb_np, top_k)
414
+
415
+ # Map indices to responses and distances to similarities
416
+ top_responses = []
417
+ for i, idx in enumerate(indices[0]):
418
+ if idx < len(self.response_pool):
419
+ top_responses.append((self.response_pool[idx], float(distances[0][i])))
420
+ else:
421
+ logger.warning(f"FAISS returned invalid index {idx}. Skipping.")
422
+
423
+ return top_responses
424
+
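
A short usage sketch (hypothetical: it assumes a `chatbot` instance whose index was populated by `_precompute_and_index_responses`, and the retrieved texts depend entirely on the indexed data):

results = chatbot.retrieve_responses_faiss("Can you recommend a book?", top_k=3)
for response, score in results:
    print(f"{score:.4f}  {response}")
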
425
+ def save_models(self, save_dir: Union[str, Path]):
426
+ """Save models and configuration."""
427
+ save_dir = Path(save_dir)
428
+ save_dir.mkdir(parents=True, exist_ok=True)
429
+
430
+ # Save config
431
+ with open(save_dir / "config.json", "w") as f:
432
+ json.dump(self.config.to_dict(), f, indent=2)
433
+
434
+ # Save models
435
+ self.encoder.pretrained.save_pretrained(save_dir / "shared_encoder")
436
+
437
+ # Save tokenizer
438
+ self.tokenizer.save_pretrained(save_dir / "tokenizer")
439
+
440
+ logger.info(f"Models and tokenizer saved to {save_dir}.")
441
+
442
+ @classmethod
443
+ def load_models(cls, load_dir: Union[str, Path]) -> 'RetrievalChatbot':
444
+ """Load saved models and configuration."""
445
+ load_dir = Path(load_dir)
446
+
447
+ # Load config
448
+ with open(load_dir / "config.json", "r") as f:
449
+ config = ChatbotConfig.from_dict(json.load(f))
450
+
451
+ # Initialize chatbot
452
+ chatbot = cls(config)
453
+
454
+ # Load models
455
+ chatbot.encoder.pretrained = TFAutoModel.from_pretrained(
456
+ load_dir / "shared_encoder",
457
+ config=config
458
+ )
459
+ # chatbot.response_encoder.pretrained = TFAutoModel.from_pretrained(
460
+ # load_dir / "response_encoder",
461
+ # config=config
462
+ # )
463
+
464
+ # Load tokenizer
465
+ chatbot.tokenizer = AutoTokenizer.from_pretrained(load_dir / "tokenizer")
466
+
467
+ logger.info(f"Models and tokenizer loaded from {load_dir}.")
468
+ return chatbot
469
+
470
+ @staticmethod
471
+ def load_training_data(data_path: Union[str, Path], debug_samples: Optional[int] = None) -> List[dict]:
472
+ """
473
+ Load training data from a JSON file.
474
+
475
+ Args:
476
+ data_path (Union[str, Path]): Path to the JSON file containing dialogues.
477
+ debug_samples (Optional[int]): Number of samples to load for debugging.
478
+
479
+ Returns:
480
+ List[dict]: List of dialogue dictionaries.
481
+ """
482
+ logger.info(f"Loading training data from {data_path}...")
483
+ data_path = Path(data_path)
484
+ if not data_path.exists():
485
+ logger.error(f"Data file {data_path} does not exist.")
486
+ return []
487
+
488
+ with open(data_path, 'r', encoding='utf-8') as f:
489
+ dialogues = json.load(f)
490
+
491
+ if debug_samples is not None:
492
+ dialogues = dialogues[:debug_samples]
493
+ logger.info(f"Debug mode: Limited to {debug_samples} dialogues")
494
+
495
+ logger.info(f"Loaded {len(dialogues)} dialogues.")
496
+ return dialogues
497
+
498
+ def prepare_dataset(
499
+ self,
500
+ dialogues: List[dict],
501
+ debug_samples: int = None
502
+ ) -> Tuple[tf.Tensor, tf.Tensor]:
503
+ """
504
+ Prepares dataset for in-batch negatives:
505
+ Only returns (query, positive) pairs.
506
+ """
507
+ logger.info("Preparing in-batch dataset...")
508
+
509
+ queries, positives = [], []
510
+
511
+ for dialogue in dialogues:
512
+ turns = dialogue.get('turns', [])
513
+ for i in range(len(turns) - 1):
514
+ current_turn = turns[i]
515
+ next_turn = turns[i+1]
516
+
517
+ if (current_turn.get('speaker') == 'user' and
518
+ next_turn.get('speaker') == 'assistant' and
519
+ 'text' in current_turn and
520
+ 'text' in next_turn):
521
+
522
+ query = current_turn['text'].strip()
523
+ positive = next_turn['text'].strip()
524
+
525
+ queries.append(query)
526
+ positives.append(positive)
527
+
528
+ # Optional debug slicing
529
+ if debug_samples is not None:
530
+ queries = queries[:debug_samples]
531
+ positives = positives[:debug_samples]
532
+ logger.info(f"Debug mode: limited to {debug_samples} pairs.")
533
+
534
+ logger.info(f"Prepared {len(queries)} (query, positive) pairs.")
535
+
536
+ # Tokenize queries
537
+ encoded_queries = self.tokenizer(
538
+ queries,
539
+ padding='max_length',
540
+ truncation=True,
541
+ max_length=self.config.max_sequence_length,
542
+ return_tensors='tf'
543
+ )
544
+ # Tokenize positives
545
+ encoded_positives = self.tokenizer(
546
+ positives,
547
+ padding='max_length',
548
+ truncation=True,
549
+ max_length=self.config.max_sequence_length,
550
+ return_tensors='tf'
551
+ )
552
+
553
+ q_tensor = encoded_queries['input_ids']
554
+ p_tensor = encoded_positives['input_ids']
555
+
556
+ logger.info("Tokenized and padded sequences for in-batch training.")
557
+ return q_tensor, p_tensor
558
+
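
Both `load_training_data` and `prepare_dataset` expect dialogues shaped like the sketch below; only the 'turns', 'speaker', and 'text' keys are read, and the message content here is purely illustrative:

example_dialogues = [
    {
        "turns": [
            {"speaker": "user", "text": "Hello, how are you today?"},
            {"speaker": "assistant", "text": "I'm doing well, thanks! How can I help?"},
            {"speaker": "user", "text": "Can you recommend a book?"},
            {"speaker": "assistant", "text": "Sure - what genres do you enjoy?"}
        ]
    }
]
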
559
+ def train(
560
+ self,
561
+ q_pad: tf.Tensor,
562
+ p_pad: tf.Tensor,
563
+ epochs: int,
564
+ batch_size: int,
565
+ validation_split: float,
566
+ checkpoint_dir: str,
567
+ use_lr_schedule: bool = True,
568
+ peak_lr: float = 2e-5,
569
+ warmup_steps_ratio: float = 0.1,
570
+ early_stopping_patience: int = 3,
571
+ min_delta: float = 1e-4
572
+ ):
573
+ dataset_size = tf.shape(q_pad)[0].numpy()
574
+ val_size = int(dataset_size * validation_split)
575
+ train_size = dataset_size - val_size
576
+
577
+ logger.info(f"Total samples: {dataset_size}")
578
+ logger.info(f"Training samples: {train_size}")
579
+ logger.info(f"Validation samples: {val_size}")
580
+
581
+ steps_per_epoch = train_size // batch_size
582
+ if train_size % batch_size != 0:
583
+ steps_per_epoch += 1
584
+ total_steps = steps_per_epoch * epochs
585
+ logger.info(f"Total training steps (approx): {total_steps}")
586
+
587
+ # 1) Set up LR schedule or fixed LR
588
+ if use_lr_schedule:
589
+ warmup_steps = int(total_steps * warmup_steps_ratio)
590
+ lr_schedule = self._get_lr_schedule(
591
+ total_steps=total_steps,
592
+ peak_lr=peak_lr,
593
+ warmup_steps=warmup_steps
594
+ )
595
+ self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
596
+ logger.info("Using custom learning rate schedule.")
597
+ else:
598
+ self.optimizer = tf.keras.optimizers.Adam(learning_rate=peak_lr)
599
+ logger.info("Using fixed learning rate.")
600
+
601
+ # 2) Prepare data splits
602
+ train_q = q_pad[:train_size]
603
+ train_p = p_pad[:train_size]
604
+ val_q = q_pad[train_size:]
605
+ val_p = p_pad[train_size:]
606
+
607
+ train_dataset = tf.data.Dataset.from_tensor_slices((train_q, train_p))
608
+ train_dataset = train_dataset.shuffle(buffer_size=4096).batch(batch_size)
609
+
610
+ val_dataset = tf.data.Dataset.from_tensor_slices((val_q, val_p))
611
+ val_dataset = val_dataset.batch(batch_size)
612
+
613
+ # 3) Checkpoint + manager
614
+ checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, model=self.encoder)
615
+ manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
616
+
617
+ # 4) TensorBoard setup
618
+ import datetime
619
+ import os
620
+ from pathlib import Path
621
+
622
+ log_dir = Path(checkpoint_dir) / "tensorboard_logs"
623
+ log_dir.mkdir(parents=True, exist_ok=True)
624
+
625
+ current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
626
+ train_log_dir = str(log_dir / f"train_{current_time}")
627
+ val_log_dir = str(log_dir / f"val_{current_time}")
628
+
629
+ train_summary_writer = tf.summary.create_file_writer(train_log_dir)
630
+ val_summary_writer = tf.summary.create_file_writer(val_log_dir)
631
+
632
+ logger.info(f"TensorBoard logs will be saved in {log_dir}")
633
+
634
+ # 5) Early stopping
635
+ best_val_loss = float("inf")
636
+ epochs_no_improve = 0
637
+
638
+ logger.info("Beginning training loop...")
639
+ global_step = 0
640
+
641
+ from tqdm import tqdm
642
+ for epoch in range(1, epochs + 1):
643
+ logger.info(f"\n=== Epoch {epoch}/{epochs} ===")
644
+ epoch_loss_avg = tf.keras.metrics.Mean()
645
+
646
+ # Training loop
647
+ with tqdm(total=steps_per_epoch, desc=f"Training Epoch {epoch}") as pbar:
648
+ for (q_batch, p_batch) in train_dataset:
649
+ global_step += 1
650
+
651
+ # Train step
652
+ batch_loss = self._train_step(q_batch, p_batch)
653
+ epoch_loss_avg(batch_loss)
654
+
655
+ # Get current LR
656
+ if use_lr_schedule:
657
+ lr = self.optimizer.learning_rate
658
+ if isinstance(lr, tf.keras.optimizers.schedules.LearningRateSchedule):
659
+ # Get the current step
660
+ current_step = tf.cast(self.optimizer.iterations, tf.float32)
661
+ # Compute the current learning rate
662
+ current_lr = lr(current_step)
663
+ else:
664
+ # If learning_rate is not a schedule, use it directly
665
+ current_lr = lr
666
+ # Convert to float for logging
667
+ current_lr_value = float(current_lr.numpy())
668
+ else:
669
+ # If using fixed learning rate
670
+ current_lr_value = float(self.optimizer.learning_rate.numpy())
671
+
672
+ # Update tqdm
673
+ pbar.update(1)
674
+ pbar.set_postfix({
675
+ "loss": f"{batch_loss.numpy():.4f}",
676
+ "lr": f"{current_lr_value:.2e}"
677
+ })
678
+
679
+ # TensorBoard: log train metrics per step
680
+ with train_summary_writer.as_default():
681
+ tf.summary.scalar("loss", batch_loss, step=global_step)
682
+ tf.summary.scalar("learning_rate", current_lr_value, step=global_step)
683
+
684
+ # Validation
685
+ val_loss_avg = tf.keras.metrics.Mean()
686
+ for q_val, p_val in val_dataset:
687
+ q_enc = self.encoder(q_val, training=False)
688
+ p_enc = self.encoder(p_val, training=False)
689
+ sim_matrix = tf.matmul(q_enc, p_enc, transpose_b=True)
690
+ bs_val = tf.shape(q_enc)[0]
691
+ labels_val = tf.range(bs_val, dtype=tf.int32)
692
+ loss_val = tf.nn.sparse_softmax_cross_entropy_with_logits(
693
+ labels=labels_val,
694
+ logits=sim_matrix
695
+ )
696
+ val_loss_avg(tf.reduce_mean(loss_val))
697
+
698
+ train_loss = epoch_loss_avg.result().numpy()
699
+ val_loss = val_loss_avg.result().numpy()
700
+
701
+ logger.info(f"Epoch {epoch} Complete: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")
702
+
703
+ # TensorBoard: validation loss
704
+ with val_summary_writer.as_default():
705
+ tf.summary.scalar("val_loss", val_loss, step=epoch)
706
+
707
+ # Save checkpoint
708
+ manager.save()
709
+
710
+ # Update history
711
+ self.history['train_loss'].append(train_loss)
712
+ self.history['val_loss'].append(val_loss)
713
+ self.history.setdefault('learning_rate', []).append(float(current_lr_value))
714
+
715
+ # Early stopping
716
+ if val_loss < best_val_loss - min_delta:
717
+ best_val_loss = val_loss
718
+ epochs_no_improve = 0
719
+ logger.info(f"Validation loss improved to {val_loss:.4f}. Reset patience.")
720
+ else:
721
+ epochs_no_improve += 1
722
+ logger.info(f"No improvement this epoch. Patience: {epochs_no_improve}/{early_stopping_patience}")
723
+ if epochs_no_improve >= early_stopping_patience:
724
+ logger.info("Early stopping triggered.")
725
+ break
726
+
727
+ logger.info("In-batch training completed!")
728
+
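
The loop above checkpoints with `tf.train.CheckpointManager` and logs per-step loss and learning rate (plus per-epoch validation loss) for TensorBoard, viewable with `tensorboard --logdir <checkpoint_dir>/tensorboard_logs`. A minimal restore sketch, assuming the same checkpoint directory was used during training and that a trained `chatbot` instance exists:

import tensorflow as tf

# Hypothetical restore sketch: rebuild the tracked objects, then restore the latest checkpoint.
optimizer = tf.keras.optimizers.Adam(learning_rate=2e-5)
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=chatbot.encoder)
manager = tf.train.CheckpointManager(checkpoint, "checkpoints", max_to_keep=3)
if manager.latest_checkpoint:
    checkpoint.restore(manager.latest_checkpoint)
    print(f"Restored from {manager.latest_checkpoint}")
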
729
+ @tf.function
730
+ def _train_step(self, q_batch, p_batch):
731
+ """
732
+ Single training step using in-batch negatives.
733
+ q_batch: (batch_size, seq_len) int32 input_ids for queries
734
+ p_batch: (batch_size, seq_len) int32 input_ids for positives
735
+ """
736
+ with tf.GradientTape() as tape:
737
+ # Encode queries and positives
738
+ q_enc = self.encoder(q_batch, training=True) # [B, emb_dim]
739
+ p_enc = self.encoder(p_batch, training=True) # [B, emb_dim]
740
+
741
+ # Compute similarity matrix: (B, B) = q_enc * p_enc^T
742
+ # If embeddings are L2-normalized, this is cosine similarity
743
+ sim_matrix = tf.matmul(q_enc, p_enc, transpose_b=True) # [B, B]
744
+
745
+ # Labels are just the diagonal indices
746
+ batch_size = tf.shape(q_enc)[0]
747
+ labels = tf.range(batch_size, dtype=tf.int32) # [0..B-1]
748
+
749
+ # Softmax cross-entropy
750
+ loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
751
+ labels=labels,
752
+ logits=sim_matrix
753
+ )
754
+ loss = tf.reduce_mean(loss)
755
+
756
+ # Compute gradients for the pretrained DistilBERT variables only
757
+ train_vars = self.encoder.pretrained.trainable_variables
758
+ gradients = tape.gradient(loss, train_vars)
759
+
760
+ # Remove any None grads (in case some layers are frozen)
761
+ grads_and_vars = [(g, v) for g, v in zip(gradients, train_vars) if g is not None]
762
+ if grads_and_vars:
763
+ self.optimizer.apply_gradients(grads_and_vars)
764
+
765
+ return loss
766
+
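
In words: with in-batch negatives, every other response in the batch serves as a negative for a given query, so row i of the similarity matrix is scored with a softmax whose correct class is column i. A tiny standalone sketch with a batch of three pre-normalized toy embeddings (illustrative values only):

import tensorflow as tf

q_enc = tf.constant([[1.0, 0.0], [0.0, 1.0], [0.7071, 0.7071]])
p_enc = tf.constant([[0.9, 0.4359], [0.1, 0.9950], [0.7071, 0.7071]])

sim_matrix = tf.matmul(q_enc, p_enc, transpose_b=True)        # [3, 3] cosine similarities
labels = tf.range(tf.shape(q_enc)[0], dtype=tf.int32)         # correct response = diagonal
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=sim_matrix)
)
print(float(loss))
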
767
+ def _prepare_sequences(
768
+ self,
769
+ queries: List[str],
770
+ positives: List[str],
771
+ negatives: List[str]
772
+ ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
773
+ """Prepare and tokenize sequences for training."""
774
+ logger.info("Preparing sequences for training...")
775
+
776
+ # Handle empty lists
777
+ if not queries:
778
+ logger.error("No queries to encode. Skipping sequence preparation.")
779
+ return tf.constant([]), tf.constant([]), tf.constant([])
780
+
781
+ # Process texts with special tokens
782
+ queries = [f"{self.special_tokens['user']} {q}" for q in queries]
783
+ positives = [f"{self.special_tokens['assistant']} {p}" for p in positives]
784
+ negatives = [f"{self.special_tokens['assistant']} {n}" for n in negatives]
785
+
786
+ # Tokenize using HuggingFace tokenizer
787
+ def encode_batch(texts: List[str]) -> tf.Tensor:
788
+ if not texts:
789
+ logger.error("Empty text list provided to tokenizer.")
790
+ return tf.constant([])
791
+ encodings = self.tokenizer(
792
+ texts,
793
+ padding='max_length',
794
+ truncation=True,
795
+ max_length=self.config.max_sequence_length,
796
+ return_tensors='tf'
797
+ )
798
+ return encodings['input_ids']
799
+
800
+ # Encode all sequences
801
+ q_tensor = encode_batch(queries)
802
+ p_tensor = encode_batch(positives)
803
+ n_tensor = encode_batch(negatives)
804
+
805
+ # Log statistics about encoded sequences
806
+ logger.info("Sequence statistics:")
807
+ logger.info(f" Query sequence shape: {q_tensor.shape}")
808
+ logger.info(f" Positive response sequence shape: {p_tensor.shape}")
809
+ logger.info(f" Negative response sequence shape: {n_tensor.shape}")
810
+
811
+ return q_tensor, p_tensor, n_tensor
812
+
813
+ def _get_lr_schedule(
814
+ self,
815
+ total_steps: int,
816
+ peak_lr: float,
817
+ warmup_steps: int
818
+ ) -> tf.keras.optimizers.schedules.LearningRateSchedule:
819
+ """Create a custom learning rate schedule with warmup and cosine decay."""
820
+ class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
821
+ def __init__(
822
+ self,
823
+ total_steps: int,
824
+ peak_lr: float,
825
+ warmup_steps: int
826
+ ):
827
+ super().__init__()
828
+ self.total_steps = tf.cast(total_steps, tf.float32)
829
+ self.peak_lr = tf.cast(peak_lr, tf.float32)
830
+
831
+ # Adjust warmup_steps to not exceed half of total_steps
832
+ adjusted_warmup_steps = min(warmup_steps, max(1, total_steps // 10))
833
+ self.warmup_steps = tf.cast(adjusted_warmup_steps, tf.float32)
834
+
835
+ # Calculate and store constants
836
+ self.initial_lr = self.peak_lr * 0.1 # Start at 10% of peak
837
+ self.min_lr = self.peak_lr * 0.01 # Minimum 1% of peak
838
+
839
+ logger.info(f"Learning rate schedule initialized:")
840
+ logger.info(f" Initial LR: {float(self.initial_lr):.6f}")
841
+ logger.info(f" Peak LR: {float(self.peak_lr):.6f}")
842
+ logger.info(f" Min LR: {float(self.min_lr):.6f}")
843
+ logger.info(f" Warmup steps: {int(self.warmup_steps)}")
844
+ logger.info(f" Total steps: {int(self.total_steps)}")
845
+
846
+ def __call__(self, step):
847
+ step = tf.cast(step, tf.float32)
848
+
849
+ # Warmup phase
850
+ warmup_factor = tf.minimum(1.0, step / self.warmup_steps)
851
+ warmup_lr = self.initial_lr + (self.peak_lr - self.initial_lr) * warmup_factor
852
+
853
+ # Decay phase
854
+ decay_steps = tf.maximum(1.0, self.total_steps - self.warmup_steps)
855
+ decay_factor = (step - self.warmup_steps) / decay_steps
856
+ decay_factor = tf.minimum(tf.maximum(0.0, decay_factor), 1.0) # Clip to [0,1]
857
+
858
+ cosine_decay = 0.5 * (1.0 + tf.cos(np.pi * decay_factor))
859
+ decay_lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay
860
+
861
+ # Choose between warmup and decay
862
+ final_lr = tf.where(step < self.warmup_steps, warmup_lr, decay_lr)
863
+
864
+ # Ensure learning rate is valid
865
+ final_lr = tf.maximum(self.min_lr, final_lr)
866
+ final_lr = tf.where(tf.math.is_finite(final_lr), final_lr, self.min_lr)
867
+
868
+ return final_lr
869
+
870
+ def get_config(self):
871
+ return {
872
+ "total_steps": self.total_steps,
873
+ "peak_lr": self.peak_lr,
874
+ "warmup_steps": self.warmup_steps,
875
+ }
876
+
877
+ return CustomSchedule(total_steps, peak_lr, warmup_steps)
878
+
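
One quick way to sanity-check the warmup-plus-cosine-decay shape is to evaluate the returned schedule at a handful of steps. The step counts and peak learning rate below are hypothetical, and `chatbot` stands for any RetrievalChatbot instance:

schedule = chatbot._get_lr_schedule(total_steps=1000, peak_lr=2e-5, warmup_steps=100)
for step in [0, 50, 100, 500, 1000]:
    print(step, float(schedule(step)))
# Expected shape: LR climbs from 10% of peak to the peak over the first 100 steps,
# then follows a cosine decay toward 1% of peak.
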
879
+ def _cosine_similarity(self, emb1: np.ndarray, emb2: np.ndarray) -> np.ndarray:
880
+ """Compute cosine similarity between two numpy arrays."""
881
+ normalized_emb1 = emb1 / np.linalg.norm(emb1, axis=1, keepdims=True)
882
+ normalized_emb2 = emb2 / np.linalg.norm(emb2, axis=1, keepdims=True)
883
+ return np.dot(normalized_emb1, normalized_emb2.T)
884
+
885
+ def run_automatic_validation(
886
+ self,
887
+ quality_checker: 'ResponseQualityChecker',
888
+ num_examples: int = 5
889
+ ) -> Dict[str, Any]:
890
+ """
891
+ Run automatic validation with quality metrics using FAISS-based retrieval.
892
+ """
893
+ logger.info("\n=== Running Automatic Validation ===")
894
+
895
+ test_queries = [
896
+ "Hello, how are you today?",
897
+ "What's the weather like?",
898
+ "Can you help me with a problem?",
899
+ "Tell me a joke",
900
+ "What time is it?",
901
+ "I need help with my homework",
902
+ "Where's a good place to eat?",
903
+ "What movies are playing?",
904
+ "How do I reset my password?",
905
+ "Can you recommend a book?"
906
+ ]
907
+
908
+ test_queries = test_queries[:num_examples]
909
+ metrics_history = []
910
+
911
+ for i, query in enumerate(test_queries, 1):
912
+ logger.info(f"\nTest Case {i}:")
913
+ logger.info(f"Query: {query}")
914
+
915
+ # Get responses and scores using FAISS
916
+ responses = self.retrieve_responses_faiss(query, top_k=5)
917
+
918
+ # Check quality
919
+ quality_metrics = quality_checker.check_response_quality(query, responses)
920
+ metrics_history.append(quality_metrics)
921
+
922
+ # Log results
923
+ logger.info(f"Quality Metrics: {quality_metrics}")
924
+ logger.info("Top responses:")
925
+ for j, (response, score) in enumerate(responses[:3], 1):
926
+ logger.info(f"{j}. Score: {score:.4f}")
927
+ logger.info(f" Response: {response}")
928
+ if j == 1 and not quality_metrics.get('is_confident', False):
929
+ logger.info(" [Low Confidence - Would abstain from answering]")
930
+
931
+ # Calculate aggregate metrics
932
+ aggregate_metrics = {
933
+ 'num_queries_tested': len(test_queries),
934
+ 'avg_top_response_score': np.mean([m.get('top_score', 0) for m in metrics_history]),
935
+ 'avg_diversity': np.mean([m.get('response_diversity', 0) for m in metrics_history]),
936
+ 'avg_relevance': np.mean([m.get('query_response_relevance', 0) for m in metrics_history]),
937
+ 'avg_length_score': np.mean([m.get('response_length_score', 0) for m in metrics_history]),
938
+ 'avg_score_gap': np.mean([m.get('top_3_score_gap', 0) for m in metrics_history]),
939
+ 'confidence_rate': np.mean([m.get('is_confident', False) for m in metrics_history]),
940
+ }
941
+
942
+ logger.info("\n=== Validation Summary ===")
943
+ for metric, value in aggregate_metrics.items():
944
+ logger.info(f"{metric}: {value:.4f}")
945
+
946
+ return aggregate_metrics
947
+
948
+ def chat(
949
+ self,
950
+ query: str,
951
+ conversation_history: Optional[List[Tuple[str, str]]] = None,
952
+ quality_checker: Optional['ResponseQualityChecker'] = None,
953
+ top_k: int = 5
954
+ ) -> Tuple[str, List[Tuple[str, float]], Dict[str, Any]]:
955
+ """
956
+ Interactive chat function with quality checking using FAISS-based retrieval.
957
+
958
+ Args:
959
+ query (str): The user's input query.
960
+ conversation_history (Optional[List[Tuple[str, str]]]): List of past (user, assistant) exchanges.
961
+ quality_checker (Optional['ResponseQualityChecker']): Quality checker instance.
962
+ top_k (int): Number of top responses to retrieve.
963
+
964
+ Returns:
965
+ Tuple[str, List[Tuple[str, float]], Dict[str, Any]]: (Response, Candidates, Quality Metrics)
966
+ """
967
+ # Retrieve responses using FAISS
968
+ responses = self.retrieve_responses_faiss(query, top_k)
969
+
970
+ # If no quality checker provided, return the top response
971
+ if quality_checker is None:
972
+ return responses[0][0] if responses else "I'm sorry, I don't have an answer for that.", responses, {}
973
+
974
+ # Check quality
975
+ quality_metrics = quality_checker.check_response_quality(query, responses)
976
+
977
+ if quality_metrics.get('is_confident', False):
978
+ return responses[0][0], responses, quality_metrics
979
+ else:
980
+ uncertainty_response = (
981
+ "I apologize, but I don't feel confident providing an answer to that "
982
+ "question at the moment. Could you please rephrase or ask something else?"
983
+ )
984
+ return uncertainty_response, responses, quality_metrics
985
+
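
Putting the pieces together, a minimal end-to-end driver sketch; the file path, hyperparameters, and the ResponseQualityChecker construction are assumptions, while the method names come from this class:

# Hypothetical driver script; adjust paths and hyperparameters as needed.
dialogues = RetrievalChatbot.load_training_data("processed_outputs/dialogues.json")
chatbot = RetrievalChatbot(ChatbotConfig(), dialogues=dialogues)

q_pad, p_pad = chatbot.prepare_dataset(dialogues)
chatbot.train(
    q_pad, p_pad,
    epochs=3,
    batch_size=32,
    validation_split=0.2,
    checkpoint_dir="checkpoints",
)

quality_checker = ResponseQualityChecker()  # assumed constructor; see response_quality_checker.py
response, candidates, metrics = chatbot.chat(
    "Can you help me with a problem?",
    quality_checker=quality_checker,
)
print(response)
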
986
+ # TODO: consider removal
987
+ # def prepare_dataset(self, dialogues: List[dict], neg_samples_per_pos: int = 1, debug_samples: int = None) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
988
+ # """Prepares the dataset for training."""
989
+ # logger.info("Preparing dataset...")
990
+
991
+ # # Extract (query, positive, negative) triples
992
+ # queries, positives, negatives = [], [], []
993
+
994
+ # for dialogue in dialogues:
995
+ # turns = dialogue.get('turns', [])
996
+ # for i in range(len(turns) - 1):
997
+ # current_turn = turns[i]
998
+ # next_turn = turns[i+1]
999
+
1000
+ # if (current_turn.get('speaker') == 'user' and
1001
+ # next_turn.get('speaker') == 'assistant' and
1002
+ # 'text' in current_turn and
1003
+ # 'text' in next_turn):
1004
+
1005
+ # query = current_turn['text'].strip()
1006
+ # positive = next_turn['text'].strip()
1007
+
1008
+ # # Generate hard negative samples
1009
+ # hard_negatives = self.hard_negative_sampling(positive, n_samples=neg_samples_per_pos)
1010
+ # for negative in hard_negatives:
1011
+ # negatives.append(negative)
1012
+ # queries.append(query)
1013
+ # positives.append(positive)
1014
+
1015
+ # logger.info(f"Prepared {len(queries)} training examples.")
1016
+
1017
+ # # Tokenize and pad sequences
1018
+ # encoded_queries = self.tokenizer(
1019
+ # queries,
1020
+ # padding='max_length',
1021
+ # truncation=True,
1022
+ # max_length=self.config.max_sequence_length,
1023
+ # return_tensors='tf'
1024
+ # )
1025
+ # encoded_positives = self.tokenizer(
1026
+ # positives,
1027
+ # padding='max_length',
1028
+ # truncation=True,
1029
+ # max_length=self.config.max_sequence_length,
1030
+ # return_tensors='tf'
1031
+ # )
1032
+ # encoded_negatives = self.tokenizer(
1033
+ # negatives,
1034
+ # padding='max_length',
1035
+ # truncation=True,
1036
+ # max_length=self.config.max_sequence_length,
1037
+ # return_tensors='tf'
1038
+ # )
1039
+
1040
+ # q_tensor = encoded_queries['input_ids']
1041
+ # p_tensor = encoded_positives['input_ids']
1042
+ # n_tensor = encoded_negatives['input_ids']
1043
+
1044
+ # logger.info(f"Tokenized and padded sequences.")
1045
+
1046
+ # return q_tensor, p_tensor, n_tensor
1047
+
1048
+
1049
+ # # TODO: consider removal
1050
+ # def hard_negative_sampling(self, positive_response, n_samples=1):
1051
+ # """Select hard negatives based on cosine similarity."""
1052
+ # try:
1053
+ # # Ensure we don't request more negatives than available
1054
+ # max_neg_samples = len(self.response_pool) - 1 # Exclude the positive response
1055
+ # n_samples = min(n_samples, max_neg_samples)
1056
+
1057
+ # if n_samples <= 0:
1058
+ # logger.error("Not enough responses to sample negatives.")
1059
+ # return []
1060
+
1061
+ # # Encode the positive response using the chatbot's encode_responses method
1062
+ # pos_emb = self.encode_responses([positive_response]).numpy()
1063
+ # faiss.normalize_L2(pos_emb)
1064
+ # #logger.info(f"Normalized positive embedding for response: {positive_response}")
1065
+
1066
+ # # Search for the top n_samples + 1 most similar responses (including the positive itself)
1067
+ # D, I = self.index.search(pos_emb, n_samples + 1)
1068
+ # #logger.info(f"FAISS search results: {I}")
1069
+
1070
+ # # Exclude the positive response itself (assuming it's indexed)
1071
+ # negatives = []
1072
+ # for i in range(n_samples):
1073
+ # idx = I[0][i + 1] # Skip the first one as it's the positive
1074
+ # if idx < len(self.response_pool):
1075
+ # negative_response = self.response_pool[idx]
1076
+ # negatives.append(negative_response)
1077
+ # logger.info(f"Selected negative: {negative_response}")
1078
+ # else:
1079
+ # logger.warning(f"Index {idx} out of range for response_pool with size {len(self.response_pool)}.")
1080
+
1081
+ # return negatives
1082
+ # except Exception as e:
1083
+ # logger.error(f"An error occurred during hard negative sampling: {e}")
1084
+ # return []
1085
+
1086
+ # def train(
1087
+ # self,
1088
+ # q_pad: tf.Tensor,
1089
+ # p_pad: tf.Tensor,
1090
+ # n_pad: tf.Tensor,
1091
+ # epochs: int,
1092
+ # batch_size: int,
1093
+ # validation_split: float,
1094
+ # checkpoint_dir: str,
1095
+ # callbacks: Optional[List[tf.keras.callbacks.Callback]] = None
1096
+ # ):
1097
+ # """
1098
+ # Train the chatbot model.
1099
+
1100
+ # Args:
1101
+ # q_pad (tf.Tensor): Padded query input_ids.
1102
+ # p_pad (tf.Tensor): Padded positive response input_ids.
1103
+ # n_pad (tf.Tensor): Padded negative response input_ids.
1104
+ # epochs (int): Number of training epochs.
1105
+ # batch_size (int): Training batch size.
1106
+ # validation_split (float): Fraction of data to use for validation.
1107
+ # checkpoint_dir (str): Directory to save model checkpoints.
1108
+ # callbacks (list, optional): List of Keras callbacks.
1109
+ # """
1110
+ # dataset_size = tf.shape(q_pad)[0].numpy()
1111
+ # val_size = int(dataset_size * validation_split)
1112
+ # train_size = dataset_size - val_size
1113
+
1114
+ # logger.info(f"Total samples: {dataset_size}")
1115
+ # logger.info(f"Training samples: {train_size}")
1116
+ # logger.info(f"Validation samples: {val_size}")
1117
+
1118
+ # # Calculate steps_per_epoch
1119
+ # steps_per_epoch = train_size // batch_size
1120
+ # if train_size % batch_size != 0:
1121
+ # steps_per_epoch += 1
1122
+ # total_steps = steps_per_epoch * epochs
1123
+
1124
+ # logger.info(f"Total training steps: {total_steps}")
1125
+
1126
+ # # Initialize learning rate schedule with adjusted warmup_steps
1127
+ # lr_schedule = self._get_lr_schedule(
1128
+ # total_steps=total_steps,
1129
+ # peak_lr=self.config.learning_rate,
1130
+ # warmup_steps=self.config.warmup_steps
1131
+ # )
1132
+
1133
+ # # callbacks = []
1134
+ # # if checkpoint_dir:
1135
+ # # checkpoint_dir = Path(checkpoint_dir)
1136
+ # # checkpoint_dir.mkdir(parents=True, exist_ok=True)
1137
+
1138
+ # # # Setup checkpoint callback with correct file format
1139
+ # # checkpoint_template = str(checkpoint_dir / "model_epoch_{epoch:04d}.weights.h5")
1140
+ # # checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
1141
+ # # checkpoint_template,
1142
+ # # save_weights_only=True,
1143
+ # # save_best_only=True,
1144
+ # # monitor='val_loss',
1145
+ # # mode='min',
1146
+ # # verbose=1
1147
+ # # )
1148
+ # # callbacks.append(checkpoint_callback)
1149
+
1150
+ # # # Early stopping callback
1151
+ # # early_stopping = tf.keras.callbacks.EarlyStopping(
1152
+ # # monitor='val_loss',
1153
+ # # patience=5,
1154
+ # # restore_best_weights=True,
1155
+ # # verbose=1
1156
+ # # )
1157
+ # # callbacks.append(early_stopping)
1158
+
1159
+ # # # TensorBoard callback
1160
+ # # tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='./logs')
1161
+ # # callbacks.append(tensorboard_callback)
1162
+
1163
+ # # Update optimizer with the new learning rate schedule
1164
+ # self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
1165
+
1166
+ # # Split the data
1167
+ # train_q = q_pad[:train_size]
1168
+ # train_p = p_pad[:train_size]
1169
+ # train_n = n_pad[:train_size]
1170
+
1171
+ # val_q = q_pad[train_size:]
1172
+ # val_p = p_pad[train_size:]
1173
+ # val_n = n_pad[train_size:]
1174
+
1175
+ # # Create TensorFlow datasets
1176
+ # train_dataset = tf.data.Dataset.from_tensor_slices((train_q, train_p, train_n))
1177
+ # train_dataset = train_dataset.shuffle(buffer_size=1000).batch(batch_size)
1178
+
1179
+ # val_dataset = tf.data.Dataset.from_tensor_slices((val_q, val_p, val_n))
1180
+ # val_dataset = val_dataset.batch(batch_size)
1181
+
1182
+ # # Log dataset sizes
1183
+ # logger.info(f"Training dataset batches: {len(list(train_dataset))}")
1184
+ # logger.info(f"Validation dataset batches: {len(list(val_dataset))}")
1185
+
1186
+ # # Create checkpoint manager
1187
+ # checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, model=self.encoder)
1188
+ # manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
1189
+
1190
+ # for epoch in range(1, epochs + 1):
1191
+ # logger.info(f"Epoch {epoch}/{epochs}")
1192
+ # epoch_loss_avg = tf.keras.metrics.Mean()
1193
+
1194
+ # # Training loop
1195
+ # for q_batch, p_batch, n_batch in train_dataset:
1196
+ # batch_loss = self._train_step(q_batch, p_batch, n_batch)
1197
+ # epoch_loss_avg(batch_loss)
1198
+
1199
+ # # Validation loop
1200
+ # val_loss_avg = tf.keras.metrics.Mean()
1201
+ # try:
1202
+ # for q_val, p_val, n_val in val_dataset:
1203
+ # # Encode queries, positives, and negatives without training
1204
+ # q_enc = self.encoder(q_val, training=False)
1205
+ # p_enc = self.encoder(p_val, training=False)
1206
+ # n_enc = self.encoder(n_val, training=False)
1207
+
1208
+ # # Compute cosine similarities
1209
+ # pos_sim = tf.reduce_sum(tf.multiply(q_enc, p_enc), axis=1)
1210
+ # neg_sim = tf.reduce_sum(tf.multiply(q_enc, n_enc), axis=1)
1211
+
1212
+ # # Ensure similarities are float32
1213
+ # pos_sim = tf.cast(pos_sim, tf.float32)
1214
+ # neg_sim = tf.cast(neg_sim, tf.float32)
1215
+
1216
+ # # Compute loss with margin
1217
+ # margin = tf.cast(self.config.margin, tf.float32)
1218
+ # loss = tf.maximum(0.0, margin - pos_sim + neg_sim)
1219
+
1220
+ # val_loss_avg(tf.reduce_mean(loss))
1221
+
1222
+ # # Optional: Log individual batch validation loss
1223
+ # logger.debug(f"Batch Validation Loss: {tf.reduce_mean(loss).numpy():.6f}")
1224
+
1225
+ # train_loss = epoch_loss_avg.result().numpy()
1226
+ # val_loss = val_loss_avg.result().numpy()
1227
+ # logger.info(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
1228
+
1229
+ # # Save checkpoint
1230
+ # manager.save()
1231
+
1232
+ # # Update history
1233
+ # self.history['train_loss'].append(train_loss)
1234
+ # self.history['val_loss'].append(val_loss)
1235
+
1236
+ # # Invoke callbacks if any
1237
+ # if callbacks:
1238
+ # for callback in callbacks:
1239
+ # callback.on_epoch_end(epoch, logs={'loss': train_loss, 'val_loss': val_loss})
1240
+
1241
+ # except tf.errors.OutOfRangeError:
1242
+ # logger.warning("Validation dataset is exhausted before expected.")
1243
+ # self.history['val_loss'].append(val_loss_avg.result().numpy())
1244
+
1245
+ # logger.info("Training completed.")
1246
+
1247
+ # @tf.function
1248
+ # def _train_step(self, q_batch, p_batch, n_batch):
1249
+ # """
1250
+ # Performs a single training step with query, positive, and negative batches.
1251
+
1252
+ # Args:
1253
+ # q_batch (tf.Tensor): Batch of query input_ids.
1254
+ # p_batch (tf.Tensor): Batch of positive response input_ids.
1255
+ # n_batch (tf.Tensor): Batch of negative response input_ids.
1256
+
1257
+ # Returns:
1258
+ # tf.Tensor: Mean loss for the batch.
1259
+ # """
1260
+ # with tf.GradientTape() as tape:
1261
+ # # Encode queries, positives, and negatives using the shared encoder
1262
+ # q_enc = self.encoder(q_batch, training=True) # Shape: (batch_size, embedding_dim)
1263
+ # p_enc = self.encoder(p_batch, training=True) # Shape: (batch_size, embedding_dim)
1264
+ # n_enc = self.encoder(n_batch, training=True) # Shape: (batch_size, embedding_dim)
1265
+
1266
+ # # Compute cosine similarities
1267
+ # pos_sim = tf.reduce_sum(tf.multiply(q_enc, p_enc), axis=1) # Shape: (batch_size,)
1268
+ # neg_sim = tf.reduce_sum(tf.multiply(q_enc, n_enc), axis=1) # Shape: (batch_size,)
1269
+
1270
+ # # Ensure similarities are float32
1271
+ # pos_sim = tf.cast(pos_sim, tf.float32)
1272
+ # neg_sim = tf.cast(neg_sim, tf.float32)
1273
+
1274
+ # # Compute loss with margin
1275
+ # margin = tf.cast(self.config.margin, tf.float32)
1276
+ # loss = tf.maximum(0.0, margin - pos_sim + neg_sim)
1277
+
1278
+ # # Compute gradients and update encoder weights
1279
+ # gradients = tape.gradient(loss, self.encoder.pretrained.trainable_variables)
1280
+
1281
+ # # Filter out None gradients (if any)
1282
+ # grads_and_vars = [
1283
+ # (g, v) for g, v in zip(gradients, self.encoder.pretrained.trainable_variables)
1284
+ # if g is not None
1285
+ # ]
1286
+
1287
+ # if grads_and_vars:
1288
+ # self.optimizer.apply_gradients(grads_and_vars)
1289
+
1290
+ # # Return mean loss
1291
+ # return tf.reduce_mean(loss)
dialogue_augmenter.py CHANGED
@@ -3,11 +3,9 @@ import numpy as np
3
  import torch
4
  import tensorflow as tf
5
  import tensorflow_hub as hub
6
- import re
7
  from pipeline_config import PipelineConfig
8
  from quality_metrics import QualityMetrics
9
  from paraphraser import Paraphraser
10
- from back_translator import BackTranslator
11
  import nlpaug.augmenter.word as naw
12
  from concurrent.futures import ThreadPoolExecutor
13
  from functools import lru_cache
@@ -29,9 +27,12 @@ class DialogueAugmenter:
29
  print(f"Using device: {self.device}")
30
  if self.use_gpu:
31
  print(f"GPU Device: {torch.cuda.get_device_name(0)}")
 
32
 
33
- # Load base models
34
  self.quality_metrics = QualityMetrics(config)
 
 
 
35
  self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
36
 
37
  # Initialize augmentation models based on hardware
@@ -39,10 +40,6 @@ class DialogueAugmenter:
39
 
40
  # Initialize caches
41
  self.embedding_cache = {}
42
- self.perplexity_cache = {}
43
-
44
- # Compile regex patterns
45
- self.spelling_pattern = re.compile(r'[a-zA-Z]{3,}')
46
 
47
  # GPU memory management if available
48
  if self.use_gpu:
@@ -57,25 +54,20 @@ class DialogueAugmenter:
57
  def _initialize_augmentation_models(self):
58
  """Initialize augmentation models with appropriate device settings"""
59
  # Advanced augmentation techniques
60
- self.paraphraser = Paraphraser()
61
- self.back_translator = BackTranslator()
62
-
63
  if self.use_gpu:
64
- # Move models to GPU if available
65
  self.paraphraser.model = self.paraphraser.model.to(self.device)
66
- self.back_translator.model_pivot_forward = self.back_translator.model_pivot_forward.to(self.device)
67
- self.back_translator.model_pivot_backward = self.back_translator.model_pivot_backward.to(self.device)
68
- self.back_translator.model_backward = self.back_translator.model_backward.to(self.device)
69
 
70
  # Basic augmentation techniques
71
  self.word_augmenter = naw.SynonymAug(aug_src='wordnet')
72
- self.spelling_augmenter = naw.SpellingAug()
73
 
74
  self.augmenters = {
75
- 'advanced': [self.paraphraser, self.back_translator],
 
 
76
  'basic': [
77
  ('synonym', self.word_augmenter),
78
- ('spelling', self.spelling_augmenter)
79
  ]
80
  }
81
 
@@ -103,52 +95,46 @@ class DialogueAugmenter:
103
 
104
  def _quick_quality_check(self, variation: str, original: str) -> bool:
105
  """
106
- Stricter preliminary quality check while maintaining reasonable pass rates
107
  """
108
  if self.config.debug:
109
  print(f"\nQuick check for variation: {variation}")
110
 
111
- # Stricter length check
112
  orig_len = len(original.split())
113
  var_len = len(variation.split())
114
-
115
- # For very short texts (1-3 words), still allow more variation
116
  if orig_len <= 3:
117
- if var_len > orig_len * 3: # Reduced from 4x to 3x
118
  if self.config.debug:
119
  print(f"Failed length check (short text): {var_len} vs {orig_len}")
120
  return False
121
  else:
122
- if var_len > orig_len * 2: # Reduced from 3x to 2x
123
  if self.config.debug:
124
  print(f"Failed length check (long text): {var_len} vs {orig_len}")
125
  return False
126
-
127
- # Enhanced content check - more words in common
128
  stop_words = {'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'is', 'are', 'that', 'this', 'will', 'can'}
129
  orig_words = set(w.lower() for w in original.split() if w.lower() not in stop_words)
130
  var_words = set(w.lower() for w in variation.split() if w.lower() not in stop_words)
131
-
132
- # Require more content word overlap
133
- content_overlap = len(orig_words.intersection(var_words)) / len(orig_words) if orig_words else 0
134
- if content_overlap < 0.3: # Increased from no minimum to 30% overlap
 
 
 
 
 
135
  if self.config.debug:
136
- print(f"Failed content check: overlap {content_overlap:.2f}")
137
- return False
138
-
139
  if self.config.debug:
140
  print("Passed all quick checks")
141
  return True
142
 
143
- def _compute_metrics_parallel(self, original: str, candidates: List[str]) -> List[Dict[str, float]]:
144
- """Compute quality metrics for multiple candidates in parallel"""
145
- with ThreadPoolExecutor(max_workers=4) as executor:
146
- futures = [
147
- executor.submit(self.quality_metrics.compute_metrics, original, candidate)
148
- for candidate in candidates
149
- ]
150
- return [future.result() for future in futures]
151
-
152
  def _filter_variations_batch(self, variations: List[str], context: List[str], original_turn: str) -> List[str]:
153
  """
154
  Filter variations using batched computations with detailed logging
@@ -162,12 +148,17 @@ class DialogueAugmenter:
162
  print(f"Original turn: {original_turn}")
163
 
164
  words = original_turn.split()
 
 
 
 
 
165
  if len(words) < 3:
166
  if self.config.debug:
167
  print("Short text detected, using predefined variations")
168
  short_text_variations = self._augment_short_text({'text': original_turn, 'speaker': ''})
169
  return [var['text'] for var in short_text_variations]
170
-
171
  # If this is the first turn (no context), be more lenient
172
  if not context:
173
  preliminary_filtered = variations
@@ -183,57 +174,85 @@ class DialogueAugmenter:
183
  print(f"Passed quick check: {passed}")
184
  if passed:
185
  preliminary_filtered.append(var)
186
-
187
  if self.config.debug:
188
  print(f"Variations after quick check: {len(preliminary_filtered)}")
189
-
190
  if not preliminary_filtered:
191
  return []
192
-
193
  # Only use last turn for coherence
194
  recent_context = [context[-1]] if context else []
195
  context_text = ' '.join(recent_context) if recent_context else ''
196
-
197
- # Even more lenient thresholds
198
- min_similarity = 0.1 # Further reduced
199
- min_coherence = 0.05 # Further reduced
200
-
201
  if context_text:
202
  if self.config.debug:
203
  print(f"\nContext text: {context_text}")
204
-
205
- all_texts = [context_text] + preliminary_filtered
206
  all_embeddings = self._compute_batch_embeddings(all_texts)
207
-
208
  context_embedding = all_embeddings[0]
209
  variation_embeddings = all_embeddings[1:]
210
-
211
  # Vectorized similarity computation
212
  context_similarities = cosine_similarity([context_embedding], variation_embeddings)[0]
213
-
214
  # Response coherence check
215
  if recent_context:
216
  prev_embedding = self._compute_embedding(recent_context[-1])
217
  response_coherence = cosine_similarity([prev_embedding], variation_embeddings)[0]
218
  else:
219
  response_coherence = np.ones_like(context_similarities)
220
-
221
- # Combined scoring with detailed logging
222
  filtered_variations = []
223
  for i, (variation, sim, coh) in enumerate(zip(
224
- preliminary_filtered, context_similarities, response_coherence)):
225
- # Use absolute values for scoring
226
  combined_score = (
227
  self.config.context_similarity_weight * abs(sim) +
228
  self.config.response_coherence_weight * abs(coh)
229
  )
230
-
231
  if self.config.debug:
232
  print(f"\nVariation: {variation}")
233
  print(f"Context similarity: {sim:.3f}")
234
  print(f"Response coherence: {coh:.3f}")
235
  print(f"Combined score: {combined_score:.3f}")
236
-
237
  # Accept if EITHER score is good enough
238
  if (combined_score >= min_similarity or abs(coh) >= min_coherence):
239
  filtered_variations.append(variation)
@@ -242,74 +261,71 @@ class DialogueAugmenter:
242
  else:
243
  if self.config.debug:
244
  print("REJECTED")
245
-
246
  # If we have enough variations, stop
247
  if len(filtered_variations) >= self.config.max_variations_per_turn:
248
  break
249
  else:
250
- filtered_variations = preliminary_filtered[:self.config.max_variations_per_turn]
251
-
252
  if self.config.debug:
253
  print(f"\nFinal filtered variations: {len(filtered_variations)}")
254
-
255
  return filtered_variations
256
 
257
  def _generate_variations_progressive(self, text: str, needed: int) -> List[str]:
258
  """
259
- Generate variations progressively until we have enough good ones
 
260
  """
261
  variations = set()
262
-
263
  if self.config.debug:
264
  print(f"\nAttempting to generate {needed} variations for text: {text}")
265
-
266
- # Try advanced augmenters first
267
  for augmenter in self.augmenters['advanced']:
268
  if len(variations) >= needed:
269
  break
270
-
271
  try:
272
  if isinstance(augmenter, Paraphraser):
273
  if self.config.debug:
274
  print("Trying paraphrase augmentation...")
275
- new_vars = augmenter.paraphrase(text, num_return_sequences=needed-len(variations))
 
 
 
 
 
 
 
276
  if self.config.debug:
277
  print(f"Paraphraser generated {len(new_vars)} variations")
278
- else:
279
- if self.config.debug:
280
- print("Trying back translation...")
281
- new_vars = [augmenter.back_translate(text)]
282
- if self.config.debug:
283
- print(f"Back translator generated {len(new_vars)} variations")
284
-
285
  valid_vars = [v for v in new_vars if v.strip() and v != text]
286
  variations.update(valid_vars)
287
-
288
  if self.config.debug:
289
  print(f"Current unique variations: {len(variations)}")
290
-
291
  except Exception as e:
292
  print(f"Error in advanced augmentation: {str(e)}")
293
  continue
294
-
295
  # Try basic augmenters if needed
296
  if len(variations) < needed:
297
  if self.config.debug:
298
  print("Not enough variations, trying basic augmenters...")
299
-
300
  for aug_type, augmenter in self.augmenters['basic']:
301
  if len(variations) >= needed:
302
  break
303
-
304
  try:
305
- if aug_type == 'spelling' and self._is_technical_or_formal_text(text):
306
- if self.config.debug:
307
- print("Skipping spelling augmentation for technical text")
308
- continue
309
-
310
  if self.config.debug:
311
  print(f"Trying {aug_type} augmentation...")
312
-
313
  new_vars = augmenter.augment(text, n=2)
314
  if isinstance(new_vars, list):
315
  valid_vars = [v for v in new_vars if v.strip() and v != text]
@@ -317,21 +333,21 @@ class DialogueAugmenter:
317
  else:
318
  if new_vars.strip() and new_vars != text:
319
  variations.add(new_vars)
320
-
321
  if self.config.debug:
322
  print(f"After {aug_type}, total variations: {len(variations)}")
323
-
324
  except Exception as e:
325
  print(f"Error in {aug_type} augmentation: {str(e)}")
326
  continue
327
-
328
  variations_list = list(variations)
329
-
330
  if self.config.debug:
331
  print(f"Final number of variations generated: {len(variations_list)}")
332
  if not variations_list:
333
  print("WARNING: No variations were generated!")
334
-
335
  return variations_list
336
 
337
  def augment_dialogue(self, dialogue: Dict) -> List[Dict]:
@@ -375,7 +391,8 @@ class DialogueAugmenter:
375
  # Generate combinations with sampling
376
  augmented_dialogues = self._generate_dialogue_combinations(
377
  dialogue['dialogue_id'],
378
- turn_variations
 
379
  )
380
 
381
  # Add original dialogue
@@ -392,47 +409,201 @@ class DialogueAugmenter:
392
 
393
  return result
394
 
395
- def _generate_dialogue_combinations(self, dialogue_id: str, turn_variations: List[List[Dict]]) -> List[Dict]:
  """
397
- Generate dialogue combinations using sampling
 
 
398
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
399
  augmented_dialogues = []
400
  used_combinations = set()
401
-
402
- def generate_dialogues(current_turns=None, turn_index=0):
403
  if current_turns is None:
404
  current_turns = []
405
 
406
- if len(augmented_dialogues) >= self.config.augmentation_factor:
407
  return
408
 
409
  if turn_index == len(turn_variations):
 
410
  dialogue_fingerprint = " | ".join(turn['text'] for turn in current_turns)
411
  if dialogue_fingerprint not in used_combinations:
412
  used_combinations.add(dialogue_fingerprint)
413
- augmented_dialogues.append({
414
- 'dialogue_id': f"{dialogue_id}_aug_{len(augmented_dialogues)}",
415
- 'turns': current_turns.copy()
416
- })
 
 
 
 
 
417
  return
418
-
419
- variations = list(turn_variations[turn_index])
420
- np.random.shuffle(variations)
421
-
422
- for variation in variations[:self.config.max_sampled_variations]:
423
- if len(augmented_dialogues) >= self.config.augmentation_factor:
424
  return
425
  current_turns.append(variation)
426
- generate_dialogues(current_turns, turn_index + 1)
427
  current_turns.pop()
428
-
429
  try:
430
- generate_dialogues()
431
  except Exception as e:
432
  print(f"Error in dialogue generation: {str(e)}")
433
  return []
434
-
435
- return augmented_dialogues
436
 
437
  def _is_dialogue_duplicate(self, dialogue1: Dict, dialogue2: Dict) -> bool:
438
  """
@@ -445,11 +616,9 @@ class DialogueAugmenter:
445
  def _augment_short_text(self, turn: Dict) -> List[Dict]:
446
  """
447
  Special handling for very short texts with predefined variations.
448
- Args:
449
- turn (Dict): Original dialogue turn
450
-
451
- Returns:
452
- List[Dict]: List of variations for the short text
453
  """
454
  text = turn['text']
455
  common_variations = {
@@ -483,71 +652,60 @@ class DialogueAugmenter:
483
  'Fantastic!', 'Amazing!', 'Terrific!'
484
  ]
485
  }
486
-
487
- # Try to find matching variations
488
  text_lower = text.lower().rstrip('!.,?')
 
489
  variations = []
490
-
491
- # Check if text matches any of our predefined categories
492
  for key, predefined_vars in common_variations.items():
493
  if key in text_lower or text_lower in key:
494
  variations.extend(predefined_vars)
495
 
496
- # If no predefined variations found, generate simple variants
497
  if not variations:
498
- # Add punctuation variations
 
499
  variations = [
500
- text.rstrip('!.,?') + '!',
501
- text.rstrip('!.,?') + '.',
502
- text.rstrip('!.,?')
503
  ]
504
 
505
  # Add capitalization variations
506
- variations.extend([
507
- v.capitalize() for v in variations
508
- if v.capitalize() not in variations
509
- ])
510
 
511
- # Filter variations for uniqueness and quality
512
  unique_variations = list(set(variations))
513
- quality_variations = []
514
-
515
- for var in unique_variations:
516
- metrics = self.quality_metrics.compute_metrics(text, var)
517
- quality_score = (
518
- 0.35 * metrics['semantic_similarity'] +
519
- 0.30 * (1.0 - metrics['perplexity'] / 100) +
520
- 0.15 * (1.0 - metrics['grammar_errors'] / 10) +
521
- 0.15 * metrics['content_preservation'] +
522
- 0.10 * metrics['type_token_ratio']
523
- )
524
-
525
- # More lenient quality threshold for short texts
526
- if quality_score >= 0.5: # Lower threshold for short texts
527
- quality_variations.append(var)
528
-
529
- # Ensure we have at least some variations
530
- if not quality_variations:
531
- quality_variations = [text]
532
 
533
- # Return the variations with original speaker
534
- return [{'speaker': turn['speaker'], 'text': v} for v in quality_variations[:self.config.augmentation_factor]]
 
 
535
 
536
- def _is_technical_or_formal_text(self, text: str) -> bool:
537
- """
538
- Check if text is formal/technical and shouldn't have spelling variations.
539
- """
540
- formal_indicators = {
541
- 'technical_terms': {'api', 'config', 'database', 'server', 'system'},
542
- 'formal_phrases': {'please advise', 'regarding', 'furthermore', 'moreover'},
543
- 'professional_context': {'meeting', 'conference', 'project', 'deadline'}
544
- }
545
-
546
- text_lower = text.lower()
547
- words = set(text_lower.split())
 
 
 
 
 
548
 
549
- for category in formal_indicators.values():
550
- if words.intersection(category):
551
- return True
 
 
 
 
 
552
 
553
- return False
 
3
  import torch
4
  import tensorflow as tf
5
  import tensorflow_hub as hub
 
6
  from pipeline_config import PipelineConfig
7
  from quality_metrics import QualityMetrics
8
  from paraphraser import Paraphraser
 
9
  import nlpaug.augmenter.word as naw
10
  from concurrent.futures import ThreadPoolExecutor
11
  from functools import lru_cache
 
27
  print(f"Using device: {self.device}")
28
  if self.use_gpu:
29
  print(f"GPU Device: {torch.cuda.get_device_name(0)}")
30
+
31
 
 
32
  self.quality_metrics = QualityMetrics(config)
33
+ self.semantic_similarity_threshold = 0.75
34
+
35
+ # Load model
36
  self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
37
 
38
  # Initialize augmentation models based on hardware
 
40
 
41
  # Initialize caches
42
  self.embedding_cache = {}
 
 
 
 
43
 
44
  # GPU memory management if available
45
  if self.use_gpu:
 
54
  def _initialize_augmentation_models(self):
55
  """Initialize augmentation models with appropriate device settings"""
56
  # Advanced augmentation techniques
57
+ self.paraphraser = Paraphraser()
 
 
58
  if self.use_gpu:
59
+ # Move model to GPU if available
60
  self.paraphraser.model = self.paraphraser.model.to(self.device)
 
 
 
61
 
62
  # Basic augmentation techniques
63
  self.word_augmenter = naw.SynonymAug(aug_src='wordnet')
 
64
 
65
  self.augmenters = {
66
+ 'advanced': [
67
+ self.paraphraser,
68
+ ],
69
  'basic': [
70
  ('synonym', self.word_augmenter),
 
71
  ]
72
  }
73
 
 
95
 
96
  def _quick_quality_check(self, variation: str, original: str) -> bool:
97
  """
98
+ Preliminary quality check while maintaining reasonable pass rates
99
  """
100
  if self.config.debug:
101
  print(f"\nQuick check for variation: {variation}")
102
 
 
103
  orig_len = len(original.split())
104
  var_len = len(variation.split())
105
+
106
+ # For very short texts (<= 3 words), still allow more variation
107
  if orig_len <= 3:
108
+ if var_len > orig_len * 3:
109
  if self.config.debug:
110
  print(f"Failed length check (short text): {var_len} vs {orig_len}")
111
  return False
112
  else:
113
+ if var_len > orig_len * 2:
114
  if self.config.debug:
115
  print(f"Failed length check (long text): {var_len} vs {orig_len}")
116
  return False
117
+
118
+ # Adjust content overlap check based on length
119
  stop_words = {'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'is', 'are', 'that', 'this', 'will', 'can'}
120
  orig_words = set(w.lower() for w in original.split() if w.lower() not in stop_words)
121
  var_words = set(w.lower() for w in variation.split() if w.lower() not in stop_words)
122
+
123
+ # If the turn is very short (fewer than 5 words), skip the content overlap check
124
+ if orig_len >= 5:
125
+ content_overlap = len(orig_words.intersection(var_words)) / len(orig_words) if orig_words else 0
126
+ if content_overlap < 0.2:
127
+ if self.config.debug:
128
+ print(f"Failed content check: overlap {content_overlap:.2f}")
129
+ return False
130
+ else:
131
  if self.config.debug:
132
+ print("Short turn detected (<5 words), skipping content overlap check")
133
+
 
134
  if self.config.debug:
135
  print("Passed all quick checks")
136
  return True
137
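To make the checks above concrete: a 3-word original may grow to at most 9 words, a 10-word original to at most 20 words, and originals of 5 or more words must additionally share at least 20% of their non-stop-word content with the variation.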
 
 
 
 
 
 
 
 
 
 
138
  def _filter_variations_batch(self, variations: List[str], context: List[str], original_turn: str) -> List[str]:
139
  """
140
  Filter variations using batched computations with detailed logging
 
148
  print(f"Original turn: {original_turn}")
149
 
150
  words = original_turn.split()
151
+ orig_len = len(words)
152
+
153
+ # If very short text, consider adjusting thresholds
154
+ is_very_short = orig_len < 5
155
+
156
  if len(words) < 3:
157
  if self.config.debug:
158
  print("Short text detected, using predefined variations")
159
  short_text_variations = self._augment_short_text({'text': original_turn, 'speaker': ''})
160
  return [var['text'] for var in short_text_variations]
161
+
162
  # If this is the first turn (no context), be more lenient
163
  if not context:
164
  preliminary_filtered = variations
 
174
  print(f"Passed quick check: {passed}")
175
  if passed:
176
  preliminary_filtered.append(var)
177
+
178
  if self.config.debug:
179
  print(f"Variations after quick check: {len(preliminary_filtered)}")
180
+
181
  if not preliminary_filtered:
182
  return []
183
+
184
+ # Compute embeddings for original and variations
185
+ original_embedding = self._compute_embedding(original_turn)
186
+ variation_embeddings = self._compute_batch_embeddings(preliminary_filtered)
187
+
188
+ # Compute similarities
189
+ sims = cosine_similarity([original_embedding], variation_embeddings)[0]
190
+
191
+ # If very short turn, slightly lower the semantic similarity threshold
192
+ dynamic_sem_threshold = self.semantic_similarity_threshold
193
+ if is_very_short:
194
+ dynamic_sem_threshold = max(0.7, self.semantic_similarity_threshold - 0.05)
195
+
196
+ # Filter by semantic similarity threshold
197
+ refined_filtered = []
198
+ for var, sim in zip(preliminary_filtered, sims):
199
+ if sim >= dynamic_sem_threshold:
200
+ refined_filtered.append(var)
201
+ else:
202
+ if self.config.debug:
203
+ print(f"Variation '{var}' discarded due to low semantic similarity: {sim:.3f}")
204
+
205
+ if not refined_filtered:
206
+ return []
207
+
208
+ # Relax context coherence thresholds further if desired
209
+ # We already have min_similarity = 0.1, min_coherence = 0.05
210
+ # Let's lower them slightly more if the turn is very short:
211
+ if is_very_short:
212
+ min_similarity = 0.05
213
+ min_coherence = 0.02
214
+ else:
215
+ min_similarity = 0.1
216
+ min_coherence = 0.05
217
+
218
  # Only use last turn for coherence
219
  recent_context = [context[-1]] if context else []
220
  context_text = ' '.join(recent_context) if recent_context else ''
221
+
 
 
 
 
222
  if context_text:
223
  if self.config.debug:
224
  print(f"\nContext text: {context_text}")
225
+
226
+ all_texts = [context_text] + refined_filtered
227
  all_embeddings = self._compute_batch_embeddings(all_texts)
228
+
229
  context_embedding = all_embeddings[0]
230
  variation_embeddings = all_embeddings[1:]
231
+
232
  # Vectorized similarity computation
233
  context_similarities = cosine_similarity([context_embedding], variation_embeddings)[0]
234
+
235
  # Response coherence check
236
  if recent_context:
237
  prev_embedding = self._compute_embedding(recent_context[-1])
238
  response_coherence = cosine_similarity([prev_embedding], variation_embeddings)[0]
239
  else:
240
  response_coherence = np.ones_like(context_similarities)
241
+
 
242
  filtered_variations = []
243
  for i, (variation, sim, coh) in enumerate(zip(
244
+ refined_filtered, context_similarities, response_coherence)):
 
245
  combined_score = (
246
  self.config.context_similarity_weight * abs(sim) +
247
  self.config.response_coherence_weight * abs(coh)
248
  )
249
+
250
  if self.config.debug:
251
  print(f"\nVariation: {variation}")
252
  print(f"Context similarity: {sim:.3f}")
253
  print(f"Response coherence: {coh:.3f}")
254
  print(f"Combined score: {combined_score:.3f}")
255
+
256
  # Accept if EITHER score is good enough
257
  if (combined_score >= min_similarity or abs(coh) >= min_coherence):
258
  filtered_variations.append(variation)
 
261
  else:
262
  if self.config.debug:
263
  print("REJECTED")
264
+
265
  # If we have enough variations, stop
266
  if len(filtered_variations) >= self.config.max_variations_per_turn:
267
  break
268
  else:
269
+ filtered_variations = refined_filtered[:self.config.max_variations_per_turn]
270
+
271
  if self.config.debug:
272
  print(f"\nFinal filtered variations: {len(filtered_variations)}")
273
+
274
  return filtered_variations
275
 
276
  def _generate_variations_progressive(self, text: str, needed: int) -> List[str]:
277
  """
278
+ Generate variations progressively until we have enough good ones.
279
+ Adjust paraphraser parameters for closer paraphrases as needed.
280
  """
281
  variations = set()
282
+
283
  if self.config.debug:
284
  print(f"\nAttempting to generate {needed} variations for text: {text}")
285
+
286
+ # The paraphraser is already configured for fewer beams and lower diversity; fine-tune further here if needed
287
  for augmenter in self.augmenters['advanced']:
288
  if len(variations) >= needed:
289
  break
290
+
291
  try:
292
  if isinstance(augmenter, Paraphraser):
293
  if self.config.debug:
294
  print("Trying paraphrase augmentation...")
295
+ new_vars = augmenter.paraphrase(
296
+ text,
297
+ num_return_sequences=needed-len(variations),
298
+ device=self.device if self.use_gpu else None,
299
+ num_beams=4, # even fewer beams for more faithful paraphrases
300
+ num_beam_groups=1,
301
+ diversity_penalty=0.0
302
+ )
303
  if self.config.debug:
304
  print(f"Paraphraser generated {len(new_vars)} variations")
305
+
 
 
 
 
 
 
306
  valid_vars = [v for v in new_vars if v.strip() and v != text]
307
  variations.update(valid_vars)
308
+
309
  if self.config.debug:
310
  print(f"Current unique variations: {len(variations)}")
311
+
312
  except Exception as e:
313
  print(f"Error in advanced augmentation: {str(e)}")
314
  continue
315
+
316
  # Try basic augmenters if needed
317
  if len(variations) < needed:
318
  if self.config.debug:
319
  print("Not enough variations, trying basic augmenters...")
320
+
321
  for aug_type, augmenter in self.augmenters['basic']:
322
  if len(variations) >= needed:
323
  break
324
+
325
  try:
 
 
 
 
 
326
  if self.config.debug:
327
  print(f"Trying {aug_type} augmentation...")
328
+
329
  new_vars = augmenter.augment(text, n=2)
330
  if isinstance(new_vars, list):
331
  valid_vars = [v for v in new_vars if v.strip() and v != text]
 
333
  else:
334
  if new_vars.strip() and new_vars != text:
335
  variations.add(new_vars)
336
+
337
  if self.config.debug:
338
  print(f"After {aug_type}, total variations: {len(variations)}")
339
+
340
  except Exception as e:
341
  print(f"Error in {aug_type} augmentation: {str(e)}")
342
  continue
343
+
344
  variations_list = list(variations)
345
+
346
  if self.config.debug:
347
  print(f"Final number of variations generated: {len(variations_list)}")
348
  if not variations_list:
349
  print("WARNING: No variations were generated!")
350
+
351
  return variations_list
352
 
353
  def augment_dialogue(self, dialogue: Dict) -> List[Dict]:
 
391
  # Generate combinations with sampling
392
  augmented_dialogues = self._generate_dialogue_combinations(
393
  dialogue['dialogue_id'],
394
+ turn_variations,
395
+ dialogue
396
  )
397
 
398
  # Add original dialogue
 
409
 
410
  return result
411
 
412
+ def _variation_score(self, original: str, variation: str) -> float:
413
+ """
414
+ Compute a single numeric score for a variation to guide selection.
415
+ You could use semantic similarity, content preservation, etc.
416
+ Higher is better.
417
+ """
418
+ metrics = self.quality_metrics.compute_metrics(original, variation)
419
+ # Example: Primarily semantic similarity, with a slight boost for content preservation
420
+ # Adjust as needed.
421
+ score = metrics['semantic_similarity'] * 0.7 + metrics['content_preservation'] * 0.3
422
+ return score
423
+
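For intuition, with the weighting above a variation scoring semantic_similarity = 0.9 and content_preservation = 0.6 receives 0.7 * 0.9 + 0.3 * 0.6 = 0.81 (illustrative numbers only).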
424
+ def _dialogue_quality_score(self, dialogue: Dict, original_dialogue: Dict) -> float:
425
  """
426
+ Compute a quality score for the entire augmented dialogue.
427
+ For example, average semantic similarity of turns to the original turns.
428
+ This is done after the dialogue is formed.
429
  """
430
+ original_texts = [t['text'] for t in original_dialogue['turns']]
431
+ aug_texts = [t['text'] for t in dialogue['turns']]
432
+
433
+ # Compute semantic similarity turn-by-turn and average it
434
+ scores = []
435
+ for orig, aug in zip(original_texts, aug_texts):
436
+ # Simple semantic similarity for scoring
437
+ emb_orig = self._compute_embedding(orig)
438
+ emb_aug = self._compute_embedding(aug)
439
+ sim = (emb_orig @ emb_aug) / (np.linalg.norm(emb_orig)*np.linalg.norm(emb_aug))
440
+ scores.append(sim)
441
+
442
+ # Could also incorporate diversity checks, content overlap, etc.
443
+ return float(np.mean(scores)) if scores else 0.0
444
+
445
+ def _generate_dialogue_combinations(self, dialogue_id: str, turn_variations: List[List[Dict]], original_dialogue: Dict) -> List[Dict]:
446
+ """
447
+ Generate dialogue combinations using a more controlled approach:
448
+ - Include the original turn as a fallback variation for each turn.
449
+ - Sort variations by a quality score.
450
+ - Ensure a balanced augmentation by requiring at least some turns to be augmented.
451
+ - Over-generate and then select top dialogues by quality.
452
+ """
453
+ # Over-generate factor: create more candidates than needed
454
+ over_generate_factor = self.config.augmentation_factor * 2
455
+
456
+ # Add the original turn as a fallback variation for each turn if not present
457
+ for i, turn_variants in enumerate(turn_variations):
458
+ original_turn_text = None
459
+ # Use the turn text from original_dialogue as the fallback variation.
460
+ # (If an "|ORIGINAL|" marker was stored earlier, it could be handled here instead.)
461
+ original_turn_text = original_dialogue['turns'][i]['text']
462
+
463
+ # Add the original turn as a variation if not already included
464
+ if not any(v['text'] == original_turn_text for v in turn_variants):
465
+ turn_variants.append({
466
+ 'speaker': original_dialogue['turns'][i]['speaker'],
467
+ 'text': original_turn_text
468
+ })
469
+
470
+ # Sort variations by score
471
+ original_text = original_dialogue['turns'][i]['text']
472
+ turn_variants.sort(key=lambda v: self._variation_score(original_text, v['text']), reverse=True)
473
+
474
  augmented_dialogues = []
475
  used_combinations = set()
476
+
477
+ def generate_candidates(current_turns=None, turn_index=0):
478
  if current_turns is None:
479
  current_turns = []
480
 
481
+ if len(augmented_dialogues) >= over_generate_factor:
482
  return
483
 
484
  if turn_index == len(turn_variations):
485
+ # Completed a candidate dialogue
486
  dialogue_fingerprint = " | ".join(turn['text'] for turn in current_turns)
487
  if dialogue_fingerprint not in used_combinations:
488
  used_combinations.add(dialogue_fingerprint)
489
+ # Check if we have enough augmented turns
490
+ aug_count = sum(1 for orig, curr in zip(original_dialogue['turns'], current_turns)
491
+ if orig['text'] != curr['text'])
492
+ # Require at least half the turns to be augmented, for example
493
+ if aug_count >= max(1, len(turn_variations)//2):
494
+ augmented_dialogues.append({
495
+ 'dialogue_id': f"{dialogue_id}_aug_{len(augmented_dialogues)}",
496
+ 'turns': current_turns.copy()
497
+ })
498
  return
499
+
500
+ turn_candidates = turn_variations[turn_index]
501
+
502
+ # If no variations are available for this turn, fall back gracefully rather than raising an error.
503
+ # This should not normally happen, since the original turn is always added above.
504
+ if not turn_candidates:
505
+ # Fall back to the original turn so the dialogue can still be completed:
506
+ original_text = original_dialogue['turns'][turn_index]['text']
507
+ turn_candidates.append({
508
+ 'speaker': original_dialogue['turns'][turn_index]['speaker'],
509
+ 'text': original_text
510
+ })
511
+
512
+ # After the fallback, if still empty for some reason, just return.
513
+ if not turn_candidates:
514
+ return
515
+
516
+ # Example strategy:
517
+ # 1. Always try the top variation (most semantically similar).
518
+ # 2. If available and allowed, pick a mid-ranked variation for diversity.
519
+ # 3. Include the original turn if not selected yet.
520
+
521
+ num_vars = min(self.config.max_sampled_variations, len(turn_candidates))
522
+
523
+ # Always include top variation
524
+ candidates_to_pick = [turn_candidates[0]]
525
+
526
+ # If we have more than 2 variations and can pick more, add a middle variation for diversity
527
+ if len(turn_candidates) > 2 and num_vars > 1:
528
+ mid_index = len(turn_candidates)//2
529
+ candidates_to_pick.append(turn_candidates[mid_index])
530
+
531
+ # If we still have room for another variation, try adding the original turn if not included
532
+ if num_vars > len(candidates_to_pick):
533
+ original_turn_text = original_dialogue['turns'][turn_index]['text']
534
+ orig_candidate = next((v for v in turn_candidates if v['text'] == original_turn_text), None)
535
+ if orig_candidate and orig_candidate not in candidates_to_pick:
536
+ candidates_to_pick.append(orig_candidate)
537
+
538
+ # Shuffle candidates to produce different dialogues
539
+ np.random.shuffle(candidates_to_pick)
540
+
541
+ for variation in candidates_to_pick:
542
+ if len(augmented_dialogues) >= over_generate_factor:
543
  return
544
  current_turns.append(variation)
545
+ generate_candidates(current_turns, turn_index + 1)
546
  current_turns.pop()
547
+
548
  try:
549
+ generate_candidates()
550
  except Exception as e:
551
  print(f"Error in dialogue generation: {str(e)}")
552
  return []
553
+
554
+ # The over-generated candidate dialogues are now available;
556
+ # score them and keep only the top-ranked ones.
556
+ scored_dialogues = []
557
+ for d in augmented_dialogues:
558
+ score = self._dialogue_quality_score(d, original_dialogue)
559
+ scored_dialogues.append((score, d))
560
+
561
+ scored_dialogues.sort(key=lambda x: x[0], reverse=True)
562
+ # Pick top `augmentation_factor` dialogues
563
+ final_dialogues = [d for _, d in scored_dialogues[:self.config.augmentation_factor]]
564
+
565
+ return final_dialogues
566
+ # def _generate_dialogue_combinations(self, dialogue_id: str, turn_variations: List[List[Dict]]) -> List[Dict]:
567
+ # """
568
+ # Generate dialogue combinations using sampling
569
+ # """
570
+ # augmented_dialogues = []
571
+ # used_combinations = set()
572
+
573
+ # def generate_dialogues(current_turns=None, turn_index=0):
574
+ # if current_turns is None:
575
+ # current_turns = []
576
+
577
+ # if len(augmented_dialogues) >= self.config.augmentation_factor:
578
+ # return
579
+
580
+ # if turn_index == len(turn_variations):
581
+ # dialogue_fingerprint = " | ".join(turn['text'] for turn in current_turns)
582
+ # if dialogue_fingerprint not in used_combinations:
583
+ # used_combinations.add(dialogue_fingerprint)
584
+ # augmented_dialogues.append({
585
+ # 'dialogue_id': f"{dialogue_id}_aug_{len(augmented_dialogues)}",
586
+ # 'turns': current_turns.copy()
587
+ # })
588
+ # return
589
+
590
+ # variations = list(turn_variations[turn_index])
591
+ # np.random.shuffle(variations)
592
+
593
+ # for variation in variations[:self.config.max_sampled_variations]:
594
+ # if len(augmented_dialogues) >= self.config.augmentation_factor:
595
+ # return
596
+ # current_turns.append(variation)
597
+ # generate_dialogues(current_turns, turn_index + 1)
598
+ # current_turns.pop()
599
+
600
+ # try:
601
+ # generate_dialogues()
602
+ # except Exception as e:
603
+ # print(f"Error in dialogue generation: {str(e)}")
604
+ # return []
605
+
606
+ # return augmented_dialogues
607
 
608
  def _is_dialogue_duplicate(self, dialogue1: Dict, dialogue2: Dict) -> bool:
609
  """
 
616
  def _augment_short_text(self, turn: Dict) -> List[Dict]:
617
  """
618
  Special handling for very short texts with predefined variations.
619
+ If predefined variations are found, return them directly.
620
+ Otherwise, produce simple punctuation and capitalization variants.
621
+ Skip heavy quality checks for efficiency. These variations are safe and minimal.
 
 
622
  """
623
  text = turn['text']
624
  common_variations = {
 
652
  'Fantastic!', 'Amazing!', 'Terrific!'
653
  ]
654
  }
655
+
 
656
  text_lower = text.lower().rstrip('!.,?')
657
+ # Check if text matches any predefined category
658
  variations = []
 
 
659
  for key, predefined_vars in common_variations.items():
660
  if key in text_lower or text_lower in key:
661
  variations.extend(predefined_vars)
662
 
 
663
  if not variations:
664
+ # Generate simple punctuation and capitalization variations if no predefined match
665
+ base = text.rstrip('!.,?')
666
  variations = [
667
+ base + '!',
668
+ base + '.',
669
+ base
670
  ]
671
 
672
  # Add capitalization variations
673
+ capitalized = [v.capitalize() for v in variations if v.capitalize() not in variations]
674
+ variations.extend(capitalized)
 
 
675
 
676
+ # Ensure uniqueness
677
  unique_variations = list(set(variations))
678
 
679
+ # Directly return these variations, as they are minimal and trusted
680
+ # No further quality checks are needed
681
+ result_variations = unique_variations[:self.config.augmentation_factor]
682
+ return [{'speaker': turn['speaker'], 'text': v} for v in result_variations]
683
 
684
+ def process_batch(self, batch: List[Dict]) -> List[Dict]:
685
+ """Process multiple dialogues at once to maximize GPU utilization"""
686
+ results = []
687
+
688
+ # Pre-compute embeddings for all texts in batch
689
+ all_texts = []
690
+ text_to_embedding = {}
691
+
692
+ for dialogue in batch:
693
+ for turn in dialogue['turns']:
694
+ all_texts.append(turn['text'])
695
+
696
+ # Batch compute embeddings
697
+ if all_texts:
698
+ embeddings = self._compute_batch_embeddings(all_texts)
699
+ for text, embedding in zip(all_texts, embeddings):
700
+ self.embedding_cache[text] = embedding
701
 
702
+ # Process each dialogue using cached embeddings
703
+ for dialogue in batch:
704
+ try:
705
+ augmented = self.augment_dialogue(dialogue)
706
+ results.extend(augmented)
707
+ except Exception as e:
708
+ print(f"Error processing dialogue {dialogue.get('dialogue_id', 'unknown')}: {e}")
709
+ continue
710
 
711
+ return results
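A short sketch of how the batched path above might be driven. The dialogue contents are illustrative, and the DialogueAugmenter/PipelineConfig construction is an assumption based on how they are used elsewhere in the pipeline:

# Hypothetical driver; only process_batch()'s signature is taken from the code above.
config = PipelineConfig(debug=False)        # assumed defaults for the remaining fields
augmenter = DialogueAugmenter(config)       # assumed constructor
batch = [
    {
        'dialogue_id': 'dlg_001',
        'turns': [
            {'speaker': 'user', 'text': 'Can you recommend a good book?'},
            {'speaker': 'assistant', 'text': 'Sure! What genres do you enjoy?'},
        ],
    },
]
augmented = augmenter.process_batch(batch)
print(f"Generated {len(augmented)} augmented dialogues")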
main.py CHANGED
@@ -59,13 +59,13 @@ def main():
59
  min_length=1,
60
  max_length=512,
61
  batch_size=32 if tf.config.list_physical_devices('GPU') else 16,
62
- max_turns_per_dialogue=6,
63
- max_variations_per_turn=3,
64
  max_sampled_variations=2,
65
  context_window_size=4,
66
  max_complexity_threshold=100,
67
  use_cache=False,
68
- debug=False,
69
  allowed_speakers=['user', 'assistant'],
70
  required_fields=['dialogue_id', 'turns']
71
  )
 
59
  min_length=1,
60
  max_length=512,
61
  batch_size=32 if tf.config.list_physical_devices('GPU') else 16,
62
+ max_turns_per_dialogue=12,
63
+ max_variations_per_turn=4,
64
  max_sampled_variations=2,
65
  context_window_size=4,
66
  max_complexity_threshold=100,
67
  use_cache=False,
68
+ debug=True,
69
  allowed_speakers=['user', 'assistant'],
70
  required_fields=['dialogue_id', 'turns']
71
  )
paraphraser.py CHANGED
@@ -9,11 +9,18 @@ class Paraphraser:
9
  self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
10
  self.model.eval()
11
 
12
- def paraphrase(self, text, num_return_sequences=5, num_beams=10, num_beam_groups=5, diversity_penalty=0.8):
 
13
  try:
14
  input_text = "paraphrase: " + text + " </s>"
15
  encoding = self.tokenizer.encode_plus(input_text, return_tensors="pt")
16
- input_ids = encoding["input_ids"]
 
 
 
 
 
 
17
 
18
  outputs = self.model.generate(
19
  input_ids=input_ids,
@@ -24,7 +31,11 @@ class Paraphraser:
24
  diversity_penalty=diversity_penalty,
25
  early_stopping=True
26
  )
27
- paraphrases = [self.tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
 
 
 
 
28
  return paraphrases
29
  except Exception as e:
30
  print(f"Error in paraphrasing: {e}")
 
9
  self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
10
  self.model.eval()
11
 
12
+ def paraphrase(self, text, num_return_sequences=5, num_beams=5,
13
+ num_beam_groups=1, diversity_penalty=0.0, device=None):
14
  try:
15
  input_text = "paraphrase: " + text + " </s>"
16
  encoding = self.tokenizer.encode_plus(input_text, return_tensors="pt")
17
+
18
+ # Move input tensors to specified device if provided
19
+ if device is not None:
20
+ input_ids = encoding["input_ids"].to(device)
21
+ self.model = self.model.to(device)
22
+ else:
23
+ input_ids = encoding["input_ids"]
24
 
25
  outputs = self.model.generate(
26
  input_ids=input_ids,
 
31
  diversity_penalty=diversity_penalty,
32
  early_stopping=True
33
  )
34
+
35
+ # Move outputs back to CPU for tokenizer decoding
36
+ outputs = outputs.cpu() if device is not None else outputs
37
+ paraphrases = [self.tokenizer.decode(output, skip_special_tokens=True)
38
+ for output in outputs]
39
  return paraphrases
40
  except Exception as e:
41
  print(f"Error in paraphrasing: {e}")
pipeline_config.py CHANGED
@@ -30,7 +30,6 @@ class PipelineConfig:
30
  grammar_error_threshold: int = 2
31
  rouge1_f1_threshold: float = 0.30
32
  rouge2_f1_threshold: float = 0.15
33
- perplexity_threshold: float = 50.0
34
 
35
  # Response coherence thresholds
36
  min_response_coherence: float = 0.3
 
30
  grammar_error_threshold: int = 2
31
  rouge1_f1_threshold: float = 0.30
32
  rouge2_f1_threshold: float = 0.15
 
33
 
34
  # Response coherence thresholds
35
  min_response_coherence: float = 0.3
processing_pipeline.py CHANGED
@@ -11,7 +11,6 @@ from pipeline_config import PipelineConfig
11
  from dialogue_augmenter import DialogueAugmenter
12
  from sklearn.feature_extraction.text import TfidfVectorizer
13
  from sklearn.metrics.pairwise import cosine_similarity
14
- from concurrent.futures import ProcessPoolExecutor
15
  from typing import Set
16
 
17
  class ProcessingPipeline:
@@ -33,7 +32,11 @@ class ProcessingPipeline:
33
  self.use_gpu = torch.cuda.is_available()
34
  self.batch_size = 32 if self.use_gpu else 8
35
  self.use_multiprocessing = not self.use_gpu
36
-
 
 
 
 
37
  if self.config.debug:
38
  print(f"ProcessingPipeline initialized with:")
39
  print(f"- GPU available: {self.use_gpu}")
@@ -75,7 +78,7 @@ class ProcessingPipeline:
75
  text_to_dialogue_map[turn['text']] = dialogue['dialogue_id']
76
 
77
  # Batch process embeddings
78
- embeddings = self.augmenter._compute_batch_embeddings(all_texts)
79
 
80
  # Process dialogues with cached embeddings
81
  for dialogue in batch:
@@ -89,16 +92,37 @@ class ProcessingPipeline:
89
  print(f"Error processing batch: {str(e)}")
90
  return results
91
92
  def combine_results(self) -> Path:
93
- """Combine all batch files into final output"""
94
  all_results = []
95
- batch_files = sorted(self.output_dir.glob("batch_*.json"))
96
 
97
- print(f"Combining {len(batch_files)} batch files...")
98
- for batch_file in tqdm(batch_files):
99
- with open(batch_file, 'r') as f:
100
- batch_data = json.load(f)
101
- all_results.extend(batch_data)
102
 
103
  # Save combined results
104
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -137,12 +161,13 @@ class ProcessingPipeline:
137
  current_position = processed_count + batch_num + len(batch)
138
 
139
  total_progress = (current_position / total_dialogues) * 100
140
- batch_progress = (batch_num + 1) / ((len(remaining_dialogues) + self.batch_size - 1) // self.batch_size) * 100
141
 
142
- print(f"\rProgress: {current_position}/{total_dialogues} dialogues "
143
- f"({total_progress:.1f}% complete) - "
144
- f"Batch {batch_num//self.batch_size + 1} of "
145
- f"{(len(remaining_dialogues) + self.batch_size - 1) // self.batch_size}", end="")
 
 
146
 
147
  # Process batch
148
  batch_results = self._process_batch(batch)
@@ -152,20 +177,37 @@ class ProcessingPipeline:
152
  batch_ids = {d['dialogue_id'] for d in batch}
153
  processed_ids.update(batch_ids)
154
  self._update_checkpoint(processed_ids)
155
-
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  print("\n" + "-" * 50)
157
  print("Processing complete. Combining results...")
158
  return self.combine_results()
159
 
160
  def cleanup(self):
161
- """Clean up intermediate batch files after successful processing"""
 
162
  batch_files = list(self.output_dir.glob("batch_*.json"))
163
  for file in batch_files:
164
  try:
165
  file.unlink()
166
  except Exception as e:
167
  print(f"Error deleting {file}: {e}")
168
-
 
 
 
169
  if self.checkpoint_file.exists():
170
  try:
171
  self.checkpoint_file.unlink()
@@ -276,4 +318,4 @@ class ProcessingPipeline:
276
  """
277
  data_str = json.dumps(data, sort_keys=True)
278
  hash_value = hashlib.md5(data_str.encode()).hexdigest()
279
- return self.cache_dir / f"cache_{hash_value}.pkl"
 
11
  from dialogue_augmenter import DialogueAugmenter
12
  from sklearn.feature_extraction.text import TfidfVectorizer
13
  from sklearn.metrics.pairwise import cosine_similarity
 
14
  from typing import Set
15
 
16
  class ProcessingPipeline:
 
32
  self.use_gpu = torch.cuda.is_available()
33
  self.batch_size = 32 if self.use_gpu else 8
34
  self.use_multiprocessing = not self.use_gpu
35
+
36
+ # Counters for grouping batches
37
+ self.batch_counter = 0 # Count batches since last group combine
38
+ self.batch_group_number = 0 # How many groups have been created
39
+
40
  if self.config.debug:
41
  print(f"ProcessingPipeline initialized with:")
42
  print(f"- GPU available: {self.use_gpu}")
 
78
  text_to_dialogue_map[turn['text']] = dialogue['dialogue_id']
79
 
80
  # Batch process embeddings
81
+ self.augmenter._compute_batch_embeddings(all_texts)
82
 
83
  # Process dialogues with cached embeddings
84
  for dialogue in batch:
 
92
  print(f"Error processing batch: {str(e)}")
93
  return results
94
 
95
+ def _combine_intermediate_batches(self):
96
+ """
97
+ Combine all current batch_*.json files into a single batch_group_XXXX.json file,
98
+ then remove the batch_*.json files.
99
+ """
100
+ batch_files = sorted(self.output_dir.glob("batch_*.json"))
101
+ if not batch_files:
102
+ return None # No files to combine
103
+
104
+ combined_data = []
105
+ for bf in batch_files:
106
+ with open(bf, 'r') as f:
107
+ combined_data.extend(json.load(f))
108
+ bf.unlink() # Remove the individual batch file after reading
109
+
110
+ self.batch_group_number += 1
111
+ group_file = self.output_dir / f"batch_group_{self.batch_group_number:04d}.json"
112
+ with open(group_file, 'w') as f:
113
+ json.dump(combined_data, f)
114
+ return group_file
115
+
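Combined with the 25-batch grouping added to the run loop below, a run that writes 60 intermediate batch_*.json files would produce batch_group_0001.json and batch_group_0002.json during processing, plus batch_group_0003.json for the final 10 leftover batches (illustrative counts).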
116
  def combine_results(self) -> Path:
117
+ """Combine all batch_group_*.json files into final output"""
118
  all_results = []
119
+ group_files = sorted(self.output_dir.glob("batch_group_*.json"))
120
 
121
+ print(f"Combining {len(group_files)} group files...")
122
+ for group_file in tqdm(group_files):
123
+ with open(group_file, 'r') as f:
124
+ group_data = json.load(f)
125
+ all_results.extend(group_data)
126
 
127
  # Save combined results
128
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
 
161
  current_position = processed_count + batch_num + len(batch)
162
 
163
  total_progress = (current_position / total_dialogues) * 100
 
164
 
165
+ print('\033[K', end='')
166
+ print(f"Processing: {current_position}/{total_dialogues} dialogues "
167
+ f"({total_progress:.1f}% complete)")
168
+ print(f"Current batch: {batch_num//self.batch_size + 1} of "
169
+ f"{(len(remaining_dialogues) + self.batch_size - 1) // self.batch_size}")
170
+ print("-" * 50)
171
 
172
  # Process batch
173
  batch_results = self._process_batch(batch)
 
177
  batch_ids = {d['dialogue_id'] for d in batch}
178
  processed_ids.update(batch_ids)
179
  self._update_checkpoint(processed_ids)
180
+
181
+ # Increment batch counter and combine if needed
182
+ self.batch_counter += 1
183
+ if self.batch_counter == 25:
184
+ # Combine these 25 batches into a group file
185
+ self._combine_intermediate_batches()
186
+ self.batch_counter = 0 # Reset counter after grouping
187
+
188
+ # If fewer than 25 batches remain at the end,
189
+ # combine them into one final group file
190
+ if self.batch_counter > 0:
191
+ self._combine_intermediate_batches()
192
+ self.batch_counter = 0
193
+
194
  print("\n" + "-" * 50)
195
  print("Processing complete. Combining results...")
196
  return self.combine_results()
197
 
198
  def cleanup(self):
199
+ """Clean up intermediate files after successful processing"""
200
+ # Clean up any leftover batch files (should not exist if logic is correct)
201
  batch_files = list(self.output_dir.glob("batch_*.json"))
202
  for file in batch_files:
203
  try:
204
  file.unlink()
205
  except Exception as e:
206
  print(f"Error deleting {file}: {e}")
207
+
208
+ # The batch_group_*.json files could also be removed here after the final combine,
209
+ # but that might not be necessary if we want to keep them.
210
+
211
  if self.checkpoint_file.exists():
212
  try:
213
  self.checkpoint_file.unlink()
 
318
  """
319
  data_str = json.dumps(data, sort_keys=True)
320
  hash_value = hashlib.md5(data_str.encode()).hexdigest()
321
+ return self.cache_dir / f"cache_{hash_value}.pkl"
quality_metrics.py CHANGED
@@ -1,129 +1,47 @@
1
- import torch
2
- import tensorflow as tf
3
  import tensorflow_hub as hub
4
- from transformers import GPT2TokenizerFast, GPT2LMHeadModel
5
- import language_tool_python
6
- from rouge_score import rouge_scorer
7
  import spacy
8
  from sklearn.metrics.pairwise import cosine_similarity
9
- import numpy as np
10
  from typing import Dict
11
  from pipeline_config import PipelineConfig
12
 
13
  class QualityMetrics:
14
  """
15
- Measure augmented text quality
16
  """
17
  def __init__(self, config: PipelineConfig):
18
  self.config = config
19
-
20
- # Semantic similarity
21
  self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
22
-
23
- # Fluency metrics
24
- self.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
25
- self.model = GPT2LMHeadModel.from_pretrained('gpt2')
26
- self.model.eval()
27
-
28
- # Grammar
29
- self.language_tool = language_tool_python.LanguageTool('en-US')
30
-
31
- # Lexical similarity
32
- self.rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
33
-
34
- # Diversity
35
- self.nlp = spacy.load('en_core_web_sm')
36
-
37
- def compute_perplexity(self, text):
38
- try:
39
- encodings = self.tokenizer(text, return_tensors='pt')
40
- input_ids = encodings['input_ids']
41
- with torch.no_grad():
42
- outputs = self.model(input_ids, labels=input_ids)
43
- loss = outputs.loss
44
- perplexity = torch.exp(loss)
45
- return perplexity.item()
46
- except Exception as e:
47
- print(f"Error computing perplexity for text '{text}': {e}")
48
- return float('inf') # High perplexity value == poor quality
49
-
50
  def compute_semantic_similarity(self, text1: str, text2: str) -> float:
51
- """
52
- Compute semantic similarity between two texts using the Universal Sentence Encoder.
53
- Args:
54
- text1 (str): First text
55
- text2 (str): Second text
56
- Returns:
57
- float: Cosine similarity score between the two texts (0-1)
58
- """
59
  embeddings = self.use_model([text1, text2])
60
  emb1, emb2 = embeddings[0].numpy(), embeddings[1].numpy()
61
  return cosine_similarity([emb1], [emb2])[0][0]
62
 
63
  def compute_metrics(self, original: str, augmented: str) -> Dict[str, float]:
64
- """
65
- Compute quality metrics
66
- """
67
  metrics = {}
68
-
69
- # 1. Semantic Preservation
70
  embeddings = self.use_model([original, augmented])
71
  emb_orig, emb_aug = embeddings[0].numpy(), embeddings[1].numpy()
72
  metrics['semantic_similarity'] = cosine_similarity([emb_orig], [emb_aug])[0][0]
73
 
74
- # 2. Fluency & Naturalness
75
- metrics['perplexity'] = self.compute_perplexity(augmented)
76
- metrics['grammar_errors'] = len(self.language_tool.check(augmented))
77
-
78
- # 3. Lexical Diversity
79
  doc_orig = self.nlp(original)
80
  doc_aug = self.nlp(augmented)
81
-
82
- # Type-token ratio with safety check
83
  aug_tokens = [token.text.lower() for token in doc_aug]
84
  metrics['type_token_ratio'] = len(set(aug_tokens)) / max(len(aug_tokens), 1)
85
 
86
- # Content word overlap with safety checks
87
- orig_content = set([token.text.lower() for token in doc_orig if not token.is_stop])
88
- aug_content = set([token.text.lower() for token in doc_aug if not token.is_stop])
89
-
90
- # Safety check for empty content sets
91
  if len(orig_content) == 0:
92
  metrics['content_preservation'] = 1.0 if len(aug_content) == 0 else 0.0
93
  else:
94
  metrics['content_preservation'] = len(orig_content.intersection(aug_content)) / len(orig_content)
95
 
96
- # 4. Structural Preservation
97
- rouge_scores = self.rouge.score(original, augmented)
98
- metrics['rouge1_f1'] = rouge_scores['rouge1'].fmeasure
99
- metrics['rouge2_f1'] = rouge_scores['rouge2'].fmeasure
100
- metrics['rougeL_f1'] = rouge_scores['rougeL'].fmeasure
101
-
102
- # 5. Length Preservation with safety check
103
  orig_words = len(original.split())
104
  aug_words = len(augmented.split())
105
  metrics['length_ratio'] = aug_words / max(orig_words, 1)
106
-
107
- return metrics
108
 
109
- def meets_quality_threshold(self, metrics: Dict[str, float]) -> bool:
110
- """
111
- Enhanced quality threshold checking
112
- """
113
- # Core quality checks
114
- basic_quality = (
115
- metrics['perplexity'] <= self.config.perplexity_threshold and
116
- metrics['semantic_similarity'] >= self.config.semantic_similarity_threshold and
117
- metrics['grammar_errors'] <= self.config.grammar_error_threshold
118
- )
119
-
120
- # Length preservation check
121
- length_ok = 0.6 <= metrics['length_ratio'] <= 1.4
122
-
123
- # Diversity check
124
- diversity_ok = metrics['type_token_ratio'] >= 0.4
125
-
126
- # Content preservation check
127
- content_ok = metrics['content_preservation'] >= 0.6
128
-
129
- return all([basic_quality, length_ok, diversity_ok, content_ok])
 
 
 
1
  import tensorflow_hub as hub
 
 
 
2
  import spacy
3
  from sklearn.metrics.pairwise import cosine_similarity
 
4
  from typing import Dict
5
  from pipeline_config import PipelineConfig
6
 
7
  class QualityMetrics:
8
  """
9
+ Quality metrics focusing on semantic similarity and basic lexical stats.
10
  """
11
  def __init__(self, config: PipelineConfig):
12
  self.config = config
 
 
13
  self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
14
+ self.nlp = spacy.load('en_core_web_md')
15
+
 
 
 
 
16
  def compute_semantic_similarity(self, text1: str, text2: str) -> float:
 
 
 
 
 
 
 
 
17
  embeddings = self.use_model([text1, text2])
18
  emb1, emb2 = embeddings[0].numpy(), embeddings[1].numpy()
19
  return cosine_similarity([emb1], [emb2])[0][0]
20
 
21
  def compute_metrics(self, original: str, augmented: str) -> Dict[str, float]:
 
 
 
22
  metrics = {}
23
+ # Semantic similarity
 
24
  embeddings = self.use_model([original, augmented])
25
  emb_orig, emb_aug = embeddings[0].numpy(), embeddings[1].numpy()
26
  metrics['semantic_similarity'] = cosine_similarity([emb_orig], [emb_aug])[0][0]
27
 
28
+ # Lexical diversity & content preservation
 
 
 
 
29
  doc_orig = self.nlp(original)
30
  doc_aug = self.nlp(augmented)
31
+
 
32
  aug_tokens = [token.text.lower() for token in doc_aug]
33
  metrics['type_token_ratio'] = len(set(aug_tokens)) / max(len(aug_tokens), 1)
34
 
35
+ orig_content = {token.text.lower() for token in doc_orig if not token.is_stop}
36
+ aug_content = {token.text.lower() for token in doc_aug if not token.is_stop}
 
 
 
37
  if len(orig_content) == 0:
38
  metrics['content_preservation'] = 1.0 if len(aug_content) == 0 else 0.0
39
  else:
40
  metrics['content_preservation'] = len(orig_content.intersection(aug_content)) / len(orig_content)
41
 
42
+ # Length ratio
 
 
 
 
 
 
43
  orig_words = len(original.split())
44
  aug_words = len(augmented.split())
45
  metrics['length_ratio'] = aug_words / max(orig_words, 1)
 
 
46
 
47
+ return metrics
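A minimal sketch of exercising the trimmed-down metrics above (illustrative only: it downloads the Universal Sentence Encoder, needs the spaCy en_core_web_md model, and assumes PipelineConfig can be constructed with its defaults):

from pipeline_config import PipelineConfig
from quality_metrics import QualityMetrics

qm = QualityMetrics(PipelineConfig())  # assumption: default config values are sufficient
m = qm.compute_metrics(
    original="I'd like to book a table for two tonight.",
    augmented="Could I reserve a table for two this evening?",
)
print(m['semantic_similarity'], m['content_preservation'], m['length_ratio'])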
 
 
 
 
readme.md CHANGED
@@ -11,14 +11,16 @@ This package automatically downloads the following models during installation:
11
  - Universal Sentence Encoder v4 (TensorFlow Hub)
12
  - ChatGPT Paraphraser T5-base
13
  - Helsinki-NLP translation models (en-de, de-es, es-en)
14
- - GPT-2 (for perplexity scoring)
15
- - spaCy en_core_web_sm
16
  - nltk wordnet and averaged_perceptron_tagger_eng models
17
 
18
  ## Install package
19
 
20
  pip install -e .
21
 
 
 
 
22
  ## Description
23
 
24
  This Python script demonstrates a complete pipeline for dialogue augmentation, including validation, optimization, and data augmentation.
 
11
  - Universal Sentence Encoder v4 (TensorFlow Hub)
12
  - ChatGPT Paraphraser T5-base
13
  - Helsinki-NLP translation models (en-de, de-es, es-en)
14
+ - spaCy en_core_web_sm, en_core_web_md
 
15
  - nltk wordnet and averaged_perceptron_tagger_eng models
16
 
17
  ## Install package
18
 
19
  pip install -e .
20
 
21
+ On Linux with CUDA/GPU:
22
+ pip install faiss-gpu>=1.7.0
23
+
24
  ## Description
25
 
26
  This Python script demonstrates a complete pipeline for dialogue augmentation, including validation, optimization, and data augmentation.
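The augmented batches produced by this pipeline are consumed by the run_model*.py scripts added in this commit; run_model4.py, the current iteration, trains the retrieval chatbot from a batch_group_*.json file under processed_outputs/.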
requirements.txt CHANGED
@@ -1,12 +1,14 @@
1
- spacy>=3.0.0 # Text processing and tokenization
 
2
  numpy>=1.19.0 # General numerical computation
3
- tqdm>=4.64.0 # Progress bar
4
- torch>=1.10.0 # PyTorch, for deep learning
5
- tensorflow>=2.6.0 # TensorFlow, for deep learning
6
- tensorflow-hub>=0.12.0 # Pretrained model hub for TensorFlow
7
- transformers>=4.21.0 # Hugging Face Transformers library
8
- rouge-score>=0.1.2 # ROUGE metric for evaluation
9
- language-tool-python>=2.7.1 # Grammar checking and text correction
10
  scikit-learn>=1.0.0 # Machine learning tools
11
- nlpaug>=1.1.0 # Data augmentation for NLP
12
- nltk>=3.6.0 # Natural language toolkit
 
 
 
 
 
 
 
 
 
1
+ nlpaug>=1.1.0 # Data augmentation for NLP
2
+ nltk>=3.6.0 # Natural language toolkit
3
  numpy>=1.19.0 # General numerical computation
 
 
 
 
 
 
 
4
  scikit-learn>=1.0.0 # Machine learning tools
5
+ sacremoses>=0.0.53 # Required for some HuggingFace models
6
+ sentencepiece>=0.1.99 # Required for HuggingFace transformers
7
+ spacy>=3.0.0 # Text processing and tokenization
8
+ tensorflow>=2.13.0 # TensorFlow, for deep learning
9
+ tensorflow-hub>=0.12.0 # Pretrained model hub for TensorFlow
10
+ tokenizers>=0.13.0 # Required for HuggingFace transformers
11
+ torch>=2.0.0 # PyTorch, for deep learning
12
+ tqdm>=4.64.0 # Progress bar
13
+ transformers>=4.30.0 # Hugging Face Transformers library
14
+ faiss-cpu>=1.7.0 # Required for Facebook AI Similarity Search
response_quality_checker.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
1
+ import numpy as np
2
+ from typing import List, Tuple, Dict, Any
3
+ from sklearn.metrics.pairwise import cosine_similarity
4
+ from chatbot4 import RetrievalChatbot
5
+
6
+ class ResponseQualityChecker:
7
+ """Handles quality checking and confidence scoring for chatbot responses."""
8
+
9
+ def __init__(
10
+ self,
11
+ chatbot: RetrievalChatbot,
12
+ confidence_threshold: float = 0.5,
13
+ diversity_threshold: float = 0.1,
14
+ min_response_length: int = 3,
15
+ max_similarity_ratio: float = 0.9
16
+ ):
17
+ self.confidence_threshold = confidence_threshold
18
+ self.diversity_threshold = diversity_threshold
19
+ self.min_response_length = min_response_length
20
+ self.max_similarity_ratio = max_similarity_ratio
21
+ self.chatbot = chatbot
22
+
23
+ def check_response_quality(
24
+ self,
25
+ query: str,
26
+ responses: List[Tuple[str, float]]
27
+ ) -> Dict[str, Any]:
28
+ """
29
+ Evaluate the quality of the responses based on various metrics.
30
+ """
31
+ # Calculate diversity based on the responses themselves
32
+ diversity = self.calculate_diversity(responses)
33
+
34
+ # Calculate relevance of the responses to the query
35
+ relevance = self.calculate_relevance(query, responses)
36
+
37
+ # Calculate length scores for each response
38
+ length_scores = [self._calculate_length_score(response) for response, _ in responses]
39
+ avg_length_score = np.mean(length_scores) if length_scores else 0.0
40
+
41
+ # Extract similarity scores
42
+ similarity_scores = [score for _, score in responses]
43
+
44
+ # Calculate score gap
45
+ score_gap = self._calculate_score_gap(similarity_scores, top_n=3)
46
+
47
+ # Aggregate metrics
48
+ metrics = {
49
+ 'top_score': similarity_scores[0] if similarity_scores else 0.0,
50
+ 'response_diversity': diversity,
51
+ 'query_response_relevance': relevance,
52
+ 'response_length_score': avg_length_score,
53
+ 'top_3_score_gap': score_gap
54
+ }
55
+
56
+ # Determine overall confidence
57
+ is_confident = self._determine_confidence(metrics)
58
+
59
+ return {
60
+ 'diversity': diversity,
61
+ 'relevance': relevance,
62
+ 'is_confident': is_confident,
63
+ 'top_score': metrics['top_score'],
64
+ 'response_diversity': metrics['response_diversity'],
65
+ 'query_response_relevance': metrics['query_response_relevance'],
66
+ 'response_length_score': metrics['response_length_score'],
67
+ 'top_3_score_gap': metrics['top_3_score_gap']
68
+ }
69
+
70
+ def calculate_diversity(self, responses: List[Tuple[str, float]]) -> float:
71
+ """
72
+ Calculate diversity as one minus the average pairwise similarity between responses.
73
+ Lower similarity indicates higher diversity.
74
+ """
75
+ if not responses:
76
+ return 0.0
77
+
78
+ # Encode responses
79
+ embeddings = [self.encode_text(response) for response, _ in responses]
80
+ if len(embeddings) < 2:
81
+ return 1.0 # Maximum diversity
82
+
83
+ # Compute pairwise cosine similarity
84
+ similarity_matrix = cosine_similarity(embeddings)
85
+
86
+ # Exclude diagonal
87
+ sum_similarities = np.sum(similarity_matrix) - len(responses)
88
+ num_pairs = len(responses) * (len(responses) - 1)
89
+ avg_similarity = sum_similarities / num_pairs if num_pairs > 0 else 0.0
90
+ diversity_score = 1 - avg_similarity # Higher value indicates more diversity
91
+ return diversity_score
92
+
93
+ def calculate_relevance(self, query: str, responses: List[Tuple[str, float]]) -> float:
94
+ """
95
+ Calculate relevance as the average similarity between the query and each response.
96
+ """
97
+ if not responses:
98
+ return 0.0
99
+
100
+ # Encode query
101
+ query_embedding = self.encode_query(query)
102
+
103
+ # Encode responses
104
+ response_embeddings = [self.encode_text(response) for response, _ in responses]
105
+
106
+ # Compute cosine similarity
107
+ similarities = cosine_similarity([query_embedding], response_embeddings)[0]
108
+
109
+ avg_relevance = np.mean(similarities) if similarities.size > 0 else 0.0
110
+ return avg_relevance
111
+
112
+ def _calculate_length_score(self, response: str) -> float:
113
+ """Score based on response length appropriateness."""
114
+ length = len(response.split())
115
+ if length < self.min_response_length:
116
+ return length / self.min_response_length
117
+ return 1.0
118
+
119
+ def _calculate_score_gap(self, scores: List[float], top_n: int = 3) -> float:
120
+ """
121
+ Calculate the average gap between the top N scores.
122
+
123
+ Args:
124
+ scores (List[float]): List of similarity scores.
125
+ top_n (int): Number of top scores to consider.
126
+
127
+ Returns:
128
+ float: Average score gap.
129
+ """
130
+ if len(scores) < top_n + 1:
131
+ return 0.0
132
+ gaps = [scores[i] - scores[i + 1] for i in range(top_n)]
133
+ avg_gap = np.mean(gaps)
134
+ return avg_gap
135
+
136
+ def _determine_confidence(self, metrics: Dict[str, float]) -> bool:
137
+ """
138
+ Determine if we're confident enough in the response.
139
+
140
+ Returns:
141
+ bool: True if we should use this response, False if we should abstain
142
+ """
143
+ conditions = [
144
+ metrics['top_score'] >= self.confidence_threshold,
145
+ metrics['response_diversity'] >= self.diversity_threshold,
146
+ metrics['response_length_score'] >= 0.8,
147
+ metrics['query_response_relevance'] >= 0.3, # was 0.5
148
+ metrics['top_3_score_gap'] >= 0.05 # was 0.1
149
+ ]
150
+ return all(conditions)
151
+
152
+ def encode_text(self, text: str) -> np.ndarray:
153
+ # 1) encode_responses() expects a list of texts, so wrap the single string.
154
+ # 2) Then call the method from the chatbot to get the embedding.
155
+ embedding_tensor = self.chatbot.encode_responses([text]) # returns tf.Tensor of shape (1, emb_dim)
156
+ embedding = embedding_tensor.numpy()[0].astype('float32') # shape: (emb_dim,)
157
+ embedding = embedding / np.linalg.norm(embedding) if np.linalg.norm(embedding) > 0 else embedding
158
+ return embedding
159
+
160
+ def encode_query(self, query: str) -> np.ndarray:
161
+ embedding_tensor = self.chatbot.encode_query(query) # returns tf.Tensor of shape (1, emb_dim)
162
+ embedding = embedding_tensor.numpy()[0].astype('float32') # shape: (emb_dim,)
163
+ embedding = embedding / np.linalg.norm(embedding) if np.linalg.norm(embedding) > 0 else embedding
164
+ return embedding
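A small usage sketch of the checker; the StubChatbot below is a stand-in invented purely for illustration so the class can run without a trained RetrievalChatbot, and the candidate scores are made up:

import numpy as np
import tensorflow as tf
from response_quality_checker import ResponseQualityChecker

class StubChatbot:
    """Stand-in encoder returning pseudo-embeddings (illustration only)."""
    def _embed(self, text: str) -> tf.Tensor:
        rng = np.random.default_rng(abs(hash(text)) % (2**32))
        return tf.constant(rng.normal(size=(1, 64)), dtype=tf.float32)
    def encode_responses(self, texts):
        return self._embed(texts[0])  # the checker passes a single-item list
    def encode_query(self, query):
        return self._embed(query)

checker = ResponseQualityChecker(chatbot=StubChatbot())
candidates = [
    ("Sure, I can book two tickets for tonight.", 0.72),
    ("Here are this evening's showtimes.", 0.55),
    ("Which theater would you like?", 0.41),
    ("I'm not sure I understood that.", 0.33),
]
result = checker.check_response_quality("Can I get two tickets?", candidates)
print(result['is_confident'], result['top_score'], result['response_diversity'])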
run_model.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
1
+ import json
2
+ import glob
3
+ import os
4
+ from chatbot import RetrievalChatbot
5
+ import tensorflow as tf
6
+ from sklearn.model_selection import train_test_split
7
+ import matplotlib.pyplot as plt
8
+
9
+ def load_training_data(data_directory: str) -> list:
10
+ """Load and combine dialogue data from multiple JSON files."""
11
+ all_dialogues = []
12
+
13
+ # Get all json files matching the pattern
14
+ pattern = os.path.join(data_directory, "batch_*.json")
15
+ json_files = sorted(glob.glob(pattern))
16
+
17
+ print(f"Found {len(json_files)} batch files")
18
+
19
+ for file_path in json_files:
20
+ try:
21
+ with open(file_path, 'r', encoding='utf-8') as f:
22
+ batch_dialogues = json.load(f)
23
+ all_dialogues.extend(batch_dialogues)
24
+ print(f"Loaded {len(batch_dialogues)} dialogues from {os.path.basename(file_path)}")
25
+ except Exception as e:
26
+ print(f"Error loading {file_path}: {str(e)}")
27
+
28
+ print(f"Total dialogues loaded: {len(all_dialogues)}")
29
+ return all_dialogues
30
+
31
+ def plot_training_history(train_losses, val_losses):
32
+ # Plot training and validation loss
33
+ plt.figure()
34
+ plt.plot(train_losses, label='Train Loss')
35
+ plt.plot(val_losses, label='Val Loss')
36
+ plt.xlabel('Epoch')
37
+ plt.ylabel('Triplet Loss')
38
+ plt.legend()
39
+ plt.show()
40
+
41
+ dialogues = load_training_data('processed_outputs')
42
+
43
+ # Initialize the chatbot
44
+ chatbot = RetrievalChatbot(
45
+ vocab_size=10000,
46
+ max_sequence_length=80,
47
+ embedding_dim=256,
48
+ lstm_units=256,
49
+ num_attention_heads=8,
50
+ margin=0.3
51
+ )
52
+
53
+ # Prepare the dataset for triplet training
54
+ q_pad, p_pad, n_pad = chatbot.prepare_dataset(dialogues, neg_samples_per_pos=3)
55
+
56
+ # Train with triplet loss
57
+ train_losses, val_losses = chatbot.train_with_triplet_loss(
58
+ q_pad, p_pad, n_pad,
59
+ epochs=1,
60
+ batch_size=32,
61
+ validation_split=0.2
62
+ )
63
+
64
+ plot_training_history(train_losses, val_losses)
65
+
66
+ # After training, test prediction
67
+ response_candidates = [turn['text'] for d in dialogues for turn in d['turns'] if turn['speaker'] == 'assistant']
68
+
69
+ # Test retrieval
70
+ test_query = "I'd like a recommendation for a Korean restaurant in NYC."
71
+ top_responses = chatbot.retrieve_top_n(test_query, response_candidates, top_n=5)
72
+ print("Top responses:")
73
+ for resp, score in top_responses:
74
+ print(f"Score: {score:.4f} - {resp}")
75
+
76
+ # Single-turn validation:
77
+ test_queries = [
78
+ "I want to book a Korean restaurant in NYC.",
79
+ "Can I get two tickets for 'What Men Want'?",
80
+ "What's the best time to watch the movie today?"
81
+ ]
82
+ for query in test_queries:
83
+ top_responses = chatbot.retrieve_top_n(query, response_candidates, top_n=3)
84
+ print(f"\nQuery: {query}")
85
+ for resp, score in top_responses:
86
+ print(f"Score: {score:.4f} - {resp}")
87
+
88
+ # Multi-turn conversation:
89
+ multi_turn_history = []
90
+
91
+ def update_context(multi_turn_history, query, response, max_context_turns=3):
92
+ multi_turn_history.append((query, response))
93
+ if len(multi_turn_history) > max_context_turns:
94
+ multi_turn_history.pop(0)
95
+
96
+ def get_context_enhanced_query(multi_turn_history, query):
97
+ if not multi_turn_history:
98
+ return query
99
+ context = " ".join([f"User: {q} Assistant: {r}" for q, r in multi_turn_history])
100
+ return f"{context} User: {query}"
101
+
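For example, after the first exchange the context-enhanced query passed to retrieve_top_n looks roughly like: "User: I'd like to watch a movie tonight. Assistant: <best response from turn 1> User: Is there a showing of 'What Men Want'?"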
102
+ conversation_queries = [
103
+ "I'd like to watch a movie tonight.",
104
+ "Is there a showing of 'What Men Want'?",
105
+ "What time is the last show?",
106
+ "Can I get two tickets?"
107
+ ]
108
+
109
+ for query in conversation_queries:
110
+ context_query = get_context_enhanced_query(multi_turn_history, query)
111
+ top_responses = chatbot.retrieve_top_n(context_query, response_candidates, top_n=3)
112
+ best_response = top_responses[0][0]
113
+ print(f"\nUser: {query}\nAssistant: {best_response}")
114
+ update_context(multi_turn_history, query, best_response)
115
+
116
+
117
+
118
+
119
+
120
+
121
+
122
+
123
+ #queries, responses, labels = chatbot.prepare_dataset(dialogues, neg_samples_per_pos=3)
124
+
125
+ #train_dialogues, val_dialogues = train_test_split(dialogues, test_size=0.2, random_state=20)
126
+ #query_train, query_val, response_train, response_val, labels_train, labels_val = train_test_split(queries, responses, labels, test_size=0.2, random_state=20)
127
+
128
+ # chatbot.model.compile(
129
+ # optimizer=tf.keras.optimizers.Adam(learning_rate=0.001, clipnorm=1.0),
130
+ # loss='binary_crossentropy',
131
+ # metrics=['accuracy']
132
+ # )
133
+
134
+ # # Train the model with early stopping to prevent overfitting
135
+ # callbacks = [
136
+ # tf.keras.callbacks.EarlyStopping(
137
+ # monitor='val_loss',
138
+ # patience=3,
139
+ # restore_best_weights=True
140
+ # ),
141
+ # tf.keras.callbacks.ReduceLROnPlateau(
142
+ # monitor='val_loss',
143
+ # factor=0.5,
144
+ # patience=2,
145
+ # min_lr=1e-6,
146
+ # verbose=1
147
+ # ),
148
+ # tf.keras.callbacks.ModelCheckpoint(
149
+ # 'chatbot_model.keras',
150
+ # monitor='val_loss',
151
+ # save_best_only=True
152
+ # )
153
+ # ]
154
+
155
+ # history = chatbot.model.fit(
156
+ # {'query_input': query_train, 'response_input': response_train},
157
+ # labels_train,
158
+ # validation_data=({'query_input': query_val, 'response_input': response_val}, labels_val),
159
+ # epochs=5,
160
+ # batch_size=32,
161
+ # callbacks=callbacks
162
+ # )
run_model2.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
1
+ from chatbot2 import RetrievalChatbot, ChatbotConfig
2
+ import os
3
+ import json
4
+ import glob
5
+ import matplotlib.pyplot as plt
6
+ import logging
7
+ from pathlib import Path
8
+ from typing import List, Dict, Optional, Any, Tuple
9
+ import numpy as np
10
+ from datetime import datetime
11
+ from response_quality_checker import ResponseQualityChecker
12
+
13
+ # Configure logging
14
+ logging.basicConfig(
15
+ level=logging.INFO,
16
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
17
+ )
18
+ logger = logging.getLogger(__name__)
19
+
20
+ def load_training_data(data_directory: str, debug_samples: Optional[int] = None) -> list:
21
+ """
22
+ Load and combine dialogue data from multiple JSON files.
23
+
24
+ Args:
25
+ data_directory: Directory containing the dialogue files
26
+ debug_samples: If set, only load this many dialogues for debugging
27
+ """
28
+ all_dialogues = []
29
+ data_directory = Path(data_directory)
30
+
31
+ # Get all json files matching the pattern
32
+ pattern = "batch_*.json"
33
+ json_files = sorted(data_directory.glob(pattern))
34
+
35
+ logger.info(f"Found {len(json_files)} batch files")
36
+
37
+ if debug_samples:
38
+ logger.info(f"Debug mode: Will load up to {debug_samples} dialogues")
39
+
40
+ for file_path in json_files:
41
+ try:
42
+ with open(file_path, 'r', encoding='utf-8') as f:
43
+ batch_dialogues = json.load(f)
44
+
45
+ # If in debug mode, only take what we need from this batch
46
+ if debug_samples is not None:
47
+ remaining_samples = debug_samples - len(all_dialogues)
48
+ if remaining_samples <= 0:
49
+ break
50
+ batch_dialogues = batch_dialogues[:remaining_samples]
51
+
52
+ all_dialogues.extend(batch_dialogues)
53
+ logger.info(f"Loaded {len(batch_dialogues)} dialogues from {file_path.name}")
54
+
55
+ # If we've reached our debug sample limit, stop loading
56
+ if debug_samples is not None and len(all_dialogues) >= debug_samples:
57
+ logger.info(f"Debug mode: Reached {debug_samples} samples, stopping load")
58
+ break
59
+
60
+ except Exception as e:
61
+ logger.error(f"Error loading {file_path}: {str(e)}")
62
+
63
+ total_loaded = len(all_dialogues)
64
+ if debug_samples:
65
+ logger.info(f"Debug mode: Loaded {total_loaded}/{debug_samples} requested dialogues")
66
+ else:
67
+ logger.info(f"Total dialogues loaded: {total_loaded}")
68
+
69
+ return all_dialogues
70
+
71
+ def plot_training_history(history: Dict[str, List[float]], save_dir: Path = None):
72
+ """Plot and optionally save training history."""
73
+ # Create figure with two subplots
74
+ fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))
75
+
76
+ # Plot losses
77
+ ax1.plot(history['train_loss'], label='Train Loss')
78
+ ax1.plot(history['val_loss'], label='Validation Loss')
79
+ ax1.set_xlabel('Epoch')
80
+ ax1.set_ylabel('Triplet Loss')
81
+ ax1.set_title('Training and Validation Loss')
82
+ ax1.legend()
83
+ ax1.grid(True)
84
+
85
+ # Plot learning rate if available
86
+ if 'learning_rate' in history:
87
+ ax2.plot(history['learning_rate'], label='Learning Rate')
88
+ ax2.set_xlabel('Step')
89
+ ax2.set_ylabel('Learning Rate')
90
+ ax2.set_title('Learning Rate Schedule')
91
+ ax2.legend()
92
+ ax2.grid(True)
93
+
94
+ plt.tight_layout()
95
+
96
+ # Save if directory provided
97
+ if save_dir:
98
+ save_dir = Path(save_dir)
99
+ save_dir.mkdir(parents=True, exist_ok=True)
100
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
101
+ plt.savefig(save_dir / f'training_history_{timestamp}.png')
102
+
103
+ plt.show()
104
+
105
+ def setup_training_directories(base_dir: str = "chatbot_training") -> Dict[str, Path]:
106
+ """Setup directory structure for training artifacts."""
107
+ base_dir = Path(base_dir)
108
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
109
+ train_dir = base_dir / f"training_run_{timestamp}"
110
+
111
+ directories = {
112
+ 'base': train_dir,
113
+ 'checkpoints': train_dir / 'checkpoints',
114
+ 'plots': train_dir / 'plots',
115
+ 'logs': train_dir / 'logs'
116
+ }
117
+
118
+ # Create directories
119
+ for dir_path in directories.values():
120
+ dir_path.mkdir(parents=True, exist_ok=True)
121
+
122
+ return directories
123
+
124
+ def run_automatic_validation(
125
+ chatbot,
126
+ response_pool: List[str],
127
+ quality_checker: ResponseQualityChecker,
128
+ num_examples: int = 5
129
+ ) -> Dict[str, Any]:
130
+ """
131
+ Run automatic validation with quality metrics.
132
+ """
133
+ logger.info("\n=== Running Automatic Validation ===")
134
+
135
+ test_queries = [
136
+ "Hello, how are you today?",
137
+ "What's the weather like?",
138
+ "Can you help me with a problem?",
139
+ "Tell me a joke",
140
+ "What time is it?",
141
+ "I need help with my homework",
142
+ "Where's a good place to eat?",
143
+ "What movies are playing?",
144
+ "How do I reset my password?",
145
+ "Can you recommend a book?"
146
+ ]
147
+
148
+ test_queries = test_queries[:num_examples]
149
+ metrics_history = []
150
+
151
+ for i, query in enumerate(test_queries, 1):
152
+ logger.info(f"\nTest Case {i}:")
153
+ logger.info(f"Query: {query}")
154
+
155
+ # Get responses and scores
156
+ responses = chatbot.retrieve_responses(
157
+ query,
158
+ response_pool,
159
+ context=None,
160
+ top_k=5
161
+ )
162
+
163
+ # Check quality
164
+ quality_metrics = quality_checker.check_response_quality(
165
+ query, responses, response_pool
166
+ )
167
+ metrics_history.append(quality_metrics)
168
+
169
+ # Log results
170
+ logger.info(f"Quality Metrics: {quality_metrics}")
171
+ logger.info("Top responses:")
172
+ for j, (response, score) in enumerate(responses[:3], 1):
173
+ logger.info(f"{j}. Score: {score:.4f}")
174
+ logger.info(f" Response: {response}")
175
+ if j == 1 and not quality_metrics['is_confident']:
176
+ logger.info(" [Low Confidence - Would abstain from answering]")
177
+
178
+ # Calculate aggregate metrics
179
+ aggregate_metrics = {
180
+ 'num_queries_tested': len(test_queries),
181
+ 'avg_top_response_score': np.mean([m['top_score'] for m in metrics_history]),
182
+ 'avg_diversity': np.mean([m['response_diversity'] for m in metrics_history]),
183
+ 'avg_relevance': np.mean([m['query_response_relevance'] for m in metrics_history]),
184
+ 'confidence_rate': np.mean([m['is_confident'] for m in metrics_history]),
185
+ }
186
+
187
+ logger.info("\n=== Validation Summary ===")
188
+ for metric, value in aggregate_metrics.items():
189
+ logger.info(f"{metric}: {value:.4f}")
190
+
191
+ return aggregate_metrics
192
+
193
+ def chat_with_quality_check(
194
+ chatbot,
195
+ query: str,
196
+ response_pool: List[str],
197
+ conversation_history: List[Tuple[str, str]],
198
+ quality_checker: ResponseQualityChecker
199
+ ) -> Tuple[Optional[str], List[Tuple[str, float]], Dict[str, Any]]:
200
+ """
201
+ Enhanced chat function with quality checking.
202
+ """
203
+ # Get responses and scores
204
+ responses = chatbot.retrieve_responses(
205
+ query,
206
+ response_pool,
207
+ conversation_history
208
+ )
209
+
210
+ # Check quality
211
+ quality_metrics = quality_checker.check_response_quality(
212
+ query, responses, response_pool
213
+ )
214
+
215
+ if quality_metrics['is_confident']:
216
+ return responses[0][0], responses, quality_metrics
217
+ else:
218
+ uncertainty_response = (
219
+ "I apologize, but I don't feel confident providing an answer to that "
220
+ "question at the moment. Could you please rephrase or ask something else?"
221
+ )
222
+ return uncertainty_response, responses, quality_metrics
223
+
224
+ def get_total_steps(dialogues: List[Dict[str, Any]], batch_size: int, epochs: int) -> int:
225
+ """
226
+ Calculate total training steps based on dialogues and batch size.
227
+ Assume 80% of data for training due to validation split
228
+ """
229
+ estimated_train_samples = int(len(dialogues) * 0.8)
230
+ steps_per_epoch = estimated_train_samples // batch_size
231
+ return steps_per_epoch * epochs
232
+
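For the values used in main() below (DEBUG_SAMPLES = 350, BATCH_SIZE = 32, EPOCHS = 5) this works out to int(350 * 0.8) = 280 training samples, 280 // 32 = 8 steps per epoch, 8 * 5 = 40 total steps, and therefore warmup_steps = int(40 * 0.1) = 4.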
233
+ def main():
234
+ DEBUG_SAMPLES = 350
235
+ BATCH_SIZE = 32
236
+ EPOCHS = 5 if DEBUG_SAMPLES else 10
237
+
238
+ # Setup training directories
239
+ dirs = setup_training_directories()
240
+
241
+ # Load training data
242
+ dialogues = load_training_data('processed_outputs', debug_samples=DEBUG_SAMPLES)
243
+ total_steps = get_total_steps(dialogues, BATCH_SIZE, EPOCHS)
244
+
245
+ # Initialize configuration
246
+ config = ChatbotConfig(
247
+ embedding_dim=32, # TODO: 256
248
+ encoder_units=32, # TODO: 256
249
+ num_attention_heads=2, # TODO: 8
250
+ warmup_steps=int(total_steps * 0.1), # 10% of total steps for warmup
251
+ )
252
+
253
+ # Save configuration
254
+ with open(dirs['base'] / 'config.json', 'w') as f:
255
+ json.dump(config.to_dict(), f, indent=2)
256
+
257
+ # Initialize chatbot
258
+ chatbot = RetrievalChatbot(config)
259
+
260
+ # Prepare dataset
261
+ logger.info("Preparing dataset...")
262
+
263
+ # Prepare and train with debug samples
264
+ q_pad, p_pad, n_pad = chatbot.prepare_dataset(
265
+ dialogues,
266
+ neg_samples_per_pos=3,
267
+ debug_samples=DEBUG_SAMPLES
268
+ )
269
+
270
+ # Train model
271
+ logger.info("Starting training...")
272
+ chatbot.train(
273
+ q_pad, p_pad, n_pad,
274
+ epochs=EPOCHS,
275
+ batch_size=BATCH_SIZE,
276
+ checkpoint_dir=dirs['checkpoints']
277
+ )
278
+
279
+ # Plot and save training history
280
+ plot_training_history(chatbot.history, save_dir=dirs['plots'])
281
+
282
+ # Save final model
283
+ chatbot.save_models(dirs['base'] / 'final_model')
284
+
285
+ # Prepare response pool for chat
286
+ response_pool = [
287
+ turn['text'] for d in dialogues
288
+ for turn in d['turns'] if turn['speaker'] == 'assistant'
289
+ ]
290
+
291
+ # Initialize quality checker with appropriate thresholds
292
+ quality_checker = ResponseQualityChecker(
293
+ confidence_threshold=0.6 if not DEBUG_SAMPLES else 0.4, # Lower threshold for debug
294
+ diversity_threshold=0.2,
295
+ min_response_length=10,
296
+ max_similarity_ratio=0.9
297
+ )
298
+
299
+ # Run automatic validation
300
+ validation_metrics = run_automatic_validation(
301
+ chatbot,
302
+ response_pool,
303
+ quality_checker,
304
+ num_examples=5 if DEBUG_SAMPLES else 10
305
+ )
306
+
307
+ # Log validation metrics
308
+ logger.info(f"Validation Metrics: {validation_metrics}")
309
+
310
+ # Now continue with interactive chat
311
+ logger.info("\nStarting interactive chat session...")
312
+ conversation_history = []
313
+
314
+ while True:
315
+ query = input("\nYou: ")
316
+ if query.lower() in ['quit', 'exit', 'bye']:
317
+ break
318
+
319
+ try:
320
+ response, candidates = chatbot.chat(
321
+ query,
322
+ response_pool,
323
+ conversation_history
324
+ )
325
+ print(f"\nAssistant: {response}")
326
+
327
+ # Print top alternative responses
328
+ print("\nAlternative responses:")
329
+ for resp, score in candidates[1:4]:
330
+ print(f"Score: {score:.4f} - {resp}")
331
+
332
+ # Update history
333
+ conversation_history.append((query, response))
334
+
335
+ except Exception as e:
336
+ logger.error(f"Error during chat: {str(e)}")
337
+ print("Sorry, I encountered an error. Please try again.")
338
+
339
+ if __name__ == "__main__":
340
+ main()
run_model3.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
1
+ from chatbot3 import RetrievalChatbot, ChatbotConfig
2
+ import os
3
+ import json
4
+ import glob
5
+ import tensorflow as tf
6
+ import matplotlib.pyplot as plt
7
+ import logging
8
+ from pathlib import Path
9
+ from typing import List, Dict, Optional, Any, Tuple
10
+ import numpy as np
11
+ from datetime import datetime
12
+ from response_quality_checker import ResponseQualityChecker
13
+ import torch
14
+ from transformers import TFAutoModel, AutoTokenizer
15
+
16
+
17
+ policy = tf.keras.mixed_precision.Policy('mixed_float16')
18
+ tf.keras.mixed_precision.set_global_policy(policy)
19
+
20
+ # Configure logging
21
+ logging.basicConfig(
22
+ level=logging.INFO,
23
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
24
+ )
25
+ logger = logging.getLogger(__name__)
26
+
27
+ def setup_model_cache(cache_dir: Optional[Path] = None) -> Path:
28
+ """Setup and manage model cache directory."""
29
+ if cache_dir is None:
30
+ cache_dir = Path.home() / '.chatbot_cache'
31
+
32
+ cache_dir.mkdir(parents=True, exist_ok=True)
33
+
34
+ # Set environment variables for various libraries
35
+ os.environ['TRANSFORMERS_CACHE'] = str(cache_dir / 'transformers')
36
+ os.environ['TORCH_HOME'] = str(cache_dir / 'torch')
37
+ os.environ['HF_HOME'] = str(cache_dir / 'huggingface')
38
+
39
+ logger.info(f"Using cache directory: {cache_dir}")
40
+ return cache_dir
41
+
42
+ def setup_gpu():
43
+ """Configure GPU settings for optimal performance."""
44
+ logger.info("Checking GPU availability...")
45
+
46
+ gpus = tf.config.list_physical_devices('GPU')
47
+ if gpus:
48
+ try:
49
+ # Allow memory growth to prevent taking all GPU memory at once
50
+ for gpu in gpus:
51
+ tf.config.experimental.set_memory_growth(gpu, True)
52
+ logger.info(f"Found {len(gpus)} GPU(s). Memory growth enabled.")
53
+
54
+ # Log GPU info
55
+ for gpu in gpus:
56
+ logger.info(f"GPU Device: {gpu}")
57
+ return True
58
+ except Exception as e:
59
+ logger.error(f"Error configuring GPU: {str(e)}")
60
+ return False
61
+
62
+ else:
63
+ logger.info("No GPU found. Using CPU.")
64
+ return False
65
+
66
+ def preload_models(config: ChatbotConfig, cache_dir: Path):
67
+ """Preload and cache models."""
68
+ logger.info("Preloading models...")
69
+
70
+ # Cache DistilBERT
71
+ model_name = config.pretrained_model
72
+ cache_path = cache_dir / 'transformers' / model_name
73
+
74
+ if not cache_path.exists():
75
+ logger.info(f"Downloading and caching {model_name}...")
76
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
77
+ model = TFAutoModel.from_pretrained(model_name)
78
+
79
+ # Save to cache
80
+ tokenizer.save_pretrained(cache_path)
81
+ model.save_pretrained(cache_path)
82
+ else:
83
+ logger.info(f"Using cached model from {cache_path}")
84
+
85
+ return cache_path
86
+
87
+ def load_training_data(data_directory: str, debug_samples: Optional[int] = None) -> list:
88
+ """
89
+ Load and combine dialogue data from multiple JSON files.
90
+
91
+ Args:
92
+ data_directory: Directory containing the dialogue files
93
+ debug_samples: If set, only load this many dialogues for debugging
94
+ """
95
+ all_dialogues = []
96
+ data_directory = Path(data_directory)
97
+
98
+ # Get all json files matching the pattern
99
+ pattern = "batch_*.json"
100
+ json_files = sorted(data_directory.glob(pattern))
101
+
102
+ logger.info(f"Found {len(json_files)} batch files")
103
+
104
+ if debug_samples:
105
+ logger.info(f"Debug mode: Will load up to {debug_samples} dialogues")
106
+
107
+ for file_path in json_files:
108
+ try:
109
+ with open(file_path, 'r', encoding='utf-8') as f:
110
+ batch_dialogues = json.load(f)
111
+
112
+ # If in debug mode, only take what we need from this batch
113
+ if debug_samples is not None:
114
+ remaining_samples = debug_samples - len(all_dialogues)
115
+ if remaining_samples <= 0:
116
+ break
117
+ batch_dialogues = batch_dialogues[:remaining_samples]
118
+
119
+ all_dialogues.extend(batch_dialogues)
120
+ logger.info(f"Loaded {len(batch_dialogues)} dialogues from {file_path.name}")
121
+
122
+ # If we've reached our debug sample limit, stop loading
123
+ if debug_samples is not None and len(all_dialogues) >= debug_samples:
124
+ logger.info(f"Debug mode: Reached {debug_samples} samples, stopping load")
125
+ break
126
+
127
+ except Exception as e:
128
+ logger.error(f"Error loading {file_path}: {str(e)}")
129
+
130
+ total_loaded = len(all_dialogues)
131
+ if debug_samples:
132
+ logger.info(f"Debug mode: Loaded {total_loaded}/{debug_samples} requested dialogues")
133
+ else:
134
+ logger.info(f"Total dialogues loaded: {total_loaded}")
135
+
136
+ return all_dialogues
137
+
138
+ def plot_training_history(history: Dict[str, List[float]], save_dir: Path = None):
139
+ """Plot and optionally save training history."""
140
+ # Create figure with two subplots
141
+ fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))
142
+
143
+ # Plot losses
144
+ ax1.plot(history['train_loss'], label='Train Loss')
145
+ ax1.plot(history['val_loss'], label='Validation Loss')
146
+ ax1.set_xlabel('Epoch')
147
+ ax1.set_ylabel('Triplet Loss')
148
+ ax1.set_title('Training and Validation Loss')
149
+ ax1.legend()
150
+ ax1.grid(True)
151
+
152
+ # Plot learning rate if available
153
+ if 'learning_rate' in history:
154
+ ax2.plot(history['learning_rate'], label='Learning Rate')
155
+ ax2.set_xlabel('Step')
156
+ ax2.set_ylabel('Learning Rate')
157
+ ax2.set_title('Learning Rate Schedule')
158
+ ax2.legend()
159
+ ax2.grid(True)
160
+
161
+ plt.tight_layout()
162
+
163
+ # Save if directory provided
164
+ if save_dir:
165
+ save_dir = Path(save_dir)
166
+ save_dir.mkdir(parents=True, exist_ok=True)
167
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
168
+ plt.savefig(save_dir / f'training_history_{timestamp}.png')
169
+
170
+ plt.show()
171
+
172
+ def setup_training_directories(base_dir: str = "chatbot_training") -> Dict[str, Path]:
173
+ """Setup directory structure for training artifacts."""
174
+ base_dir = Path(base_dir)
175
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
176
+ train_dir = base_dir / f"training_run_{timestamp}"
177
+
178
+ directories = {
179
+ 'base': train_dir,
180
+ 'checkpoints': train_dir / 'checkpoints',
181
+ 'plots': train_dir / 'plots',
182
+ 'logs': train_dir / 'logs'
183
+ }
184
+
185
+ # Create directories
186
+ for dir_path in directories.values():
187
+ dir_path.mkdir(parents=True, exist_ok=True)
188
+
189
+ return directories
190
+
191
+ def run_automatic_validation(
192
+ chatbot,
193
+ response_pool: List[str],
194
+ quality_checker: ResponseQualityChecker,
195
+ num_examples: int = 5
196
+ ) -> Dict[str, Any]:
197
+ """
198
+ Run automatic validation with quality metrics.
199
+ """
200
+ logger.info("\n=== Running Automatic Validation ===")
201
+
202
+ test_queries = [
203
+ "Hello, how are you today?",
204
+ "What's the weather like?",
205
+ "Can you help me with a problem?",
206
+ "Tell me a joke",
207
+ "What time is it?",
208
+ "I need help with my homework",
209
+ "Where's a good place to eat?",
210
+ "What movies are playing?",
211
+ "How do I reset my password?",
212
+ "Can you recommend a book?"
213
+ ]
214
+
215
+ test_queries = test_queries[:num_examples]
216
+ metrics_history = []
217
+
218
+ for i, query in enumerate(test_queries, 1):
219
+ logger.info(f"\nTest Case {i}:")
220
+ logger.info(f"Query: {query}")
221
+
222
+ # Get responses and scores
223
+ responses = chatbot.retrieve_responses(
224
+ query,
225
+ response_pool,
226
+ context=None,
227
+ top_k=5
228
+ )
229
+
230
+ # Check quality
231
+ quality_metrics = quality_checker.check_response_quality(
232
+ query, responses, response_pool
233
+ )
234
+ metrics_history.append(quality_metrics)
235
+
236
+ # Log results
237
+ logger.info(f"Quality Metrics: {quality_metrics}")
238
+ logger.info("Top responses:")
239
+ for j, (response, score) in enumerate(responses[:3], 1):
240
+ logger.info(f"{j}. Score: {score:.4f}")
241
+ logger.info(f" Response: {response}")
242
+ if j == 1 and not quality_metrics['is_confident']:
243
+ logger.info(" [Low Confidence - Would abstain from answering]")
244
+
245
+ # Calculate aggregate metrics
246
+ aggregate_metrics = {
247
+ 'num_queries_tested': len(test_queries),
248
+ 'avg_top_response_score': np.mean([m['top_score'] for m in metrics_history]),
249
+ 'avg_diversity': np.mean([m['response_diversity'] for m in metrics_history]),
250
+ 'avg_relevance': np.mean([m['query_response_relevance'] for m in metrics_history]),
251
+ 'confidence_rate': np.mean([m['is_confident'] for m in metrics_history]),
252
+ }
253
+
254
+ logger.info("\n=== Validation Summary ===")
255
+ for metric, value in aggregate_metrics.items():
256
+ logger.info(f"{metric}: {value:.4f}")
257
+
258
+ return aggregate_metrics
259
+
260
+ def chat_with_quality_check(
261
+ chatbot,
262
+ query: str,
263
+ response_pool: List[str],
264
+ conversation_history: List[Tuple[str, str]],
265
+ quality_checker: ResponseQualityChecker
266
+ ) -> Tuple[Optional[str], List[Tuple[str, float]], Dict[str, Any]]:
267
+ """
268
+ Enhanced chat function with quality checking.
269
+ """
270
+ # Get responses and scores
271
+ responses = chatbot.retrieve_responses(
272
+ query,
273
+ response_pool,
274
+ conversation_history
275
+ )
276
+
277
+ # Check quality
278
+ quality_metrics = quality_checker.check_response_quality(
279
+ query, responses, response_pool
280
+ )
281
+
282
+ if quality_metrics['is_confident']:
283
+ return responses[0][0], responses, quality_metrics
284
+ else:
285
+ uncertainty_response = (
286
+ "I apologize, but I don't feel confident providing an answer to that "
287
+ "question at the moment. Could you please rephrase or ask something else?"
288
+ )
289
+ return uncertainty_response, responses, quality_metrics
290
+
291
+ def get_total_steps(dialogues: List[Dict[str, Any]], batch_size: int, epochs: int) -> int:
292
+ """
293
+ Calculate total training steps based on dialogues and batch size.
294
+ Assume 80% of data for training due to validation split
295
+ """
296
+ estimated_train_samples = int(len(dialogues) * 0.8)
297
+ steps_per_epoch = estimated_train_samples // batch_size
298
+ return steps_per_epoch * epochs
299
+
300
+ def main():
301
+ # Set up GPU
302
+ is_gpu = setup_gpu()
303
+
304
+ DEBUG_SAMPLES = 350
305
+ BATCH_SIZE = 64 if is_gpu else 32
306
+ EPOCHS = 5 if DEBUG_SAMPLES else 10
307
+
308
+ # Set device
309
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
310
+ logger.info(f"Using device: {device}")
311
+
312
+ # Set up caching
313
+ cache_dir = setup_model_cache()
314
+
315
+ # Set up training directories
316
+ dirs = setup_training_directories()
317
+
318
+ # Load training data
319
+ dialogues = load_training_data('processed_outputs', debug_samples=DEBUG_SAMPLES)
320
+ total_steps = get_total_steps(dialogues, BATCH_SIZE, EPOCHS)
321
+
322
+ # Initialize configuration
323
+ config = ChatbotConfig(
324
+ embedding_dim=768, # Match DistilBERT's dimension
325
+ encoder_units=256,
326
+ num_attention_heads=8,
327
+ warmup_steps=int(total_steps * 0.1),
328
+ learning_rate=0.0003,
329
+ margin=0.5,
330
+ pretrained_model='distilbert-base-uncased'
331
+ )
332
+
333
+ # Preload models
334
+ preload_models(config, cache_dir)
335
+
336
+ # Save configuration
337
+ with open(dirs['base'] / 'config.json', 'w') as f:
338
+ json.dump(config.to_dict(), f, indent=2)
339
+
340
+ # Initialize chatbot
341
+ chatbot = RetrievalChatbot(config)
342
+
343
+ # Prepare dataset
344
+ logger.info("Preparing dataset...")
345
+
346
+ # Prepare and train with debug samples
347
+ q_pad, p_pad, n_pad = chatbot.prepare_dataset(
348
+ dialogues,
349
+ neg_samples_per_pos=3,
350
+ debug_samples=DEBUG_SAMPLES
351
+ )
352
+
353
+ tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='./logs')
354
+
355
+ # Train model
356
+ logger.info("Starting training...")
357
+ chatbot.train(
358
+ q_pad, p_pad, n_pad,
359
+ epochs=EPOCHS,
360
+ batch_size=BATCH_SIZE,
361
+ validation_split=0.2,
362
+ checkpoint_dir=dirs['checkpoints'],
363
+ callbacks=[tensorboard_callback]
364
+ )
365
+
366
+ # Plot and save training history
367
+ plot_training_history(chatbot.history, save_dir=dirs['plots'])
368
+
369
+ # Save final model
370
+ chatbot.save_models(dirs['base'] / 'final_model')
371
+
372
+ # Prepare response pool for chat
373
+ response_pool = [
374
+ turn['text'] for d in dialogues
375
+ for turn in d['turns'] if turn['speaker'] == 'assistant'
376
+ ]
377
+
378
+ # Initialize quality checker with appropriate thresholds
379
+ quality_checker = ResponseQualityChecker(
380
+ confidence_threshold=0.6 if not DEBUG_SAMPLES else 0.4, # Lower threshold for debug
381
+ diversity_threshold=0.2,
382
+ min_response_length=10,
383
+ max_similarity_ratio=0.9
384
+ )
385
+
386
+ # Run automatic validation
387
+ validation_metrics = run_automatic_validation(
388
+ chatbot,
389
+ response_pool,
390
+ quality_checker,
391
+ num_examples=5 if DEBUG_SAMPLES else 10
392
+ )
393
+
394
+ # Log validation metrics
395
+ logger.info(f"Validation Metrics: {validation_metrics}")
396
+
397
+ # Now continue with interactive chat
398
+ logger.info("\nStarting interactive chat session...")
399
+ conversation_history = []
400
+
401
+ while True:
402
+ query = input("\nYou: ")
403
+ if query.lower() in ['quit', 'exit', 'bye']:
404
+ break
405
+
406
+ try:
407
+ response, candidates, quality_metrics = chat_with_quality_check(
408
+ chatbot,
409
+ query,
410
+ response_pool,
411
+ conversation_history,
412
+ quality_checker
413
+ )
414
+ print(f"\nAssistant: {response}")
415
+
416
+ # Print top alternative responses if confident
417
+ if quality_metrics['is_confident']:
418
+ print("\nAlternative responses:")
419
+ for resp, score in candidates[1:4]:
420
+ print(f"Score: {score:.4f} - {resp}")
421
+
422
+ # Update history only for confident responses
423
+ conversation_history.append((query, response))
424
+ else:
425
+ print("\nQuality metrics indicated low confidence:")
426
+ print(f"Confidence score: {quality_metrics['top_score']:.4f}")
427
+ print(f"Response relevance: {quality_metrics['query_response_relevance']:.4f}")
428
+
429
+ except Exception as e:
430
+ logger.error(f"Error during chat: {str(e)}")
431
+ print("Sorry, I encountered an error. Please try again.")
432
+
433
+ if __name__ == "__main__":
434
+ main()
run_model4.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
1
+ from chatbot4 import RetrievalChatbot, ChatbotConfig
2
+ import os
3
+ import tensorflow as tf
4
+ import matplotlib.pyplot as plt
5
+ import logging
6
+ from pathlib import Path
7
+ from typing import List, Dict, Optional
8
+ from datetime import datetime
9
+ from response_quality_checker import ResponseQualityChecker
10
+
11
+ # Configure logging
12
+ logging.basicConfig(
13
+ level=logging.INFO,
14
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
15
+ )
16
+ logger = logging.getLogger(__name__)
17
+
18
+ def setup_model_cache(cache_dir: Optional[Path] = None) -> Path:
19
+ """Setup and manage model cache directory."""
20
+ if cache_dir is None:
21
+ cache_dir = Path.home() / '.chatbot_cache'
22
+
23
+ cache_dir.mkdir(parents=True, exist_ok=True)
24
+
25
+ # Set environment variables for various libraries
26
+ os.environ['TRANSFORMERS_CACHE'] = str(cache_dir / 'transformers')
27
+ os.environ['TORCH_HOME'] = str(cache_dir / 'torch')
28
+ os.environ['HF_HOME'] = str(cache_dir / 'huggingface')
29
+
30
+ logger.info(f"Using cache directory: {cache_dir}")
31
+ return cache_dir
32
+
33
+ def setup_gpu():
34
+ """Configure GPU settings for optimal performance."""
35
+ logger.info("Checking GPU availability...")
36
+
37
+ gpus = tf.config.list_physical_devices('GPU')
38
+ if gpus:
39
+ try:
40
+ # Allow memory growth to prevent taking all GPU memory at once
41
+ for gpu in gpus:
42
+ tf.config.experimental.set_memory_growth(gpu, True)
43
+ logger.info(f"Found {len(gpus)} GPU(s). Memory growth enabled.")
44
+
45
+ # Log GPU info
46
+ for gpu in gpus:
47
+ logger.info(f"GPU Device: {gpu}")
48
+ return True
49
+ except Exception as e:
50
+ logger.error(f"Error configuring GPU: {str(e)}")
51
+ return False
52
+
53
+ else:
54
+ logger.info("No GPU found. Using CPU.")
55
+ return False
56
+
57
+ # def preload_models(config: ChatbotConfig, cache_dir: Path):
58
+ # """Preload and cache models."""
59
+ # logger.info("Preloading models...")
60
+
61
+ # # Cache DistilBERT
62
+ # model_name = config.pretrained_model
63
+ # cache_path = cache_dir / 'transformers' / model_name
64
+
65
+ # if not cache_path.exists():
66
+ # logger.info(f"Downloading and caching {model_name}...")
67
+ # tokenizer = AutoTokenizer.from_pretrained(model_name)
68
+ # model = TFAutoModel.from_pretrained(model_name)
69
+
70
+ # # Save to cache
71
+ # tokenizer.save_pretrained(cache_path)
72
+ # model.save_pretrained(cache_path)
73
+ # else:
74
+ # logger.info(f"Using cached model from {cache_path}")
75
+
76
+ # return cache_path
77
+
78
+ def plot_training_history(history: Dict[str, List[float]], save_dir: Path = None):
79
+ """Plot and optionally save training history."""
80
+ # Create figure with two subplots
81
+ fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))
82
+
83
+ # Plot losses
84
+ ax1.plot(history['train_loss'], label='Train Loss')
85
+ ax1.plot(history['val_loss'], label='Validation Loss')
86
+ ax1.set_xlabel('Epoch')
87
+ ax1.set_ylabel('Triplet Loss')
88
+ ax1.set_title('Training and Validation Loss')
89
+ ax1.legend()
90
+ ax1.grid(True)
91
+
92
+ # Plot learning rate if available
93
+ if 'learning_rate' in history:
94
+ ax2.plot(history['learning_rate'], label='Learning Rate')
95
+ ax2.set_xlabel('Step')
96
+ ax2.set_ylabel('Learning Rate')
97
+ ax2.set_title('Learning Rate Schedule')
98
+ ax2.legend()
99
+ ax2.grid(True)
100
+
101
+ plt.tight_layout()
102
+
103
+ # Save if directory provided
104
+ if save_dir:
105
+ save_dir = Path(save_dir)
106
+ save_dir.mkdir(parents=True, exist_ok=True)
107
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
108
+ plt.savefig(save_dir / f'training_history_{timestamp}.png')
109
+
110
+ plt.show()
111
+
112
+ def setup_training_directories(base_dir: str = "chatbot_training") -> Dict[str, Path]:
113
+ """Setup directory structure for training artifacts."""
114
+ base_dir = Path(base_dir)
115
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
116
+ train_dir = base_dir / f"training_run_{timestamp}"
117
+
118
+ directories = {
119
+ 'base': train_dir,
120
+ 'checkpoints': train_dir / 'checkpoints',
121
+ 'plots': train_dir / 'plots',
122
+ 'logs': train_dir / 'logs'
123
+ }
124
+
125
+ # Create directories
126
+ for dir_path in directories.values():
127
+ dir_path.mkdir(parents=True, exist_ok=True)
128
+
129
+ return directories
130
+
131
+ def main():
132
+ # Set up GPU
133
+ is_gpu = setup_gpu()
134
+
135
+ DEBUG_SAMPLES = 2000
136
+ BATCH_SIZE = 128 if is_gpu else 64
137
+ EPOCHS = 5 if DEBUG_SAMPLES else 10
138
+
139
+ # Set up caching
140
+ cache_dir = setup_model_cache()
141
+
142
+ # Set up training directories
143
+ dirs = setup_training_directories()
144
+
145
+ # Initialize configuration
146
+ config = ChatbotConfig(
147
+ embedding_dim=768, # Match DistilBERT's dimension
148
+ max_sequence_length=512,
149
+ freeze_embeddings=False
150
+ )
151
+
152
+ # Preload models
153
+ #preload_models(config, cache_dir)
154
+
155
+ # Save configuration
156
+ # with open(dirs['base'] / 'config.json', 'w') as f:
157
+ # json.dump(config.to_dict(), f, indent=4)
158
+
159
+ # Load training data
160
+ dialogues = RetrievalChatbot.load_training_data(data_path='processed_outputs/batch_group_0010.json', debug_samples=DEBUG_SAMPLES)
161
+
162
+ # Initialize chatbot
163
+ chatbot = RetrievalChatbot(config, dialogues)
164
+
165
+ # Check trainable variables
166
+ chatbot.check_trainable_variables()
167
+
168
+ # Verify FAISS
169
+ chatbot.verify_faiss_index()
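+ # The FAISS index provides the nearest-neighbour search used to retrieve candidate responses at chat time.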
170
+
171
+ # Prepare dataset
172
+ logger.info("Preparing dataset...")
173
+
174
+ # Prepare and train with debug samples
175
+ q_tensor, p_tensor = chatbot.prepare_dataset(dialogues)
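+ # q_tensor / p_tensor are the padded query and positive-response token tensors consumed below by train() as q_pad / p_pad.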
176
+
177
+ quality_checker = ResponseQualityChecker(chatbot=chatbot)
178
+
179
+ # Train model
180
+ logger.info("Starting training...")
181
+
182
+ tf.config.optimizer.set_jit(True) # XLA
183
+ policy = tf.keras.mixed_precision.Policy('mixed_float16')
184
+ tf.keras.mixed_precision.set_global_policy(policy)
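+ # XLA JIT and mixed_float16 mainly pay off on GPU; for CPU-only debugging both settings can be skipped.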
185
+
186
+ chatbot.train(
187
+ q_pad=q_tensor,
188
+ p_pad=p_tensor,
189
+ epochs=EPOCHS,
190
+ batch_size=BATCH_SIZE,
191
+ validation_split=0.2,
192
+ checkpoint_dir="checkpoints/",
193
+ use_lr_schedule=True, # Enable custom schedule
194
+ peak_lr=2e-5, # Peak learning rate
195
+ warmup_steps_ratio=0.1, # 10% warmup
196
+ early_stopping_patience=3 # Stop if no improvement for 3 epochs
197
+ )
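+ # Rough sketch (not the exact implementation) of the schedule the flags above imply:
+ #   warmup_steps = int(total_steps * warmup_steps_ratio)   # 10% of steps
+ #   lr(step)     = peak_lr * step / warmup_steps            # linear warmup to 2e-5
+ #   lr(step)     = peak_lr * decay(step - warmup_steps)     # decay afterwards; the decay shape is defined by the chatbot's schedule, not shown here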
198
+
199
+ # Plot and save training history
200
+ #plot_training_history(chatbot.history, save_dir=dirs['plots'])
201
+
202
+ # Save final model
203
+ chatbot.save_models(dirs['base'] / 'final_model')
204
+
205
+ # Run automatic validation
206
+ validation_metrics = chatbot.run_automatic_validation(quality_checker, num_examples=5)
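+ # Presumably samples num_examples validation queries and scores the retrieved responses with the quality checker; metrics are logged below.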
207
+
208
+ # Log validation metrics
209
+ logger.info(f"Validation Metrics: {validation_metrics}")
210
+
211
+ # Now continue with interactive chat
212
+ logger.info("\nStarting interactive chat session...")
213
+ conversation_history = []
214
+
215
+ while True:
216
+ user_input = input("You: ")
217
+ if user_input.lower() in ['quit', 'exit', 'bye']:
218
+ print("Assistant: Goodbye!")
219
+ break
220
+
221
+ response, candidates, metrics = chatbot.chat(
222
+ query=user_input,
223
+ conversation_history=None, # Pass conversation history if available
224
+ quality_checker=quality_checker,
225
+ top_k=5
226
+ )
227
+
228
+ print(f"Assistant: {response}")
229
+
230
+ # Optionally, display alternative responses
231
+ if metrics.get('is_confident', False):
232
+ print("\nAlternative responses:")
233
+ for resp, score in candidates[1:4]:
234
+ print(f"Score: {score:.4f} - {resp}")
235
+
236
+ if __name__ == "__main__":
237
+ main()
setup.py CHANGED
@@ -1,6 +1,7 @@
1
  from setuptools import setup, find_packages
2
  import subprocess
3
  import sys
 
4
 
5
  with open("README.md", "r", encoding="utf-8") as fh:
6
  long_description = fh.read()
@@ -8,11 +9,21 @@ with open("README.md", "r", encoding="utf-8") as fh:
8
  with open("requirements.txt", "r", encoding="utf-8") as fh:
9
  requirements = [line.strip() for line in fh if line.strip() and not line.startswith("#")]
10
 
11
- def setup_spacy_model():
12
  """
13
- Download spaCy model.
14
  """
15
- subprocess.check_call([sys.executable, "-m", "spacy", "download", "en_core_web_sm"])
16
 
17
  def setup_models():
18
  """
@@ -22,10 +33,17 @@ def setup_models():
22
  from sklearn.feature_extraction.text import TfidfVectorizer
23
  from transformers import (
24
  AutoTokenizer,
 
25
  GPT2TokenizerFast,
26
- MarianTokenizer
 
 
27
  )
28
 
 
 
 
 
29
  # Download Universal Sentence Encoder
30
  _ = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
31
 
@@ -63,12 +81,48 @@ def setup_nltk():
63
  except Exception as e:
64
  print(f"Warning: Could not download {package}: {str(e)}")
65
 
 
66
  setup(
67
  name="text-data-augmenter",
68
  version="0.1.0",
69
  author="Joe Armani",
70
  author_email="[email protected]",
71
  description="A tool for generating high-quality dialogue variations",
 
 
72
  packages=find_packages(),
73
  classifiers=[
74
  "Development Status :: 3 - Alpha",
@@ -95,6 +149,7 @@ setup(
95
  )
96
 
97
  if __name__ == '__main__':
98
- setup_spacy_model()
99
  setup_models()
100
- setup_nltk()
 
 
1
  from setuptools import setup, find_packages
2
  import subprocess
3
  import sys
4
+ import platform
5
 
6
  with open("README.md", "r", encoding="utf-8") as fh:
7
  long_description = fh.read()
 
9
  with open("requirements.txt", "r", encoding="utf-8") as fh:
10
  requirements = [line.strip() for line in fh if line.strip() and not line.startswith("#")]
11
 
12
+ def setup_spacy_models(models=['en_core_web_sm', 'en_core_web_md']):
13
  """
14
+ Download the specified spaCy models.
15
+
16
+ Args:
17
+ models (List[str]): Names of the spaCy models to download.
18
  """
19
+ try:
20
+ for model in models:
21
+ print(f"Downloading spaCy model: {model}")
22
+ subprocess.check_call([sys.executable, "-m", "spacy", "download", model])
23
+ print(f"Successfully downloaded spaCy model: {model}")
24
+ except subprocess.CalledProcessError as e:
25
+ print(f"Error downloading spaCy model: {model}")
26
+ print(e)
27
 
28
  def setup_models():
29
  """
 
33
  from sklearn.feature_extraction.text import TfidfVectorizer
34
  from transformers import (
35
  AutoTokenizer,
36
+ AutoModel,
37
  GPT2TokenizerFast,
38
+ MarianTokenizer,
39
+ DistilBertTokenizer,
40
+ DistilBertModel
41
  )
42
 
43
+ # Download DistilBERT for chatbot
44
+ tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
45
+ model = DistilBertModel.from_pretrained('distilbert-base-uncased')
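+ # from_pretrained caches the weights locally (by default under ~/.cache/huggingface), so later runs reuse the download.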
46
+
47
  # Download Universal Sentence Encoder
48
  _ = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
49
 
 
81
  except Exception as e:
82
  print(f"Warning: Could not download {package}: {str(e)}")
83
 
84
+ def setup_faiss():
85
+ """
86
+ Download required faiss library.
87
+ """
88
+ current_os = platform.system()
89
+ cuda_available = False
90
+
91
+ # Function to check CUDA availability
92
+ def check_cuda():
93
+ try:
94
+ import torch
95
+ return torch.cuda.is_available()
96
+ except Exception:
97
+ return False
98
+
99
+ if current_os == "Linux" and check_cuda():
100
+ # Attempt to install faiss-gpu
101
+ try:
102
+ print("Attempting to install faiss-gpu...")
103
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "faiss-gpu>=1.7.0"])
104
+ print("Successfully installed faiss-gpu")
105
+ return
106
+ except subprocess.CalledProcessError:
107
+ print("Failed to install faiss-gpu. Falling back to faiss-cpu.")
108
+
109
+ # Install faiss-cpu as the default
110
+ try:
111
+ print("Installing faiss-cpu...")
112
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "faiss-cpu>=1.7.0"])
113
+ print("Successfully installed faiss-cpu")
114
+ except subprocess.CalledProcessError as e:
115
+ print("Error installing faiss-cpu")
116
+ print(e)
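+ # Note: prebuilt faiss-gpu wheels are generally published for Linux only, hence the platform and CUDA check above; macOS and Windows fall back to faiss-cpu.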
117
+
118
  setup(
119
  name="text-data-augmenter",
120
  version="0.1.0",
121
  author="Joe Armani",
122
  author_email="[email protected]",
123
  description="A tool for generating high-quality dialogue variations",
124
+ long_description=long_description,
125
+ long_description_content_type="text/markdown",
126
  packages=find_packages(),
127
  classifiers=[
128
  "Development Status :: 3 - Alpha",
 
149
  )
150
 
151
  if __name__ == '__main__':
152
+ setup_spacy_models()
153
  setup_models()
154
+ setup_nltk()
155
+ setup_faiss()
test_trained_model.py ADDED
File without changes