JoeArmani committed
Commit f7b283c · 1 Parent(s): 300fe5d
summarization, reranker, environment setup, and response quality checker
Browse files
- chatbot.py +0 -261
- chatbot2.py +0 -839
- chatbot3.py +0 -824
- chatbot4.py → chatbot_model.py +595 -629
- chatbot_validator.py +207 -0
- conversation_summarizer.py +147 -0
- cross_encoder_reranker.py +51 -0
- dialogue_augmenter.py +0 -1
- environment_setup.py +207 -0
- logger_config.py +10 -0
- requirements.txt +13 -1
- response_quality_checker.py +125 -119
- run_model.py +0 -162
- run_model2.py +0 -340
- run_model3.py +0 -434
- run_model4.py +0 -237
- run_model_train.py +96 -0
- setup.py +40 -4
- training_plotter.py +127 -0
chatbot.py
DELETED
@@ -1,261 +0,0 @@
import numpy as np
import tensorflow as tf
import keras
print(tf.__version__)
print(keras.__version__)
import spacy
import random
from tqdm import trange

class RetrievalChatbot:
    def __init__(
        self,
        vocab_size: int = 10000,
        max_sequence_length: int = 80,
        embedding_dim: int = 256,
        lstm_units: int = 256,
        num_attention_heads: int = 8,
        margin: float = 0.3
    ):
        self.vocab_size = vocab_size
        self.max_sequence_length = max_sequence_length
        self.embedding_dim = embedding_dim
        self.lstm_units = lstm_units
        self.num_attention_heads = num_attention_heads
        self.margin = margin

        self.nlp = spacy.load('en_core_web_md')
        self.tokenizer = tf.keras.preprocessing.text.Tokenizer(
            num_words=vocab_size,
            oov_token="<OOV>"
        )

        self.query_encoder_model, self.response_encoder_model = self._build_encoders()

    def _positional_encoding(self, position: int, d_model: int) -> tf.Tensor:
        angles = np.arange(position)[:, np.newaxis] / np.power(
            10000,
            (2 * (np.arange(d_model)[np.newaxis, :] // 2)) / d_model
        )
        sines = np.sin(angles[:, 0::2])
        cosines = np.cos(angles[:, 1::2])
        pos_encoding = np.concatenate([sines, cosines], axis=-1)
        pos_encoding = pos_encoding[np.newaxis, ...]
        return tf.cast(pos_encoding, dtype=tf.float32)

    def _build_single_encoder(self, name_prefix: str):
        input_layer = tf.keras.Input(shape=(self.max_sequence_length,), name=f"{name_prefix}_input")
        embedding = tf.keras.layers.Embedding(
            self.vocab_size,
            self.embedding_dim,
            mask_zero=True,
            name=f"{name_prefix}_embedding"
        )(input_layer)

        pos_encoding = self._positional_encoding(self.max_sequence_length, self.embedding_dim)
        x = embedding + pos_encoding

        # # Multi-head attention
        # attention_output = tf.keras.layers.MultiHeadAttention(
        #     num_heads=self.num_attention_heads,
        #     key_dim=self.embedding_dim // self.num_attention_heads
        # )(x, x)
        # x = tf.keras.layers.LayerNormalization()(x + attention_output)

        for i in range(2):
            lstm_out = tf.keras.layers.LSTM(
                self.lstm_units,
                return_sequences=True,
                kernel_regularizer=tf.keras.regularizers.l2(0.01),
                name=f"{name_prefix}_lstm_{i}"
            )(x)
            x = tf.keras.layers.LayerNormalization()(x + lstm_out)

        encoder_output = tf.keras.layers.LSTM(
            self.lstm_units,
            name=f"{name_prefix}_final_lstm"
        )(x)
        encoder_output = tf.keras.layers.Dropout(0.2)(encoder_output)
        encoder_output = tf.keras.layers.Lambda(lambda t: tf.nn.l2_normalize(t, axis=1))(encoder_output)

        return tf.keras.Model(input_layer, encoder_output, name=f"{name_prefix}_encoder")

    def _build_encoders(self):
        query_encoder = self._build_single_encoder("query")
        response_encoder = self._build_single_encoder("response")
        return query_encoder, response_encoder

    def _spacy_similarity(self, text1: str, text2: str) -> float:
        doc1 = self.nlp(text1)
        doc2 = self.nlp(text2)
        print('doc1:', doc1)
        print('doc2:', doc2)
        print('doc1.similarity(doc2):', doc1.similarity(doc2))
        return doc1.similarity(doc2)

    def prepare_dataset(self, dialogues: list, neg_samples_per_pos=3):
        # Create triplets: (query, positive, negative)
        response_pool = [
            turn['text'] for d in dialogues for turn in d['turns'] if turn['speaker'] == 'assistant'
        ]
        queries, positives, negatives = [], [], []

        for dialogue in dialogues:
            turns = dialogue['turns']
            for i in range(0, len(turns)-1):
                if turns[i]['speaker'] == 'user' and turns[i+1]['speaker'] == 'assistant':
                    q = turns[i]['text']
                    p = turns[i+1]['text']

                    # Find negatives using spaCy similarity
                    neg_candidates = []
                    attempts = 0
                    while len(neg_candidates) < neg_samples_per_pos and attempts < 200:
                        cand = random.choice(response_pool)
                        if cand != p:
                            sim = self._spacy_similarity(cand, p)
                            # Choose thresholds that produce hard negatives
                            if 0.4 < sim < 0.9:
                                neg_candidates.append(cand)
                        attempts += 1

                    if len(neg_candidates) == neg_samples_per_pos:
                        for neg in neg_candidates:
                            queries.append(q)
                            positives.append(p)
                            negatives.append(neg)

        # Fit tokenizer
        all_text = queries + positives + negatives
        self.tokenizer.fit_on_texts(all_text)

        def seq_pad(txts):
            seq = self.tokenizer.texts_to_sequences(txts)
            return tf.keras.preprocessing.sequence.pad_sequences(seq, maxlen=self.max_sequence_length, padding='post')

        q_pad = seq_pad(queries)
        p_pad = seq_pad(positives)
        n_pad = seq_pad(negatives)

        return q_pad, p_pad, n_pad

    def triplet_loss(self, q_emb, p_emb, n_emb):
        pos_dist = tf.reduce_sum(tf.square(q_emb - p_emb), axis=1)
        neg_dist = tf.reduce_sum(tf.square(q_emb - n_emb), axis=1)
        loss = tf.maximum(0.0, self.margin + pos_dist - neg_dist)
        return tf.reduce_mean(loss)

    def train_with_triplet_loss(
        self, q_pad, p_pad, n_pad,
        epochs=3,
        batch_size=16,
        validation_split=0.2,
        early_stopping_patience=3,
        use_tqdm=True
    ):
        train_losses = []
        val_losses = []

        total_samples = len(q_pad)
        idxs = np.arange(total_samples)
        np.random.shuffle(idxs)
        train_size = int((1 - validation_split) * total_samples)

        train_idxs = idxs[:train_size]
        val_idxs = idxs[train_size:]

        q_train, p_train, n_train = q_pad[train_idxs], p_pad[train_idxs], n_pad[train_idxs]
        q_val, p_val, n_val = q_pad[val_idxs], p_pad[val_idxs], n_pad[val_idxs]

        optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
        best_val_loss = float('inf')
        wait = 0

        for epoch in range(epochs):
            # Shuffle training data each epoch
            perm = np.random.permutation(len(q_train))
            q_train, p_train, n_train = q_train[perm], p_train[perm], n_train[perm]

            num_batches = len(q_train) // batch_size
            epoch_train_loss = 0.0

            batch_iter = range(num_batches)
            if use_tqdm:
                batch_iter = trange(num_batches, desc=f"Epoch {epoch+1}/{epochs}")

            for i in batch_iter:
                q_batch = q_train[i*batch_size:(i+1)*batch_size]
                p_batch = p_train[i*batch_size:(i+1)*batch_size]
                n_batch = n_train[i*batch_size:(i+1)*batch_size]

                with tf.GradientTape() as tape:
                    q_emb = self.query_encoder_model(q_batch, training=True)
                    p_emb = self.response_encoder_model(p_batch, training=True)
                    n_emb = self.response_encoder_model(n_batch, training=True)
                    loss = self.triplet_loss(q_emb, p_emb, n_emb)

                grads = tape.gradient(
                    loss,
                    self.query_encoder_model.trainable_variables +
                    self.response_encoder_model.trainable_variables
                )
                optimizer.apply_gradients(zip(
                    grads,
                    self.query_encoder_model.trainable_variables +
                    self.response_encoder_model.trainable_variables
                ))
                epoch_train_loss += loss.numpy()

            epoch_train_loss /= num_batches

            # Validation loss
            val_batches = len(q_val) // batch_size
            epoch_val_loss = 0.0
            for i in range(val_batches):
                q_batch = q_val[i*batch_size:(i+1)*batch_size]
                p_batch = p_val[i*batch_size:(i+1)*batch_size]
                n_batch = n_val[i*batch_size:(i+1)*batch_size]

                q_emb = self.query_encoder_model(q_batch, training=False)
                p_emb = self.response_encoder_model(p_batch, training=False)
                n_emb = self.response_encoder_model(n_batch, training=False)
                v_loss = self.triplet_loss(q_emb, p_emb, n_emb)
                epoch_val_loss += v_loss.numpy()

            if val_batches > 0:
                epoch_val_loss /= val_batches

            train_losses.append(epoch_train_loss)
            val_losses.append(epoch_val_loss)

            print(f"Epoch {epoch+1}/{epochs}, Train Loss: {epoch_train_loss:.4f}, Val Loss: {epoch_val_loss:.4f}")

            # Early Stopping logic
            if epoch_val_loss < best_val_loss:
                best_val_loss = epoch_val_loss
                wait = 0
                # (Optional) Save best weights
            else:
                wait += 1
                if wait >= early_stopping_patience:
                    print("Early stopping triggered.")
                    break

        return train_losses, val_losses

    def encode_texts(self, texts, is_query=True):
        seq = self.tokenizer.texts_to_sequences(texts)
        pad_seq = tf.keras.preprocessing.sequence.pad_sequences(seq, maxlen=self.max_sequence_length, padding='post')
        if is_query:
            return self.query_encoder_model(pad_seq, training=False)
        else:
            return self.response_encoder_model(pad_seq, training=False)

    def retrieve_top_n(self, query: str, candidates: list, top_n=5):
        q_emb = self.encode_texts([query], is_query=True)  # shape (1, d)
        c_emb = self.encode_texts(candidates, is_query=False)  # shape (num_cand, d)
        sim = tf.matmul(q_emb, c_emb, transpose_b=True).numpy()[0]  # dot product similarity
        top_indices = np.argsort(sim)[::-1][:top_n]
        return [(candidates[i], sim[i]) for i in top_indices]
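For orientation, here is a minimal, hypothetical usage sketch of the RetrievalChatbot API deleted above; it is not part of the commit. `load_dialogues` and the file path are placeholders for whatever produces the `{"turns": [...]}` structure that `prepare_dataset` expects.

# Hypothetical usage of the deleted chatbot.py API -- illustration only.
# load_dialogues() is a placeholder; prepare_dataset expects a list of
# {"turns": [{"speaker": "user" | "assistant", "text": str}, ...]} dicts.
dialogues = load_dialogues("dialogues.json")

bot = RetrievalChatbot(vocab_size=10000, max_sequence_length=80)
q_pad, p_pad, n_pad = bot.prepare_dataset(dialogues, neg_samples_per_pos=3)
train_losses, val_losses = bot.train_with_triplet_loss(q_pad, p_pad, n_pad, epochs=3, batch_size=16)

# Rank candidate responses for a new query by dot-product similarity.
candidates = [t["text"] for d in dialogues for t in d["turns"] if t["speaker"] == "assistant"]
for response, score in bot.retrieve_top_n("Can you help me reset my password?", candidates, top_n=5):
    print(f"{score:.3f}  {response}")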
chatbot2.py
DELETED
@@ -1,839 +0,0 @@
import numpy as np
import tensorflow as tf
import spacy
import random
from typing import List, Tuple, Dict, Optional, Union
from dataclasses import dataclass
from tqdm import tqdm
import logging
from pathlib import Path
import json

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

@dataclass
class ChatbotConfig:
    """Configuration for the retrieval chatbot."""
    vocab_size: int = 10000
    max_sequence_length: int = 512
    embedding_dim: int = 256
    encoder_units: int = 256
    num_attention_heads: int = 8
    dropout_rate: float = 0.2
    l2_reg_weight: float = 0.001
    margin: float = 0.3
    learning_rate: float = 0.001
    min_text_length: int = 3  # Reduced from 10 to allow shorter responses
    max_context_turns: int = 5
    warmup_steps: int = 200
    spacy_model: str = 'en_core_web_md'

    def to_dict(self) -> dict:
        """Convert config to dictionary."""
        return {k: str(v) if isinstance(v, Path) else v
                for k, v in self.__dict__.items()}

    @classmethod
    def from_dict(cls, config_dict: dict) -> 'ChatbotConfig':
        """Create config from dictionary."""
        return cls(**{k: v for k, v in config_dict.items()
                      if k in cls.__dataclass_fields__})

class TransformerBlock(tf.keras.layers.Layer):
    """Custom Transformer block with pre-layer normalization."""
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        ff_dim: int,
        dropout: float = 0.1,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.dropout = dropout

        self.attention = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=embed_dim // num_heads,
            dropout=dropout
        )
        self.ffn = tf.keras.Sequential([
            tf.keras.layers.Dense(ff_dim, activation="gelu"),
            tf.keras.layers.Dense(embed_dim),
        ])

        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(dropout)
        self.dropout2 = tf.keras.layers.Dropout(dropout)

    def call(self, inputs: tf.Tensor, training: bool, mask: Optional[tf.Tensor] = None) -> tf.Tensor:
        # Pre-layer normalization
        norm_inputs = self.layernorm1(inputs)

        # Self-attention
        attention_output = self.attention(
            query=norm_inputs,
            value=norm_inputs,
            key=norm_inputs,
            attention_mask=mask,
            training=training
        )
        attention_output = self.dropout1(attention_output, training=training)
        attention_output = inputs + attention_output

        # Feed-forward network
        norm_attention = self.layernorm2(attention_output)
        ffn_output = self.ffn(norm_attention)
        ffn_output = self.dropout2(ffn_output, training=training)

        return attention_output + ffn_output

    def get_config(self) -> dict:
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "ff_dim": self.ff_dim,
            "dropout": self.dropout,
        })
        return config

class EncoderModel(tf.keras.Model):
    """Dual encoder model with shared weights option."""
    def __init__(
        self,
        config: ChatbotConfig,
        name: str = "encoder",
        shared_weights: bool = False,
        **kwargs
    ):
        super().__init__(name=name, **kwargs)
        self.config = config
        self.shared_weights = shared_weights

        # Input embedding layer
        self.embedding = tf.keras.layers.Embedding(
            config.vocab_size,
            config.embedding_dim,
            mask_zero=True,
            name=f"{name}_embedding"
        )

        # Positional encoding
        self.pos_encoding = self._get_positional_encoding()

        # Transformer blocks
        self.transformer_blocks = [
            TransformerBlock(
                config.embedding_dim,
                config.num_attention_heads,
                config.encoder_units * 4,
                config.dropout_rate,
                name=f"{name}_transformer_{i}"
            ) for i in range(3)
        ]

        # Final LSTM layer
        self.final_lstm = tf.keras.layers.LSTM(
            config.encoder_units,
            kernel_regularizer=tf.keras.regularizers.l2(config.l2_reg_weight),
            name=f"{name}_final_lstm"
        )

        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
        self.normalize = tf.keras.layers.Lambda(
            lambda x: tf.nn.l2_normalize(x, axis=1)
        )

    def _get_positional_encoding(self) -> tf.Tensor:
        """Generate positional encoding matrix."""
        pos = np.arange(self.config.max_sequence_length)[:, np.newaxis]
        i = np.arange(self.config.embedding_dim)[np.newaxis, :]
        angle = pos / np.power(10000, (2 * (i // 2)) / self.config.embedding_dim)

        pos_encoding = np.zeros_like(angle)
        pos_encoding[:, 0::2] = np.sin(angle[:, 0::2])
        pos_encoding[:, 1::2] = np.cos(angle[:, 1::2])

        return tf.cast(pos_encoding[np.newaxis, ...], dtype=tf.float32)

    def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
        # Get input mask
        mask = self.embedding.compute_mask(inputs)
        mask = mask[:, tf.newaxis, tf.newaxis, :]  # Add attention dims

        # Embedding + positional encoding
        x = self.embedding(inputs)
        x = x + self.pos_encoding

        # Apply transformer blocks
        for transformer_block in self.transformer_blocks:
            x = transformer_block(x, training=training, mask=mask)

        # Final processing
        x = self.final_lstm(x)
        x = self.dropout(x, training=training)
        return self.normalize(x)

class RetrievalChatbot:
    """Professional implementation of a retrieval-based chatbot."""
    def __init__(self, config: ChatbotConfig):
        self.config = config
        self.nlp = spacy.load(config.spacy_model)

        # Initialize tokenizer
        self.tokenizer = tf.keras.preprocessing.text.Tokenizer(
            num_words=config.vocab_size,
            oov_token="<OOV>"
        )

        # Special tokens
        self.special_tokens = {
            "user": "<USER>",
            "assistant": "<ASSISTANT>",
            "context": "<CONTEXT>",
            "sep": "<SEP>"
        }

        # Build models
        self._build_models()

        # Training history
        self.history = {
            "train_loss": [],
            "val_loss": [],
            "train_metrics": {},
            "val_metrics": {}
        }

        # Initialize similarity cache
        self.similarity_cache = {}

    def _build_models(self):
        """Initialize the encoder models."""
        # Query encoder
        self.query_encoder = EncoderModel(
            self.config,
            name="query_encoder",
            shared_weights=False
        )

        # Response encoder (can share weights with query encoder)
        self.response_encoder = EncoderModel(
            self.config,
            name="response_encoder",
            shared_weights=False
        )

    def save_models(self, save_dir: Union[str, Path]):
        """Save models and configuration."""
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

        # Save config
        with open(save_dir / "config.json", "w") as f:
            json.dump(self.config.to_dict(), f, indent=2)

        # Save models with proper extension
        self.query_encoder.save(save_dir / "query_encoder.keras")
        self.response_encoder.save(save_dir / "response_encoder.keras")

        # Save tokenizer config
        tokenizer_config = {
            "word_index": self.tokenizer.word_index,
            "word_counts": self.tokenizer.word_counts,
            "document_count": self.tokenizer.document_count,
            "index_docs": self.tokenizer.index_docs,
            "index_word": self.tokenizer.index_word
        }
        with open(save_dir / "tokenizer_config.json", "w") as f:
            json.dump(tokenizer_config, f)

    @classmethod
    def load_models(cls, load_dir: Union[str, Path]) -> 'RetrievalChatbot':
        """Load saved models and configuration."""
        load_dir = Path(load_dir)

        # Load config
        with open(load_dir / "config.json", "r") as f:
            config = ChatbotConfig.from_dict(json.load(f))

        # Initialize chatbot
        chatbot = cls(config)

        # Load models with proper extension
        chatbot.query_encoder = tf.keras.models.load_model(
            load_dir / "query_encoder.keras",
            custom_objects={"TransformerBlock": TransformerBlock}
        )
        chatbot.response_encoder = tf.keras.models.load_model(
            load_dir / "response_encoder.keras",
            custom_objects={"TransformerBlock": TransformerBlock}
        )

        # Load tokenizer config
        with open(load_dir / "tokenizer_config.json", "r") as f:
            tokenizer_config = json.load(f)

        chatbot.tokenizer = tf.keras.preprocessing.text.Tokenizer(
            num_words=config.vocab_size,
            oov_token="<OOV>"
        )
        chatbot.tokenizer.word_index = tokenizer_config["word_index"]
        chatbot.tokenizer.word_counts = tokenizer_config["word_counts"]
        chatbot.tokenizer.document_count = tokenizer_config["document_count"]
        chatbot.tokenizer.index_docs = tokenizer_config["index_docs"]
        chatbot.tokenizer.index_word = tokenizer_config["index_word"]

        return chatbot

    def _improved_spacy_similarity(self, text1: str, text2: str) -> float:
        """Calculate semantic similarity between texts with preprocessing."""
        def preprocess(text: str) -> str:
            # Basic cleaning
            text = ' '.join(text.split())
            return text if text.strip() else "empty_document"

        # Get cache key
        cache_key = f"{hash(text1)}_{hash(text2)}"
        if cache_key in self.similarity_cache:
            return self.similarity_cache[cache_key]

        # Process texts
        text1, text2 = preprocess(text1), preprocess(text2)
        doc1, doc2 = self.nlp(text1), self.nlp(text2)

        # Calculate similarity
        if doc1.has_vector and doc2.has_vector:
            sim = doc1.similarity(doc2)
        else:
            # Fallback to token overlap similarity
            tokens1 = {t.lower_ for t in doc1 if not t.is_stop and not t.is_punct}
            tokens2 = {t.lower_ for t in doc2 if not t.is_stop and not t.is_punct}
            intersection = len(tokens1.intersection(tokens2))
            union = len(tokens1.union(tokens2))
            sim = intersection / union if union > 0 else 0.0

        # Cache result
        self.similarity_cache[cache_key] = sim
        return sim

    def _smart_negative_sampling(
        self,
        positive: str,
        response_pool: List[str],
        n_samples: int,
        max_attempts: int = 200,
        similarity_bounds: Tuple[float, float] = (0.3, 0.8),
        batch_size: int = 10
    ) -> List[str]:
        """Smart negative sampling with similarity bounds and batching."""
        candidates = []
        seen = set()
        attempts = 0

        while len(candidates) < n_samples and attempts < max_attempts:
            # Batch process candidates
            batch = random.sample(
                response_pool,
                min(batch_size, max_attempts - attempts)
            )

            for candidate in batch:
                if candidate != positive and candidate not in seen:
                    seen.add(candidate)
                    sim = self._improved_spacy_similarity(candidate, positive)

                    # Check similarity bounds
                    if similarity_bounds[0] < sim < similarity_bounds[1]:
                        candidates.append(candidate)
                        if len(candidates) == n_samples:
                            break

            attempts += len(batch)

        return candidates

    def train(
        self,
        q_pad: tf.Tensor,
        p_pad: tf.Tensor,
        n_pad: tf.Tensor,
        epochs: int = 3,
        batch_size: int = 32,
        validation_split: float = 0.2,
        checkpoint_dir: Optional[Union[str, Path]] = None
    ):
        """Train the model with improved training loop."""
        # Setup training
        total_samples = len(q_pad)
        train_size = int((1 - validation_split) * total_samples)

        # Split data
        indices = np.random.permutation(total_samples)
        train_idx, val_idx = indices[:train_size], indices[train_size:]

        train_data = (q_pad[train_idx], p_pad[train_idx], n_pad[train_idx])
        val_data = (q_pad[val_idx], p_pad[val_idx], n_pad[val_idx])

        # Setup optimizer with learning rate schedule
        steps_per_epoch = train_size // batch_size
        total_steps = steps_per_epoch * epochs

        lr_schedule = self._get_lr_schedule(
            total_steps,
            self.config.learning_rate,
            self.config.warmup_steps
        )

        optimizer = tf.keras.optimizers.Adam(lr_schedule)

        # Setup checkpointing
        if checkpoint_dir:
            checkpoint_dir = Path(checkpoint_dir)
            checkpoint_dir.mkdir(parents=True, exist_ok=True)

            # Setup checkpoint callback with correct file format
            checkpoint_template = str(checkpoint_dir / "model_epoch_{epoch:04d}.weights.h5")
            checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
                checkpoint_template,
                save_weights_only=True,
                save_best_only=True,
                monitor='val_loss',
                mode='min',
                verbose=1
            )

        # Training loop
        best_val_loss = float('inf')
        patience = 5
        wait = 0

        for epoch in range(epochs):
            # Training
            train_loss = self._train_epoch(
                train_data,
                optimizer,
                batch_size,
                training=True
            )

            # Validation
            val_loss = self._train_epoch(
                val_data,
                optimizer,
                batch_size,
                training=False
            )

            # Update history
            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)

            logger.info(
                f"Epoch {epoch + 1}/{epochs} - "
                f"train_loss: {train_loss:.4f} - "
                f"val_loss: {val_loss:.4f}"
            )

            # Early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                wait = 0
                if checkpoint_dir:
                    self.save_models(checkpoint_dir / f"best_model")
            else:
                wait += 1
                if wait >= patience:
                    logger.info("Early stopping triggered")
                    break

    def _get_lr_schedule(
        self,
        total_steps: int,
        peak_lr: float,
        warmup_steps: int
    ) -> tf.keras.optimizers.schedules.LearningRateSchedule:
        """Enhanced learning rate schedule with better error handling and logging."""
        class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
            def __init__(
                self,
                total_steps: int,
                peak_lr: float,
                warmup_steps: int
            ):
                super().__init__()
                self.total_steps = tf.cast(total_steps, tf.float32)
                self.peak_lr = tf.cast(peak_lr, tf.float32)
                self.warmup_steps = tf.cast(max(1, warmup_steps), tf.float32)  # Prevent 0

                # Calculate and store constants
                self.initial_lr = self.peak_lr * 0.1  # Start at 10% of peak
                self.min_lr = self.peak_lr * 0.01  # Minimum 1% of peak

                logger.info(f"Learning rate schedule initialized:")
                logger.info(f"  Initial LR: {float(self.initial_lr):.6f}")
                logger.info(f"  Peak LR: {float(self.peak_lr):.6f}")
                logger.info(f"  Min LR: {float(self.min_lr):.6f}")
                logger.info(f"  Warmup steps: {int(self.warmup_steps)}")
                logger.info(f"  Total steps: {int(self.total_steps)}")

            def __call__(self, step):
                step = tf.cast(step, tf.float32)

                # Warmup phase
                warmup_factor = tf.minimum(1.0, step / self.warmup_steps)
                warmup_lr = self.initial_lr + (self.peak_lr - self.initial_lr) * warmup_factor

                # Decay phase
                decay_steps = tf.maximum(1.0, self.total_steps - self.warmup_steps)
                decay_factor = (step - self.warmup_steps) / decay_steps
                decay_factor = tf.minimum(tf.maximum(0.0, decay_factor), 1.0)  # Clip to [0,1]

                cosine_decay = 0.5 * (1.0 + tf.cos(tf.constant(np.pi) * decay_factor))
                decay_lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay

                # Choose between warmup and decay
                final_lr = tf.where(step < self.warmup_steps, warmup_lr, decay_lr)

                # Ensure learning rate is valid
                final_lr = tf.maximum(self.min_lr, final_lr)
                final_lr = tf.where(tf.math.is_finite(final_lr), final_lr, self.min_lr)

                return final_lr

            def get_config(self):
                return {
                    "total_steps": self.total_steps,
                    "peak_lr": self.peak_lr,
                    "warmup_steps": self.warmup_steps,
                }

        return CustomSchedule(total_steps, peak_lr, warmup_steps)

    @tf.function
    def _train_step(
        self,
        q_batch: tf.Tensor,
        p_batch: tf.Tensor,
        n_batch: tf.Tensor,
        optimizer: tf.keras.optimizers.Optimizer,
        training: bool = True
    ) -> tf.Tensor:
        """Single training step with triplet loss."""
        with tf.GradientTape() as tape:
            # Get embeddings
            q_emb = self.query_encoder(q_batch, training=training)
            p_emb = self.response_encoder(p_batch, training=training)
            n_emb = self.response_encoder(n_batch, training=training)

            # Calculate triplet loss
            pos_dist = tf.reduce_sum(tf.square(q_emb - p_emb), axis=1)
            neg_dist = tf.reduce_sum(tf.square(q_emb - n_emb), axis=1)

            loss = tf.maximum(0.0, self.config.margin + pos_dist - neg_dist)
            loss = tf.reduce_mean(loss)

        if training:
            # Apply gradients
            gradients = tape.gradient(
                loss,
                self.query_encoder.trainable_variables +
                self.response_encoder.trainable_variables
            )
            optimizer.apply_gradients(zip(
                gradients,
                self.query_encoder.trainable_variables +
                self.response_encoder.trainable_variables
            ))

        return loss

    def _train_epoch(
        self,
        data: Tuple[tf.Tensor, tf.Tensor, tf.Tensor],
        optimizer: tf.keras.optimizers.Optimizer,
        batch_size: int,
        training: bool = True
    ) -> float:
        """Train for one epoch with enhanced logging and progress tracking."""
        q_data, p_data, n_data = data
        total_loss = 0
        num_batches = len(q_data) // batch_size

        # Log current learning rate at start of epoch
        if training:
            if hasattr(optimizer.learning_rate, '__call__'):
                current_lr = optimizer.learning_rate(optimizer.iterations)
            else:
                current_lr = optimizer.learning_rate
            logger.info(f"Current learning rate: {float(current_lr):.6f}")

        # Shuffle data
        indices = np.random.permutation(len(q_data))
        q_data = q_data[indices]
        p_data = p_data[indices]
        n_data = n_data[indices]

        # Create progress bar
        mode = "Training" if training else "Validation"
        pbar = tqdm(
            total=num_batches,
            desc=f"{mode} batches",
            unit="batch",
            dynamic_ncols=True  # Automatically adjust width
        )

        # Process batches
        for i in range(num_batches):
            start_idx = i * batch_size
            end_idx = start_idx + batch_size

            batch_loss = self._train_step(
                q_data[start_idx:end_idx],
                p_data[start_idx:end_idx],
                n_data[start_idx:end_idx],
                optimizer,
                training
            )
            total_loss += batch_loss

            # Update progress bar with current loss
            avg_loss = total_loss / (i + 1)
            pbar.set_postfix({
                'loss': f'{avg_loss:.4f}',
                'lr': f'{float(current_lr):.6f}' if training else 'N/A'
            })
            pbar.update(1)

        pbar.close()
        return total_loss / num_batches if num_batches > 0 else 0

    def _prepare_sequences(
        self,
        queries: List[str],
        positives: List[str],
        negatives: List[str]
    ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """Enhanced sequence preparation with logging and text preprocessing."""
        logger.info("Preparing sequences...")

        # Text cleaning function from old version
        def clean_text(text: str) -> str:
            # Remove excessive whitespace
            text = ' '.join(text.split())
            # Remove very long repetitive sequences
            if len(text) > 500:  # Add length limit
                text = ' '.join(dict.fromkeys(text.split()))
            return text

        # Process texts with special tokens and cleaning
        queries = [f"{self.special_tokens['user']} {clean_text(q)}" for q in queries]
        positives = [f"{self.special_tokens['assistant']} {clean_text(p)}" for p in positives]
        negatives = [f"{self.special_tokens['assistant']} {clean_text(n)}" for n in negatives]

        # Fit tokenizer and log vocabulary statistics
        all_texts = queries + positives + negatives
        self.tokenizer.fit_on_texts(all_texts)

        # Log vocabulary statistics
        vocab_size = len(self.tokenizer.word_index)
        logger.info(f"Vocabulary statistics:")
        logger.info(f"  Total unique tokens: {vocab_size}")
        logger.info(f"  Vocab limit: {self.config.vocab_size}")

        # Log most common tokens
        word_freq = sorted(
            self.tokenizer.word_counts.items(),
            key=lambda x: x[1],
            reverse=True
        )[:10]
        logger.info("Most common tokens:")
        for word, freq in word_freq:
            logger.info(f"  {word}: {freq}")

        # Padding function from old version
        def pad_sequences(texts: List[str]) -> tf.Tensor:
            sequences = self.tokenizer.texts_to_sequences(texts)
            return tf.keras.preprocessing.sequence.pad_sequences(
                sequences,
                maxlen=self.config.max_sequence_length,
                padding='post',
                truncating='post'
            )

        # Return padded sequences
        return (
            pad_sequences(queries),
            pad_sequences(positives),
            pad_sequences(negatives)
        )

    def prepare_dataset(
        self,
        dialogues: List[dict],
        neg_samples_per_pos: int = 3,
        debug_samples: Optional[int] = None
    ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """Prepare dataset with enhanced logging and statistics."""
        logger.info("Preparing dataset...")

        # Log dataset statistics
        total_dialogues = len(dialogues)
        total_turns = sum(len(d['turns']) for d in dialogues)
        avg_turns = total_turns / total_dialogues

        logger.info(f"Dataset statistics:")
        logger.info(f"  Total dialogues: {total_dialogues}")
        logger.info(f"  Total turns: {total_turns}")
        logger.info(f"  Average turns per dialogue: {avg_turns:.2f}")

        # Extract and filter responses with logging
        response_pool = []
        skipped_short = 0
        skipped_long = 0

        for d in dialogues:
            for turn in d['turns']:
                if turn['speaker'] == 'assistant':
                    text = turn['text'].strip()
                    length = len(text.split())
                    if length < self.config.min_text_length:
                        skipped_short += 1
                        continue
                    if length > self.config.max_sequence_length:
                        skipped_long += 1
                        continue
                    response_pool.append(text)

        logger.info(f"Response pool statistics:")
        logger.info(f"  Total responses: {len(response_pool)}")
        logger.info(f"  Skipped (too short): {skipped_short}")
        logger.info(f"  Skipped (too long): {skipped_long}")

        # Process dialogues and create training examples
        queries, positives, negatives = [], [], []

        for dialogue in tqdm(dialogues, desc="Processing dialogues"):
            turns = dialogue['turns']
            for i in range(len(turns) - 1):
                if turns[i]['speaker'] == 'user' and turns[i+1]['speaker'] == 'assistant':
                    query = turns[i]['text'].strip()
                    positive = turns[i+1]['text'].strip()

                    # Skip short texts
                    if (len(query.split()) < self.config.min_text_length or
                        len(positive.split()) < self.config.min_text_length):  # Fixed
                        continue

                    # Get negative samples
                    neg_samples = self._smart_negative_sampling(
                        positive,
                        response_pool,
                        neg_samples_per_pos
                    )

                    if len(neg_samples) == neg_samples_per_pos:
                        for neg in neg_samples:
                            queries.append(query)
                            positives.append(positive)
                            negatives.append(neg)

        # Log final dataset statistics
        logger.info(f"Final dataset statistics:")
        logger.info(f"  Training examples: {len(queries)}")
        logger.info(f"  Unique queries: {len(set(queries))}")
        logger.info(f"  Unique responses: {len(set(positives))}")

        return self._prepare_sequences(queries, positives, negatives)

    def encode_query(self, query: str, context: Optional[List[Tuple[str, str]]] = None) -> tf.Tensor:
        """Encode a query with optional conversation context."""
        # Prepare query with context
        if context:
            context_str = ' '.join([
                f"{self.special_tokens['user']} {q} "
                f"{self.special_tokens['assistant']} {r}"
                for q, r in context[-self.config.max_context_turns:]
            ])
            query = f"{context_str} {self.special_tokens['user']} {query}"
        else:
            query = f"{self.special_tokens['user']} {query}"

        # Tokenize and pad
        seq = self.tokenizer.texts_to_sequences([query])
        padded_seq = tf.keras.preprocessing.sequence.pad_sequences(
            seq,
            maxlen=self.config.max_sequence_length,
            padding='post',
            truncating='post'
        )

        return self.query_encoder(padded_seq, training=False)

    def encode_responses(self, responses: List[str]) -> tf.Tensor:
        """Encode a batch of responses."""
        # Prepare responses
        responses = [
            f"{self.special_tokens['assistant']} {r}"
            for r in responses
        ]

        # Tokenize and pad
        sequences = self.tokenizer.texts_to_sequences(responses)
        padded_sequences = tf.keras.preprocessing.sequence.pad_sequences(
            sequences,
            maxlen=self.config.max_sequence_length,
            padding='post',
            truncating='post'
        )

        return self.response_encoder(padded_sequences, training=False)

    def retrieve_responses(
        self,
        query: str,
        candidates: List[str],
        context: Optional[List[Tuple[str, str]]] = None,
        top_k: int = 5
    ) -> List[Tuple[str, float]]:
        """Retrieve top-k responses for a query."""
        # Encode query and candidates
        q_emb = self.encode_query(query, context)
        c_emb = self.encode_responses(candidates)

        # Calculate similarities
        similarities = tf.matmul(q_emb, c_emb, transpose_b=True).numpy()[0]

        # Get top-k responses
        top_indices = np.argsort(similarities)[::-1][:top_k]

        return [(candidates[i], similarities[i]) for i in top_indices]

    def chat(
        self,
        query: str,
        response_pool: List[str],
        conversation_history: Optional[List[Tuple[str, str]]] = None,
        top_k: int = 5
    ) -> Tuple[str, List[Tuple[str, float]]]:
        """Interactive chat with response selection."""
        # Get responses with scores
        responses = self.retrieve_responses(
            query,
            response_pool,
            conversation_history,
            top_k
        )

        # Return best response and all candidates with scores
        return responses[0][0], responses
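The CustomSchedule in the deleted file above ramps linearly from 10% of the peak learning rate to the peak over warmup_steps, then follows a cosine decay down to 1% of the peak. Below is a small, self-contained NumPy sketch of the same shape, for illustration only; it is not part of the commit, and the step counts in the demo loop are arbitrary.

import numpy as np

def lr_at(step, total_steps=1000, peak_lr=1e-3, warmup_steps=200):
    """Warmup-then-cosine schedule mirroring the CustomSchedule logic above (illustration only)."""
    initial_lr, min_lr = peak_lr * 0.1, peak_lr * 0.01
    if step < warmup_steps:
        # Linear warmup from 10% of peak up to the peak learning rate.
        return initial_lr + (peak_lr - initial_lr) * (step / warmup_steps)
    # Cosine decay from the peak down to 1% of the peak over the remaining steps.
    decay = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    decay = min(max(decay, 0.0), 1.0)
    cosine = 0.5 * (1.0 + np.cos(np.pi * decay))
    return min_lr + (peak_lr - min_lr) * cosine

for s in (0, 100, 200, 600, 1000):
    print(s, round(lr_at(s), 6))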
chatbot3.py
DELETED
@@ -1,824 +0,0 @@
|
|
1 |
-
from transformers import TFAutoModel, AutoTokenizer
|
2 |
-
import tensorflow as tf
|
3 |
-
import numpy as np
|
4 |
-
from typing import List, Tuple, Dict, Optional, Union
|
5 |
-
from dataclasses import dataclass
|
6 |
-
import logging
|
7 |
-
import spacy
|
8 |
-
import random
|
9 |
-
import json
|
10 |
-
from tqdm import tqdm
|
11 |
-
from pathlib import Path
|
12 |
-
|
13 |
-
# Configure logging
|
14 |
-
logging.basicConfig(
|
15 |
-
level=logging.INFO,
|
16 |
-
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
17 |
-
)
|
18 |
-
logger = logging.getLogger(__name__)
|
19 |
-
|
20 |
-
@dataclass
|
21 |
-
class ChatbotConfig:
|
22 |
-
"""Enhanced configuration with pretrained model settings."""
|
23 |
-
vocab_size: int = 10000
|
24 |
-
max_sequence_length: int = 512
|
25 |
-
embedding_dim: int = 768 # Match DistilBERT's dimension
|
26 |
-
encoder_units: int = 256
|
27 |
-
num_attention_heads: int = 8
|
28 |
-
dropout_rate: float = 0.2
|
29 |
-
l2_reg_weight: float = 0.001
|
30 |
-
margin: float = 0.3
|
31 |
-
learning_rate: float = 0.001
|
32 |
-
min_text_length: int = 3
|
33 |
-
max_context_turns: int = 5
|
34 |
-
warmup_steps: int = 200
|
35 |
-
pretrained_model: str = 'distilbert-base-uncased'
|
36 |
-
freeze_embeddings: bool = True
|
37 |
-
spacy_model: str = 'en_core_web_md'
|
38 |
-
|
39 |
-
def to_dict(self) -> dict:
|
40 |
-
"""Convert config to dictionary."""
|
41 |
-
return {k: str(v) if isinstance(v, Path) else v
|
42 |
-
for k, v in self.__dict__.items()}
|
43 |
-
|
44 |
-
@classmethod
|
45 |
-
def from_dict(cls, config_dict: dict) -> 'ChatbotConfig':
|
46 |
-
"""Create config from dictionary."""
|
47 |
-
return cls(**{k: v for k, v in config_dict.items()
|
48 |
-
if k in cls.__dataclass_fields__})
|
49 |
-
|
50 |
-
class TransformerBlock(tf.keras.layers.Layer):
|
51 |
-
"""Custom Transformer block with pre-layer normalization."""
|
52 |
-
def __init__(
|
53 |
-
self,
|
54 |
-
embed_dim: int,
|
55 |
-
num_heads: int,
|
56 |
-
ff_dim: int,
|
57 |
-
dropout: float = 0.1,
|
58 |
-
**kwargs
|
59 |
-
):
|
60 |
-
super().__init__(**kwargs)
|
61 |
-
self.embed_dim = embed_dim
|
62 |
-
self.num_heads = num_heads
|
63 |
-
self.ff_dim = ff_dim
|
64 |
-
self.dropout = dropout
|
65 |
-
|
66 |
-
self.attention = tf.keras.layers.MultiHeadAttention(
|
67 |
-
num_heads=num_heads,
|
68 |
-
key_dim=embed_dim // num_heads,
|
69 |
-
dropout=dropout
|
70 |
-
)
|
71 |
-
self.ffn = tf.keras.Sequential([
|
72 |
-
tf.keras.layers.Dense(ff_dim, activation="gelu"),
|
73 |
-
tf.keras.layers.Dense(embed_dim),
|
74 |
-
])
|
75 |
-
|
76 |
-
self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
|
77 |
-
self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
|
78 |
-
self.dropout1 = tf.keras.layers.Dropout(dropout)
|
79 |
-
self.dropout2 = tf.keras.layers.Dropout(dropout)
|
80 |
-
|
81 |
-
def call(self, inputs: tf.Tensor, training: bool, mask: Optional[tf.Tensor] = None) -> tf.Tensor:
|
82 |
-
# Pre-layer normalization
|
83 |
-
norm_inputs = self.layernorm1(inputs)
|
84 |
-
|
85 |
-
# Self-attention
|
86 |
-
attention_output = self.attention(
|
87 |
-
query=norm_inputs,
|
88 |
-
value=norm_inputs,
|
89 |
-
key=norm_inputs,
|
90 |
-
attention_mask=mask,
|
91 |
-
training=training
|
92 |
-
)
|
93 |
-
attention_output = self.dropout1(attention_output, training=training)
|
94 |
-
attention_output = inputs + attention_output
|
95 |
-
|
96 |
-
# Feed-forward network
|
97 |
-
norm_attention = self.layernorm2(attention_output)
|
98 |
-
ffn_output = self.ffn(norm_attention)
|
99 |
-
ffn_output = self.dropout2(ffn_output, training=training)
|
100 |
-
|
101 |
-
return attention_output + ffn_output
|
102 |
-
|
103 |
-
def get_config(self) -> dict:
|
104 |
-
config = super().get_config()
|
105 |
-
config.update({
|
106 |
-
"embed_dim": self.embed_dim,
|
107 |
-
"num_heads": self.num_heads,
|
108 |
-
"ff_dim": self.ff_dim,
|
109 |
-
"dropout": self.dropout,
|
110 |
-
})
|
111 |
-
return config
|
112 |
-
|
113 |
-
class EncoderModel(tf.keras.Model):
|
114 |
-
"""Dual encoder model with pretrained embeddings."""
|
115 |
-
def __init__(
|
116 |
-
self,
|
117 |
-
config: ChatbotConfig,
|
118 |
-
name: str = "encoder",
|
119 |
-
shared_weights: bool = False,
|
120 |
-
**kwargs
|
121 |
-
):
|
122 |
-
super().__init__(name=name, **kwargs)
|
123 |
-
self.config = config
|
124 |
-
self.shared_weights = shared_weights
|
125 |
-
|
126 |
-
# Load pretrained model and tokenizer
|
127 |
-
self.pretrained = TFAutoModel.from_pretrained(config.pretrained_model)
|
128 |
-
|
129 |
-
# Freeze pretrained weights if specified
|
130 |
-
if config.freeze_embeddings:
|
131 |
-
self.pretrained.trainable = False
|
132 |
-
|
133 |
-
# Transformer blocks for additional processing
|
134 |
-
self.transformer_blocks = [
|
135 |
-
TransformerBlock(
|
136 |
-
config.embedding_dim,
|
137 |
-
config.num_attention_heads,
|
138 |
-
config.encoder_units * 4,
|
139 |
-
config.dropout_rate,
|
140 |
-
name=f"{name}_transformer_{i}"
|
141 |
-
) for i in range(2) # Reduced number of blocks since we're using pretrained
|
142 |
-
]
|
143 |
-
|
144 |
-
# Final LSTM layer
|
145 |
-
self.final_lstm = tf.keras.layers.LSTM(
|
146 |
-
config.encoder_units,
|
147 |
-
kernel_regularizer=tf.keras.regularizers.l2(config.l2_reg_weight),
|
148 |
-
name=f"{name}_final_lstm"
|
149 |
-
)
|
150 |
-
|
151 |
-
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
|
152 |
-
self.normalize = tf.keras.layers.Lambda(
|
153 |
-
lambda x: tf.nn.l2_normalize(x, axis=1)
|
154 |
-
)
|
155 |
-
|
156 |
-
def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
|
157 |
-
# Get pretrained embeddings
|
158 |
-
pretrained_outputs = self.pretrained(inputs, training=training)
|
159 |
-
x = pretrained_outputs.last_hidden_state
|
160 |
-
|
161 |
-
# Get attention mask from input
|
162 |
-
attention_mask = tf.cast(tf.not_equal(inputs, 0), tf.float32)
|
163 |
-
attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
|
164 |
-
|
165 |
-
# Apply transformer blocks
|
166 |
-
for transformer_block in self.transformer_blocks:
|
167 |
-
x = transformer_block(x, training=training, mask=attention_mask)
|
168 |
-
|
169 |
-
# Final processing
|
170 |
-
x = self.final_lstm(x)
|
171 |
-
x = self.dropout(x, training=training)
|
172 |
-
return self.normalize(x)
|
173 |
-
|
174 |
-
class RetrievalChatbot:
|
175 |
-
"""Modified chatbot using pretrained embeddings with full functionality."""
|
176 |
-
def __init__(self, config: ChatbotConfig):
|
177 |
-
self.config = config
|
178 |
-
self.nlp = spacy.load(config.spacy_model)
|
179 |
-
|
180 |
-
# Use HuggingFace tokenizer instead of Keras
|
181 |
-
self.tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
|
182 |
-
|
183 |
-
# Special tokens
|
184 |
-
self.special_tokens = {
|
185 |
-
"user": "<USER>",
|
186 |
-
"assistant": "<ASSISTANT>",
|
187 |
-
"context": "<CONTEXT>",
|
188 |
-
"sep": "<SEP>"
|
189 |
-
}
|
190 |
-
|
191 |
-
# Add special tokens to tokenizer
|
192 |
-
self.tokenizer.add_special_tokens(
|
193 |
-
{'additional_special_tokens': list(self.special_tokens.values())}
|
194 |
-
)
|
195 |
-
|
196 |
-
# Build models
|
197 |
-
self._build_models()
|
198 |
-
|
199 |
-
# Initialize training tracking
|
200 |
-
self.history = {
|
201 |
-
"train_loss": [],
|
202 |
-
"val_loss": [],
|
203 |
-
"train_metrics": {},
|
204 |
-
"val_metrics": {}
|
205 |
-
}
|
206 |
-
|
207 |
-
self.similarity_cache = {}
|
208 |
-
|
209 |
-
def _build_models(self):
|
210 |
-
"""Initialize the encoder models."""
|
211 |
-
# Query encoder
|
212 |
-
self.query_encoder = EncoderModel(
|
213 |
-
self.config,
|
214 |
-
name="query_encoder",
|
215 |
-
shared_weights=False
|
216 |
-
)
|
217 |
-
|
218 |
-
# Response encoder (can share weights with query encoder)
|
219 |
-
self.response_encoder = EncoderModel(
|
220 |
-
self.config,
|
221 |
-
name="response_encoder",
|
222 |
-
shared_weights=False
|
223 |
-
)
|
224 |
-
|
225 |
-
# Resize token embeddings to match the tokenizer's vocab size
|
226 |
-
new_vocab_size = len(self.tokenizer)
|
227 |
-
self.query_encoder.pretrained.resize_token_embeddings(new_vocab_size)
|
228 |
-
self.response_encoder.pretrained.resize_token_embeddings(new_vocab_size)
|
229 |
-
|
230 |
-
def save_models(self, save_dir: Union[str, Path]):
|
231 |
-
"""Save models and configuration."""
|
232 |
-
save_dir = Path(save_dir)
|
233 |
-
save_dir.mkdir(parents=True, exist_ok=True)
|
234 |
-
|
235 |
-
# Save config
|
236 |
-
with open(save_dir / "config.json", "w") as f:
|
237 |
-
json.dump(self.config.to_dict(), f, indent=2)
|
238 |
-
|
239 |
-
# Save models
|
240 |
-
self.query_encoder.pretrained.save_pretrained(save_dir / "query_encoder")
|
241 |
-
self.response_encoder.pretrained.save_pretrained(save_dir / "response_encoder")
|
242 |
-
|
243 |
-
# Save tokenizer
|
244 |
-
self.tokenizer.save_pretrained(save_dir / "tokenizer")
|
245 |
-
|
246 |
-
@classmethod
|
247 |
-
def load_models(cls, load_dir: Union[str, Path]) -> 'RetrievalChatbot':
|
248 |
-
"""Load saved models and configuration."""
|
249 |
-
load_dir = Path(load_dir)
|
250 |
-
|
251 |
-
# Load config
|
252 |
-
with open(load_dir / "config.json", "r") as f:
|
253 |
-
config = ChatbotConfig.from_dict(json.load(f))
|
254 |
-
|
255 |
-
# Initialize chatbot
|
256 |
-
chatbot = cls(config)
|
257 |
-
|
258 |
-
# Load models
|
259 |
-
chatbot.query_encoder.pretrained = TFAutoModel.from_pretrained(
|
260 |
-
load_dir / "query_encoder",
|
261 |
-
config=config
|
262 |
-
)
|
263 |
-
chatbot.response_encoder.pretrained = TFAutoModel.from_pretrained(
|
264 |
-
load_dir / "response_encoder",
|
265 |
-
config=config
|
266 |
-
)
|
267 |
-
|
268 |
-
# Load tokenizer
|
269 |
-
chatbot.tokenizer = AutoTokenizer.from_pretrained(load_dir / "tokenizer")
|
270 |
-
|
271 |
-
return chatbot
|
272 |
-
|
-    def _improved_spacy_similarity(self, text1: str, text2: str) -> float:
-        """Calculate semantic similarity between texts with preprocessing."""
-        def preprocess(text: str) -> str:
-            # Basic cleaning
-            text = ' '.join(text.split())
-            return text if text.strip() else "empty_document"
-
-        # Get cache key
-        cache_key = f"{hash(text1)}_{hash(text2)}"
-        if cache_key in self.similarity_cache:
-            return self.similarity_cache[cache_key]
-
-        # Process texts
-        text1, text2 = preprocess(text1), preprocess(text2)
-        doc1, doc2 = self.nlp(text1), self.nlp(text2)
-
-        # Calculate similarity
-        if doc1.has_vector and doc2.has_vector:
-            sim = doc1.similarity(doc2)
-        else:
-            # Fallback to token overlap similarity
-            tokens1 = {t.lower_ for t in doc1 if not t.is_stop and not t.is_punct}
-            tokens2 = {t.lower_ for t in doc2 if not t.is_stop and not t.is_punct}
-            intersection = len(tokens1.intersection(tokens2))
-            union = len(tokens1.union(tokens2))
-            sim = intersection / union if union > 0 else 0.0
-
-        # Cache result
-        self.similarity_cache[cache_key] = sim
-        return sim
-
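Illustrative aside, not part of this commit: the method above combines spaCy vector similarity with a token-overlap fallback. A minimal standalone sketch of that two-tier idea follows, assuming the en_core_web_md model is installed; the function name text_similarity and the sample sentences are made up for the example.

import spacy

nlp = spacy.load("en_core_web_md")

def text_similarity(a: str, b: str) -> float:
    # Prefer vector similarity; fall back to token-overlap (Jaccard) when vectors are missing.
    doc_a, doc_b = nlp(" ".join(a.split())), nlp(" ".join(b.split()))
    if doc_a.has_vector and doc_b.has_vector:
        return doc_a.similarity(doc_b)
    tokens_a = {t.lower_ for t in doc_a if not t.is_stop and not t.is_punct}
    tokens_b = {t.lower_ for t in doc_b if not t.is_stop and not t.is_punct}
    union = tokens_a | tokens_b
    return len(tokens_a & tokens_b) / len(union) if union else 0.0

print(text_similarity("book a table for two", "reserve a table for two people"))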
304 |
-
def prepare_dataset(
|
305 |
-
self,
|
306 |
-
dialogues: List[dict],
|
307 |
-
neg_samples_per_pos: int = 3,
|
308 |
-
debug_samples: Optional[int] = None
|
309 |
-
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
|
310 |
-
"""Prepare dataset with enhanced logging and statistics."""
|
311 |
-
logger.info("Preparing dataset...")
|
312 |
-
|
313 |
-
# Apply debug_samples limit if specified
|
314 |
-
if debug_samples is not None:
|
315 |
-
dialogues = dialogues[:debug_samples]
|
316 |
-
logger.info(f"Debug mode: Limited to {debug_samples} dialogues")
|
317 |
-
|
318 |
-
# Log dataset statistics
|
319 |
-
total_dialogues = len(dialogues)
|
320 |
-
total_turns = sum(len(d['turns']) for d in dialogues)
|
321 |
-
avg_turns = total_turns / total_dialogues if total_dialogues > 0 else 0
|
322 |
-
|
323 |
-
logger.info(f"Dataset statistics:")
|
324 |
-
logger.info(f" Total dialogues: {total_dialogues}")
|
325 |
-
logger.info(f" Total turns: {total_turns}")
|
326 |
-
logger.info(f" Average turns per dialogue: {avg_turns:.2f}")
|
327 |
-
|
328 |
-
# Extract and filter responses with logging
|
329 |
-
response_pool = []
|
330 |
-
skipped_short = 0
|
331 |
-
skipped_long = 0
|
332 |
-
|
333 |
-
for d in dialogues:
|
334 |
-
for turn in d['turns']:
|
335 |
-
if turn.get('speaker') == 'assistant' and 'text' in turn:
|
336 |
-
text = turn['text'].strip()
|
337 |
-
length = len(text.split())
|
338 |
-
if length < self.config.min_text_length:
|
339 |
-
skipped_short += 1
|
340 |
-
continue
|
341 |
-
if length > self.config.max_sequence_length:
|
342 |
-
skipped_long += 1
|
343 |
-
continue
|
344 |
-
response_pool.append(text)
|
345 |
-
|
346 |
-
logger.info(f"Response pool statistics:")
|
347 |
-
logger.info(f" Total responses: {len(response_pool)}")
|
348 |
-
logger.info(f" Skipped (too short): {skipped_short}")
|
349 |
-
logger.info(f" Skipped (too long): {skipped_long}")
|
350 |
-
|
351 |
-
# Process dialogues and create training examples
|
352 |
-
queries, positives, negatives = [], [], []
|
353 |
-
|
354 |
-
for dialogue in tqdm(dialogues, desc="Processing dialogues"):
|
355 |
-
turns = dialogue.get('turns', [])
|
356 |
-
for i in range(len(turns) - 1):
|
357 |
-
current_turn = turns[i]
|
358 |
-
next_turn = turns[i+1]
|
359 |
-
|
360 |
-
if (current_turn.get('speaker') == 'user' and
|
361 |
-
next_turn.get('speaker') == 'assistant' and
|
362 |
-
'text' in current_turn and
|
363 |
-
'text' in next_turn):
|
364 |
-
|
365 |
-
query = current_turn['text'].strip()
|
366 |
-
positive = next_turn['text'].strip()
|
367 |
-
|
368 |
-
# Skip short texts
|
369 |
-
if (len(query.split()) < self.config.min_text_length or
|
370 |
-
len(positive.split()) < self.config.min_text_length):
|
371 |
-
continue
|
372 |
-
|
373 |
-
# Get negative samples
|
374 |
-
neg_samples = self._smart_negative_sampling(
|
375 |
-
positive,
|
376 |
-
response_pool,
|
377 |
-
neg_samples_per_pos
|
378 |
-
)
|
379 |
-
|
380 |
-
if len(neg_samples) == neg_samples_per_pos:
|
381 |
-
for neg in neg_samples:
|
382 |
-
queries.append(query)
|
383 |
-
positives.append(positive)
|
384 |
-
negatives.append(neg)
|
385 |
-
else:
|
386 |
-
logger.warning(f"Insufficient negative samples for positive response: '{positive}'")
|
387 |
-
|
388 |
-
# Log final dataset statistics
|
389 |
-
logger.info(f"Final dataset statistics:")
|
390 |
-
logger.info(f" Training examples: {len(queries)}")
|
391 |
-
logger.info(f" Unique queries: {len(set(queries))}")
|
392 |
-
logger.info(f" Unique responses: {len(set(positives))}")
|
393 |
-
|
394 |
-
return self._prepare_sequences(queries, positives, negatives)
|
395 |
-
|
396 |
-
def _smart_negative_sampling(
|
397 |
-
self,
|
398 |
-
positive: str,
|
399 |
-
response_pool: List[str],
|
400 |
-
n_samples: int,
|
401 |
-
max_attempts: int = 200,
|
402 |
-
similarity_bounds: Tuple[float, float] = (0.2, 0.9),
|
403 |
-
batch_size: int = 10
|
404 |
-
) -> List[str]:
|
405 |
-
"""Smart negative sampling with similarity bounds and fallback strategies."""
|
406 |
-
candidates = []
|
407 |
-
seen = set()
|
408 |
-
attempts = 0
|
409 |
-
|
410 |
-
while len(candidates) < n_samples and attempts < max_attempts:
|
411 |
-
remaining = min(batch_size, len(response_pool) - len(seen), max_attempts - attempts)
|
412 |
-
if remaining <= 0:
|
413 |
-
break
|
414 |
-
batch = random.sample(
|
415 |
-
[r for r in response_pool if r not in seen and r != positive],
|
416 |
-
remaining
|
417 |
-
)
|
418 |
-
|
419 |
-
for candidate in batch:
|
420 |
-
seen.add(candidate)
|
421 |
-
sim = self._improved_spacy_similarity(candidate, positive)
|
422 |
-
|
423 |
-
if similarity_bounds[0] < sim < similarity_bounds[1]:
|
424 |
-
candidates.append(candidate)
|
425 |
-
if len(candidates) == n_samples:
|
426 |
-
break
|
427 |
-
|
428 |
-
attempts += len(batch)
|
429 |
-
|
430 |
-
if len(candidates) < n_samples:
|
431 |
-
logger.warning(f"Only found {len(candidates)} negative samples for positive response: '{positive}'")
|
432 |
-
# Fallback to random negatives without similarity constraints
|
433 |
-
fallback_needed = n_samples - len(candidates)
|
434 |
-
available_negatives = [r for r in response_pool if r != positive and r not in seen]
|
435 |
-
if available_negatives:
|
436 |
-
additional_negatives = random.sample(
|
437 |
-
available_negatives,
|
438 |
-
min(fallback_needed, len(available_negatives))
|
439 |
-
)
|
440 |
-
candidates.extend(additional_negatives)
|
441 |
-
|
442 |
-
return candidates
|
443 |
-
|
444 |
-
def _prepare_sequences(
|
445 |
-
self,
|
446 |
-
queries: List[str],
|
447 |
-
positives: List[str],
|
448 |
-
negatives: List[str]
|
449 |
-
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
|
450 |
-
"""Modified sequence preparation for pretrained tokenizer."""
|
451 |
-
logger.info("Preparing sequences...")
|
452 |
-
|
453 |
-
# Process texts with special tokens
|
454 |
-
queries = [f"{self.special_tokens['user']} {q}" for q in queries]
|
455 |
-
positives = [f"{self.special_tokens['assistant']} {p}" for p in positives]
|
456 |
-
negatives = [f"{self.special_tokens['assistant']} {n}" for n in negatives]
|
457 |
-
|
458 |
-
# Tokenize using HuggingFace tokenizer
|
459 |
-
def encode_batch(texts: List[str]) -> tf.Tensor:
|
460 |
-
# HuggingFace tokenizer returns TensorFlow tensors when return_tensors='tf'
|
461 |
-
encodings = self.tokenizer(
|
462 |
-
texts,
|
463 |
-
padding='max_length',
|
464 |
-
truncation=True,
|
465 |
-
max_length=self.config.max_sequence_length,
|
466 |
-
return_tensors='tf'
|
467 |
-
)
|
468 |
-
return encodings['input_ids']
|
469 |
-
|
470 |
-
# Encode all sequences
|
471 |
-
q_tensor = encode_batch(queries)
|
472 |
-
p_tensor = encode_batch(positives)
|
473 |
-
n_tensor = encode_batch(negatives)
|
474 |
-
|
475 |
-
# Log statistics about encoded sequences
|
476 |
-
logger.info("Sequence statistics:")
|
477 |
-
logger.info(f" Query sequence shape: {q_tensor.shape}")
|
478 |
-
logger.info(f" Positive response sequence shape: {p_tensor.shape}")
|
479 |
-
logger.info(f" Negative response sequence shape: {n_tensor.shape}")
|
480 |
-
|
481 |
-
return q_tensor, p_tensor, n_tensor
|
482 |
-
|
483 |
-
def train(
|
484 |
-
self,
|
485 |
-
q_pad: tf.Tensor,
|
486 |
-
p_pad: tf.Tensor,
|
487 |
-
n_pad: tf.Tensor,
|
488 |
-
epochs: int = 3,
|
489 |
-
batch_size: int = 32,
|
490 |
-
validation_split: float = 0.2,
|
491 |
-
checkpoint_dir: Optional[Union[str, Path]] = None
|
492 |
-
):
|
493 |
-
"""Train the model with improved training loop."""
|
494 |
-
# Setup training
|
495 |
-
total_samples = tf.shape(q_pad)[0]
|
496 |
-
train_size = int((1 - validation_split) * total_samples.numpy())
|
497 |
-
|
498 |
-
# Shuffle and split data
|
499 |
-
indices = tf.random.shuffle(tf.range(start=0, limit=total_samples, dtype=tf.int32))
|
500 |
-
train_idx = indices[:train_size]
|
501 |
-
val_idx = indices[train_size:]
|
502 |
-
|
503 |
-
# Split data using TF indexing
|
504 |
-
train_data = (
|
505 |
-
tf.gather(q_pad, train_idx),
|
506 |
-
tf.gather(p_pad, train_idx),
|
507 |
-
tf.gather(n_pad, train_idx)
|
508 |
-
)
|
509 |
-
val_data = (
|
510 |
-
tf.gather(q_pad, val_idx),
|
511 |
-
tf.gather(p_pad, val_idx),
|
512 |
-
tf.gather(n_pad, val_idx)
|
513 |
-
)
|
514 |
-
|
515 |
-
# Setup optimizer with learning rate schedule
|
516 |
-
steps_per_epoch = train_size // batch_size
|
517 |
-
total_steps = steps_per_epoch * epochs
|
518 |
-
|
519 |
-
lr_schedule = self._get_lr_schedule(
|
520 |
-
total_steps,
|
521 |
-
self.config.learning_rate,
|
522 |
-
self.config.warmup_steps
|
523 |
-
)
|
524 |
-
|
525 |
-
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
|
526 |
-
|
527 |
-
# Setup checkpointing
|
528 |
-
if checkpoint_dir:
|
529 |
-
checkpoint_dir = Path(checkpoint_dir)
|
530 |
-
checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
531 |
-
|
532 |
-
# Setup checkpoint callback with correct file format
|
533 |
-
checkpoint_template = str(checkpoint_dir / "model_epoch_{epoch:04d}.weights.h5")
|
534 |
-
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
|
535 |
-
checkpoint_template,
|
536 |
-
save_weights_only=True,
|
537 |
-
save_best_only=True,
|
538 |
-
monitor='val_loss',
|
539 |
-
mode='min',
|
540 |
-
verbose=1
|
541 |
-
)
|
542 |
-
|
543 |
-
# Training loop
|
544 |
-
best_val_loss = float('inf')
|
545 |
-
patience = 5
|
546 |
-
wait = 0
|
547 |
-
|
548 |
-
for epoch in range(epochs):
|
549 |
-
# Training
|
550 |
-
train_loss = self._train_epoch(
|
551 |
-
train_data,
|
552 |
-
optimizer,
|
553 |
-
batch_size,
|
554 |
-
training=True
|
555 |
-
)
|
556 |
-
|
557 |
-
# Validation
|
558 |
-
val_loss = self._train_epoch(
|
559 |
-
val_data,
|
560 |
-
optimizer,
|
561 |
-
batch_size,
|
562 |
-
training=False
|
563 |
-
)
|
564 |
-
|
565 |
-
# Update history
|
566 |
-
self.history['train_loss'].append(train_loss)
|
567 |
-
self.history['val_loss'].append(val_loss)
|
568 |
-
|
569 |
-
logger.info(
|
570 |
-
f"Epoch {epoch + 1}/{epochs} - "
|
571 |
-
f"train_loss: {train_loss:.4f} - "
|
572 |
-
f"val_loss: {val_loss:.4f}"
|
573 |
-
)
|
574 |
-
|
575 |
-
# Early stopping
|
576 |
-
if val_loss < best_val_loss:
|
577 |
-
best_val_loss = val_loss
|
578 |
-
wait = 0
|
579 |
-
if checkpoint_dir:
|
580 |
-
self.save_models(checkpoint_dir / f"best_model")
|
581 |
-
else:
|
582 |
-
wait += 1
|
583 |
-
if wait >= patience:
|
584 |
-
logger.info("Early stopping triggered")
|
585 |
-
break
|
586 |
-
|
587 |
-
def _train_epoch(
|
588 |
-
self,
|
589 |
-
data: Tuple[tf.Tensor, tf.Tensor, tf.Tensor],
|
590 |
-
optimizer: tf.keras.optimizers.Optimizer,
|
591 |
-
batch_size: int,
|
592 |
-
training: bool = True
|
593 |
-
) -> float:
|
594 |
-
"""Train for one epoch with enhanced logging and progress tracking."""
|
595 |
-
q_data, p_data, n_data = data
|
596 |
-
total_loss = 0.0
|
597 |
-
num_batches = tf.shape(q_data)[0] // batch_size
|
598 |
-
|
599 |
-
# Log current learning rate at start of epoch
|
600 |
-
if training:
|
601 |
-
if hasattr(optimizer.learning_rate, '__call__'):
|
602 |
-
current_lr = optimizer.learning_rate(optimizer.iterations)
|
603 |
-
else:
|
604 |
-
current_lr = optimizer.learning_rate
|
605 |
-
logger.info(f"Current learning rate: {float(current_lr):.6f}")
|
606 |
-
|
607 |
-
# Create progress bar
|
608 |
-
mode = "Training" if training else "Validation"
|
609 |
-
pbar = tqdm(
|
610 |
-
total=num_batches.numpy(),
|
611 |
-
desc=f"{mode} batches",
|
612 |
-
unit="batch",
|
613 |
-
dynamic_ncols=True
|
614 |
-
)
|
615 |
-
|
616 |
-
# Process batches
|
617 |
-
for i in range(num_batches):
|
618 |
-
start_idx = i * batch_size
|
619 |
-
end_idx = start_idx + batch_size
|
620 |
-
|
621 |
-
batch_loss = self._train_step(
|
622 |
-
q_data[start_idx:end_idx],
|
623 |
-
p_data[start_idx:end_idx],
|
624 |
-
n_data[start_idx:end_idx],
|
625 |
-
optimizer,
|
626 |
-
training
|
627 |
-
)
|
628 |
-
total_loss += batch_loss.numpy()
|
629 |
-
|
630 |
-
# Update progress bar with current loss
|
631 |
-
avg_loss = total_loss / (i + 1)
|
632 |
-
pbar.set_postfix({
|
633 |
-
'loss': f'{avg_loss:.4f}',
|
634 |
-
'lr': f'{float(current_lr):.6f}' if training else 'N/A'
|
635 |
-
})
|
636 |
-
pbar.update(1)
|
637 |
-
|
638 |
-
pbar.close()
|
639 |
-
return total_loss / num_batches.numpy() if num_batches > 0 else 0.0
|
640 |
-
|
-    @tf.function
-    def _train_step(
-        self,
-        q_batch: tf.Tensor,
-        p_batch: tf.Tensor,
-        n_batch: tf.Tensor,
-        optimizer: tf.keras.optimizers.Optimizer,
-        training: bool = True
-    ) -> tf.Tensor:
-        """Single training step with triplet loss."""
-        with tf.GradientTape() as tape:
-            # Get embeddings
-            q_emb = self.query_encoder(q_batch, training=training)
-            p_emb = self.response_encoder(p_batch, training=training)
-            n_emb = self.response_encoder(n_batch, training=training)
-
-            # Calculate triplet loss
-            pos_dist = tf.reduce_sum(tf.square(q_emb - p_emb), axis=1)
-            neg_dist = tf.reduce_sum(tf.square(q_emb - n_emb), axis=1)
-
-            loss = tf.maximum(0.0, self.config.margin + pos_dist - neg_dist)
-            loss = tf.reduce_mean(loss)
-
-        if training:
-            # Apply gradients
-            gradients = tape.gradient(
-                loss,
-                self.query_encoder.trainable_variables +
-                self.response_encoder.trainable_variables
-            )
-            optimizer.apply_gradients(zip(
-                gradients,
-                self.query_encoder.trainable_variables +
-                self.response_encoder.trainable_variables
-            ))
-
-        return loss
-
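For reference (not part of the diff): the step above is a squared-distance triplet margin loss, mean(max(0, margin + ||q - p||^2 - ||q - n||^2)). A small NumPy sketch of the same objective on toy embeddings follows; the margin value and vectors are illustrative, not the config defaults.

import numpy as np

def triplet_loss(q, p, n, margin=0.3):
    # mean(max(0, margin + ||q - p||^2 - ||q - n||^2)) over the batch
    pos_dist = np.sum((q - p) ** 2, axis=1)
    neg_dist = np.sum((q - n) ** 2, axis=1)
    return np.mean(np.maximum(0.0, margin + pos_dist - neg_dist))

q = np.array([[1.0, 0.0], [0.0, 1.0]])
p = np.array([[0.9, 0.1], [0.1, 0.9]])   # positives sit close to their queries
n = np.array([[0.0, 1.0], [1.0, 0.0]])   # negatives sit far away
print(triplet_loss(q, p, n))             # 0.0 here: the margin is already satisfied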
679 |
-
def _get_lr_schedule(
|
680 |
-
self,
|
681 |
-
total_steps: int,
|
682 |
-
peak_lr: float,
|
683 |
-
warmup_steps: int
|
684 |
-
) -> tf.keras.optimizers.schedules.LearningRateSchedule:
|
685 |
-
"""Enhanced learning rate schedule with better error handling and logging."""
|
686 |
-
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
|
687 |
-
def __init__(
|
688 |
-
self,
|
689 |
-
total_steps: int,
|
690 |
-
peak_lr: float,
|
691 |
-
warmup_steps: int
|
692 |
-
):
|
693 |
-
super().__init__()
|
694 |
-
self.total_steps = tf.cast(total_steps, tf.float32)
|
695 |
-
self.peak_lr = tf.cast(peak_lr, tf.float32)
|
696 |
-
self.warmup_steps = tf.cast(max(1, warmup_steps), tf.float32) # Prevent 0
|
697 |
-
|
698 |
-
# Calculate and store constants
|
699 |
-
self.initial_lr = self.peak_lr * 0.1 # Start at 10% of peak
|
700 |
-
self.min_lr = self.peak_lr * 0.01 # Minimum 1% of peak
|
701 |
-
|
702 |
-
logger.info(f"Learning rate schedule initialized:")
|
703 |
-
logger.info(f" Initial LR: {float(self.initial_lr):.6f}")
|
704 |
-
logger.info(f" Peak LR: {float(self.peak_lr):.6f}")
|
705 |
-
logger.info(f" Min LR: {float(self.min_lr):.6f}")
|
706 |
-
logger.info(f" Warmup steps: {int(self.warmup_steps)}")
|
707 |
-
logger.info(f" Total steps: {int(self.total_steps)}")
|
708 |
-
|
709 |
-
def __call__(self, step):
|
710 |
-
step = tf.cast(step, tf.float32)
|
711 |
-
|
712 |
-
# Warmup phase
|
713 |
-
warmup_factor = tf.minimum(1.0, step / self.warmup_steps)
|
714 |
-
warmup_lr = self.initial_lr + (self.peak_lr - self.initial_lr) * warmup_factor
|
715 |
-
|
716 |
-
# Decay phase
|
717 |
-
decay_steps = tf.maximum(1.0, self.total_steps - self.warmup_steps)
|
718 |
-
decay_factor = (step - self.warmup_steps) / decay_steps
|
719 |
-
decay_factor = tf.minimum(tf.maximum(0.0, decay_factor), 1.0) # Clip to [0,1]
|
720 |
-
|
721 |
-
cosine_decay = 0.5 * (1.0 + tf.cos(np.pi * decay_factor))
|
722 |
-
decay_lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay
|
723 |
-
|
724 |
-
# Choose between warmup and decay
|
725 |
-
final_lr = tf.where(step < self.warmup_steps, warmup_lr, decay_lr)
|
726 |
-
|
727 |
-
# Ensure learning rate is valid
|
728 |
-
final_lr = tf.maximum(self.min_lr, final_lr)
|
729 |
-
final_lr = tf.where(tf.math.is_finite(final_lr), final_lr, self.min_lr)
|
730 |
-
|
731 |
-
return final_lr
|
732 |
-
|
733 |
-
def get_config(self):
|
734 |
-
return {
|
735 |
-
"total_steps": self.total_steps,
|
736 |
-
"peak_lr": self.peak_lr,
|
737 |
-
"warmup_steps": self.warmup_steps,
|
738 |
-
}
|
739 |
-
|
740 |
-
return CustomSchedule(total_steps, peak_lr, warmup_steps)
|
741 |
-
|
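For reference (not part of the diff): the schedule above warms up linearly from 10% of the peak learning rate and then follows a cosine decay down to 1% of the peak. A plain-NumPy sketch of that shape is below; the step counts and peak value are illustrative, not the config defaults.

import numpy as np

def lr_at(step, total_steps=1000, warmup_steps=100, peak_lr=2e-5):
    initial_lr, min_lr = peak_lr * 0.1, peak_lr * 0.01
    if step < warmup_steps:                       # linear warmup from 10% of peak
        return initial_lr + (peak_lr - initial_lr) * step / warmup_steps
    decay = min(max((step - warmup_steps) / (total_steps - warmup_steps), 0.0), 1.0)
    cosine = 0.5 * (1.0 + np.cos(np.pi * decay))  # cosine decay toward 1% of peak
    return min_lr + (peak_lr - min_lr) * cosine

for s in (0, 50, 100, 500, 1000):
    print(s, f"{lr_at(s):.2e}")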
742 |
-
def encode_query(self, query: str, context: Optional[List[Tuple[str, str]]] = None) -> tf.Tensor:
|
743 |
-
"""Encode a query with optional conversation context."""
|
744 |
-
# Prepare query with context
|
745 |
-
if context:
|
746 |
-
context_str = ' '.join([
|
747 |
-
f"{self.special_tokens['user']} {q} "
|
748 |
-
f"{self.special_tokens['assistant']} {r}"
|
749 |
-
for q, r in context[-self.config.max_context_turns:]
|
750 |
-
])
|
751 |
-
query = f"{context_str} {self.special_tokens['user']} {query}"
|
752 |
-
else:
|
753 |
-
query = f"{self.special_tokens['user']} {query}"
|
754 |
-
|
755 |
-
# Tokenize and pad using TensorFlow tensors
|
756 |
-
encodings = self.tokenizer(
|
757 |
-
[query],
|
758 |
-
padding='max_length',
|
759 |
-
truncation=True,
|
760 |
-
max_length=self.config.max_sequence_length,
|
761 |
-
return_tensors='tf'
|
762 |
-
)
|
763 |
-
input_ids = encodings['input_ids']
|
764 |
-
|
765 |
-
return self.query_encoder(input_ids, training=False)
|
766 |
-
|
767 |
-
def encode_responses(self, responses: List[str]) -> tf.Tensor:
|
768 |
-
"""Encode a batch of responses."""
|
769 |
-
# Prepare responses
|
770 |
-
responses = [
|
771 |
-
f"{self.special_tokens['assistant']} {r}"
|
772 |
-
for r in responses
|
773 |
-
]
|
774 |
-
|
775 |
-
# Tokenize and pad using TensorFlow tensors
|
776 |
-
encodings = self.tokenizer(
|
777 |
-
responses,
|
778 |
-
padding='max_length',
|
779 |
-
truncation=True,
|
780 |
-
max_length=self.config.max_sequence_length,
|
781 |
-
return_tensors='tf'
|
782 |
-
)
|
783 |
-
input_ids = encodings['input_ids']
|
784 |
-
|
785 |
-
return self.response_encoder(input_ids, training=False)
|
786 |
-
|
787 |
-
def retrieve_responses(
|
788 |
-
self,
|
789 |
-
query: str,
|
790 |
-
candidates: List[str],
|
791 |
-
context: Optional[List[Tuple[str, str]]] = None,
|
792 |
-
top_k: int = 5
|
793 |
-
) -> List[Tuple[str, float]]:
|
794 |
-
"""Retrieve top-k responses for a query."""
|
795 |
-
# Encode query and candidates
|
796 |
-
q_emb = self.encode_query(query, context)
|
797 |
-
c_emb = self.encode_responses(candidates)
|
798 |
-
|
799 |
-
# Calculate similarities
|
800 |
-
similarities = tf.matmul(q_emb, c_emb, transpose_b=True).numpy()[0]
|
801 |
-
|
802 |
-
# Get top-k responses
|
803 |
-
top_indices = np.argsort(similarities)[::-1][:top_k]
|
804 |
-
|
805 |
-
return [(candidates[i], similarities[i]) for i in top_indices]
|
806 |
-
|
807 |
-
def chat(
|
808 |
-
self,
|
809 |
-
query: str,
|
810 |
-
response_pool: List[str],
|
811 |
-
conversation_history: Optional[List[Tuple[str, str]]] = None,
|
812 |
-
top_k: int = 5
|
813 |
-
) -> Tuple[str, List[Tuple[str, float]]]:
|
814 |
-
"""Interactive chat with response selection."""
|
815 |
-
# Get responses with scores
|
816 |
-
responses = self.retrieve_responses(
|
817 |
-
query,
|
818 |
-
response_pool,
|
819 |
-
conversation_history,
|
820 |
-
top_k
|
821 |
-
)
|
822 |
-
|
823 |
-
# Return best response and all candidates with scores
|
824 |
-
return responses[0][0], responses
|
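Illustrative aside, not part of this commit: with L2-normalized embeddings, the retrieval path in this deleted file reduces to a dot product followed by an argsort. A minimal NumPy sketch of that ranking step follows, with the encoder calls stubbed out by random vectors.

import numpy as np

rng = np.random.default_rng(0)

def normalize(x):
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

query_emb = normalize(rng.normal(size=(1, 256)))    # stand-in for encode_query(...)
cand_embs = normalize(rng.normal(size=(10, 256)))   # stand-in for encode_responses(...)

scores = (query_emb @ cand_embs.T).ravel()          # cosine similarity per candidate
top_k = np.argsort(scores)[::-1][:5]                # same ranking rule as retrieve_responses
print([(int(i), float(scores[i])) for i in top_k])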
chatbot4.py → chatbot_model.py
RENAMED
@@ -2,30 +2,25 @@ from transformers import TFAutoModel, AutoTokenizer
|
|
2 |
import tensorflow as tf
|
3 |
import numpy as np
|
4 |
from typing import List, Tuple, Dict, Optional, Union, Any
|
|
|
5 |
from dataclasses import dataclass
|
6 |
-
import logging
|
7 |
import json
|
8 |
from tqdm import tqdm
|
9 |
from pathlib import Path
|
|
|
10 |
import faiss
|
11 |
from response_quality_checker import ResponseQualityChecker
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
# Configure logging
|
17 |
-
logging.basicConfig(
|
18 |
-
level=logging.DEBUG,
|
19 |
-
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
20 |
-
)
|
21 |
-
logger = logging.getLogger(__name__)
|
22 |
|
23 |
@dataclass
|
24 |
class ChatbotConfig:
|
25 |
"""Configuration for the RetrievalChatbot."""
|
26 |
vocab_size: int = 30526 # DistilBERT vocab size
|
27 |
-
|
28 |
-
embedding_dim: int =
|
29 |
encoder_units: int = 256
|
30 |
num_attention_heads: int = 8
|
31 |
dropout_rate: float = 0.2
|
@@ -70,13 +65,20 @@ class EncoderModel(tf.keras.Model):
|
|
70 |
# Freeze pretrained weights if specified
|
71 |
self.pretrained.distilbert.embeddings.trainable = False
|
72 |
for i, layer_module in enumerate(self.pretrained.distilbert.transformer.layer):
|
73 |
-
if i <
|
74 |
layer_module.trainable = False
|
75 |
else:
|
76 |
layer_module.trainable = True
|
77 |
|
78 |
# Pooling layer (Global Average Pooling)
|
79 |
self.pooler = tf.keras.layers.GlobalAveragePooling1D()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
|
81 |
# Dropout and normalization
|
82 |
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
|
@@ -90,14 +92,11 @@ class EncoderModel(tf.keras.Model):
|
|
90 |
pretrained_outputs = self.pretrained(inputs, training=training)
|
91 |
x = pretrained_outputs.last_hidden_state # Shape: [batch_size, seq_len, embedding_dim]
|
92 |
|
93 |
-
# Apply pooling
|
94 |
-
x = self.pooler(x) # Shape: [batch_size,
|
95 |
-
|
96 |
-
# Apply dropout
|
97 |
-
x = self.
|
98 |
-
|
99 |
-
# L2 normalization
|
100 |
-
x = self.normalize(x) # Shape: [batch_size, embedding_dim]
|
101 |
|
102 |
return x
|
103 |
|
@@ -110,42 +109,34 @@ class EncoderModel(tf.keras.Model):
|
|
110 |
"name": self.name
|
111 |
})
|
112 |
return config
|
113 |
-
|
114 |
-
# class CustomLRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
|
115 |
-
# def __init__(self, initial_lr, peak_lr, min_lr, warmup_steps, total_steps):
|
116 |
-
# super().__init__()
|
117 |
-
# self.initial_lr = initial_lr
|
118 |
-
# self.peak_lr = peak_lr
|
119 |
-
# self.min_lr = min_lr
|
120 |
-
# self.warmup_steps = min(warmup_steps, total_steps // 2) # Ensure warmup_steps <= total_steps
|
121 |
-
# self.total_steps = total_steps
|
122 |
-
|
123 |
-
# def __call__(self, step):
|
124 |
-
# if step < self.warmup_steps:
|
125 |
-
# # Linear warmup
|
126 |
-
# lr = self.initial_lr + (self.peak_lr - self.initial_lr) * (step / self.warmup_steps)
|
127 |
-
# else:
|
128 |
-
# # Linear decay
|
129 |
-
# decay_steps = self.total_steps - self.warmup_steps
|
130 |
-
# if decay_steps > 0:
|
131 |
-
# lr = self.peak_lr - (self.peak_lr - self.min_lr) * ((step - self.warmup_steps) / decay_steps)
|
132 |
-
# else:
|
133 |
-
# lr = self.peak_lr
|
134 |
-
# return lr
|
135 |
-
|
136 |
-
# def get_config(self):
|
137 |
-
# return {
|
138 |
-
# "initial_lr": self.initial_lr,
|
139 |
-
# "peak_lr": self.peak_lr,
|
140 |
-
# "min_lr": self.min_lr,
|
141 |
-
# "warmup_steps": self.warmup_steps,
|
142 |
-
# "total_steps": self.total_steps,
|
143 |
-
# }
|
144 |
|
145 |
-
class RetrievalChatbot:
|
146 |
"""Retrieval-based chatbot using pretrained embeddings and FAISS for similarity search."""
|
147 |
-
def __init__(self, config: ChatbotConfig, dialogues: List[dict] = []):
|
148 |
self.config = config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
|
150 |
# Special tokens
|
151 |
self.special_tokens = {
|
@@ -161,8 +152,12 @@ class RetrievalChatbot:
|
|
161 |
{'additional_special_tokens': list(self.special_tokens.values())}
|
162 |
)
|
163 |
|
164 |
-
# Build encoders
|
165 |
-
self.
|
|
|
|
|
|
|
|
|
166 |
|
167 |
# Initialize FAISS index
|
168 |
self._initialize_faiss()
|
@@ -193,31 +188,41 @@ class RetrievalChatbot:
|
|
193 |
self.encoder.pretrained.resize_token_embeddings(new_vocab_size)
|
194 |
logger.info(f"Token embeddings resized to: {new_vocab_size}")
|
195 |
|
196 |
-
#
|
197 |
logger.info("Inspecting embeddings attributes:")
|
198 |
for attr in dir(self.encoder.pretrained.distilbert.embeddings):
|
199 |
if not attr.startswith('_'):
|
200 |
logger.info(f" {attr}")
|
201 |
|
202 |
-
#
|
203 |
-
|
204 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
205 |
logger.info(f"Encoder Embedding Dimension: {embedding_dim}")
|
206 |
logger.info(f"Encoder Embedding Vocabulary Size: {vocab_size}")
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
logger.
|
211 |
-
|
212 |
-
def check_trainable_variables(self):
|
213 |
-
"""Logs the trainable variables in both encoders."""
|
214 |
-
logger.info("Checking trainable variables in shared_encoder:")
|
215 |
-
for var in self.encoder.pretrained.trainable_variables:
|
216 |
-
logger.info(f" {var.name}, shape: {var.shape}")
|
217 |
-
|
218 |
-
# logger.info("Checking trainable variables in response_encoder:")
|
219 |
-
# for var in self.response_encoder.pretrained.trainable_variables:
|
220 |
-
# logger.info(f" {var.name}, shape: {var.shape}")
|
221 |
|
222 |
def _initialize_faiss(self):
|
223 |
"""Initialize FAISS index based on available resources."""
|
@@ -239,10 +244,10 @@ class RetrievalChatbot:
|
|
239 |
self.index = faiss.IndexFlatIP(self.config.embedding_dim)
|
240 |
logger.info("FAISS index initialized.")
|
241 |
|
242 |
-
def verify_faiss_index(
|
243 |
"""Verify that FAISS index matches the response pool."""
|
244 |
-
indexed_size =
|
245 |
-
pool_size = len(
|
246 |
logger.info(f"FAISS index size: {indexed_size}")
|
247 |
logger.info(f"Response pool size: {pool_size}")
|
248 |
if indexed_size != pool_size:
|
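Illustrative aside, not part of this commit: a minimal sketch of the FAISS usage implied here, an IndexFlatIP over L2-normalized float32 embeddings so that inner product behaves as cosine similarity. The dimension and random data are made up for the example.

import faiss
import numpy as np

dim = 512
embeddings = np.random.rand(100, dim).astype("float32")
faiss.normalize_L2(embeddings)            # inner product == cosine after L2 normalization

index = faiss.IndexFlatIP(dim)
index.add(embeddings)                     # index.ntotal is now 100

query = np.random.rand(1, dim).astype("float32")
faiss.normalize_L2(query)
scores, ids = index.search(query, 5)      # top-5 row ids and their similarities
print(ids[0], scores[0])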
@@ -268,12 +273,12 @@ class RetrievalChatbot:
|
|
268 |
logger.info(f"Found {len(unique_responses)} unique responses.")
|
269 |
|
270 |
# Encode responses
|
|
|
271 |
response_embeddings = self.encode_responses(unique_responses)
|
272 |
response_embeddings = response_embeddings.numpy()
|
273 |
|
274 |
# Ensure float32
|
275 |
if response_embeddings.dtype != np.float32:
|
276 |
-
logger.info(f"Converting embeddings from {response_embeddings.dtype} to float32.")
|
277 |
response_embeddings = response_embeddings.astype('float32')
|
278 |
|
279 |
# Ensure the array is contiguous in memory
|
@@ -312,14 +317,8 @@ class RetrievalChatbot:
|
|
312 |
Returns:
|
313 |
tf.Tensor: Tensor of shape (N, emb_dim) with all response embeddings.
|
314 |
"""
|
315 |
-
|
316 |
-
|
317 |
-
# We'll accumulate embeddings in a list and concatenate at the end
|
318 |
all_embeddings = []
|
319 |
-
|
320 |
-
# Set up a progress bar
|
321 |
-
from tqdm import tqdm
|
322 |
-
pbar = tqdm(total=len(responses), desc="Encoding responses")
|
323 |
|
324 |
# Process the responses in chunks of 'batch_size'
|
325 |
for start_idx in range(0, len(responses), batch_size):
|
@@ -331,7 +330,7 @@ class RetrievalChatbot:
|
|
331 |
batch_texts,
|
332 |
padding='max_length',
|
333 |
truncation=True,
|
334 |
-
max_length=self.config.
|
335 |
return_tensors='tf',
|
336 |
)
|
337 |
|
@@ -346,11 +345,6 @@ class RetrievalChatbot:
|
|
346 |
# Collect
|
347 |
all_embeddings.append(embeddings_batch)
|
348 |
|
349 |
-
# Update progress bar
|
350 |
-
pbar.update(len(batch_texts))
|
351 |
-
|
352 |
-
pbar.close()
|
353 |
-
|
354 |
# Concatenate all batch embeddings along axis=0
|
355 |
if len(all_embeddings) == 1:
|
356 |
# Only one batch
|
@@ -359,10 +353,6 @@ class RetrievalChatbot:
|
|
359 |
# Multiple batches, concatenate
|
360 |
final_embeddings = tf.concat(all_embeddings, axis=0)
|
361 |
|
362 |
-
logger.info(
|
363 |
-
f"Finished encoding {len(responses)} responses. "
|
364 |
-
f"Final shape: {final_embeddings.shape}"
|
365 |
-
)
|
366 |
return final_embeddings
|
367 |
|
368 |
def encode_query(self, query: str, context: Optional[List[Tuple[str, str]]] = None) -> tf.Tensor:
|
@@ -383,7 +373,7 @@ class RetrievalChatbot:
|
|
383 |
[query],
|
384 |
padding='max_length',
|
385 |
truncation=True,
|
386 |
-
max_length=self.config.
|
387 |
return_tensors='tf'
|
388 |
)
|
389 |
input_ids = encodings['input_ids']
|
@@ -391,7 +381,6 @@ class RetrievalChatbot:
|
|
391 |
# Verify token IDs
|
392 |
max_id = tf.reduce_max(input_ids).numpy()
|
393 |
new_vocab_size = len(self.tokenizer)
|
394 |
-
logger.info(f"Maximum input_id: {max_id}, Vocab Size: {new_vocab_size}")
|
395 |
|
396 |
if max_id >= new_vocab_size:
|
397 |
logger.error(f"Token ID {max_id} exceeds the vocabulary size {new_vocab_size}.")
|
@@ -399,6 +388,46 @@ class RetrievalChatbot:
|
|
399 |
|
400 |
# Get embeddings from the shared encoder
|
401 |
return self.encoder(input_ids, training=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
402 |
|
403 |
def retrieve_responses_faiss(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
|
404 |
"""Retrieve top-k responses using FAISS."""
|
@@ -456,10 +485,6 @@ class RetrievalChatbot:
|
|
456 |
load_dir / "shared_encoder",
|
457 |
config=config
|
458 |
)
|
459 |
-
# chatbot.response_encoder.pretrained = TFAutoModel.from_pretrained(
|
460 |
-
# load_dir / "response_encoder",
|
461 |
-
# config=config
|
462 |
-
# )
|
463 |
|
464 |
# Load tokenizer
|
465 |
chatbot.tokenizer = AutoTokenizer.from_pretrained(load_dir / "tokenizer")
|
@@ -498,77 +523,137 @@ class RetrievalChatbot:
|
|
498 |
def prepare_dataset(
|
499 |
self,
|
500 |
dialogues: List[dict],
|
|
|
501 |
debug_samples: int = None
|
502 |
) -> Tuple[tf.Tensor, tf.Tensor]:
|
503 |
"""
|
504 |
-
Prepares dataset for
|
505 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
506 |
"""
|
507 |
-
|
|
|
508 |
|
509 |
queries, positives = [], []
|
510 |
|
|
|
511 |
for dialogue in dialogues:
|
512 |
turns = dialogue.get('turns', [])
|
513 |
for i in range(len(turns) - 1):
|
514 |
current_turn = turns[i]
|
515 |
next_turn = turns[i+1]
|
516 |
|
517 |
-
if (current_turn.get('speaker') == 'user'
|
518 |
-
next_turn.get('speaker') == 'assistant'
|
519 |
-
'text' in current_turn
|
520 |
-
'text' in next_turn):
|
521 |
|
522 |
-
|
523 |
-
|
524 |
|
525 |
-
queries.append(
|
526 |
-
positives.append(
|
527 |
|
528 |
-
#
|
529 |
if debug_samples is not None:
|
530 |
queries = queries[:debug_samples]
|
531 |
positives = positives[:debug_samples]
|
532 |
logger.info(f"Debug mode: limited to {debug_samples} pairs.")
|
533 |
|
534 |
-
logger.info(f"Prepared {len(queries)} (query, positive) pairs.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
535 |
|
536 |
-
#
|
|
|
|
|
|
|
|
|
|
|
537 |
encoded_queries = self.tokenizer(
|
538 |
-
|
539 |
padding='max_length',
|
540 |
truncation=True,
|
541 |
-
max_length=self.config.
|
542 |
return_tensors='tf'
|
543 |
)
|
544 |
-
# Tokenize positives
|
545 |
encoded_positives = self.tokenizer(
|
546 |
-
|
547 |
padding='max_length',
|
548 |
truncation=True,
|
549 |
-
max_length=self.config.
|
550 |
return_tensors='tf'
|
551 |
)
|
552 |
|
553 |
q_tensor = encoded_queries['input_ids']
|
554 |
p_tensor = encoded_positives['input_ids']
|
555 |
|
556 |
-
logger.info("Tokenized and padded sequences for in-batch training.")
|
557 |
return q_tensor, p_tensor
|
558 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
559 |
def train(
|
560 |
self,
|
561 |
q_pad: tf.Tensor,
|
562 |
p_pad: tf.Tensor,
|
563 |
-
epochs: int,
|
564 |
-
batch_size: int,
|
565 |
-
validation_split: float,
|
566 |
-
checkpoint_dir: str,
|
567 |
use_lr_schedule: bool = True,
|
568 |
peak_lr: float = 2e-5,
|
569 |
warmup_steps_ratio: float = 0.1,
|
570 |
early_stopping_patience: int = 3,
|
571 |
-
min_delta: float = 1e-4
|
|
|
572 |
):
|
573 |
dataset_size = tf.shape(q_pad)[0].numpy()
|
574 |
val_size = int(dataset_size * validation_split)
|
@@ -604,21 +689,20 @@ class RetrievalChatbot:
|
|
604 |
val_q = q_pad[train_size:]
|
605 |
val_p = p_pad[train_size:]
|
606 |
|
607 |
-
train_dataset = tf.data.Dataset.from_tensor_slices((train_q, train_p))
|
608 |
-
|
|
|
|
|
609 |
|
610 |
-
val_dataset = tf.data.Dataset.from_tensor_slices((val_q, val_p))
|
611 |
-
|
|
|
612 |
|
613 |
# 3) Checkpoint + manager
|
614 |
checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, model=self.encoder)
|
615 |
manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
|
616 |
|
617 |
# 4) TensorBoard setup
|
618 |
-
import datetime
|
619 |
-
import os
|
620 |
-
from pathlib import Path
|
621 |
-
|
622 |
log_dir = Path(checkpoint_dir) / "tensorboard_logs"
|
623 |
log_dir.mkdir(parents=True, exist_ok=True)
|
624 |
|
@@ -638,48 +722,91 @@ class RetrievalChatbot:
|
|
638 |
logger.info("Beginning training loop...")
|
639 |
global_step = 0
|
640 |
|
|
|
|
|
|
|
|
|
|
|
641 |
from tqdm import tqdm
|
642 |
for epoch in range(1, epochs + 1):
|
643 |
logger.info(f"\n=== Epoch {epoch}/{epochs} ===")
|
644 |
epoch_loss_avg = tf.keras.metrics.Mean()
|
645 |
|
646 |
-
|
647 |
with tqdm(total=steps_per_epoch, desc=f"Training Epoch {epoch}") as pbar:
|
648 |
for (q_batch, p_batch) in train_dataset:
|
|
|
649 |
global_step += 1
|
650 |
|
651 |
-
|
652 |
-
|
653 |
-
|
654 |
-
|
655 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
656 |
if use_lr_schedule:
|
|
|
657 |
lr = self.optimizer.learning_rate
|
658 |
if isinstance(lr, tf.keras.optimizers.schedules.LearningRateSchedule):
|
659 |
-
# Get the current step
|
660 |
current_step = tf.cast(self.optimizer.iterations, tf.float32)
|
661 |
-
# Compute the current learning rate
|
662 |
current_lr = lr(current_step)
|
663 |
else:
|
664 |
-
# If learning_rate is not a schedule, use it directly
|
665 |
current_lr = lr
|
666 |
-
# Convert to float for logging
|
667 |
current_lr_value = float(current_lr.numpy())
|
668 |
else:
|
669 |
-
# If using fixed learning rate
|
670 |
current_lr_value = float(self.optimizer.learning_rate.numpy())
|
671 |
|
672 |
-
# Update tqdm
|
673 |
pbar.update(1)
|
674 |
pbar.set_postfix({
|
675 |
-
"loss": f"{
|
676 |
"lr": f"{current_lr_value:.2e}"
|
677 |
})
|
678 |
|
679 |
-
# TensorBoard
|
680 |
-
|
681 |
-
|
682 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
683 |
|
684 |
# Validation
|
685 |
val_loss_avg = tf.keras.metrics.Mean()
|
@@ -726,90 +853,6 @@ class RetrievalChatbot:
|
|
726 |
|
727 |
logger.info("In-batch training completed!")
|
728 |
|
729 |
-
@tf.function
|
730 |
-
def _train_step(self, q_batch, p_batch):
|
731 |
-
"""
|
732 |
-
Single training step using in-batch negatives.
|
733 |
-
q_batch: (batch_size, seq_len) int32 input_ids for queries
|
734 |
-
p_batch: (batch_size, seq_len) int32 input_ids for positives
|
735 |
-
"""
|
736 |
-
with tf.GradientTape() as tape:
|
737 |
-
# Encode queries and positives
|
738 |
-
q_enc = self.encoder(q_batch, training=True) # [B, emb_dim]
|
739 |
-
p_enc = self.encoder(p_batch, training=True) # [B, emb_dim]
|
740 |
-
|
741 |
-
# Compute similarity matrix: (B, B) = q_enc * p_enc^T
|
742 |
-
# If embeddings are L2-normalized, this is cosine similarity
|
743 |
-
sim_matrix = tf.matmul(q_enc, p_enc, transpose_b=True) # [B, B]
|
744 |
-
|
745 |
-
# Labels are just the diagonal indices
|
746 |
-
batch_size = tf.shape(q_enc)[0]
|
747 |
-
labels = tf.range(batch_size, dtype=tf.int32) # [0..B-1]
|
748 |
-
|
749 |
-
# Softmax cross-entropy
|
750 |
-
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
|
751 |
-
labels=labels,
|
752 |
-
logits=sim_matrix
|
753 |
-
)
|
754 |
-
loss = tf.reduce_mean(loss)
|
755 |
-
|
756 |
-
# Compute gradients for the pretrained DistilBERT variables only
|
757 |
-
train_vars = self.encoder.pretrained.trainable_variables
|
758 |
-
gradients = tape.gradient(loss, train_vars)
|
759 |
-
|
760 |
-
# Remove any None grads (in case some layers are frozen)
|
761 |
-
grads_and_vars = [(g, v) for g, v in zip(gradients, train_vars) if g is not None]
|
762 |
-
if grads_and_vars:
|
763 |
-
self.optimizer.apply_gradients(grads_and_vars)
|
764 |
-
|
765 |
-
return loss
|
766 |
-
|
767 |
-
def _prepare_sequences(
|
768 |
-
self,
|
769 |
-
queries: List[str],
|
770 |
-
positives: List[str],
|
771 |
-
negatives: List[str]
|
772 |
-
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
|
773 |
-
"""Prepare and tokenize sequences for training."""
|
774 |
-
logger.info("Preparing sequences for training...")
|
775 |
-
|
776 |
-
# Handle empty lists
|
777 |
-
if not queries:
|
778 |
-
logger.error("No queries to encode. Skipping sequence preparation.")
|
779 |
-
return tf.constant([]), tf.constant([]), tf.constant([])
|
780 |
-
|
781 |
-
# Process texts with special tokens
|
782 |
-
queries = [f"{self.special_tokens['user']} {q}" for q in queries]
|
783 |
-
positives = [f"{self.special_tokens['assistant']} {p}" for p in positives]
|
784 |
-
negatives = [f"{self.special_tokens['assistant']} {n}" for n in negatives]
|
785 |
-
|
786 |
-
# Tokenize using HuggingFace tokenizer
|
787 |
-
def encode_batch(texts: List[str]) -> tf.Tensor:
|
788 |
-
if not texts:
|
789 |
-
logger.error("Empty text list provided to tokenizer.")
|
790 |
-
return tf.constant([])
|
791 |
-
encodings = self.tokenizer(
|
792 |
-
texts,
|
793 |
-
padding='max_length',
|
794 |
-
truncation=True,
|
795 |
-
max_length=self.config.max_sequence_length,
|
796 |
-
return_tensors='tf'
|
797 |
-
)
|
798 |
-
return encodings['input_ids']
|
799 |
-
|
800 |
-
# Encode all sequences
|
801 |
-
q_tensor = encode_batch(queries)
|
802 |
-
p_tensor = encode_batch(positives)
|
803 |
-
n_tensor = encode_batch(negatives)
|
804 |
-
|
805 |
-
# Log statistics about encoded sequences
|
806 |
-
logger.info("Sequence statistics:")
|
807 |
-
logger.info(f" Query sequence shape: {q_tensor.shape}")
|
808 |
-
logger.info(f" Positive response sequence shape: {p_tensor.shape}")
|
809 |
-
logger.info(f" Negative response sequence shape: {n_tensor.shape}")
|
810 |
-
|
811 |
-
return q_tensor, p_tensor, n_tensor
|
812 |
-
|
813 |
def _get_lr_schedule(
|
814 |
self,
|
815 |
total_steps: int,
|
@@ -855,7 +898,7 @@ class RetrievalChatbot:
|
|
855 |
decay_factor = (step - self.warmup_steps) / decay_steps
|
856 |
decay_factor = tf.minimum(tf.maximum(0.0, decay_factor), 1.0) # Clip to [0,1]
|
857 |
|
858 |
-
cosine_decay = 0.5 * (1.0 + tf.cos(
|
859 |
decay_lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay
|
860 |
|
861 |
# Choose between warmup and decay
|
@@ -881,411 +924,334 @@ class RetrievalChatbot:
|
|
881 |
normalized_emb1 = emb1 / np.linalg.norm(emb1, axis=1, keepdims=True)
|
882 |
normalized_emb2 = emb2 / np.linalg.norm(emb2, axis=1, keepdims=True)
|
883 |
return np.dot(normalized_emb1, normalized_emb2.T)
|
884 |
-
|
885 |
-
def run_automatic_validation(
|
886 |
-
self,
|
887 |
-
quality_checker: 'ResponseQualityChecker',
|
888 |
-
num_examples: int = 5
|
889 |
-
) -> Dict[str, Any]:
|
890 |
-
"""
|
891 |
-
Run automatic validation with quality metrics using FAISS-based retrieval.
|
892 |
-
"""
|
893 |
-
logger.info("\n=== Running Automatic Validation ===")
|
894 |
-
|
895 |
-
test_queries = [
|
896 |
-
"Hello, how are you today?",
|
897 |
-
"What's the weather like?",
|
898 |
-
"Can you help me with a problem?",
|
899 |
-
"Tell me a joke",
|
900 |
-
"What time is it?",
|
901 |
-
"I need help with my homework",
|
902 |
-
"Where's a good place to eat?",
|
903 |
-
"What movies are playing?",
|
904 |
-
"How do I reset my password?",
|
905 |
-
"Can you recommend a book?"
|
906 |
-
]
|
907 |
-
|
908 |
-
test_queries = test_queries[:num_examples]
|
909 |
-
metrics_history = []
|
910 |
-
|
911 |
-
for i, query in enumerate(test_queries, 1):
|
912 |
-
logger.info(f"\nTest Case {i}:")
|
913 |
-
logger.info(f"Query: {query}")
|
914 |
-
|
915 |
-
# Get responses and scores using FAISS
|
916 |
-
responses = self.retrieve_responses_faiss(query, top_k=5)
|
917 |
-
|
918 |
-
# Check quality
|
919 |
-
quality_metrics = quality_checker.check_response_quality(query, responses)
|
920 |
-
metrics_history.append(quality_metrics)
|
921 |
-
|
922 |
-
# Log results
|
923 |
-
logger.info(f"Quality Metrics: {quality_metrics}")
|
924 |
-
logger.info("Top responses:")
|
925 |
-
for j, (response, score) in enumerate(responses[:3], 1):
|
926 |
-
logger.info(f"{j}. Score: {score:.4f}")
|
927 |
-
logger.info(f" Response: {response}")
|
928 |
-
if j == 1 and not quality_metrics.get('is_confident', False):
|
929 |
-
logger.info(" [Low Confidence - Would abstain from answering]")
|
930 |
-
|
931 |
-
# Calculate aggregate metrics
|
932 |
-
aggregate_metrics = {
|
933 |
-
'num_queries_tested': len(test_queries),
|
934 |
-
'avg_top_response_score': np.mean([m.get('top_score', 0) for m in metrics_history]),
|
935 |
-
'avg_diversity': np.mean([m.get('response_diversity', 0) for m in metrics_history]),
|
936 |
-
'avg_relevance': np.mean([m.get('query_response_relevance', 0) for m in metrics_history]),
|
937 |
-
'avg_length_score': np.mean([m.get('response_length_score', 0) for m in metrics_history]),
|
938 |
-
'avg_score_gap': np.mean([m.get('top_3_score_gap', 0) for m in metrics_history]),
|
939 |
-
'confidence_rate': np.mean([m.get('is_confident', False) for m in metrics_history]),
|
940 |
-
}
|
941 |
-
|
942 |
-
logger.info("\n=== Validation Summary ===")
|
943 |
-
for metric, value in aggregate_metrics.items():
|
944 |
-
logger.info(f"{metric}: {value:.4f}")
|
945 |
-
|
946 |
-
return aggregate_metrics
|
947 |
|
948 |
def chat(
|
949 |
self,
|
950 |
query: str,
|
951 |
conversation_history: Optional[List[Tuple[str, str]]] = None,
|
952 |
quality_checker: Optional['ResponseQualityChecker'] = None,
|
953 |
-
top_k: int = 5
|
954 |
) -> Tuple[str, List[Tuple[str, float]], Dict[str, Any]]:
|
955 |
"""
|
956 |
-
|
957 |
-
|
958 |
-
Args:
|
959 |
-
query (str): The user's input query.
|
960 |
-
conversation_history (Optional[List[Tuple[str, str]]]): List of past (user, assistant) exchanges.
|
961 |
-
quality_checker (Optional['ResponseQualityChecker']): Quality checker instance.
|
962 |
-
top_k (int): Number of top responses to retrieve.
|
963 |
-
|
964 |
-
Returns:
|
965 |
-
Tuple[str, List[Tuple[str, float]], Dict[str, Any]]: (Response, Candidates, Quality Metrics)
|
966 |
"""
|
967 |
-
|
968 |
-
|
969 |
-
|
970 |
-
|
971 |
-
if quality_checker is None:
|
972 |
-
return responses[0][0] if responses else "I'm sorry, I don't have an answer for that.", responses, {}
|
973 |
-
|
974 |
-
# Check quality
|
975 |
-
quality_metrics = quality_checker.check_response_quality(query, responses)
|
976 |
-
|
977 |
-
if quality_metrics.get('is_confident', False):
|
978 |
-
return responses[0][0], responses, quality_metrics
|
979 |
-
else:
|
980 |
-
uncertainty_response = (
|
981 |
-
"I apologize, but I don't feel confident providing an answer to that "
|
982 |
-
"question at the moment. Could you please rephrase or ask something else?"
|
983 |
-
)
|
984 |
-
return uncertainty_response, responses, quality_metrics
|
985 |
-
|
986 |
-
# TODO: consider removal
|
987 |
-
# def prepare_dataset(self, dialogues: List[dict], neg_samples_per_pos: int = 1, debug_samples: int = None) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
|
988 |
-
# """Prepares the dataset for training."""
|
989 |
-
# logger.info("Preparing dataset...")
|
990 |
-
|
991 |
-
# # Extract (query, positive, negative) triples
|
992 |
-
# queries, positives, negatives = [], [], []
|
993 |
-
|
994 |
-
# for dialogue in dialogues:
|
995 |
-
# turns = dialogue.get('turns', [])
|
996 |
-
# for i in range(len(turns) - 1):
|
997 |
-
# current_turn = turns[i]
|
998 |
-
# next_turn = turns[i+1]
|
999 |
-
|
1000 |
-
# if (current_turn.get('speaker') == 'user' and
|
1001 |
-
# next_turn.get('speaker') == 'assistant' and
|
1002 |
-
# 'text' in current_turn and
|
1003 |
-
# 'text' in next_turn):
|
1004 |
-
|
1005 |
-
# query = current_turn['text'].strip()
|
1006 |
-
# positive = next_turn['text'].strip()
|
1007 |
-
|
1008 |
-
# # Generate hard negative samples
|
1009 |
-
# hard_negatives = self.hard_negative_sampling(positive, n_samples=neg_samples_per_pos)
|
1010 |
-
# for negative in hard_negatives:
|
1011 |
-
# negatives.append(negative)
|
1012 |
-
# queries.append(query)
|
1013 |
-
# positives.append(positive)
|
1014 |
-
|
1015 |
-
# logger.info(f"Prepared {len(queries)} training examples.")
|
1016 |
-
|
1017 |
-
# # Tokenize and pad sequences
|
1018 |
-
# encoded_queries = self.tokenizer(
|
1019 |
-
# queries,
|
1020 |
-
# padding='max_length',
|
1021 |
-
# truncation=True,
|
1022 |
-
# max_length=self.config.max_sequence_length,
|
1023 |
-
# return_tensors='tf'
|
1024 |
-
# )
|
1025 |
-
# encoded_positives = self.tokenizer(
|
1026 |
-
# positives,
|
1027 |
-
# padding='max_length',
|
1028 |
-
# truncation=True,
|
1029 |
-
# max_length=self.config.max_sequence_length,
|
1030 |
-
# return_tensors='tf'
|
1031 |
-
# )
|
1032 |
-
# encoded_negatives = self.tokenizer(
|
1033 |
-
# negatives,
|
1034 |
-
# padding='max_length',
|
1035 |
-
# truncation=True,
|
1036 |
-
# max_length=self.config.max_sequence_length,
|
1037 |
-
# return_tensors='tf'
|
1038 |
-
# )
|
1039 |
-
|
1040 |
-
# q_tensor = encoded_queries['input_ids']
|
1041 |
-
# p_tensor = encoded_positives['input_ids']
|
1042 |
-
# n_tensor = encoded_negatives['input_ids']
|
1043 |
-
|
1044 |
-
# logger.info(f"Tokenized and padded sequences.")
|
1045 |
-
|
1046 |
-
# return q_tensor, p_tensor, n_tensor
|
1047 |
-
|
1048 |
-
|
1049 |
-
# # TODO: consider removal
|
1050 |
-
# def hard_negative_sampling(self, positive_response, n_samples=1):
|
1051 |
-
# """Select hard negatives based on cosine similarity."""
|
1052 |
-
# try:
|
1053 |
-
# # Ensure we don't request more negatives than available
|
1054 |
-
# max_neg_samples = len(self.response_pool) - 1 # Exclude the positive response
|
1055 |
-
# n_samples = min(n_samples, max_neg_samples)
|
1056 |
-
|
1057 |
-
# if n_samples <= 0:
|
1058 |
-
# logger.error("Not enough responses to sample negatives.")
|
1059 |
-
# return []
|
1060 |
-
|
1061 |
-
# # Encode the positive response using the chatbot's encode_responses method
|
1062 |
-
# pos_emb = self.encode_responses([positive_response]).numpy()
|
1063 |
-
# faiss.normalize_L2(pos_emb)
|
1064 |
-
# #logger.info(f"Normalized positive embedding for response: {positive_response}")
|
1065 |
-
|
1066 |
-
# # Search for the top n_samples + 1 most similar responses (including the positive itself)
|
1067 |
-
# D, I = self.index.search(pos_emb, n_samples + 1)
|
1068 |
-
# #logger.info(f"FAISS search results: {I}")
|
1069 |
-
|
1070 |
-
# # Exclude the positive response itself (assuming it's indexed)
|
1071 |
-
# negatives = []
|
1072 |
-
# for i in range(n_samples):
|
1073 |
-
# idx = I[0][i + 1] # Skip the first one as it's the positive
|
1074 |
-
# if idx < len(self.response_pool):
|
1075 |
-
# negative_response = self.response_pool[idx]
|
1076 |
-
# negatives.append(negative_response)
|
1077 |
-
# logger.info(f"Selected negative: {negative_response}")
|
1078 |
-
# else:
|
1079 |
-
# logger.warning(f"Index {idx} out of range for response_pool with size {len(self.response_pool)}.")
|
1080 |
|
1081 |
-
|
1082 |
-
|
1083 |
-
|
1084 |
-
|
1085 |
-
|
1086 |
-
|
1087 |
-
|
1088 |
-
|
1089 |
-
|
1090 |
-
|
1091 |
-
|
1092 |
-
|
1093 |
-
|
1094 |
-
|
1095 |
-
|
1096 |
-
|
1097 |
-
|
1098 |
-
|
1099 |
-
|
1100 |
-
|
1101 |
-
|
1102 |
-
|
1103 |
-
|
1104 |
-
|
1105 |
-
|
1106 |
-
|
1107 |
-
# checkpoint_dir (str): Directory to save model checkpoints.
|
1108 |
-
# callbacks (list, optional): List of Keras callbacks.
|
1109 |
-
# """
|
1110 |
-
# dataset_size = tf.shape(q_pad)[0].numpy()
|
1111 |
-
# val_size = int(dataset_size * validation_split)
|
1112 |
-
# train_size = dataset_size - val_size
|
1113 |
-
|
1114 |
-
# logger.info(f"Total samples: {dataset_size}")
|
1115 |
-
# logger.info(f"Training samples: {train_size}")
|
1116 |
-
# logger.info(f"Validation samples: {val_size}")
|
1117 |
-
|
1118 |
-
# # Calculate steps_per_epoch
|
1119 |
-
# steps_per_epoch = train_size // batch_size
|
1120 |
-
# if train_size % batch_size != 0:
|
1121 |
-
# steps_per_epoch += 1
|
1122 |
-
# total_steps = steps_per_epoch * epochs
|
1123 |
-
|
1124 |
-
# logger.info(f"Total training steps: {total_steps}")
|
1125 |
-
|
1126 |
-
# # Initialize learning rate schedule with adjusted warmup_steps
|
1127 |
-
# lr_schedule = self._get_lr_schedule(
|
1128 |
-
# total_steps=total_steps,
|
1129 |
-
# peak_lr=self.config.learning_rate,
|
1130 |
-
# warmup_steps=self.config.warmup_steps
|
1131 |
-
# )
|
1132 |
-
|
1133 |
-
# # callbacks = []
|
1134 |
-
# # if checkpoint_dir:
|
1135 |
-
# # checkpoint_dir = Path(checkpoint_dir)
|
1136 |
-
# # checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
1137 |
-
|
1138 |
-
# # # Setup checkpoint callback with correct file format
|
1139 |
-
# # checkpoint_template = str(checkpoint_dir / "model_epoch_{epoch:04d}.weights.h5")
|
1140 |
-
# # checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
|
1141 |
-
# # checkpoint_template,
|
1142 |
-
# # save_weights_only=True,
|
1143 |
-
# # save_best_only=True,
|
1144 |
-
# # monitor='val_loss',
|
1145 |
-
# # mode='min',
|
1146 |
-
# # verbose=1
|
1147 |
-
# # )
|
1148 |
-
# # callbacks.append(checkpoint_callback)
|
1149 |
-
|
1150 |
-
# # # Early stopping callback
|
1151 |
-
# # early_stopping = tf.keras.callbacks.EarlyStopping(
|
1152 |
-
# # monitor='val_loss',
|
1153 |
-
# # patience=5,
|
1154 |
-
# # restore_best_weights=True,
|
1155 |
-
# # verbose=1
|
1156 |
-
# # )
|
1157 |
-
# # callbacks.append(early_stopping)
|
1158 |
|
1159 |
-
|
1160 |
-
# # tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='./logs')
|
1161 |
-
# # callbacks.append(tensorboard_callback)
|
1162 |
-
|
1163 |
-
# # Update optimizer with the new learning rate schedule
|
1164 |
-
# self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
|
1165 |
-
|
1166 |
-
# # Split the data
|
1167 |
-
# train_q = q_pad[:train_size]
|
1168 |
-
# train_p = p_pad[:train_size]
|
1169 |
-
# train_n = n_pad[:train_size]
|
1170 |
-
|
1171 |
-
# val_q = q_pad[train_size:]
|
1172 |
-
# val_p = p_pad[train_size:]
|
1173 |
-
# val_n = n_pad[train_size:]
|
1174 |
|
1175 |
-
|
1176 |
-
|
1177 |
-
|
1178 |
-
|
1179 |
-
|
1180 |
-
|
1181 |
-
|
1182 |
-
|
1183 |
-
|
1184 |
-
|
1185 |
-
|
1186 |
-
# # Create checkpoint manager
|
1187 |
-
# checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, model=self.encoder)
|
1188 |
-
# manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
|
1189 |
-
|
1190 |
-
# for epoch in range(1, epochs + 1):
|
1191 |
-
# logger.info(f"Epoch {epoch}/{epochs}")
|
1192 |
-
# epoch_loss_avg = tf.keras.metrics.Mean()
|
1193 |
|
1194 |
-
|
1195 |
-
|
1196 |
-
|
1197 |
-
|
|
|
|
|
1198 |
|
1199 |
-
|
1200 |
-
|
1201 |
-
|
1202 |
-
|
1203 |
-
|
1204 |
-
|
1205 |
-
|
1206 |
-
|
1207 |
-
|
1208 |
-
|
1209 |
-
|
1210 |
-
|
1211 |
-
|
1212 |
-
|
1213 |
-
|
1214 |
-
|
1215 |
-
|
1216 |
-
|
1217 |
-
|
1218 |
-
|
1219 |
-
|
1220 |
-
|
1221 |
-
|
1222 |
-
|
1223 |
-
|
1224 |
-
|
1225 |
-
# train_loss = epoch_loss_avg.result().numpy()
|
1226 |
-
# val_loss = val_loss_avg.result().numpy()
|
1227 |
-
# logger.info(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
|
1228 |
|
1229 |
-
|
1230 |
-
|
1231 |
|
1232 |
-
|
1233 |
-
|
1234 |
-
|
1235 |
-
|
1236 |
-
|
1237 |
-
|
1238 |
-
|
1239 |
-
|
1240 |
-
|
1241 |
-
|
1242 |
-
|
1243 |
-
|
1244 |
-
|
1245 |
-
|
1246 |
-
|
1247 |
-
|
1248 |
-
|
1249 |
-
|
1250 |
-
|
1251 |
-
|
1252 |
-
|
1253 |
-
|
1254 |
-
|
1255 |
-
|
1256 |
-
|
1257 |
-
|
1258 |
-
|
1259 |
-
|
1260 |
-
|
1261 |
-
|
1262 |
-
|
1263 |
-
|
1264 |
-
|
1265 |
-
|
1266 |
-
|
1267 |
-
|
1268 |
-
|
1269 |
-
|
1270 |
-
|
1271 |
-
|
1272 |
-
|
1273 |
-
|
1274 |
-
|
1275 |
-
|
1276 |
-
|
1277 |
-
|
1278 |
-
|
1279 |
-
|
1280 |
-
|
1281 |
-
|
1282 |
-
|
1283 |
-
|
1284 |
-
|
1285 |
-
|
1286 |
-
|
1287 |
-
|
1288 |
-
|
1289 |
-
|
1290 |
-
|
1291 |
-
|
 import tensorflow as tf
 import numpy as np
 from typing import List, Tuple, Dict, Optional, Union, Any
+import math
 from dataclasses import dataclass
 import json
 from tqdm import tqdm
 from pathlib import Path
+import datetime
 import faiss
 from response_quality_checker import ResponseQualityChecker
+from cross_encoder_reranker import CrossEncoderReranker
+from conversation_summarizer import DeviceAwareModel, Summarizer
+from logger_config import config_logger
+logger = config_logger(__name__)
 
 @dataclass
 class ChatbotConfig:
     """Configuration for the RetrievalChatbot."""
     vocab_size: int = 30526 # DistilBERT vocab size
+    max_context_token_limit: int = 512
+    embedding_dim: int = 512 # Match DistilBERT's dimension
     encoder_units: int = 256
     num_attention_heads: int = 8
     dropout_rate: float = 0.2
65 |
# Freeze pretrained weights if specified
|
66 |
self.pretrained.distilbert.embeddings.trainable = False
|
67 |
for i, layer_module in enumerate(self.pretrained.distilbert.transformer.layer):
|
68 |
+
if i < 1: # freeze first layer
|
69 |
layer_module.trainable = False
|
70 |
else:
|
71 |
layer_module.trainable = True
|
72 |
|
73 |
# Pooling layer (Global Average Pooling)
|
74 |
self.pooler = tf.keras.layers.GlobalAveragePooling1D()
|
75 |
+
|
76 |
+
# Projection layer
|
77 |
+
self.projection = tf.keras.layers.Dense(
|
78 |
+
config.embedding_dim,
|
79 |
+
activation='tanh',
|
80 |
+
name="projection"
|
81 |
+
)
|
82 |
|
83 |
# Dropout and normalization
|
84 |
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
|
|
|
92 |
pretrained_outputs = self.pretrained(inputs, training=training)
|
93 |
x = pretrained_outputs.last_hidden_state # Shape: [batch_size, seq_len, embedding_dim]
|
94 |
|
95 |
+
# Apply pooling, projection, dropout, and normalization
|
96 |
+
x = self.pooler(x) # Shape: [batch_size, 768]
|
97 |
+
x = self.projection(x) # Shape: [batch_size, 512]
|
98 |
+
x = self.dropout(x, training=training) # Apply dropout
|
99 |
+
x = self.normalize(x) # Shape: [batch_size, 512]
|
|
|
|
|
|
|
100 |
|
101 |
return x
|
102 |
|
|
|
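
For orientation, the encoder call above boils down to the following shape flow (an illustrative summary only; the normalize layer is assumed to be an L2 normalization, which is consistent with the inner-product FAISS index used below):

    # Shape walk-through of the encoder forward pass (illustrative):
    #   input_ids                 [batch, 512]        tokenized text, padded to max_context_token_limit
    #   DistilBERT last_hidden    [batch, 512, 768]   embeddings + first transformer layer frozen
    #   GlobalAveragePooling1D    [batch, 768]
    #   Dense(512, tanh)          [batch, 512]        projection
    #   Dropout + normalize       [batch, 512]        unit-length embedding for similarity search
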
109 |
"name": self.name
|
110 |
})
|
111 |
return config
|
112 |
|
113 |
+
class RetrievalChatbot(DeviceAwareModel):
|
114 |
"""Retrieval-based chatbot using pretrained embeddings and FAISS for similarity search."""
|
115 |
+
def __init__(self, config: ChatbotConfig, dialogues: List[dict] = [], device: str = None, strategy=None, reranker: Optional[CrossEncoderReranker] = None, summarizer: Optional[Summarizer] = None):
|
116 |
self.config = config
|
117 |
+
self.strategy = strategy
|
118 |
+
self.setup_device(device)
|
119 |
+
|
120 |
+
if reranker is None:
|
121 |
+
logger.info("Creating default CrossEncoderReranker...")
|
122 |
+
reranker = CrossEncoderReranker(model_name="cross-encoder/ms-marco-MiniLM-L-12-v2")
|
123 |
+
self.reranker = reranker
|
124 |
+
|
125 |
+
if summarizer is None:
|
126 |
+
logger.info("Creating default Summarizer...")
|
127 |
+
summarizer = Summarizer(device=self.device)
|
128 |
+
self.summarizer = summarizer
|
129 |
+
|
130 |
+
# Configure XLA optimization if on GPU/TPU
|
131 |
+
if self.device in ["GPU", "TPU"]:
|
132 |
+
tf.config.optimizer.set_jit(True)
|
133 |
+
logger.info(f"XLA compilation enabled for {self.device}")
|
134 |
+
|
135 |
+
# Configure mixed precision for GPU/TPU
|
136 |
+
if self.device != "CPU":
|
137 |
+
policy = tf.keras.mixed_precision.Policy('mixed_float16')
|
138 |
+
tf.keras.mixed_precision.set_global_policy(policy)
|
139 |
+
logger.info("Mixed precision training enabled (float16)")
|
140 |
|
141 |
# Special tokens
|
142 |
self.special_tokens = {
|
|
|
152 |
{'additional_special_tokens': list(self.special_tokens.values())}
|
153 |
)
|
154 |
|
155 |
+
# Build encoders within device strategy scope
|
156 |
+
if self.strategy:
|
157 |
+
with self.strategy.scope():
|
158 |
+
self._build_models()
|
159 |
+
else:
|
160 |
+
self._build_models()
|
161 |
|
162 |
# Initialize FAISS index
|
163 |
self._initialize_faiss()
|
|
|
188 |
self.encoder.pretrained.resize_token_embeddings(new_vocab_size)
|
189 |
logger.info(f"Token embeddings resized to: {new_vocab_size}")
|
190 |
|
191 |
+
# Debug embeddings attributes
|
192 |
logger.info("Inspecting embeddings attributes:")
|
193 |
for attr in dir(self.encoder.pretrained.distilbert.embeddings):
|
194 |
if not attr.startswith('_'):
|
195 |
logger.info(f" {attr}")
|
196 |
|
197 |
+
# Try different ways to get embedding dimension
|
198 |
+
try:
|
199 |
+
# First try: from config
|
200 |
+
embedding_dim = self.encoder.pretrained.config.dim
|
201 |
+
logger.info("Got embedding dim from config")
|
202 |
+
except AttributeError:
|
203 |
+
try:
|
204 |
+
# Second try: from word embeddings
|
205 |
+
embedding_dim = self.encoder.pretrained.distilbert.embeddings.word_embeddings.embedding_dim
|
206 |
+
logger.info("Got embedding dim from word embeddings")
|
207 |
+
except AttributeError:
|
208 |
+
try:
|
209 |
+
# Third try: from embeddings module
|
210 |
+
embedding_dim = self.encoder.pretrained.distilbert.embeddings.embedding_dim
|
211 |
+
logger.info("Got embedding dim from embeddings module")
|
212 |
+
except AttributeError:
|
213 |
+
# Fallback to config value
|
214 |
+
embedding_dim = self.config.embedding_dim
|
215 |
+
logger.info("Using config embedding dim")
|
216 |
+
|
217 |
+
vocab_size = len(self.tokenizer)
|
218 |
+
|
219 |
logger.info(f"Encoder Embedding Dimension: {embedding_dim}")
|
220 |
logger.info(f"Encoder Embedding Vocabulary Size: {vocab_size}")
|
221 |
+
if vocab_size >= embedding_dim:
|
222 |
+
logger.info("Encoder model built and embeddings resized successfully.")
|
223 |
+
else:
|
224 |
+
logger.error("Vocabulary size is less than embedding dimension.")
|
225 |
+
raise ValueError("Vocabulary size is less than embedding dimension.")
|
226 |
|
227 |
def _initialize_faiss(self):
|
228 |
"""Initialize FAISS index based on available resources."""
|
|
|
244 |
self.index = faiss.IndexFlatIP(self.config.embedding_dim)
|
245 |
logger.info("FAISS index initialized.")
|
246 |
|
247 |
+
def verify_faiss_index(self):
|
248 |
"""Verify that FAISS index matches the response pool."""
|
249 |
+
indexed_size = self.index.ntotal
|
250 |
+
pool_size = len(self.response_pool)
|
251 |
logger.info(f"FAISS index size: {indexed_size}")
|
252 |
logger.info(f"Response pool size: {pool_size}")
|
253 |
if indexed_size != pool_size:
|
|
|
273 |
logger.info(f"Found {len(unique_responses)} unique responses.")
|
274 |
|
275 |
# Encode responses
|
276 |
+
logger.info("Encoding unique responses")
|
277 |
response_embeddings = self.encode_responses(unique_responses)
|
278 |
response_embeddings = response_embeddings.numpy()
|
279 |
|
280 |
# Ensure float32
|
281 |
if response_embeddings.dtype != np.float32:
|
|
|
282 |
response_embeddings = response_embeddings.astype('float32')
|
283 |
|
284 |
# Ensure the array is contiguous in memory
|
|
|
317 |
Returns:
|
318 |
tf.Tensor: Tensor of shape (N, emb_dim) with all response embeddings.
|
319 |
"""
|
320 |
+
# Accumulate embeddings in a list and concatenate at the end
|
|
|
|
|
321 |
all_embeddings = []
|
|
|
|
|
|
|
|
|
322 |
|
323 |
# Process the responses in chunks of 'batch_size'
|
324 |
for start_idx in range(0, len(responses), batch_size):
|
|
|
330 |
batch_texts,
|
331 |
padding='max_length',
|
332 |
truncation=True,
|
333 |
+
max_length=self.config.max_context_token_limit,
|
334 |
return_tensors='tf',
|
335 |
)
|
336 |
|
|
|
345 |
# Collect
|
346 |
all_embeddings.append(embeddings_batch)
|
347 |
|
|
|
|
|
|
|
|
|
|
|
348 |
# Concatenate all batch embeddings along axis=0
|
349 |
if len(all_embeddings) == 1:
|
350 |
# Only one batch
|
|
|
353 |
# Multiple batches, concatenate
|
354 |
final_embeddings = tf.concat(all_embeddings, axis=0)
|
355 |
|
|
|
|
|
|
|
|
|
356 |
return final_embeddings
|
357 |
|
358 |
def encode_query(self, query: str, context: Optional[List[Tuple[str, str]]] = None) -> tf.Tensor:
|
|
|
373 |
[query],
|
374 |
padding='max_length',
|
375 |
truncation=True,
|
376 |
+
max_length=self.config.max_context_token_limit,
|
377 |
return_tensors='tf'
|
378 |
)
|
379 |
input_ids = encodings['input_ids']
|
|
|
381 |
# Verify token IDs
|
382 |
max_id = tf.reduce_max(input_ids).numpy()
|
383 |
new_vocab_size = len(self.tokenizer)
|
|
|
384 |
|
385 |
if max_id >= new_vocab_size:
|
386 |
logger.error(f"Token ID {max_id} exceeds the vocabulary size {new_vocab_size}.")
|
|
|
388 |
|
389 |
# Get embeddings from the shared encoder
|
390 |
return self.encoder(input_ids, training=False)
|
391 |
+
|
392 |
+
def retrieve_responses_cross_encoder(
|
393 |
+
self,
|
394 |
+
query: str,
|
395 |
+
top_k: int,
|
396 |
+
reranker: Optional[CrossEncoderReranker] = None,
|
397 |
+
summarizer: Optional[Summarizer] = None,
|
398 |
+
summarize_threshold: int = 512 # Summarize queries longer than ~512 whitespace-split words
|
399 |
+
) -> List[Tuple[str, float]]:
|
400 |
+
"""
|
401 |
+
Retrieve top-k from FAISS, then re-rank them with a cross-encoder.
|
402 |
+
Optionally summarize the user query if it's too long.
|
403 |
+
"""
|
404 |
+
if reranker is None:
|
405 |
+
reranker = self.reranker
|
406 |
+
if summarizer is None:
|
407 |
+
summarizer = self.summarizer
|
408 |
+
|
409 |
+
# Optional summarization
|
410 |
+
if summarizer and len(query.split()) > summarize_threshold:
|
411 |
+
logger.info(f"Query is long. Summarizing before cross-encoder. Original length: {len(query.split())}")
|
412 |
+
query = summarizer.summarize_text(query)
|
413 |
+
logger.info(f"Summarized query: {query}")
|
414 |
+
|
415 |
+
# 2) Dense retrieval
|
416 |
+
dense_topk = self.retrieve_responses_faiss(query, top_k=top_k) # [(resp, dense_score), ...]
|
417 |
+
|
418 |
+
if not dense_topk:
|
419 |
+
return []
|
420 |
+
|
421 |
+
# 3) Cross-encoder rerank
|
422 |
+
candidate_texts = [pair[0] for pair in dense_topk]
|
423 |
+
cross_scores = reranker.rerank(query, candidate_texts, max_length=256)
|
424 |
+
|
425 |
+
# Combine
|
426 |
+
combined = [(text, score) for (text, _), score in zip(dense_topk, cross_scores)]
|
427 |
+
# Sort descending by cross-encoder score
|
428 |
+
combined.sort(key=lambda x: x[1], reverse=True)
|
429 |
+
|
430 |
+
return combined
|
431 |
|
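
As a rough usage sketch (not part of the commit), the retrieve-then-rerank path above could be exercised like this, assuming a chatbot whose FAISS index has already been built over the response pool:

    config = ChatbotConfig()
    bot = RetrievalChatbot(config)              # default CrossEncoderReranker + Summarizer are created
    results = bot.retrieve_responses_cross_encoder(
        query="Can you book a table for two at 7pm?",
        top_k=10,
    )
    for response, score in results[:3]:
        print(f"{score:.3f}  {response}")       # cross-encoder scores, highest first
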
432 |
def retrieve_responses_faiss(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
|
433 |
"""Retrieve top-k responses using FAISS."""
|
|
|
485 |
load_dir / "shared_encoder",
|
486 |
config=config
|
487 |
)
|
488 |
|
489 |
# Load tokenizer
|
490 |
chatbot.tokenizer = AutoTokenizer.from_pretrained(load_dir / "tokenizer")
|
|
|
523 |
def prepare_dataset(
|
524 |
self,
|
525 |
dialogues: List[dict],
|
526 |
+
neg_samples: int = 1,
|
527 |
debug_samples: int = None
|
528 |
) -> Tuple[tf.Tensor, tf.Tensor]:
|
529 |
"""
|
530 |
+
Prepares dataset for multiple-negatives ranking,
|
531 |
+
but also appends 'hard negative' pairs for each query.
|
532 |
+
|
533 |
+
We'll generate:
|
534 |
+
- (query, positive) as usual
|
535 |
+
- (query, negative) for each query, using FAISS top-1 approx. negative.
|
536 |
+
Then, in-batch training sees them as 'two different positives'
|
537 |
+
for the same query, forcing the model to discriminate them.
|
538 |
"""
|
539 |
+
|
540 |
+
logger.info("Preparing in-batch dataset with hard negatives...")
|
541 |
|
542 |
queries, positives = [], []
|
543 |
|
544 |
+
# Assemble (q, p)
|
545 |
for dialogue in dialogues:
|
546 |
turns = dialogue.get('turns', [])
|
547 |
for i in range(len(turns) - 1):
|
548 |
current_turn = turns[i]
|
549 |
next_turn = turns[i+1]
|
550 |
|
551 |
+
if (current_turn.get('speaker') == 'user'
|
552 |
+
and next_turn.get('speaker') == 'assistant'
|
553 |
+
and 'text' in current_turn
|
554 |
+
and 'text' in next_turn):
|
555 |
|
556 |
+
query_text = current_turn['text'].strip()
|
557 |
+
pos_text = next_turn['text'].strip()
|
558 |
|
559 |
+
queries.append(query_text)
|
560 |
+
positives.append(pos_text)
|
561 |
|
562 |
+
# Debug slicing
|
563 |
if debug_samples is not None:
|
564 |
queries = queries[:debug_samples]
|
565 |
positives = positives[:debug_samples]
|
566 |
logger.info(f"Debug mode: limited to {debug_samples} pairs.")
|
567 |
|
568 |
+
logger.info(f"Prepared {len(queries)} (query, positive) pairs initially.")
|
569 |
+
|
570 |
+
# Find a hard negative from FAISS for each (q, p)
|
571 |
+
# Create a second 'positive' row => (q, negative). In-batch, it's seen as a different 'positive' row, but is a hard negative.
|
572 |
+
augmented_queries = []
|
573 |
+
augmented_positives = []
|
574 |
+
|
575 |
+
for q_text, p_text in zip(queries, positives):
|
576 |
+
neg_texts = self._find_hard_negative(q_text, p_text, top_k=5, neg_samples=neg_samples)
|
577 |
+
for neg_text in neg_texts:
|
578 |
+
augmented_queries.append(q_text)
|
579 |
+
augmented_positives.append(neg_text)
|
580 |
+
|
581 |
+
logger.info(f"Found hard negatives for {len(augmented_queries)} queries.")
|
582 |
|
583 |
+
# Combine them into a single big list -> Original pairs: (q, p) & Hard neg pairs: (q, n)
|
584 |
+
final_queries = queries + augmented_queries
|
585 |
+
final_positives = positives + augmented_positives
|
586 |
+
logger.info(f"Total dataset size after adding hard neg: {len(final_queries)}")
|
587 |
+
|
588 |
+
# Tokenize
|
589 |
encoded_queries = self.tokenizer(
|
590 |
+
final_queries,
|
591 |
padding='max_length',
|
592 |
truncation=True,
|
593 |
+
max_length=self.config.max_context_token_limit,
|
594 |
return_tensors='tf'
|
595 |
)
|
|
|
596 |
encoded_positives = self.tokenizer(
|
597 |
+
final_positives,
|
598 |
padding='max_length',
|
599 |
truncation=True,
|
600 |
+
max_length=self.config.max_context_token_limit,
|
601 |
return_tensors='tf'
|
602 |
)
|
603 |
|
604 |
q_tensor = encoded_queries['input_ids']
|
605 |
p_tensor = encoded_positives['input_ids']
|
606 |
|
607 |
+
logger.info("Tokenized and padded sequences for in-batch training + hard negatives.")
|
608 |
return q_tensor, p_tensor
|
609 |
|
610 |
+
def _find_hard_negative(
|
611 |
+
self,
|
612 |
+
query_text: str,
|
613 |
+
positive_text: str,
|
614 |
+
top_k: int = 5,
|
615 |
+
neg_samples: int = 1
|
616 |
+
) -> List[str]:
|
617 |
+
"""
|
618 |
+
Return up to `neg_samples` unique negatives from top_k FAISS results,
|
619 |
+
excluding the known positive_text.
|
620 |
+
"""
|
621 |
+
# Encode the query to get the embedding
|
622 |
+
query_emb = self.encode_query(query_text)
|
623 |
+
q_emb_np = query_emb.numpy().astype('float32')
|
624 |
+
|
625 |
+
# Normalize for cosine similarity
|
626 |
+
faiss.normalize_L2(q_emb_np)
|
627 |
+
|
628 |
+
# Search in FAISS
|
629 |
+
distances, indices = self.index.search(q_emb_np, top_k)
|
630 |
+
|
631 |
+
# Exclude the actual positive from these results
|
632 |
+
hard_negatives = []
|
633 |
+
for idx in indices[0]:
|
634 |
+
if idx < len(self.response_pool):
|
635 |
+
candidate = self.response_pool[idx].strip()
|
636 |
+
if candidate != positive_text.strip():
|
637 |
+
hard_negatives.append(candidate)
|
638 |
+
if len(hard_negatives) == neg_samples:
|
639 |
+
break
|
640 |
+
|
641 |
+
return hard_negatives
|
642 |
+
|
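
To make the shape of the prepared data concrete: a single user/assistant exchange yields one (query, positive) row plus, when a hard negative can be found in the FAISS index, one (query, hard_negative) row, so the returned tensors roughly double in length. A toy call might look like this (sketch only; the texts are made up and the index must already be populated):

    dialogues = [{
        "turns": [
            {"speaker": "user", "text": "I need a table for four tonight."},
            {"speaker": "assistant", "text": "Sure, what time would you like the reservation?"},
        ]
    }]
    q_tensor, p_tensor = bot.prepare_dataset(dialogues, neg_samples=1)
    # q_tensor / p_tensor: [num_pairs, 512] token IDs; num_pairs = original pairs + hard-negative rows
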
643 |
def train(
|
644 |
self,
|
645 |
q_pad: tf.Tensor,
|
646 |
p_pad: tf.Tensor,
|
647 |
+
epochs: int = 20,
|
648 |
+
batch_size: int = 16,
|
649 |
+
validation_split: float = 0.2,
|
650 |
+
checkpoint_dir: str = "checkpoints/",
|
651 |
use_lr_schedule: bool = True,
|
652 |
peak_lr: float = 2e-5,
|
653 |
warmup_steps_ratio: float = 0.1,
|
654 |
early_stopping_patience: int = 3,
|
655 |
+
min_delta: float = 1e-4,
|
656 |
+
accum_steps: int = 2 # Gradient accumulation steps
|
657 |
):
|
658 |
dataset_size = tf.shape(q_pad)[0].numpy()
|
659 |
val_size = int(dataset_size * validation_split)
|
|
|
689 |
val_q = q_pad[train_size:]
|
690 |
val_p = p_pad[train_size:]
|
691 |
|
692 |
+
train_dataset = (tf.data.Dataset.from_tensor_slices((train_q, train_p))
|
693 |
+
.shuffle(4096)
|
694 |
+
.batch(batch_size)
|
695 |
+
.prefetch(tf.data.AUTOTUNE))
|
696 |
|
697 |
+
val_dataset = (tf.data.Dataset.from_tensor_slices((val_q, val_p))
|
698 |
+
.batch(batch_size)
|
699 |
+
.prefetch(tf.data.AUTOTUNE))
|
700 |
|
701 |
# 3) Checkpoint + manager
|
702 |
checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, model=self.encoder)
|
703 |
manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
|
704 |
|
705 |
# 4) TensorBoard setup
|
706 |
log_dir = Path(checkpoint_dir) / "tensorboard_logs"
|
707 |
log_dir.mkdir(parents=True, exist_ok=True)
|
708 |
|
|
|
722 |
logger.info("Beginning training loop...")
|
723 |
global_step = 0
|
724 |
|
725 |
+
# Prepare zero-initialized accumulators for your trainable variables
|
726 |
+
# We'll accumulate gradients across mini-batches, then apply them every accum_steps.
|
727 |
+
train_vars = self.encoder.pretrained.trainable_variables
|
728 |
+
accum_grads = [tf.zeros_like(var, dtype=tf.float32) for var in train_vars]
|
729 |
+
|
730 |
from tqdm import tqdm
|
731 |
for epoch in range(1, epochs + 1):
|
732 |
logger.info(f"\n=== Epoch {epoch}/{epochs} ===")
|
733 |
epoch_loss_avg = tf.keras.metrics.Mean()
|
734 |
|
735 |
+
step_in_epoch = 0
|
736 |
with tqdm(total=steps_per_epoch, desc=f"Training Epoch {epoch}") as pbar:
|
737 |
for (q_batch, p_batch) in train_dataset:
|
738 |
+
step_in_epoch += 1
|
739 |
global_step += 1
|
740 |
|
741 |
+
with tf.GradientTape() as tape:
|
742 |
+
q_enc = self.encoder(q_batch, training=True)
|
743 |
+
p_enc = self.encoder(p_batch, training=True)
|
744 |
+
|
745 |
+
sim_matrix = tf.matmul(q_enc, p_enc, transpose_b=True)
|
746 |
+
bsz = tf.shape(q_enc)[0]
|
747 |
+
labels = tf.range(bsz, dtype=tf.int32)
|
748 |
+
loss_value = tf.nn.sparse_softmax_cross_entropy_with_logits(
|
749 |
+
labels=labels, logits=sim_matrix
|
750 |
+
)
|
751 |
+
loss_value = tf.reduce_mean(loss_value)
|
752 |
+
|
753 |
+
gradients = tape.gradient(loss_value, train_vars)
|
754 |
+
|
755 |
+
# -- Accumulate gradients --
|
756 |
+
for i, grad in enumerate(gradients):
|
757 |
+
if grad is not None:
|
758 |
+
accum_grads[i] += tf.cast(grad, tf.float32)
|
759 |
+
|
760 |
+
epoch_loss_avg(loss_value)
|
761 |
+
|
762 |
+
# -- Apply gradients every 'accum_steps' mini-batches --
|
763 |
+
if (step_in_epoch % accum_steps) == 0:
|
764 |
+
# Scale by 1/accum_steps so that each accumulation cycle
|
765 |
+
# is effectively the same as one “normal” update
|
766 |
+
for i in range(len(accum_grads)):
|
767 |
+
accum_grads[i] /= accum_steps
|
768 |
+
|
769 |
+
self.optimizer.apply_gradients(
|
770 |
+
[(accum_grads[i], train_vars[i]) for i in range(len(accum_grads))]
|
771 |
+
)
|
772 |
+
# Reset the accumulator
|
773 |
+
accum_grads = [tf.zeros_like(var, dtype=tf.float32) for var in train_vars]
|
774 |
+
|
775 |
+
# Logging / tqdm updates
|
776 |
if use_lr_schedule:
|
777 |
+
# measure current LR
|
778 |
lr = self.optimizer.learning_rate
|
779 |
if isinstance(lr, tf.keras.optimizers.schedules.LearningRateSchedule):
|
|
|
780 |
current_step = tf.cast(self.optimizer.iterations, tf.float32)
|
|
|
781 |
current_lr = lr(current_step)
|
782 |
else:
|
|
|
783 |
current_lr = lr
|
|
|
784 |
current_lr_value = float(current_lr.numpy())
|
785 |
else:
|
|
|
786 |
current_lr_value = float(self.optimizer.learning_rate.numpy())
|
787 |
|
|
|
788 |
pbar.update(1)
|
789 |
pbar.set_postfix({
|
790 |
+
"loss": f"{loss_value.numpy():.4f}",
|
791 |
"lr": f"{current_lr_value:.2e}"
|
792 |
})
|
793 |
|
794 |
+
# TensorBoard logging omitted for brevity...
|
795 |
+
|
796 |
+
# -- Handle leftover partial accumulation at epoch end --
|
797 |
+
leftover = (step_in_epoch % accum_steps)
|
798 |
+
if leftover != 0:
|
799 |
+
logger.info(f"Applying leftover accum_grads for partial batch group (size={leftover}).")
|
800 |
+
# If you want each leftover batch to contribute proportionally:
|
801 |
+
# multiply by leftover/accum_steps (this ensures leftover
|
802 |
+
# steps have the same "average" effect as a full accumulation cycle)
|
803 |
+
for i in range(len(accum_grads)):
|
804 |
+
accum_grads[i] *= float(leftover) / float(accum_steps)
|
805 |
+
|
806 |
+
self.optimizer.apply_gradients(
|
807 |
+
[(accum_grads[i], train_vars[i]) for i in range(len(accum_grads))]
|
808 |
+
)
|
809 |
+
accum_grads = [tf.zeros_like(var, dtype=tf.float32) for var in train_vars]
|
810 |
|
811 |
# Validation
|
812 |
val_loss_avg = tf.keras.metrics.Mean()
|
|
|
853 |
|
854 |
logger.info("In-batch training completed!")
|
855 |
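
The loss inside the loop above treats the i-th positive as the only correct match for the i-th query, with every other row in the batch acting as a negative (including the injected hard-negative rows). Stripped of mixed precision and gradient accumulation, that step is essentially:

    def in_batch_contrastive_loss(q_enc: tf.Tensor, p_enc: tf.Tensor) -> tf.Tensor:
        # q_enc, p_enc: [batch, dim] embeddings from the shared encoder.
        sim_matrix = tf.matmul(q_enc, p_enc, transpose_b=True)   # [batch, batch] similarity logits
        labels = tf.range(tf.shape(q_enc)[0], dtype=tf.int32)    # diagonal entries are the true pairs
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=sim_matrix)
        return tf.reduce_mean(loss)
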
856 |
def _get_lr_schedule(
|
857 |
self,
|
858 |
total_steps: int,
|
|
|
898 |
decay_factor = (step - self.warmup_steps) / decay_steps
|
899 |
decay_factor = tf.minimum(tf.maximum(0.0, decay_factor), 1.0) # Clip to [0,1]
|
900 |
|
901 |
+
cosine_decay = 0.5 * (1.0 + tf.cos(tf.constant(math.pi) * decay_factor))
|
902 |
decay_lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay
|
903 |
|
904 |
# Choose between warmup and decay
|
|
|
924 |
normalized_emb1 = emb1 / np.linalg.norm(emb1, axis=1, keepdims=True)
|
925 |
normalized_emb2 = emb2 / np.linalg.norm(emb2, axis=1, keepdims=True)
|
926 |
return np.dot(normalized_emb1, normalized_emb2.T)
|
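
For reference, the warmup-plus-cosine behaviour that _get_lr_schedule wraps in a LearningRateSchedule can be written as a plain function. This is a sketch under the same peak_lr / warmup_steps parameters; the linear warmup phase and the min_lr default of 0.0 are assumptions, only the cosine decay branch appears verbatim above:

    def warmup_cosine_lr(step: int, total_steps: int, warmup_steps: int,
                         peak_lr: float = 2e-5, min_lr: float = 0.0) -> float:
        if step < warmup_steps:
            return peak_lr * step / max(1, warmup_steps)                # assumed linear warmup to peak_lr
        decay_steps = max(1, total_steps - warmup_steps)
        decay_factor = min(max((step - warmup_steps) / decay_steps, 0.0), 1.0)
        cosine_decay = 0.5 * (1.0 + math.cos(math.pi * decay_factor))   # 1 -> 0 over the decay phase
        return min_lr + (peak_lr - min_lr) * cosine_decay
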
927 |
|
928 |
def chat(
|
929 |
self,
|
930 |
query: str,
|
931 |
conversation_history: Optional[List[Tuple[str, str]]] = None,
|
932 |
quality_checker: Optional['ResponseQualityChecker'] = None,
|
933 |
+
top_k: int = 5,
|
934 |
) -> Tuple[str, List[Tuple[str, float]], Dict[str, Any]]:
|
935 |
"""
|
936 |
+
Chat entry point: retrieves candidate responses and re-ranks them with the
|
937 |
+
cross-encoder (self.reranker) before returning the top answer.
|
938 |
"""
|
939 |
+
@self.run_on_device
|
940 |
+
def get_response(self_arg, query_arg): # Add parameters that match decorator's expectations
|
941 |
+
# 1) Build conversation context string
|
942 |
+
conversation_str = self_arg._build_conversation_context(query_arg, conversation_history)
|
943 |
|
944 |
+
# 2) Retrieve + cross-encoder re-rank
|
945 |
+
results = self_arg.retrieve_responses_cross_encoder(
|
946 |
+
query=conversation_str,
|
947 |
+
top_k=top_k,
|
948 |
+
reranker=self_arg.reranker,
|
949 |
+
summarizer=self_arg.summarizer,
|
950 |
+
summarize_threshold=512
|
951 |
+
)
|
952 |
+
|
953 |
+
# 3) Handle empty or confidence
|
954 |
+
if not results:
|
955 |
+
return (
|
956 |
+
"I'm sorry, but I couldn't find a relevant response.",
|
957 |
+
[],
|
958 |
+
{}
|
959 |
+
)
|
960 |
+
|
961 |
+
if quality_checker:
|
962 |
+
metrics = quality_checker.check_response_quality(query_arg, results)
|
963 |
+
if not metrics.get('is_confident', False):
|
964 |
+
return (
|
965 |
+
"I need more information to provide a good answer. Could you please clarify?",
|
966 |
+
results,
|
967 |
+
metrics
|
968 |
+
)
|
969 |
+
return results[0][0], results, metrics
|
970 |
|
971 |
+
return results[0][0], results, {}
|
972 |
|
973 |
+
return get_response(self, query)
|
974 |
+
|
975 |
+
def _build_conversation_context(
|
976 |
+
self,
|
977 |
+
query: str,
|
978 |
+
conversation_history: Optional[List[Tuple[str, str]]]
|
979 |
+
) -> str:
|
980 |
+
"""Build conversation context with better memory management."""
|
981 |
+
if not conversation_history:
|
982 |
+
return f"{self.special_tokens['user']} {query}"
|
983 |
|
984 |
+
conversation_parts = []
|
985 |
+
for user_txt, assistant_txt in conversation_history:
|
986 |
+
conversation_parts.extend([
|
987 |
+
f"{self.special_tokens['user']} {user_txt}",
|
988 |
+
f"{self.special_tokens['assistant']} {assistant_txt}"
|
989 |
+
])
|
990 |
|
991 |
+
conversation_parts.append(f"{self.special_tokens['user']} {query}")
|
992 |
+
return "\n".join(conversation_parts)
|
993 |
+
|
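
A minimal interactive use of the chat path above (a sketch; `bot` is a trained RetrievalChatbot with a populated index, and `quality_checker` is a ResponseQualityChecker instance whose constructor is not shown in this diff):

    history = [("I need a ride downtown.", "Sure, where should the driver pick you up?")]
    answer, candidates, metrics = bot.chat(
        query="How long will it take to get there?",
        conversation_history=history,
        quality_checker=quality_checker,
        top_k=5,
    )
    print(answer)   # top-ranked response, or a clarification prompt if confidence is low
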
994 |
+
# def prepare_dataset(
|
995 |
+
# self,
|
996 |
+
# dialogues: List[dict],
|
997 |
+
# debug_samples: int = None
|
998 |
+
# ) -> Tuple[tf.Tensor, tf.Tensor]:
|
999 |
+
# """
|
1000 |
+
# Prepares dataset for in-batch negatives:
|
1001 |
+
# Only returns (query, positive) pairs.
|
1002 |
+
# """
|
1003 |
+
# logger.info("Preparing in-batch dataset...")
|
1004 |
+
|
1005 |
+
# queries, positives = [], []
|
1006 |
+
|
1007 |
+
# for dialogue in dialogues:
|
1008 |
+
# turns = dialogue.get('turns', [])
|
1009 |
+
# for i in range(len(turns) - 1):
|
1010 |
+
# current_turn = turns[i]
|
1011 |
+
# next_turn = turns[i+1]
|
1012 |
+
|
1013 |
+
# if (current_turn.get('speaker') == 'user' and
|
1014 |
+
# next_turn.get('speaker') == 'assistant' and
|
1015 |
+
# 'text' in current_turn and
|
1016 |
+
# 'text' in next_turn):
|
1017 |
|
1018 |
+
# query = current_turn['text'].strip()
|
1019 |
+
# positive = next_turn['text'].strip()
|
1020 |
|
1021 |
+
# queries.append(query)
|
1022 |
+
# positives.append(positive)
|
1023 |
+
|
1024 |
+
# # Optional debug slicing
|
1025 |
+
# if debug_samples is not None:
|
1026 |
+
# queries = queries[:debug_samples]
|
1027 |
+
# positives = positives[:debug_samples]
|
1028 |
+
# logger.info(f"Debug mode: limited to {debug_samples} pairs.")
|
1029 |
+
|
1030 |
+
# logger.info(f"Prepared {len(queries)} (query, positive) pairs.")
|
1031 |
+
|
1032 |
+
# # Tokenize queries
|
1033 |
+
# encoded_queries = self.tokenizer(
|
1034 |
+
# queries,
|
1035 |
+
# padding='max_length',
|
1036 |
+
# truncation=True,
|
1037 |
+
# max_length=self.config.max_sequence_length,
|
1038 |
+
# return_tensors='tf'
|
1039 |
+
# )
|
1040 |
+
# # Tokenize positives
|
1041 |
+
# encoded_positives = self.tokenizer(
|
1042 |
+
# positives,
|
1043 |
+
# padding='max_length',
|
1044 |
+
# truncation=True,
|
1045 |
+
# max_length=self.config.max_sequence_length,
|
1046 |
+
# return_tensors='tf'
|
1047 |
+
# )
|
1048 |
+
|
1049 |
+
# q_tensor = encoded_queries['input_ids']
|
1050 |
+
# p_tensor = encoded_positives['input_ids']
|
1051 |
+
|
1052 |
+
# logger.info("Tokenized and padded sequences for in-batch training.")
|
1053 |
+
# return q_tensor, p_tensor
|
1054 |
+
|
1055 |
+
# def train(
|
1056 |
+
# self,
|
1057 |
+
# q_pad: tf.Tensor,
|
1058 |
+
# p_pad: tf.Tensor,
|
1059 |
+
# epochs: int = 20,
|
1060 |
+
# batch_size: int = 16,
|
1061 |
+
# validation_split: float = 0.2,
|
1062 |
+
# checkpoint_dir: str = "checkpoints/",
|
1063 |
+
# use_lr_schedule: bool = True,
|
1064 |
+
# peak_lr: float = 2e-5,
|
1065 |
+
# warmup_steps_ratio: float = 0.1,
|
1066 |
+
# early_stopping_patience: int = 3,
|
1067 |
+
# min_delta: float = 1e-4
|
1068 |
+
# ):
|
1069 |
+
# dataset_size = tf.shape(q_pad)[0].numpy()
|
1070 |
+
# val_size = int(dataset_size * validation_split)
|
1071 |
+
# train_size = dataset_size - val_size
|
1072 |
+
|
1073 |
+
# logger.info(f"Total samples: {dataset_size}")
|
1074 |
+
# logger.info(f"Training samples: {train_size}")
|
1075 |
+
# logger.info(f"Validation samples: {val_size}")
|
1076 |
+
|
1077 |
+
# steps_per_epoch = train_size // batch_size
|
1078 |
+
# if train_size % batch_size != 0:
|
1079 |
+
# steps_per_epoch += 1
|
1080 |
+
# total_steps = steps_per_epoch * epochs
|
1081 |
+
# logger.info(f"Total training steps (approx): {total_steps}")
|
1082 |
+
|
1083 |
+
# # 1) Set up LR schedule or fixed LR
|
1084 |
+
# if use_lr_schedule:
|
1085 |
+
# warmup_steps = int(total_steps * warmup_steps_ratio)
|
1086 |
+
# lr_schedule = self._get_lr_schedule(
|
1087 |
+
# total_steps=total_steps,
|
1088 |
+
# peak_lr=peak_lr,
|
1089 |
+
# warmup_steps=warmup_steps
|
1090 |
+
# )
|
1091 |
+
# self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
|
1092 |
+
# logger.info("Using custom learning rate schedule.")
|
1093 |
+
# else:
|
1094 |
+
# self.optimizer = tf.keras.optimizers.Adam(learning_rate=peak_lr)
|
1095 |
+
# logger.info("Using fixed learning rate.")
|
1096 |
+
|
1097 |
+
# # 2) Prepare data splits
|
1098 |
+
# train_q = q_pad[:train_size]
|
1099 |
+
# train_p = p_pad[:train_size]
|
1100 |
+
# val_q = q_pad[train_size:]
|
1101 |
+
# val_p = p_pad[train_size:]
|
1102 |
+
|
1103 |
+
# train_dataset = tf.data.Dataset.from_tensor_slices((train_q, train_p))
|
1104 |
+
# train_dataset = train_dataset.shuffle(buffer_size=4096).batch(batch_size)
|
1105 |
+
|
1106 |
+
# val_dataset = tf.data.Dataset.from_tensor_slices((val_q, val_p))
|
1107 |
+
# val_dataset = val_dataset.batch(batch_size)
|
1108 |
+
|
1109 |
+
# # 3) Checkpoint + manager
|
1110 |
+
# checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, model=self.encoder)
|
1111 |
+
# manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
|
1112 |
+
|
1113 |
+
# # 4) TensorBoard setup
|
1114 |
+
# log_dir = Path(checkpoint_dir) / "tensorboard_logs"
|
1115 |
+
# log_dir.mkdir(parents=True, exist_ok=True)
|
1116 |
+
|
1117 |
+
# current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
1118 |
+
# train_log_dir = str(log_dir / f"train_{current_time}")
|
1119 |
+
# val_log_dir = str(log_dir / f"val_{current_time}")
|
1120 |
+
|
1121 |
+
# train_summary_writer = tf.summary.create_file_writer(train_log_dir)
|
1122 |
+
# val_summary_writer = tf.summary.create_file_writer(val_log_dir)
|
1123 |
+
|
1124 |
+
# logger.info(f"TensorBoard logs will be saved in {log_dir}")
|
1125 |
+
|
1126 |
+
# # 5) Early stopping
|
1127 |
+
# best_val_loss = float("inf")
|
1128 |
+
# epochs_no_improve = 0
|
1129 |
+
|
1130 |
+
# logger.info("Beginning training loop...")
|
1131 |
+
# global_step = 0
|
1132 |
+
|
1133 |
+
# from tqdm import tqdm
|
1134 |
+
# for epoch in range(1, epochs + 1):
|
1135 |
+
# logger.info(f"\n=== Epoch {epoch}/{epochs} ===")
|
1136 |
+
# epoch_loss_avg = tf.keras.metrics.Mean()
|
1137 |
+
|
1138 |
+
# # Training loop
|
1139 |
+
# with tqdm(total=steps_per_epoch, desc=f"Training Epoch {epoch}") as pbar:
|
1140 |
+
# for (q_batch, p_batch) in train_dataset:
|
1141 |
+
# global_step += 1
|
1142 |
+
|
1143 |
+
# # Train step
|
1144 |
+
# batch_loss = self._train_step(q_batch, p_batch)
|
1145 |
+
# epoch_loss_avg(batch_loss)
|
1146 |
+
|
1147 |
+
# # Get current LR
|
1148 |
+
# if use_lr_schedule:
|
1149 |
+
# lr = self.optimizer.learning_rate
|
1150 |
+
# if isinstance(lr, tf.keras.optimizers.schedules.LearningRateSchedule):
|
1151 |
+
# # Get the current step
|
1152 |
+
# current_step = tf.cast(self.optimizer.iterations, tf.float32)
|
1153 |
+
# # Compute the current learning rate
|
1154 |
+
# current_lr = lr(current_step)
|
1155 |
+
# else:
|
1156 |
+
# # If learning_rate is not a schedule, use it directly
|
1157 |
+
# current_lr = lr
|
1158 |
+
# # Convert to float for logging
|
1159 |
+
# current_lr_value = float(current_lr.numpy())
|
1160 |
+
# else:
|
1161 |
+
# # If using fixed learning rate
|
1162 |
+
# current_lr_value = float(self.optimizer.learning_rate.numpy())
|
1163 |
+
|
1164 |
+
# # Update tqdm
|
1165 |
+
# pbar.update(1)
|
1166 |
+
# pbar.set_postfix({
|
1167 |
+
# "loss": f"{batch_loss.numpy():.4f}",
|
1168 |
+
# "lr": f"{current_lr_value:.2e}"
|
1169 |
+
# })
|
1170 |
+
|
1171 |
+
# # TensorBoard: log train metrics per step
|
1172 |
+
# with train_summary_writer.as_default():
|
1173 |
+
# tf.summary.scalar("loss", batch_loss, step=global_step)
|
1174 |
+
# tf.summary.scalar("learning_rate", current_lr_value, step=global_step)
|
1175 |
+
|
1176 |
+
# # Validation
|
1177 |
+
# val_loss_avg = tf.keras.metrics.Mean()
|
1178 |
+
# for q_val, p_val in val_dataset:
|
1179 |
+
# q_enc = self.encoder(q_val, training=False)
|
1180 |
+
# p_enc = self.encoder(p_val, training=False)
|
1181 |
+
# sim_matrix = tf.matmul(q_enc, p_enc, transpose_b=True)
|
1182 |
+
# bs_val = tf.shape(q_enc)[0]
|
1183 |
+
# labels_val = tf.range(bs_val, dtype=tf.int32)
|
1184 |
+
# loss_val = tf.nn.sparse_softmax_cross_entropy_with_logits(
|
1185 |
+
# labels=labels_val,
|
1186 |
+
# logits=sim_matrix
|
1187 |
+
# )
|
1188 |
+
# val_loss_avg(tf.reduce_mean(loss_val))
|
1189 |
+
|
1190 |
+
# train_loss = epoch_loss_avg.result().numpy()
|
1191 |
+
# val_loss = val_loss_avg.result().numpy()
|
1192 |
+
|
1193 |
+
# logger.info(f"Epoch {epoch} Complete: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")
|
1194 |
+
|
1195 |
+
# # TensorBoard: validation loss
|
1196 |
+
# with val_summary_writer.as_default():
|
1197 |
+
# tf.summary.scalar("val_loss", val_loss, step=epoch)
|
1198 |
+
|
1199 |
+
# # Save checkpoint
|
1200 |
+
# manager.save()
|
1201 |
+
|
1202 |
+
# # Update history
|
1203 |
+
# self.history['train_loss'].append(train_loss)
|
1204 |
+
# self.history['val_loss'].append(val_loss)
|
1205 |
+
# self.history.setdefault('learning_rate', []).append(float(current_lr_value))
|
1206 |
+
|
1207 |
+
# # Early stopping
|
1208 |
+
# if val_loss < best_val_loss - min_delta:
|
1209 |
+
# best_val_loss = val_loss
|
1210 |
+
# epochs_no_improve = 0
|
1211 |
+
# logger.info(f"Validation loss improved to {val_loss:.4f}. Reset patience.")
|
1212 |
+
# else:
|
1213 |
+
# epochs_no_improve += 1
|
1214 |
+
# logger.info(f"No improvement this epoch. Patience: {epochs_no_improve}/{early_stopping_patience}")
|
1215 |
+
# if epochs_no_improve >= early_stopping_patience:
|
1216 |
+
# logger.info("Early stopping triggered.")
|
1217 |
+
# break
|
1218 |
+
|
1219 |
+
# logger.info("In-batch training completed!")
|
1220 |
+
|
1221 |
+
# @tf.function
|
1222 |
+
# def _train_step(self, q_batch, p_batch):
|
1223 |
+
# """
|
1224 |
+
# Single training step using in-batch negatives.
|
1225 |
+
# q_batch: (batch_size, seq_len) int32 input_ids for queries
|
1226 |
+
# p_batch: (batch_size, seq_len) int32 input_ids for positives
|
1227 |
+
# """
|
1228 |
+
# with tf.GradientTape() as tape:
|
1229 |
+
# # Encode queries and positives
|
1230 |
+
# q_enc = self.encoder(q_batch, training=True) # [B, emb_dim]
|
1231 |
+
# p_enc = self.encoder(p_batch, training=True) # [B, emb_dim]
|
1232 |
+
|
1233 |
+
# # Compute similarity matrix: (B, B) = q_enc * p_enc^T
|
1234 |
+
# # If embeddings are L2-normalized, this is cosine similarity
|
1235 |
+
# sim_matrix = tf.matmul(q_enc, p_enc, transpose_b=True) # [B, B]
|
1236 |
+
|
1237 |
+
# # Labels are just the diagonal indices
|
1238 |
+
# batch_size = tf.shape(q_enc)[0]
|
1239 |
+
# labels = tf.range(batch_size, dtype=tf.int32) # [0..B-1]
|
1240 |
+
|
1241 |
+
# # Softmax cross-entropy
|
1242 |
+
# loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
|
1243 |
+
# labels=labels,
|
1244 |
+
# logits=sim_matrix
|
1245 |
+
# )
|
1246 |
+
# loss = tf.reduce_mean(loss)
|
1247 |
+
|
1248 |
+
# # Compute gradients for the pretrained DistilBERT variables only
|
1249 |
+
# train_vars = self.encoder.pretrained.trainable_variables
|
1250 |
+
# gradients = tape.gradient(loss, train_vars)
|
1251 |
+
|
1252 |
+
# # Remove any None grads (in case some layers are frozen)
|
1253 |
+
# grads_and_vars = [(g, v) for g, v in zip(gradients, train_vars) if g is not None]
|
1254 |
+
# if grads_and_vars:
|
1255 |
+
# self.optimizer.apply_gradients(grads_and_vars)
|
1256 |
+
|
1257 |
+
# return loss
|
chatbot_validator.py
ADDED
@@ -0,0 +1,207 @@
1 |
+
from typing import Dict, List, Tuple, Any, Optional
|
2 |
+
import numpy as np
|
3 |
+
from logger_config import config_logger
|
4 |
+
|
5 |
+
logger = config_logger(__name__)
|
6 |
+
|
7 |
+
class ChatbotValidator:
|
8 |
+
"""Handles automated validation and performance analysis for the chatbot."""
|
9 |
+
|
10 |
+
def __init__(self, chatbot, quality_checker):
|
11 |
+
"""
|
12 |
+
Initialize the validator.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
chatbot: RetrievalChatbot instance
|
16 |
+
quality_checker: ResponseQualityChecker instance
|
17 |
+
"""
|
18 |
+
self.chatbot = chatbot
|
19 |
+
self.quality_checker = quality_checker
|
20 |
+
|
21 |
+
# Domain-specific test queries aligned with Taskmaster-1 and Schema-Guided
|
22 |
+
self.domain_queries = {
|
23 |
+
'restaurant': [
|
24 |
+
"I'd like to make a reservation for dinner tonight.",
|
25 |
+
"Can you book a table for 4 people at an Italian place?",
|
26 |
+
"Do you have any availability for tomorrow at 7pm?",
|
27 |
+
"I need to change my dinner reservation time.",
|
28 |
+
"What's the wait time for a table right now?"
|
29 |
+
],
|
30 |
+
'movie_tickets': [
|
31 |
+
"I want to buy tickets for the new Marvel movie.",
|
32 |
+
"Are there any showings of Avatar after 6pm?",
|
33 |
+
"Can I get 3 tickets for the 8pm show?",
|
34 |
+
"What movies are playing this weekend?",
|
35 |
+
"Do you have any matinee showtimes available?"
|
36 |
+
],
|
37 |
+
'rideshare': [
|
38 |
+
"I need a ride from the airport to downtown.",
|
39 |
+
"How much would it cost to get to the mall?",
|
40 |
+
"Can you book a car for tomorrow morning?",
|
41 |
+
"Is there a driver available now?",
|
42 |
+
"What's the estimated arrival time?"
|
43 |
+
],
|
44 |
+
'services': [
|
45 |
+
"I need to schedule an oil change for my car.",
|
46 |
+
"When can I bring my car in for maintenance?",
|
47 |
+
"Do you have any openings for auto repair today?",
|
48 |
+
"How long will the service take?",
|
49 |
+
"Can I get an estimate for brake repair?"
|
50 |
+
],
|
51 |
+
'events': [
|
52 |
+
"I need tickets to the concert this weekend.",
|
53 |
+
"What events are happening near me?",
|
54 |
+
"Can I book seats for the basketball game?",
|
55 |
+
"Are there any comedy shows tonight?",
|
56 |
+
"How much are tickets to the theater?"
|
57 |
+
]
|
58 |
+
}
|
59 |
+
|
60 |
+
def run_validation(
|
61 |
+
self,
|
62 |
+
num_examples: int = 10,
|
63 |
+
top_k: int = 10,
|
64 |
+
domains: Optional[List[str]] = None
|
65 |
+
) -> Dict[str, Any]:
|
66 |
+
"""
|
67 |
+
Run comprehensive validation across specified domains.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
num_examples: Number of test queries per domain
|
71 |
+
top_k: Number of responses to retrieve for each query
|
72 |
+
domains: Optional list of specific domains to test
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
Dict containing detailed validation metrics and domain-specific performance
|
76 |
+
"""
|
77 |
+
logger.info("\n=== Running Enhanced Automatic Validation ===")
|
78 |
+
|
79 |
+
# Select domains to test
|
80 |
+
test_domains = domains if domains else list(self.domain_queries.keys())
|
81 |
+
metrics_history = []
|
82 |
+
domain_metrics = {}
|
83 |
+
|
84 |
+
# Run validation for each domain
|
85 |
+
for domain in test_domains:
|
86 |
+
domain_metrics[domain] = []
|
87 |
+
queries = self.domain_queries[domain][:num_examples]
|
88 |
+
|
89 |
+
logger.info(f"\n=== Testing {domain.title()} Domain ===")
|
90 |
+
|
91 |
+
for i, query in enumerate(queries, 1):
|
92 |
+
logger.info(f"\nTest Case {i}:")
|
93 |
+
logger.info(f"Query: {query}")
|
94 |
+
|
95 |
+
# Get responses with increased top_k
|
96 |
+
responses = self.chatbot.retrieve_responses_cross_encoder(query, top_k=top_k)
|
97 |
+
|
98 |
+
# Enhanced quality checking with context
|
99 |
+
quality_metrics = self.quality_checker.check_response_quality(query, responses)
|
100 |
+
|
101 |
+
# Add domain info
|
102 |
+
quality_metrics['domain'] = domain
|
103 |
+
metrics_history.append(quality_metrics)
|
104 |
+
domain_metrics[domain].append(quality_metrics)
|
105 |
+
|
106 |
+
# Detailed logging
|
107 |
+
self._log_validation_results(query, responses, quality_metrics, i)
|
108 |
+
|
109 |
+
# Calculate and log overall metrics
|
110 |
+
aggregate_metrics = self._calculate_aggregate_metrics(metrics_history)
|
111 |
+
domain_analysis = self._analyze_domain_performance(domain_metrics)
|
112 |
+
confidence_analysis = self._analyze_confidence_distribution(metrics_history)
|
113 |
+
|
114 |
+
aggregate_metrics.update({
|
115 |
+
'domain_performance': domain_analysis,
|
116 |
+
'confidence_analysis': confidence_analysis
|
117 |
+
})
|
118 |
+
|
119 |
+
self._log_validation_summary(aggregate_metrics)
|
120 |
+
return aggregate_metrics
|
121 |
+
|
122 |
+
def _calculate_aggregate_metrics(self, metrics_history: List[Dict]) -> Dict[str, float]:
|
123 |
+
"""Calculate comprehensive aggregate metrics."""
|
124 |
+
metrics = {
|
125 |
+
'num_queries_tested': len(metrics_history),
|
126 |
+
'avg_top_response_score': np.mean([m.get('top_score', 0) for m in metrics_history]),
|
127 |
+
'avg_diversity': np.mean([m.get('response_diversity', 0) for m in metrics_history]),
|
128 |
+
'avg_relevance': np.mean([m.get('query_response_relevance', 0) for m in metrics_history]),
|
129 |
+
'avg_length_score': np.mean([m.get('response_length_score', 0) for m in metrics_history]),
|
130 |
+
'avg_score_gap': np.mean([m.get('top_3_score_gap', 0) for m in metrics_history]),
|
131 |
+
'confidence_rate': np.mean([m.get('is_confident', False) for m in metrics_history]),
|
132 |
+
|
133 |
+
# Additional statistical metrics
|
134 |
+
'median_top_score': np.median([m.get('top_score', 0) for m in metrics_history]),
|
135 |
+
'score_std': np.std([m.get('top_score', 0) for m in metrics_history]),
|
136 |
+
'min_score': np.min([m.get('top_score', 0) for m in metrics_history]),
|
137 |
+
'max_score': np.max([m.get('top_score', 0) for m in metrics_history])
|
138 |
+
}
|
139 |
+
return metrics
|
140 |
+
|
141 |
+
def _analyze_domain_performance(self, domain_metrics: Dict[str, List[Dict]]) -> Dict[str, Dict]:
|
142 |
+
"""Analyze performance by domain."""
|
143 |
+
domain_analysis = {}
|
144 |
+
|
145 |
+
for domain, metrics in domain_metrics.items():
|
146 |
+
domain_analysis[domain] = {
|
147 |
+
'confidence_rate': np.mean([m.get('is_confident', False) for m in metrics]),
|
148 |
+
'avg_relevance': np.mean([m.get('query_response_relevance', 0) for m in metrics]),
|
149 |
+
'avg_diversity': np.mean([m.get('response_diversity', 0) for m in metrics]),
|
150 |
+
'avg_top_score': np.mean([m.get('top_score', 0) for m in metrics]),
|
151 |
+
'num_samples': len(metrics)
|
152 |
+
}
|
153 |
+
|
154 |
+
return domain_analysis
|
155 |
+
|
156 |
+
def _analyze_confidence_distribution(self, metrics_history: List[Dict]) -> Dict[str, float]:
|
157 |
+
"""Analyze the distribution of confidence scores."""
|
158 |
+
scores = [m.get('top_score', 0) for m in metrics_history]
|
159 |
+
|
160 |
+
return {
|
161 |
+
'percentile_25': np.percentile(scores, 25),
|
162 |
+
'percentile_50': np.percentile(scores, 50),
|
163 |
+
'percentile_75': np.percentile(scores, 75),
|
164 |
+
'percentile_90': np.percentile(scores, 90)
|
165 |
+
}
|
166 |
+
|
167 |
+
def _log_validation_results(
|
168 |
+
self,
|
169 |
+
query: str,
|
170 |
+
responses: List[Tuple[str, float]],
|
171 |
+
metrics: Dict[str, Any],
|
172 |
+
case_num: int
|
173 |
+
):
|
174 |
+
"""Log detailed validation results."""
|
175 |
+
logger.info(f"\nTest Case {case_num}:")
|
176 |
+
logger.info(f"Query: {query}")
|
177 |
+
logger.info(f"Domain: {metrics.get('domain', 'Unknown')}")
|
178 |
+
logger.info(f"Confidence: {'Yes' if metrics.get('is_confident', False) else 'No'}")
|
179 |
+
logger.info("\nQuality Metrics:")
|
180 |
+
for metric, value in metrics.items():
|
181 |
+
if isinstance(value, (int, float)):
|
182 |
+
logger.info(f" {metric}: {value:.4f}")
|
183 |
+
|
184 |
+
logger.info("\nTop Responses:")
|
185 |
+
for i, (response, score) in enumerate(responses[:3], 1):
|
186 |
+
logger.info(f"{i}. Score: {score:.4f}. Response: {response}")
|
187 |
+
if i == 1 and not metrics.get('is_confident', False):
|
188 |
+
logger.info(" [Low Confidence]")
|
189 |
+
|
190 |
+
def _log_validation_summary(self, metrics: Dict[str, Any]):
|
191 |
+
"""Log comprehensive validation summary."""
|
192 |
+
logger.info("\n=== Validation Summary ===")
|
193 |
+
|
194 |
+
logger.info("\nOverall Metrics:")
|
195 |
+
for metric, value in metrics.items():
|
196 |
+
if isinstance(value, (int, float)):
|
197 |
+
logger.info(f"{metric}: {value:.4f}")
|
198 |
+
|
199 |
+
logger.info("\nDomain Performance:")
|
200 |
+
for domain, domain_metrics in metrics['domain_performance'].items():
|
201 |
+
logger.info(f"\n{domain.title()}:")
|
202 |
+
for metric, value in domain_metrics.items():
|
203 |
+
logger.info(f" {metric}: {value:.4f}")
|
204 |
+
|
205 |
+
logger.info("\nConfidence Distribution:")
|
206 |
+
for percentile, value in metrics['confidence_analysis'].items():
|
207 |
+
logger.info(f"{percentile}: {value:.4f}")
|
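
Typical use of the new validator (a sketch; `bot` and `quality_checker` are assumed to exist already):

    validator = ChatbotValidator(chatbot=bot, quality_checker=quality_checker)
    summary = validator.run_validation(num_examples=5, top_k=10, domains=["restaurant", "rideshare"])
    print(summary["confidence_rate"])                   # fraction of queries answered confidently
    print(summary["domain_performance"]["restaurant"])  # per-domain averages
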
conversation_summarizer.py
ADDED
@@ -0,0 +1,147 @@
1 |
+
import tensorflow as tf
|
2 |
+
from typing import List, Dict
|
3 |
+
from transformers import TFAutoModelForSeq2SeqLM, AutoTokenizer
|
4 |
+
import logging
|
5 |
+
from dataclasses import dataclass
|
6 |
+
|
7 |
+
logger = logging.getLogger(__name__)
|
8 |
+
|
9 |
+
@dataclass
|
10 |
+
class ChatConfig:
|
11 |
+
max_sequence_length: int = 512
|
12 |
+
default_top_k: int = 5
|
13 |
+
chunk_size: int = 512
|
14 |
+
chunk_overlap: int = 256
|
15 |
+
min_confidence_score: float = 0.7
|
16 |
+
|
17 |
+
class DeviceAwareModel:
|
18 |
+
"""Mixin to handle device placement and mixed precision training."""
|
19 |
+
|
20 |
+
def setup_device(self, device: str = None):
|
21 |
+
if device is None:
|
22 |
+
device = 'GPU' if tf.config.list_physical_devices('GPU') else 'CPU'
|
23 |
+
|
24 |
+
self.device = device.upper()
|
25 |
+
self.strategy = None
|
26 |
+
|
27 |
+
if self.device == 'GPU':
|
28 |
+
# Enable mixed precision for better performance
|
29 |
+
policy = tf.keras.mixed_precision.Policy('mixed_float16')
|
30 |
+
tf.keras.mixed_precision.set_global_policy(policy)
|
31 |
+
|
32 |
+
# Setup distribution strategy for multi-GPU if available
|
33 |
+
gpus = tf.config.list_physical_devices('GPU')
|
34 |
+
if len(gpus) > 1:
|
35 |
+
self.strategy = tf.distribute.MirroredStrategy()
|
36 |
+
|
37 |
+
return self.device
|
38 |
+
|
39 |
+
def run_on_device(self, func):
|
40 |
+
"""Decorator to ensure ops run on the correct device."""
|
41 |
+
def wrapper(*args, **kwargs):
|
42 |
+
with tf.device(f'/{self.device}:0'):
|
43 |
+
return func(*args, **kwargs)
|
44 |
+
return wrapper
|
45 |
+
|
46 |
+
class Summarizer(DeviceAwareModel):
|
47 |
+
"""
|
48 |
+
Enhanced T5-based summarizer with better chunking and device management.
|
49 |
+
Handles long conversations by intelligent chunking and progressive summarization.
|
50 |
+
"""
|
51 |
+
|
52 |
+
def __init__(self, model_name="t5-small", max_summary_length=128, device=None, max_summary_rounds=2):
|
53 |
+
self.setup_device(device)
|
54 |
+
|
55 |
+
# Initialize model within strategy scope if using distribution
|
56 |
+
if self.strategy:
|
57 |
+
with self.strategy.scope():
|
58 |
+
self._setup_model(model_name)
|
59 |
+
else:
|
60 |
+
self._setup_model(model_name)
|
61 |
+
|
62 |
+
self.max_summary_length = max_summary_length
|
63 |
+
self.max_summary_rounds = max_summary_rounds
|
64 |
+
|
65 |
+
def _setup_model(self, model_name):
|
66 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
67 |
+
self.model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)
|
68 |
+
|
69 |
+
# Optimize model for inference
|
70 |
+
self.model.predict = tf.function(
|
71 |
+
self.model.predict,
|
72 |
+
input_signature=[
|
73 |
+
{
|
74 |
+
'input_ids': tf.TensorSpec(shape=[None, None], dtype=tf.int32),
|
75 |
+
'attention_mask': tf.TensorSpec(shape=[None, None], dtype=tf.int32)
|
76 |
+
}
|
77 |
+
]
|
78 |
+
)
|
79 |
+
|
80 |
+
@tf.function
|
81 |
+
def _generate_summary(self, inputs):
|
82 |
+
return self.model.generate(
|
83 |
+
inputs,
|
84 |
+
max_length=self.max_summary_length,
|
85 |
+
num_beams=4,
|
86 |
+
length_penalty=2.0,
|
87 |
+
early_stopping=True,
|
88 |
+
no_repeat_ngram_size=3
|
89 |
+
)
|
90 |
+
|
91 |
+
def chunk_text(self, text: str, chunk_size: int = 512, overlap: int = 256) -> List[str]:
|
92 |
+
"""Split text into overlapping chunks for better context preservation."""
|
93 |
+
tokens = self.tokenizer.encode(text)
|
94 |
+
chunks = []
|
95 |
+
|
96 |
+
for i in range(0, len(tokens), chunk_size - overlap):
|
97 |
+
chunk = tokens[i:i + chunk_size]
|
98 |
+
chunks.append(self.tokenizer.decode(chunk, skip_special_tokens=True))
|
99 |
+
|
100 |
+
return chunks
|
101 |
+
|
102 |
+
def summarize_text(
|
103 |
+
self,
|
104 |
+
text: str,
|
105 |
+
progressive: bool = True,
|
106 |
+
round_idx: int = 0
|
107 |
+
) -> str:
|
108 |
+
"""
|
109 |
+
Summarize text with optional progressive summarization
|
110 |
+
and limit the maximum number of re-summarization rounds.
|
111 |
+
"""
|
112 |
+
@self.run_on_device
|
113 |
+
def _summarize_chunk(chunk: str) -> str:
|
114 |
+
input_text = "summarize: " + chunk
|
115 |
+
inputs = self.tokenizer(
|
116 |
+
input_text,
|
117 |
+
return_tensors="tf",
|
118 |
+
padding=True,
|
119 |
+
truncation=True
|
120 |
+
)
|
121 |
+
summary_ids = self._generate_summary(inputs)
|
122 |
+
return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
123 |
+
|
124 |
+
# If we've hit our max allowed summarization rounds, just do a single pass
|
125 |
+
if round_idx >= self.max_summary_rounds:
|
126 |
+
return _summarize_chunk(text)
|
127 |
+
|
128 |
+
# If text is longer than threshold and progressive summarization is on
|
129 |
+
if len(text.split()) > 512 and progressive:
|
130 |
+
chunks = self.chunk_text(text)
|
131 |
+
chunk_summaries = [_summarize_chunk(chunk) for chunk in chunks]
|
132 |
+
|
133 |
+
# Combine chunk-level summaries
|
134 |
+
combined_summary = " ".join(chunk_summaries)
|
135 |
+
|
136 |
+
# If still too long, do another summarization pass but increment round_idx
|
137 |
+
if len(combined_summary.split()) > 512:
|
138 |
+
return self.summarize_text(
|
139 |
+
combined_summary,
|
140 |
+
progressive=True,
|
141 |
+
round_idx=round_idx + 1
|
142 |
+
)
|
143 |
+
|
144 |
+
return combined_summary
|
145 |
+
else:
|
146 |
+
# If text is not too long, just summarize once and return
|
147 |
+
return _summarize_chunk(text)
|
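
With the defaults above (chunk_size=512, overlap=256) chunk_text steps through the token stream with a stride of 256, so consecutive chunks share half their tokens. The resulting windows for, say, a 1000-token input (illustrative arithmetic only):

    chunk_size, overlap = 512, 256
    n_tokens = 1000
    starts = list(range(0, n_tokens, chunk_size - overlap))            # [0, 256, 512, 768]
    windows = [(s, min(s + chunk_size, n_tokens)) for s in starts]
    print(windows)   # [(0, 512), (256, 768), (512, 1000), (768, 1000)]
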
cross_encoder_reranker.py
ADDED
@@ -0,0 +1,51 @@
+from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
+import tensorflow as tf
+from typing import List, Tuple
+
+from logger_config import config_logger
+logger = config_logger(__name__)
+
+class CrossEncoderReranker:
+    """
+    Cross-Encoder Re-Ranker: Takes (query, candidate) pairs,
+    outputs a single relevance score (one logit).
+    """
+    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-12-v2"):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
+        # Model outputs shape [batch_size, 1] -> Interpret the logit as relevance score.
+
+    def rerank(
+        self,
+        query: str,
+        candidates: List[str],
+        max_length: int = 256
+    ) -> List[float]:
+        """
+        Returns a list of re_scores, one for each candidate, indicating
+        how relevant the candidate is to the query.
+        """
+        # Build (query, candidate) pairs
+        pair_texts = [(query, candidate) for candidate in candidates]
+
+        # Tokenize the entire batch
+        encodings = self.tokenizer(
+            pair_texts,
+            padding=True,
+            truncation=True,
+            max_length=max_length,
+            return_tensors="tf"
+        )
+
+        # Forward pass -> logits shape [batch_size, 1]
+        outputs = self.model(
+            input_ids=encodings["input_ids"],
+            attention_mask=encodings["attention_mask"],
+            token_type_ids=encodings.get("token_type_ids")
+        )
+
+        logits = outputs.logits
+        # Flatten to shape [batch_size]
+        scores = tf.reshape(logits, [-1]).numpy()
+
+        return scores.tolist()
dialogue_augmenter.py
CHANGED
@@ -7,7 +7,6 @@ from pipeline_config import PipelineConfig
 from quality_metrics import QualityMetrics
 from paraphraser import Paraphraser
 import nlpaug.augmenter.word as naw
-from concurrent.futures import ThreadPoolExecutor
 from functools import lru_cache
 from sklearn.metrics.pairwise import cosine_similarity
environment_setup.py
ADDED
@@ -0,0 +1,207 @@
1 |
+
from typing import Dict, Optional, Tuple
|
2 |
+
from pathlib import Path
|
3 |
+
import tensorflow as tf
|
4 |
+
import os
|
5 |
+
import subprocess
|
6 |
+
from datetime import datetime
|
7 |
+
from logger_config import config_logger
|
8 |
+
|
9 |
+
logger = config_logger(__name__)
|
10 |
+
|
11 |
+
class EnvironmentSetup:
|
12 |
+
def __init__(self):
|
13 |
+
self.device_type, self.strategy = self.setup_devices()
|
14 |
+
self.cache_dir = None
|
15 |
+
|
16 |
+
def initialize(self, cache_dir: Optional[Path] = None):
|
17 |
+
self.cache_dir = self.setup_model_cache(cache_dir)
|
18 |
+
self.training_dirs = self.setup_training_directories()
|
19 |
+
|
20 |
+
@staticmethod
|
21 |
+
def setup_model_cache(cache_dir: Optional[Path] = None) -> Path:
|
22 |
+
"""Setup and manage model cache directory."""
|
23 |
+
if cache_dir is None:
|
24 |
+
cache_dir = Path.home() / '.chatbot_cache'
|
25 |
+
|
26 |
+
cache_dir.mkdir(parents=True, exist_ok=True)
|
27 |
+
|
28 |
+
# Set environment variables for various libraries
|
29 |
+
os.environ['TRANSFORMERS_CACHE'] = str(cache_dir / 'transformers')
|
30 |
+
os.environ['TORCH_HOME'] = str(cache_dir / 'torch')
|
31 |
+
os.environ['HF_HOME'] = str(cache_dir / 'huggingface')
|
32 |
+
|
33 |
+
logger.info(f"Using cache directory: {cache_dir}")
|
34 |
+
return cache_dir
|
35 |
+
|
36 |
+
@staticmethod
|
37 |
+
def setup_training_directories(base_dir: str = "chatbot_training") -> Dict[str, Path]:
|
38 |
+
"""Setup directory structure for training artifacts."""
|
39 |
+
base_dir = Path(base_dir)
|
40 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
41 |
+
train_dir = base_dir / f"training_run_{timestamp}"
|
42 |
+
|
43 |
+
directories = {
|
44 |
+
'base': train_dir,
|
45 |
+
'checkpoints': train_dir / 'checkpoints',
|
46 |
+
'plots': train_dir / 'plots',
|
47 |
+
'logs': train_dir / 'logs'
|
48 |
+
}
|
49 |
+
|
50 |
+
# Create directories
|
51 |
+
for dir_path in directories.values():
|
52 |
+
dir_path.mkdir(parents=True, exist_ok=True)
|
53 |
+
|
54 |
+
return directories
|
55 |
+
|
56 |
+
@staticmethod
|
57 |
+
def is_colab() -> bool:
|
58 |
+
"""Check if code is running in Google Colab."""
|
59 |
+
try:
|
60 |
+
# Handle both import and attribute checks
|
61 |
+
import google.colab # type: ignore
|
62 |
+
import IPython # type: ignore
|
63 |
+
return True
|
64 |
+
except (ImportError, AttributeError):
|
65 |
+
return False
|
66 |
+
|
67 |
+
def setup_colab_tpu(self) -> Optional[tf.distribute.Strategy]:
|
68 |
+
"""Setup TPU in Colab environment if available."""
|
69 |
+
if not self.is_colab():
|
70 |
+
return None
|
71 |
+
|
72 |
+
try:
|
73 |
+
import requests
|
74 |
+
import os
|
75 |
+
|
76 |
+
# Check TPU availability
|
77 |
+
if 'COLAB_TPU_ADDR' not in os.environ:
|
78 |
+
return None
|
79 |
+
|
80 |
+
# TPU address should be set
|
81 |
+
tpu_address = 'grpc://' + os.environ['COLAB_TPU_ADDR']
|
82 |
+
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_address)
|
83 |
+
tf.config.experimental_connect_to_cluster(resolver)
|
84 |
+
tf.tpu.experimental.initialize_tpu_system(resolver)
|
85 |
+
strategy = tf.distribute.TPUStrategy(resolver)
|
86 |
+
|
87 |
+
return strategy
|
88 |
+
except Exception as e:
|
89 |
+
logger.warning(f"Failed to initialize Colab TPU: {e}")
|
90 |
+
return None
|
91 |
+
|
92 |
+
def setup_devices(self) -> Tuple[str, tf.distribute.Strategy]:
|
93 |
+
"""Configure available compute devices with Colab-specific optimizations."""
|
94 |
+
logger.info("Checking available compute devices...")
|
95 |
+
|
96 |
+
# Colab-specific setup
|
97 |
+
if self.is_colab():
|
98 |
+
logger.info("Running in Google Colab environment")
|
99 |
+
|
100 |
+
# Try TPU first in Colab
|
101 |
+
tpu_strategy = self.setup_colab_tpu()
|
102 |
+
if tpu_strategy is not None:
|
103 |
+
logger.info("Colab TPU detected and initialized")
|
104 |
+
return "TPU", tpu_strategy
|
105 |
+
|
106 |
+
# Colab GPU setup
|
107 |
+
gpus = tf.config.list_physical_devices('GPU')
|
108 |
+
if gpus:
|
109 |
+
try:
|
110 |
+
# Colab-specific GPU memory management
|
111 |
+
for gpu in gpus:
|
112 |
+
tf.config.experimental.set_memory_growth(gpu, True)
|
113 |
+
|
114 |
+
# Get GPU info using subprocess
|
115 |
+
try:
|
116 |
+
gpu_name = subprocess.check_output(
|
117 |
+
['nvidia-smi', '--query-gpu=gpu_name', '--format=csv,noheader'],
|
118 |
+
stderr=subprocess.DEVNULL
|
119 |
+
).decode('utf-8').strip()
|
120 |
+
logger.info(f"Colab GPU detected: {gpu_name}")
|
121 |
+
|
122 |
+
except (subprocess.SubprocessError, FileNotFoundError):
|
123 |
+
logger.warning("Could not detect specific GPU model")
|
124 |
+
|
125 |
+
# Enable XLA
|
126 |
+
tf.config.optimizer.set_jit(True)
|
127 |
+
logger.info("XLA compilation enabled for Colab GPU")
|
128 |
+
|
129 |
+
# Set mixed precision policy
|
130 |
+
policy = tf.keras.mixed_precision.Policy('mixed_float16')
|
131 |
+
tf.keras.mixed_precision.set_global_policy(policy)
|
132 |
+
logger.info("Mixed precision training enabled (float16)")
|
133 |
+
|
134 |
+
strategy = tf.distribute.OneDeviceStrategy("/GPU:0")
|
135 |
+
return "GPU", strategy
|
136 |
+
|
137 |
+
except Exception as e:
|
138 |
+
logger.error(f"Error configuring Colab GPU: {str(e)}")
|
139 |
+
|
140 |
+
# Non-Colab setup (same as before)
|
141 |
+
else:
|
142 |
+
# Check for TPU
|
143 |
+
try:
|
144 |
+
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
|
145 |
+
tf.config.experimental_connect_to_cluster(resolver)
|
146 |
+
tf.tpu.experimental.initialize_tpu_system(resolver)
|
147 |
+
strategy = tf.distribute.TPUStrategy(resolver)
|
148 |
+
logger.info("TPU detected and initialized")
|
149 |
+
return "TPU", strategy
|
150 |
+
except ValueError:
|
151 |
+
logger.info("No TPU detected. Checking for GPUs...")
|
152 |
+
|
153 |
+
# Check for GPUs
|
154 |
+
gpus = tf.config.list_physical_devices('GPU')
|
155 |
+
if gpus:
|
156 |
+
try:
|
157 |
+
for gpu in gpus:
|
158 |
+
tf.config.experimental.set_memory_growth(gpu, True)
|
159 |
+
|
160 |
+
if len(gpus) > 1:
|
161 |
+
strategy = tf.distribute.MirroredStrategy()
|
162 |
+
logger.info(f"Multi-GPU strategy set up with {len(gpus)} GPUs")
|
163 |
+
else:
|
164 |
+
strategy = tf.distribute.OneDeviceStrategy("/GPU:0")
|
165 |
+
logger.info("Single GPU strategy set up")
|
166 |
+
|
167 |
+
return "GPU", strategy
|
168 |
+
|
169 |
+
except Exception as e:
|
170 |
+
logger.error(f"Error configuring GPU: {str(e)}")
|
171 |
+
|
172 |
+
# CPU fallback
|
173 |
+
strategy = tf.distribute.OneDeviceStrategy("/CPU:0")
|
174 |
+
logger.info("Using CPU strategy")
|
175 |
+
return "CPU", strategy
|
176 |
+
|
177 |
+
def optimize_batch_size(self, base_batch_size: int = 16) -> int:
|
178 |
+
"""Apply Colab-specific optimizations for training."""
|
179 |
+
if not self.is_colab():
|
180 |
+
return base_batch_size
|
181 |
+
|
182 |
+
# Colab-specific batch size optimization
|
183 |
+
if self.device_type == "GPU":
|
184 |
+
try:
|
185 |
+
gpu_name = subprocess.check_output(
|
186 |
+
['nvidia-smi', '--query-gpu=gpu_name', '--format=csv,noheader'],
|
187 |
+
stderr=subprocess.DEVNULL
|
188 |
+
).decode('utf-8').strip()
|
189 |
+
|
190 |
+
if "T4" in gpu_name:
|
191 |
+
# T4 optimizations
|
192 |
+
logger.info("Optimizing for Colab T4 GPU")
|
193 |
+
base_batch_size = min(base_batch_size * 2, 32) # T4 can handle larger batches
|
194 |
+
elif "V100" in gpu_name:
|
195 |
+
# V100 optimizations
|
196 |
+
logger.info("Optimizing for Colab V100 GPU")
|
197 |
+
base_batch_size = min(base_batch_size * 3, 48) # V100 can handle even larger batches
|
198 |
+
except (subprocess.SubprocessError, FileNotFoundError):
|
199 |
+
logger.warning("Could not detect specific GPU model, using default settings")
|
200 |
+
|
201 |
+
elif self.device_type == "TPU":
|
202 |
+
# TPU optimizations
|
203 |
+
base_batch_size = min(base_batch_size * 4, 64) # TPUs can handle very large batches
|
204 |
+
logger.info("Optimizing for Colab TPU")
|
205 |
+
|
206 |
+
logger.info(f"Optimized batch size for Colab: {base_batch_size}")
|
207 |
+
return base_batch_size
|
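A minimal usage sketch for the class above, assuming the model is built inside the selected distribution strategy's scope (as the new run_model_train.py entry point is expected to do); the build_model() call is a placeholder, not an API from this repository.

# Usage sketch (build_model is a hypothetical placeholder)
from environment_setup import EnvironmentSetup

env = EnvironmentSetup()                           # picks TPU, GPU, or CPU and builds a tf.distribute strategy
env.initialize()                                   # sets up the model cache and timestamped training directories
batch_size = env.optimize_batch_size(base_batch_size=16)

with env.strategy.scope():
    model = build_model()                          # create variables under the selected strategy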
logger_config.py
ADDED
@@ -0,0 +1,10 @@
import logging

def config_logger(name):
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    return logger
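Usage is the same one-liner the other new modules in this commit use. Note that basicConfig() configures the root logger globally, so the first module to call config_logger decides the shared format and level.

from logger_config import config_logger

logger = config_logger(__name__)
logger.debug("logger configured")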
requirements.txt
CHANGED
@@ -1,6 +1,12 @@
+faiss-cpu>=1.7.0        # Required for Facebook AI Similarity Search
+ipython>=8.0.0          # For interactive Python
+loguru>=0.7.0           # Enhanced logging (optional but recommended)
+matplotlib>=3.5.0       # For validation plotting
 nlpaug>=1.1.0           # Data augmentation for NLP
 nltk>=3.6.0             # Natural language toolkit
 numpy>=1.19.0           # General numerical computation
+pandas>=1.5.0           # For data handling
+pyyaml>=6.0.0           # For config management
 scikit-learn>=1.0.0     # Machine learning tools
 sacremoses>=0.0.53      # Required for some HuggingFace models
 sentencepiece>=0.1.99   # Required for HuggingFace transformers
@@ -11,4 +17,10 @@ tokenizers>=0.13.0      # Required for HuggingFace transformers
 torch>=2.0.0            # PyTorch, for deep learning
 tqdm>=4.64.0            # Progress bar
 transformers>=4.30.0    # Hugging Face Transformers library
-
+typing-extensions>=4.0.0 # For better type hints
+
+# Dev dependencies
+black>=22.0.0           # For code formatting
+isort>=5.10.0           # For import sorting
+mypy>=1.0.0             # For type checking
+pytest>=7.0.0           # For testing
response_quality_checker.py
CHANGED
@@ -1,164 +1,170 @@
import numpy as np
from typing import List, Tuple, Dict, Any, TYPE_CHECKING
from sklearn.metrics.pairwise import cosine_similarity

from logger_config import config_logger
logger = config_logger(__name__)

if TYPE_CHECKING:
    from chatbot_model import RetrievalChatbot

class ResponseQualityChecker:
    """Enhanced quality checking with dynamic thresholds."""

    def __init__(
        self,
        chatbot: 'RetrievalChatbot',
        confidence_threshold: float = 0.6,
        diversity_threshold: float = 0.15,
        min_response_length: int = 5,
        similarity_cap: float = 0.85  # Renamed from max_similarity_ratio and used in diversity calc
    ):
        self.confidence_threshold = confidence_threshold
        self.diversity_threshold = diversity_threshold
        self.min_response_length = min_response_length
        self.similarity_cap = similarity_cap
        self.chatbot = chatbot

        # Dynamic thresholds based on response patterns
        self.thresholds = {
            'relevance': 0.35,
            'length_score': 0.85,
            'score_gap': 0.07
        }

    def check_response_quality(
        self,
        query: str,
        responses: List[Tuple[str, float]]
    ) -> Dict[str, Any]:
        """
        Evaluate the quality of responses based on various metrics.

        Args:
            query: The user's query
            responses: List of (response_text, score) tuples

        Returns:
            Dict containing quality metrics and confidence assessment
        """
        if not responses:
            return {
                'response_diversity': 0.0,
                'query_response_relevance': 0.0,
                'is_confident': False,
                'top_score': 0.0,
                'response_length_score': 0.0,
                'top_3_score_gap': 0.0
            }

        # Calculate core metrics
        metrics = {
            'response_diversity': self.calculate_diversity(responses),
            'query_response_relevance': self.calculate_relevance(query, responses),
            'response_length_score': np.mean([
                self._calculate_length_score(response) for response, _ in responses
            ]),
            'top_score': responses[0][1],
            'top_3_score_gap': self._calculate_score_gap([score for _, score in responses], top_n=3)
        }

        # Determine confidence using thresholds
        metrics['is_confident'] = self._determine_confidence(metrics)

        logger.info(f"Quality metrics: {metrics}")
        return metrics

    def calculate_relevance(self, query: str, responses: List[Tuple[str, float]]) -> float:
        """Calculate relevance as weighted similarity between query and responses."""
        if not responses:
            return 0.0

        # Get embeddings
        query_embedding = self.encode_query(query)
        response_embeddings = [self.encode_text(response) for response, _ in responses]

        # Compute similarities with decreasing weights for later responses
        similarities = cosine_similarity([query_embedding], response_embeddings)[0]
        weights = np.array([1.0 / (i + 1) for i in range(len(similarities))])

        return np.average(similarities, weights=weights)

    def calculate_diversity(self, responses: List[Tuple[str, float]]) -> float:
        """Calculate diversity with length normalization and similarity capping."""
        if not responses:
            return 0.0

        embeddings = [self.encode_text(response) for response, _ in responses]
        if len(embeddings) < 2:
            return 1.0

        # Calculate similarities and apply cap
        similarity_matrix = cosine_similarity(embeddings)
        similarity_matrix = np.minimum(similarity_matrix, self.similarity_cap)

        # Apply length normalization
        lengths = [len(resp[0].split()) for resp in responses]
        length_ratios = np.array([min(a, b) / max(a, b) for a in lengths for b in lengths])
        length_ratios = length_ratios.reshape(len(responses), len(responses))

        # Combine factors with weights
        adjusted_similarity = (similarity_matrix * 0.7 + length_ratios * 0.3)

        # Calculate final score
        sum_similarities = np.sum(adjusted_similarity) - len(responses)
        num_pairs = len(responses) * (len(responses) - 1)
        avg_similarity = sum_similarities / num_pairs if num_pairs > 0 else 0.0

        return 1 - avg_similarity

    def _determine_confidence(self, metrics: Dict[str, float]) -> bool:
        """Determine confidence using primary and secondary conditions."""
        # Primary conditions (must all be met)
        primary_conditions = [
            metrics['top_score'] >= self.confidence_threshold,
            metrics['response_diversity'] >= self.diversity_threshold,
            metrics['response_length_score'] >= self.thresholds['length_score']
        ]

        # Secondary conditions (majority must be met)
        secondary_conditions = [
            metrics['query_response_relevance'] >= self.thresholds['relevance'],
            metrics['top_3_score_gap'] >= self.thresholds['score_gap'],
            metrics['top_score'] >= (self.confidence_threshold * 1.1)  # Extra confidence boost
        ]

        return all(primary_conditions) and sum(secondary_conditions) >= 2

    def _calculate_length_score(self, response: str) -> float:
        """Calculate length score with penalty for very short or long responses."""
        words = len(response.split())

        if words < self.min_response_length:
            return words / self.min_response_length
        elif words > 50:  # Penalty for very long responses
            return min(1.0, 50 / words)
        return 1.0

    def _calculate_score_gap(self, scores: List[float], top_n: int = 3) -> float:
        """Calculate average gap between top N scores."""
        if len(scores) < top_n + 1:
            return 0.0
        gaps = [scores[i] - scores[i + 1] for i in range(min(len(scores) - 1, top_n))]
        return np.mean(gaps)

    def encode_text(self, text: str) -> np.ndarray:
        """Encode response text to embedding."""
        embedding_tensor = self.chatbot.encode_responses([text])
        embedding = embedding_tensor.numpy()[0].astype('float32')
        return self._normalize_embedding(embedding)

    def encode_query(self, query: str) -> np.ndarray:
        """Encode query text to embedding."""
        embedding_tensor = self.chatbot.encode_query(query)
        embedding = embedding_tensor.numpy()[0].astype('float32')
        return self._normalize_embedding(embedding)

    def _normalize_embedding(self, embedding: np.ndarray) -> np.ndarray:
        """Normalize embedding vector."""
        norm = np.linalg.norm(embedding)
        return embedding / norm if norm > 0 else embedding
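A minimal sketch of how the checker is meant to be driven, assuming a trained RetrievalChatbot whose retrieval call returns (response, score) tuples; the retrieve_responses() name and arguments below are assumptions, since the actual retrieval interface lives in chatbot_model.py and is not shown in this section.

# Usage sketch (retrieval signature assumed, not confirmed here)
from chatbot_model import RetrievalChatbot, ChatbotConfig
from response_quality_checker import ResponseQualityChecker

chatbot = RetrievalChatbot(ChatbotConfig())               # assumed default config
checker = ResponseQualityChecker(chatbot=chatbot)

query = "Can you recommend a book?"
responses = chatbot.retrieve_responses(query, top_k=5)    # assumed: List[Tuple[str, float]]
metrics = checker.check_response_quality(query, responses)
if not metrics['is_confident']:
    print("Low confidence - abstaining from answering")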
run_model.py
DELETED
@@ -1,162 +0,0 @@
import json
import glob
import os
from chatbot import RetrievalChatbot
import tensorflow as tf
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

def load_training_data(data_directory: str) -> list:
    """Load and combine dialogue data from multiple JSON files."""
    all_dialogues = []

    # Get all json files matching the pattern
    pattern = os.path.join(data_directory, "batch_*.json")
    json_files = sorted(glob.glob(pattern))

    print(f"Found {len(json_files)} batch files")

    for file_path in json_files:
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                batch_dialogues = json.load(f)
                all_dialogues.extend(batch_dialogues)
                print(f"Loaded {len(batch_dialogues)} dialogues from {os.path.basename(file_path)}")
        except Exception as e:
            print(f"Error loading {file_path}: {str(e)}")

    print(f"Total dialogues loaded: {len(all_dialogues)}")
    return all_dialogues

def plot_training_history(train_losses, val_losses):
    # Plot training and validation loss
    plt.figure()
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Triplet Loss')
    plt.legend()
    plt.show()

dialogues = load_training_data('processed_outputs')

# Initialize the chatbot
chatbot = RetrievalChatbot(
    vocab_size=10000,
    max_sequence_length=80,
    embedding_dim=256,
    lstm_units=256,
    num_attention_heads=8,
    margin=0.3
)

# Prepare the dataset for triplet training
q_pad, p_pad, n_pad = chatbot.prepare_dataset(dialogues, neg_samples_per_pos=3)

# Train with triplet loss
train_losses, val_losses = chatbot.train_with_triplet_loss(
    q_pad, p_pad, n_pad,
    epochs=1,
    batch_size=32,
    validation_split=0.2
)

plot_training_history(train_losses, val_losses)

# After training, test prediction
response_candidates = [turn['text'] for d in dialogues for turn in d['turns'] if turn['speaker'] == 'assistant']

# Test retrieval
test_query = "I'd like a recommendation for a Korean restaurant in NYC."
top_responses = chatbot.retrieve_top_n(test_query, response_candidates, top_n=5)
print("Top responses:")
for resp, score in top_responses:
    print(f"Score: {score:.4f} - {resp}")

# Single-turn validation:
test_queries = [
    "I want to book a Korean restaurant in NYC.",
    "Can I get two tickets for 'What Men Want'?",
    "What's the best time to watch the movie today?"
]
for query in test_queries:
    top_responses = chatbot.retrieve_top_n(query, response_candidates, top_n=3)
    print(f"\nQuery: {query}")
    for resp, score in top_responses:
        print(f"Score: {score:.4f} - {resp}")

# Multi-turn conversation:
multi_turn_history = []

def update_context(multi_turn_history, query, response, max_context_turns=3):
    multi_turn_history.append((query, response))
    if len(multi_turn_history) > max_context_turns:
        multi_turn_history.pop(0)

def get_context_enhanced_query(multi_turn_history, query):
    if not multi_turn_history:
        return query
    context = " ".join([f"User: {q} Assistant: {r}" for q, r in multi_turn_history])
    return f"{context} User: {query}"

conversation_queries = [
    "I'd like to watch a movie tonight.",
    "Is there a showing of 'What Men Want'?",
    "What time is the last show?",
    "Can I get two tickets?"
]

for query in conversation_queries:
    context_query = get_context_enhanced_query(multi_turn_history, query)
    top_responses = chatbot.retrieve_top_n(context_query, response_candidates, top_n=3)
    best_response = top_responses[0][0]
    print(f"\nUser: {query}\nAssistant: {best_response}")
    update_context(multi_turn_history, query, best_response)

#queries, responses, labels = chatbot.prepare_dataset(dialogues, neg_samples_per_pos=3)

#train_dialogues, val_dialogues = train_test_split(dialogues, test_size=0.2, random_state=20)
#query_train, query_val, response_train, response_val, labels_train, labels_val = train_test_split(queries, responses, labels, test_size=0.2, random_state=20)

# chatbot.model.compile(
#     optimizer=tf.keras.optimizers.Adam(learning_rate=0.001, clipnorm=1.0),
#     loss='binary_crossentropy',
#     metrics=['accuracy']
# )

# # Train the model with early stopping to prevent overfitting
# callbacks = [
#     tf.keras.callbacks.EarlyStopping(
#         monitor='val_loss',
#         patience=3,
#         restore_best_weights=True
#     ),
#     tf.keras.callbacks.ReduceLROnPlateau(
#         monitor='val_loss',
#         factor=0.5,
#         patience=2,
#         min_lr=1e-6,
#         verbose=1
#     ),
#     tf.keras.callbacks.ModelCheckpoint(
#         'chatbot_model.keras',
#         monitor='val_loss',
#         save_best_only=True
#     )
# ]

# history = chatbot.model.fit(
#     {'query_input': query_train, 'response_input': response_train},
#     labels_train,
#     validation_data=({'query_input': query_val, 'response_input': response_val}, labels_val),
#     epochs=5,
#     batch_size=32,
#     callbacks=callbacks
# )
run_model2.py
DELETED
@@ -1,340 +0,0 @@
from chatbot2 import RetrievalChatbot, ChatbotConfig
import os
import json
import glob
import matplotlib.pyplot as plt
import logging
from pathlib import Path
from typing import List, Dict, Optional, Any, Tuple
import numpy as np
from datetime import datetime
from response_quality_checker import ResponseQualityChecker

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

def load_training_data(data_directory: str, debug_samples: Optional[int] = None) -> list:
    """
    Load and combine dialogue data from multiple JSON files.

    Args:
        data_directory: Directory containing the dialogue files
        debug_samples: If set, only load this many dialogues for debugging
    """
    all_dialogues = []
    data_directory = Path(data_directory)

    # Get all json files matching the pattern
    pattern = "batch_*.json"
    json_files = sorted(data_directory.glob(pattern))

    logger.info(f"Found {len(json_files)} batch files")

    if debug_samples:
        logger.info(f"Debug mode: Will load up to {debug_samples} dialogues")

    for file_path in json_files:
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                batch_dialogues = json.load(f)

            # If in debug mode, only take what we need from this batch
            if debug_samples is not None:
                remaining_samples = debug_samples - len(all_dialogues)
                if remaining_samples <= 0:
                    break
                batch_dialogues = batch_dialogues[:remaining_samples]

            all_dialogues.extend(batch_dialogues)
            logger.info(f"Loaded {len(batch_dialogues)} dialogues from {file_path.name}")

            # If we've reached our debug sample limit, stop loading
            if debug_samples is not None and len(all_dialogues) >= debug_samples:
                logger.info(f"Debug mode: Reached {debug_samples} samples, stopping load")
                break

        except Exception as e:
            logger.error(f"Error loading {file_path}: {str(e)}")

    total_loaded = len(all_dialogues)
    if debug_samples:
        logger.info(f"Debug mode: Loaded {total_loaded}/{debug_samples} requested dialogues")
    else:
        logger.info(f"Total dialogues loaded: {total_loaded}")

    return all_dialogues

def plot_training_history(history: Dict[str, List[float]], save_dir: Path = None):
    """Plot and optionally save training history."""
    # Create figure with two subplots
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))

    # Plot losses
    ax1.plot(history['train_loss'], label='Train Loss')
    ax1.plot(history['val_loss'], label='Validation Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Triplet Loss')
    ax1.set_title('Training and Validation Loss')
    ax1.legend()
    ax1.grid(True)

    # Plot learning rate if available
    if 'learning_rate' in history:
        ax2.plot(history['learning_rate'], label='Learning Rate')
        ax2.set_xlabel('Step')
        ax2.set_ylabel('Learning Rate')
        ax2.set_title('Learning Rate Schedule')
        ax2.legend()
        ax2.grid(True)

    plt.tight_layout()

    # Save if directory provided
    if save_dir:
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        plt.savefig(save_dir / f'training_history_{timestamp}.png')

    plt.show()

def setup_training_directories(base_dir: str = "chatbot_training") -> Dict[str, Path]:
    """Setup directory structure for training artifacts."""
    base_dir = Path(base_dir)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    train_dir = base_dir / f"training_run_{timestamp}"

    directories = {
        'base': train_dir,
        'checkpoints': train_dir / 'checkpoints',
        'plots': train_dir / 'plots',
        'logs': train_dir / 'logs'
    }

    # Create directories
    for dir_path in directories.values():
        dir_path.mkdir(parents=True, exist_ok=True)

    return directories

def run_automatic_validation(
    chatbot,
    response_pool: List[str],
    quality_checker: ResponseQualityChecker,
    num_examples: int = 5
) -> Dict[str, Any]:
    """
    Run automatic validation with quality metrics.
    """
    logger.info("\n=== Running Automatic Validation ===")

    test_queries = [
        "Hello, how are you today?",
        "What's the weather like?",
        "Can you help me with a problem?",
        "Tell me a joke",
        "What time is it?",
        "I need help with my homework",
        "Where's a good place to eat?",
        "What movies are playing?",
        "How do I reset my password?",
        "Can you recommend a book?"
    ]

    test_queries = test_queries[:num_examples]
    metrics_history = []

    for i, query in enumerate(test_queries, 1):
        logger.info(f"\nTest Case {i}:")
        logger.info(f"Query: {query}")

        # Get responses and scores
        responses = chatbot.retrieve_responses(
            query,
            response_pool,
            context=None,
            top_k=5
        )

        # Check quality
        quality_metrics = quality_checker.check_response_quality(
            query, responses, response_pool
        )
        metrics_history.append(quality_metrics)

        # Log results
        logger.info(f"Quality Metrics: {quality_metrics}")
        logger.info("Top responses:")
        for j, (response, score) in enumerate(responses[:3], 1):
            logger.info(f"{j}. Score: {score:.4f}")
            logger.info(f"   Response: {response}")
            if j == 1 and not quality_metrics['is_confident']:
                logger.info("   [Low Confidence - Would abstain from answering]")

    # Calculate aggregate metrics
    aggregate_metrics = {
        'num_queries_tested': len(test_queries),
        'avg_top_response_score': np.mean([m['top_score'] for m in metrics_history]),
        'avg_diversity': np.mean([m['response_diversity'] for m in metrics_history]),
        'avg_relevance': np.mean([m['query_response_relevance'] for m in metrics_history]),
        'confidence_rate': np.mean([m['is_confident'] for m in metrics_history]),
    }

    logger.info("\n=== Validation Summary ===")
    for metric, value in aggregate_metrics.items():
        logger.info(f"{metric}: {value:.4f}")

    return aggregate_metrics

def chat_with_quality_check(
    chatbot,
    query: str,
    response_pool: List[str],
    conversation_history: List[Tuple[str, str]],
    quality_checker: ResponseQualityChecker
) -> Tuple[Optional[str], List[Tuple[str, float]], Dict[str, Any]]:
    """
    Enhanced chat function with quality checking.
    """
    # Get responses and scores
    responses = chatbot.retrieve_responses(
        query,
        response_pool,
        conversation_history
    )

    # Check quality
    quality_metrics = quality_checker.check_response_quality(
        query, responses, response_pool
    )

    if quality_metrics['is_confident']:
        return responses[0][0], responses, quality_metrics
    else:
        uncertainty_response = (
            "I apologize, but I don't feel confident providing an answer to that "
            "question at the moment. Could you please rephrase or ask something else?"
        )
        return uncertainty_response, responses, quality_metrics

def get_total_steps(dialogues: List[Dict[str, Any]], batch_size: int, epochs: int) -> int:
    """
    Calculate total training steps based on dialogues and batch size.
    Assume 80% of data for training due to validation split
    """
    estimated_train_samples = int(len(dialogues) * 0.8)
    steps_per_epoch = estimated_train_samples // batch_size
    return steps_per_epoch * epochs

def main():
    DEBUG_SAMPLES = 350
    BATCH_SIZE = 32
    EPOCHS = 5 if DEBUG_SAMPLES else 10

    # Setup training directories
    dirs = setup_training_directories()

    # Load training data
    dialogues = load_training_data('processed_outputs', debug_samples=DEBUG_SAMPLES)
    total_steps = get_total_steps(dialogues, BATCH_SIZE, EPOCHS)

    # Initialize configuration
    config = ChatbotConfig(
        embedding_dim=32,   # TODO: 256
        encoder_units=32,   # TODO: 256
        num_attention_heads=2,  # TODO: 8
        warmup_steps=int(total_steps * 0.1),  # 10% of total steps for warmup
    )

    # Save configuration
    with open(dirs['base'] / 'config.json', 'w') as f:
        json.dump(config.to_dict(), f, indent=2)

    # Initialize chatbot
    chatbot = RetrievalChatbot(config)

    # Prepare dataset
    logger.info("Preparing dataset...")

    # Prepare and train with debug samples
    q_pad, p_pad, n_pad = chatbot.prepare_dataset(
        dialogues,
        neg_samples_per_pos=3,
        debug_samples=DEBUG_SAMPLES
    )

    # Train model
    logger.info("Starting training...")
    chatbot.train(
        q_pad, p_pad, n_pad,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        checkpoint_dir=dirs['checkpoints']
    )

    # Plot and save training history
    plot_training_history(chatbot.history, save_dir=dirs['plots'])

    # Save final model
    chatbot.save_models(dirs['base'] / 'final_model')

    # Prepare response pool for chat
    response_pool = [
        turn['text'] for d in dialogues
        for turn in d['turns'] if turn['speaker'] == 'assistant'
    ]

    # Initialize quality checker with appropriate thresholds
    quality_checker = ResponseQualityChecker(
        confidence_threshold=0.6 if not DEBUG_SAMPLES else 0.4,  # Lower threshold for debug
        diversity_threshold=0.2,
        min_response_length=10,
        max_similarity_ratio=0.9
    )

    # Run automatic validation
    validation_metrics = run_automatic_validation(
        chatbot,
        response_pool,
        quality_checker,
        num_examples=5 if DEBUG_SAMPLES else 10
    )

    # Log validation metrics
    logger.info(f"Validation Metrics: {validation_metrics}")

    # Now continue with interactive chat
    logger.info("\nStarting interactive chat session...")
    conversation_history = []

    while True:
        query = input("\nYou: ")
        if query.lower() in ['quit', 'exit', 'bye']:
            break

        try:
            response, candidates = chatbot.chat(
                query,
                response_pool,
                conversation_history
            )
            print(f"\nAssistant: {response}")

            # Print top alternative responses
            print("\nAlternative responses:")
            for resp, score in candidates[1:4]:
                print(f"Score: {score:.4f} - {resp}")

            # Update history
            conversation_history.append((query, response))

        except Exception as e:
            logger.error(f"Error during chat: {str(e)}")
            print("Sorry, I encountered an error. Please try again.")

if __name__ == "__main__":
    main()
run_model3.py
DELETED
@@ -1,434 +0,0 @@
|
|
1 |
-
from chatbot3 import RetrievalChatbot, ChatbotConfig
|
2 |
-
import os
|
3 |
-
import json
|
4 |
-
import glob
|
5 |
-
import tensorflow as tf
|
6 |
-
import matplotlib.pyplot as plt
|
7 |
-
import logging
|
8 |
-
from pathlib import Path
|
9 |
-
from typing import List, Dict, Optional, Any, Tuple
|
10 |
-
import numpy as np
|
11 |
-
from datetime import datetime
|
12 |
-
from response_quality_checker import ResponseQualityChecker
|
13 |
-
import torch
|
14 |
-
from transformers import TFAutoModel, AutoTokenizer
|
15 |
-
|
16 |
-
|
17 |
-
policy = tf.keras.mixed_precision.Policy('mixed_float16')
|
18 |
-
tf.keras.mixed_precision.set_global_policy(policy)
|
19 |
-
|
20 |
-
# Configure logging
|
21 |
-
logging.basicConfig(
|
22 |
-
level=logging.INFO,
|
23 |
-
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
24 |
-
)
|
25 |
-
logger = logging.getLogger(__name__)
|
26 |
-
|
27 |
-
def setup_model_cache(cache_dir: Optional[Path] = None) -> Path:
|
28 |
-
"""Setup and manage model cache directory."""
|
29 |
-
if cache_dir is None:
|
30 |
-
cache_dir = Path.home() / '.chatbot_cache'
|
31 |
-
|
32 |
-
cache_dir.mkdir(parents=True, exist_ok=True)
|
33 |
-
|
34 |
-
# Set environment variables for various libraries
|
35 |
-
os.environ['TRANSFORMERS_CACHE'] = str(cache_dir / 'transformers')
|
36 |
-
os.environ['TORCH_HOME'] = str(cache_dir / 'torch')
|
37 |
-
os.environ['HF_HOME'] = str(cache_dir / 'huggingface')
|
38 |
-
|
39 |
-
logger.info(f"Using cache directory: {cache_dir}")
|
40 |
-
return cache_dir
|
41 |
-
|
42 |
-
def setup_gpu():
|
43 |
-
"""Configure GPU settings for optimal performance."""
|
44 |
-
logger.info("Checking GPU availability...")
|
45 |
-
|
46 |
-
gpus = tf.config.list_physical_devices('GPU')
|
47 |
-
if gpus:
|
48 |
-
try:
|
49 |
-
# Allow memory growth to prevent taking all GPU memory at once
|
50 |
-
for gpu in gpus:
|
51 |
-
tf.config.experimental.set_memory_growth(gpu, True)
|
52 |
-
logger.info(f"Found {len(gpus)} GPU(s). Memory growth enabled.")
|
53 |
-
|
54 |
-
# Log GPU info
|
55 |
-
for gpu in gpus:
|
56 |
-
logger.info(f"GPU Device: {gpu}")
|
57 |
-
return True
|
58 |
-
except Exception as e:
|
59 |
-
logger.error(f"Error configuring GPU: {str(e)}")
|
60 |
-
return False
|
61 |
-
|
62 |
-
else:
|
63 |
-
logger.info("No GPU found. Using CPU.")
|
64 |
-
return False
|
65 |
-
|
66 |
-
def preload_models(config: ChatbotConfig, cache_dir: Path):
|
67 |
-
"""Preload and cache models."""
|
68 |
-
logger.info("Preloading models...")
|
69 |
-
|
70 |
-
# Cache DistilBERT
|
71 |
-
model_name = config.pretrained_model
|
72 |
-
cache_path = cache_dir / 'transformers' / model_name
|
73 |
-
|
74 |
-
if not cache_path.exists():
|
75 |
-
logger.info(f"Downloading and caching {model_name}...")
|
76 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
77 |
-
model = TFAutoModel.from_pretrained(model_name)
|
78 |
-
|
79 |
-
# Save to cache
|
80 |
-
tokenizer.save_pretrained(cache_path)
|
81 |
-
model.save_pretrained(cache_path)
|
82 |
-
else:
|
83 |
-
logger.info(f"Using cached model from {cache_path}")
|
84 |
-
|
85 |
-
return cache_path
|
86 |
-
|
87 |
-
def load_training_data(data_directory: str, debug_samples: Optional[int] = None) -> list:
|
88 |
-
"""
|
89 |
-
Load and combine dialogue data from multiple JSON files.
|
90 |
-
|
91 |
-
Args:
|
92 |
-
data_directory: Directory containing the dialogue files
|
93 |
-
debug_samples: If set, only load this many dialogues for debugging
|
94 |
-
"""
|
95 |
-
all_dialogues = []
|
96 |
-
data_directory = Path(data_directory)
|
97 |
-
|
98 |
-
# Get all json files matching the pattern
|
99 |
-
pattern = "batch_*.json"
|
100 |
-
json_files = sorted(data_directory.glob(pattern))
|
101 |
-
|
102 |
-
logger.info(f"Found {len(json_files)} batch files")
|
103 |
-
|
104 |
-
if debug_samples:
|
105 |
-
logger.info(f"Debug mode: Will load up to {debug_samples} dialogues")
|
106 |
-
|
107 |
-
for file_path in json_files:
|
108 |
-
try:
|
109 |
-
with open(file_path, 'r', encoding='utf-8') as f:
|
110 |
-
batch_dialogues = json.load(f)
|
111 |
-
|
112 |
-
# If in debug mode, only take what we need from this batch
|
113 |
-
if debug_samples is not None:
|
114 |
-
remaining_samples = debug_samples - len(all_dialogues)
|
115 |
-
if remaining_samples <= 0:
|
116 |
-
break
|
117 |
-
batch_dialogues = batch_dialogues[:remaining_samples]
|
118 |
-
|
119 |
-
all_dialogues.extend(batch_dialogues)
|
120 |
-
logger.info(f"Loaded {len(batch_dialogues)} dialogues from {file_path.name}")
|
121 |
-
|
122 |
-
# If we've reached our debug sample limit, stop loading
|
123 |
-
if debug_samples is not None and len(all_dialogues) >= debug_samples:
|
124 |
-
logger.info(f"Debug mode: Reached {debug_samples} samples, stopping load")
|
125 |
-
break
|
126 |
-
|
127 |
-
except Exception as e:
|
128 |
-
logger.error(f"Error loading {file_path}: {str(e)}")
|
129 |
-
|
130 |
-
total_loaded = len(all_dialogues)
|
131 |
-
if debug_samples:
|
132 |
-
logger.info(f"Debug mode: Loaded {total_loaded}/{debug_samples} requested dialogues")
|
133 |
-
else:
|
134 |
-
logger.info(f"Total dialogues loaded: {total_loaded}")
|
135 |
-
|
136 |
-
return all_dialogues
|
137 |
-
|
138 |
-
def plot_training_history(history: Dict[str, List[float]], save_dir: Path = None):
|
139 |
-
"""Plot and optionally save training history."""
|
140 |
-
# Create figure with two subplots
|
141 |
-
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))
|
142 |
-
|
143 |
-
# Plot losses
|
144 |
-
ax1.plot(history['train_loss'], label='Train Loss')
|
145 |
-
ax1.plot(history['val_loss'], label='Validation Loss')
|
146 |
-
ax1.set_xlabel('Epoch')
|
147 |
-
ax1.set_ylabel('Triplet Loss')
|
148 |
-
ax1.set_title('Training and Validation Loss')
|
149 |
-
ax1.legend()
|
150 |
-
ax1.grid(True)
|
151 |
-
|
152 |
-
# Plot learning rate if available
|
153 |
-
if 'learning_rate' in history:
|
154 |
-
ax2.plot(history['learning_rate'], label='Learning Rate')
|
155 |
-
ax2.set_xlabel('Step')
|
156 |
-
ax2.set_ylabel('Learning Rate')
|
157 |
-
ax2.set_title('Learning Rate Schedule')
|
158 |
-
ax2.legend()
|
159 |
-
ax2.grid(True)
|
160 |
-
|
161 |
-
plt.tight_layout()
|
162 |
-
|
163 |
-
# Save if directory provided
|
164 |
-
if save_dir:
|
165 |
-
save_dir = Path(save_dir)
|
166 |
-
save_dir.mkdir(parents=True, exist_ok=True)
|
167 |
-
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
168 |
-
plt.savefig(save_dir / f'training_history_{timestamp}.png')
|
169 |
-
|
170 |
-
plt.show()
|
171 |
-
|
172 |
-
def setup_training_directories(base_dir: str = "chatbot_training") -> Dict[str, Path]:
|
173 |
-
"""Setup directory structure for training artifacts."""
|
174 |
-
base_dir = Path(base_dir)
|
175 |
-
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
176 |
-
train_dir = base_dir / f"training_run_{timestamp}"
|
177 |
-
|
178 |
-
directories = {
|
179 |
-
'base': train_dir,
|
180 |
-
'checkpoints': train_dir / 'checkpoints',
|
181 |
-
'plots': train_dir / 'plots',
|
182 |
-
'logs': train_dir / 'logs'
|
183 |
-
}
|
184 |
-
|
185 |
-
# Create directories
|
186 |
-
for dir_path in directories.values():
|
187 |
-
dir_path.mkdir(parents=True, exist_ok=True)
|
188 |
-
|
189 |
-
return directories
|
190 |
-
|
191 |
-
def run_automatic_validation(
|
192 |
-
chatbot,
|
193 |
-
response_pool: List[str],
|
194 |
-
quality_checker: ResponseQualityChecker,
|
195 |
-
num_examples: int = 5
|
196 |
-
) -> Dict[str, Any]:
|
197 |
-
"""
|
198 |
-
Run automatic validation with quality metrics.
|
199 |
-
"""
|
200 |
-
logger.info("\n=== Running Automatic Validation ===")
|
201 |
-
|
202 |
-
test_queries = [
|
203 |
-
"Hello, how are you today?",
|
204 |
-
"What's the weather like?",
|
205 |
-
"Can you help me with a problem?",
|
206 |
-
"Tell me a joke",
|
207 |
-
"What time is it?",
|
208 |
-
"I need help with my homework",
|
209 |
-
"Where's a good place to eat?",
|
210 |
-
"What movies are playing?",
|
211 |
-
"How do I reset my password?",
|
212 |
-
"Can you recommend a book?"
|
213 |
-
]
|
214 |
-
|
215 |
-
test_queries = test_queries[:num_examples]
|
216 |
-
metrics_history = []
|
217 |
-
|
218 |
-
for i, query in enumerate(test_queries, 1):
|
219 |
-
logger.info(f"\nTest Case {i}:")
|
220 |
-
logger.info(f"Query: {query}")
|
221 |
-
|
222 |
-
# Get responses and scores
|
223 |
-
responses = chatbot.retrieve_responses(
|
224 |
-
query,
|
225 |
-
response_pool,
|
226 |
-
context=None,
|
227 |
-
top_k=5
|
228 |
-
)
|
229 |
-
|
230 |
-
# Check quality
|
231 |
-
quality_metrics = quality_checker.check_response_quality(
|
232 |
-
query, responses, response_pool
|
233 |
-
)
|
234 |
-
metrics_history.append(quality_metrics)
|
235 |
-
|
236 |
-
# Log results
|
237 |
-
logger.info(f"Quality Metrics: {quality_metrics}")
|
238 |
-
logger.info("Top responses:")
|
239 |
-
for j, (response, score) in enumerate(responses[:3], 1):
|
240 |
-
logger.info(f"{j}. Score: {score:.4f}")
|
241 |
-
logger.info(f" Response: {response}")
|
242 |
-
if j == 1 and not quality_metrics['is_confident']:
|
243 |
-
logger.info(" [Low Confidence - Would abstain from answering]")
|
244 |
-
|
245 |
-
# Calculate aggregate metrics
|
246 |
-
aggregate_metrics = {
|
247 |
-
'num_queries_tested': len(test_queries),
|
248 |
-
'avg_top_response_score': np.mean([m['top_score'] for m in metrics_history]),
|
249 |
-
'avg_diversity': np.mean([m['response_diversity'] for m in metrics_history]),
|
250 |
-
'avg_relevance': np.mean([m['query_response_relevance'] for m in metrics_history]),
|
251 |
-
'confidence_rate': np.mean([m['is_confident'] for m in metrics_history]),
|
252 |
-
}
|
253 |
-
|
254 |
-
logger.info("\n=== Validation Summary ===")
|
255 |
-
for metric, value in aggregate_metrics.items():
|
256 |
-
logger.info(f"{metric}: {value:.4f}")
|
257 |
-
|
258 |
-
return aggregate_metrics
|
259 |
-
|
260 |
-
def chat_with_quality_check(
|
261 |
-
chatbot,
|
262 |
-
query: str,
|
263 |
-
response_pool: List[str],
|
264 |
-
conversation_history: List[Tuple[str, str]],
|
265 |
-
quality_checker: ResponseQualityChecker
|
266 |
-
) -> Tuple[Optional[str], List[Tuple[str, float]], Dict[str, Any]]:
|
267 |
-
"""
|
268 |
-
Enhanced chat function with quality checking.
|
269 |
-
"""
|
270 |
-
# Get responses and scores
|
271 |
-
responses = chatbot.retrieve_responses(
|
272 |
-
query,
|
273 |
-
response_pool,
|
274 |
-
conversation_history
|
275 |
-
)
|
276 |
-
|
277 |
-
# Check quality
|
278 |
-
quality_metrics = quality_checker.check_response_quality(
|
279 |
-
query, responses, response_pool
|
280 |
-
)
|
281 |
-
|
282 |
-
if quality_metrics['is_confident']:
|
283 |
-
return responses[0][0], responses, quality_metrics
|
284 |
-
else:
|
285 |
-
uncertainty_response = (
|
286 |
-
"I apologize, but I don't feel confident providing an answer to that "
|
287 |
-
"question at the moment. Could you please rephrase or ask something else?"
|
288 |
-
)
|
289 |
-
return uncertainty_response, responses, quality_metrics
|
290 |
-
|
291 |
-
def get_total_steps(dialogues: List[Dict[str, Any]], batch_size: int, epochs: int) -> int:
|
292 |
-
"""
|
293 |
-
Calculate total training steps based on dialogues and batch size.
|
294 |
-
Assume 80% of data for training due to validation split
|
295 |
-
"""
|
296 |
-
estimated_train_samples = int(len(dialogues) * 0.8)
|
297 |
-
steps_per_epoch = estimated_train_samples // batch_size
|
298 |
-
return steps_per_epoch * epochs
|
299 |
-
|
300 |
-
def main():
|
301 |
-
# Set up GPU
|
302 |
-
is_gpu = setup_gpu()
|
303 |
-
|
304 |
-
DEBUG_SAMPLES = 350
|
305 |
-
BATCH_SIZE = 64 if is_gpu else 32
|
306 |
-
EPOCHS = 5 if DEBUG_SAMPLES else 10
|
307 |
-
|
308 |
-
# Set device
|
309 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
310 |
-
logger.info(f"Using device: {device}")
|
311 |
-
|
312 |
-
# Set up caching
|
313 |
-
cache_dir = setup_model_cache()
|
314 |
-
|
315 |
-
# Set up training directories
|
316 |
-
dirs = setup_training_directories()
|
317 |
-
|
318 |
-
# Load training data
|
319 |
-
-    dialogues = load_training_data('processed_outputs', debug_samples=DEBUG_SAMPLES)
-    total_steps = get_total_steps(dialogues, BATCH_SIZE, EPOCHS)
-
-    # Initialize configuration
-    config = ChatbotConfig(
-        embedding_dim=768,  # Match DistilBERT's dimension
-        encoder_units=256,
-        num_attention_heads=8,
-        warmup_steps=int(total_steps * 0.1),
-        learning_rate=0.0003,
-        margin=0.5,
-        pretrained_model='distilbert-base-uncased'
-    )
-
-    # Preload models
-    preload_models(config, cache_dir)
-
-    # Save configuration
-    with open(dirs['base'] / 'config.json', 'w') as f:
-        json.dump(config.to_dict(), f, indent=2)
-
-    # Initialize chatbot
-    chatbot = RetrievalChatbot(config)
-
-    # Prepare dataset
-    logger.info("Preparing dataset...")
-
-    # Prepare and train with debug samples
-    q_pad, p_pad, n_pad = chatbot.prepare_dataset(
-        dialogues,
-        neg_samples_per_pos=3,
-        debug_samples=DEBUG_SAMPLES
-    )
-
-    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='./logs')
-
-    # Train model
-    logger.info("Starting training...")
-    chatbot.train(
-        q_pad, p_pad, n_pad,
-        epochs=EPOCHS,
-        batch_size=BATCH_SIZE,
-        validation_split=0.2,
-        checkpoint_dir=dirs['checkpoints'],
-        callbacks=[tensorboard_callback]
-    )
-
-    # Plot and save training history
-    plot_training_history(chatbot.history, save_dir=dirs['plots'])
-
-    # Save final model
-    chatbot.save_models(dirs['base'] / 'final_model')
-
-    # Prepare response pool for chat
-    response_pool = [
-        turn['text'] for d in dialogues
-        for turn in d['turns'] if turn['speaker'] == 'assistant'
-    ]
-
-    # Initialize quality checker with appropriate thresholds
-    quality_checker = ResponseQualityChecker(
-        confidence_threshold=0.6 if not DEBUG_SAMPLES else 0.4,  # Lower threshold for debug
-        diversity_threshold=0.2,
-        min_response_length=10,
-        max_similarity_ratio=0.9
-    )
-
-    # Run automatic validation
-    validation_metrics = run_automatic_validation(
-        chatbot,
-        response_pool,
-        quality_checker,
-        num_examples=5 if DEBUG_SAMPLES else 10
-    )
-
-    # Log validation metrics
-    logger.info(f"Validation Metrics: {validation_metrics}")
-
-    # Now continue with interactive chat
-    logger.info("\nStarting interactive chat session...")
-    conversation_history = []
-
-    while True:
-        query = input("\nYou: ")
-        if query.lower() in ['quit', 'exit', 'bye']:
-            break
-
-        try:
-            response, candidates, quality_metrics = chat_with_quality_check(
-                chatbot,
-                query,
-                response_pool,
-                conversation_history,
-                quality_checker
-            )
-            print(f"\nAssistant: {response}")
-
-            # Print top alternative responses if confident
-            if quality_metrics['is_confident']:
-                print("\nAlternative responses:")
-                for resp, score in candidates[1:4]:
-                    print(f"Score: {score:.4f} - {resp}")
-
-                # Update history only for confident responses
-                conversation_history.append((query, response))
-            else:
-                print("\nQuality metrics indicated low confidence:")
-                print(f"Confidence score: {quality_metrics['top_score']:.4f}")
-                print(f"Response relevance: {quality_metrics['query_response_relevance']:.4f}")
-
-        except Exception as e:
-            logger.error(f"Error during chat: {str(e)}")
-            print("Sorry, I encountered an error. Please try again.")
-
-if __name__ == "__main__":
-    main()
run_model4.py
DELETED
@@ -1,237 +0,0 @@
-from chatbot4 import RetrievalChatbot, ChatbotConfig
-import os
-import tensorflow as tf
-import matplotlib.pyplot as plt
-import logging
-from pathlib import Path
-from typing import List, Dict, Optional
-from datetime import datetime
-from response_quality_checker import ResponseQualityChecker
-
-# Configure logging
-logging.basicConfig(
-    level=logging.INFO,
-    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
-logger = logging.getLogger(__name__)
-
-def setup_model_cache(cache_dir: Optional[Path] = None) -> Path:
-    """Setup and manage model cache directory."""
-    if cache_dir is None:
-        cache_dir = Path.home() / '.chatbot_cache'
-
-    cache_dir.mkdir(parents=True, exist_ok=True)
-
-    # Set environment variables for various libraries
-    os.environ['TRANSFORMERS_CACHE'] = str(cache_dir / 'transformers')
-    os.environ['TORCH_HOME'] = str(cache_dir / 'torch')
-    os.environ['HF_HOME'] = str(cache_dir / 'huggingface')
-
-    logger.info(f"Using cache directory: {cache_dir}")
-    return cache_dir
-
-def setup_gpu():
-    """Configure GPU settings for optimal performance."""
-    logger.info("Checking GPU availability...")
-
-    gpus = tf.config.list_physical_devices('GPU')
-    if gpus:
-        try:
-            # Allow memory growth to prevent taking all GPU memory at once
-            for gpu in gpus:
-                tf.config.experimental.set_memory_growth(gpu, True)
-            logger.info(f"Found {len(gpus)} GPU(s). Memory growth enabled.")
-
-            # Log GPU info
-            for gpu in gpus:
-                logger.info(f"GPU Device: {gpu}")
-            return True
-        except Exception as e:
-            logger.error(f"Error configuring GPU: {str(e)}")
-            return False
-
-    else:
-        logger.info("No GPU found. Using CPU.")
-        return False
-
-# def preload_models(config: ChatbotConfig, cache_dir: Path):
-#     """Preload and cache models."""
-#     logger.info("Preloading models...")
-
-#     # Cache DistilBERT
-#     model_name = config.pretrained_model
-#     cache_path = cache_dir / 'transformers' / model_name
-
-#     if not cache_path.exists():
-#         logger.info(f"Downloading and caching {model_name}...")
-#         tokenizer = AutoTokenizer.from_pretrained(model_name)
-#         model = TFAutoModel.from_pretrained(model_name)
-
-#         # Save to cache
-#         tokenizer.save_pretrained(cache_path)
-#         model.save_pretrained(cache_path)
-#     else:
-#         logger.info(f"Using cached model from {cache_path}")
-
-#     return cache_path
-
-def plot_training_history(history: Dict[str, List[float]], save_dir: Path = None):
-    """Plot and optionally save training history."""
-    # Create figure with two subplots
-    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))
-
-    # Plot losses
-    ax1.plot(history['train_loss'], label='Train Loss')
-    ax1.plot(history['val_loss'], label='Validation Loss')
-    ax1.set_xlabel('Epoch')
-    ax1.set_ylabel('Triplet Loss')
-    ax1.set_title('Training and Validation Loss')
-    ax1.legend()
-    ax1.grid(True)
-
-    # Plot learning rate if available
-    if 'learning_rate' in history:
-        ax2.plot(history['learning_rate'], label='Learning Rate')
-        ax2.set_xlabel('Step')
-        ax2.set_ylabel('Learning Rate')
-        ax2.set_title('Learning Rate Schedule')
-        ax2.legend()
-        ax2.grid(True)
-
-    plt.tight_layout()
-
-    # Save if directory provided
-    if save_dir:
-        save_dir = Path(save_dir)
-        save_dir.mkdir(parents=True, exist_ok=True)
-        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-        plt.savefig(save_dir / f'training_history_{timestamp}.png')
-
-    plt.show()
-
-def setup_training_directories(base_dir: str = "chatbot_training") -> Dict[str, Path]:
-    """Setup directory structure for training artifacts."""
-    base_dir = Path(base_dir)
-    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-    train_dir = base_dir / f"training_run_{timestamp}"
-
-    directories = {
-        'base': train_dir,
-        'checkpoints': train_dir / 'checkpoints',
-        'plots': train_dir / 'plots',
-        'logs': train_dir / 'logs'
-    }
-
-    # Create directories
-    for dir_path in directories.values():
-        dir_path.mkdir(parents=True, exist_ok=True)
-
-    return directories
-
-def main():
-    # Set up GPU
-    is_gpu = setup_gpu()
-
-    DEBUG_SAMPLES = 2000
-    BATCH_SIZE = 128 if is_gpu else 64
-    EPOCHS = 5 if DEBUG_SAMPLES else 10
-
-    # Set up caching
-    cache_dir = setup_model_cache()
-
-    # Set up training directories
-    dirs = setup_training_directories()
-
-    # Initialize configuration
-    config = ChatbotConfig(
-        embedding_dim=768,  # Match DistilBERT's dimension
-        max_sequence_length=512,
-        freeze_embeddings=False
-    )
-
-    # Preload models
-    #preload_models(config, cache_dir)
-
-    # Save configuration
-    # with open(dirs['base'] / 'config.json', 'w') as f:
-    #     json.dump(config.to_dict(), f, indent=4)
-
-    # Load training data
-    dialogues = RetrievalChatbot.load_training_data(data_path='processed_outputs/batch_group_0010.json', debug_samples=DEBUG_SAMPLES)
-
-    # Initialize chatbot
-    chatbot = RetrievalChatbot(config, dialogues)
-
-    # Check trainable variables
-    chatbot.check_trainable_variables()
-
-    # Verify FAISS
-    chatbot.verify_faiss_index()
-
-    # Prepare dataset
-    logger.info("Preparing dataset...")
-
-    # Prepare and train with debug samples
-    q_tensor, p_tensor = chatbot.prepare_dataset(dialogues)
-
-    quality_checker = ResponseQualityChecker(chatbot=chatbot)
-
-    # Train model
-    logger.info("Starting training...")
-
-    tf.config.optimizer.set_jit(True)  # XLA
-    policy = tf.keras.mixed_precision.Policy('mixed_float16')
-    tf.keras.mixed_precision.set_global_policy(policy)
-
-    chatbot.train(
-        q_pad=q_tensor,
-        p_pad=p_tensor,
-        epochs=EPOCHS,
-        batch_size=BATCH_SIZE,
-        validation_split=0.2,
-        checkpoint_dir="checkpoints/",
-        use_lr_schedule=True,      # Enable custom schedule
-        peak_lr=2e-5,              # Peak learning rate
-        warmup_steps_ratio=0.1,    # 10% warmup
-        early_stopping_patience=3  # Stop if no improvement for 3 epochs
-    )
-
-    # Plot and save training history
-    #plot_training_history(chatbot.history, save_dir=dirs['plots'])
-
-    # Save final model
-    chatbot.save_models(dirs['base'] / 'final_model')
-
-    # Run automatic validation
-    validation_metrics = chatbot.run_automatic_validation(quality_checker, num_examples=5)
-
-    # Log validation metrics
-    logger.info(f"Validation Metrics: {validation_metrics}")
-
-    # Now continue with interactive chat
-    logger.info("\nStarting interactive chat session...")
-    conversation_history = []
-
-    while True:
-        user_input = input("You: ")
-        if user_input.lower() in ['quit', 'exit', 'bye']:
-            print("Assistant: Goodbye!")
-            break
-
-        response, candidates, metrics = chatbot.chat(
-            query=user_input,
-            conversation_history=None,  # Pass conversation history if available
-            quality_checker=quality_checker,
-            top_k=5
-        )
-
-        print(f"Assistant: {response}")
-
-        # Optionally, display alternative responses
-        if metrics.get('is_confident', False):
-            print("\nAlternative responses:")
-            for resp, score in candidates[1:4]:
-                print(f"Score: {score:.4f} - {resp}")
-
-if __name__ == "__main__":
-    main()
run_model_train.py
ADDED
@@ -0,0 +1,96 @@
+from chatbot_model import RetrievalChatbot, ChatbotConfig
+from environment_setup import EnvironmentSetup
+from response_quality_checker import ResponseQualityChecker
+from chatbot_validator import ChatbotValidator
+from training_plotter import TrainingPlotter
+
+
+# Configure logging
+from logger_config import config_logger
+logger = config_logger(__name__)
+
+def run_interactive_chat(chatbot, quality_checker):
+    """Separate function for interactive chat loop"""
+    while True:
+        user_input = input("You: ")
+        if user_input.lower() in ['quit', 'exit', 'bye']:
+            print("Assistant: Goodbye!")
+            break
+
+        response, candidates, metrics = chatbot.chat(
+            query=user_input,
+            conversation_history=None,
+            quality_checker=quality_checker,
+            top_k=5
+        )
+
+        print(f"Assistant: {response}")
+
+        if metrics.get('is_confident', False):
+            print("\nAlternative responses:")
+            for resp, score in candidates[1:4]:
+                print(f"Score: {score:.4f} - {resp}")
+
+def main():
+    # Initialize environment
+    env = EnvironmentSetup()
+    env.initialize()
+
+    DEBUG_SAMPLES = 5
+    EPOCHS = 1 if DEBUG_SAMPLES else 20
+    TRAINING_DATA_PATH = 'processed_outputs/batch_group_0010.json'
+
+    # Optimize batch size for Colab
+    batch_size = env.optimize_batch_size(base_batch_size=16)
+
+    # Initialize configuration
+    config = ChatbotConfig(
+        embedding_dim=512,  # 768, # Match DistilBERT's dimension
+        max_context_token_limit=512,
+        freeze_embeddings=False,
+    )
+
+    # Load training data
+    dialogues = RetrievalChatbot.load_training_data(data_path=TRAINING_DATA_PATH, debug_samples=DEBUG_SAMPLES)
+
+    # Initialize chatbot and verify FAISS index
+    with env.strategy.scope():
+        chatbot = RetrievalChatbot(config, dialogues)
+        chatbot.verify_faiss_index()
+
+    # Prepare dataset
+    logger.info("Preparing dataset...")
+    q_tensor, p_tensor = chatbot.prepare_dataset(dialogues)
+    quality_checker = ResponseQualityChecker(chatbot=chatbot)
+
+    # Train model
+    logger.info("Starting training...")
+    chatbot.train(
+        q_pad=q_tensor,
+        p_pad=p_tensor,
+        epochs=EPOCHS,
+        batch_size=batch_size,
+        validation_split=0.2,
+    )
+
+    # Save final model
+    model_save_path = env.training_dirs['base'] / 'final_model'
+    chatbot.save_models(model_save_path)
+
+    # Run automatic validation
+    quality_checker = ResponseQualityChecker(chatbot=chatbot)
+    validator = ChatbotValidator(chatbot, quality_checker)
+    validation_metrics = validator.run_validation(num_examples=5)
+    logger.info(f"Validation Metrics: {validation_metrics}")
+
+    # Plot and save training history
+    plotter = TrainingPlotter(save_dir=env.training_dirs['plots'])
+    plotter.plot_training_history(chatbot.history)
+    plotter.plot_validation_metrics(validation_metrics)
+
+    # Run interactive chat
+    logger.info("\nStarting interactive chat session...")
+    run_interactive_chat(chatbot, quality_checker)
+
+if __name__ == "__main__":
+    main()
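
Note: the same chat API used in run_interactive_chat can also be called once, non-interactively. A minimal single-turn sketch (assuming the trained chatbot and quality_checker from main() above; the query string is illustrative only):

# Sketch only: `chatbot` and `quality_checker` come from main() above.
response, candidates, metrics = chatbot.chat(
    query="How do I reset my password?",   # illustrative query
    conversation_history=None,
    quality_checker=quality_checker,
    top_k=5
)
if metrics.get('is_confident', False):
    print(response)                        # top-ranked retrieval
else:
    print("Low-confidence response:", response)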
setup.py
CHANGED
@@ -25,6 +25,27 @@ def setup_spacy_models(models=['en_core_web_sm', 'en_core_web_md']):
         print(f"Error downloading spaCy model: {model}")
         print(e)
 
+def setup_gpu_dependencies():
+    """Setup GPU-specific dependencies."""
+    cuda_available = False
+
+    # Check CUDA availability
+    try:
+        import torch
+        cuda_available = torch.cuda.is_available()
+    except ImportError:
+        pass
+
+    if cuda_available:
+        try:
+            subprocess.check_call([sys.executable, "-m", "pip", "install", "faiss-gpu>=1.7.0"])
+            print("Successfully installed faiss-gpu")
+        except subprocess.CalledProcessError:
+            print("Failed to install faiss-gpu. Falling back to faiss-cpu")
+            subprocess.check_call([sys.executable, "-m", "pip", "install", "faiss-cpu>=1.7.0"])
+    else:
+        subprocess.check_call([sys.executable, "-m", "pip", "install", "faiss-cpu>=1.7.0"])
+
 def setup_models():
     """
     Download other required models.
@@ -40,7 +61,7 @@ def setup_models():
         DistilBertModel
     )
 
-    #
+    # Cache the models
     tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
     model = DistilBertModel.from_pretrained('distilbert-base-uncased')
 
@@ -116,11 +137,11 @@ def setup_faiss():
         print(e)
 
 setup(
-    name="
-    version="0.
+    name="retrieval-chatbot",
+    version="0.2.0",
     author="Joe Armani",
     author_email="[email protected]",
-    description="A
+    description="A retrieval-based chatbot with enhanced validation",
     long_description=long_description,
     long_description_content_type="text/markdown",
     packages=find_packages(),
@@ -132,24 +153,39 @@ setup(
         "Programming Language :: Python :: 3",
         "Programming Language :: Python :: 3.8",
         "Programming Language :: Python :: 3.9",
+        "Programming Language :: Python :: 3.10",
         "Topic :: Scientific/Engineering :: Artificial Intelligence",
         "Topic :: Text Processing :: Linguistic",
     ],
     python_requires=">=3.8",
     install_requires=requirements,
+    extras_require={
+        'dev': [
+            'pytest>=7.0.0',
+            'black>=22.0.0',
+            'isort>=5.10.0',
+            'mypy>=1.0.0',
+        ],
+        'gpu': [
+            'faiss-gpu>=1.7.0',
+        ],
+    },
     entry_points={
         "console_scripts": [
             "dialogue-augment=dialogue_augmenter.main:main",
+            "run-chatbot=chatbot.main:main",
         ],
     },
     include_package_data=True,
     package_data={
+        "chatbot": ["config/*.yaml"],
        "dialogue_augmenter": ["data/*.json", "config/*.yaml"],
    },
 )
 
 if __name__ == '__main__':
     setup_spacy_models()
+    setup_gpu_dependencies()
     setup_models()
     setup_nltk()
     setup_faiss()
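
Note: with the new extras_require section, the optional dependencies can be pulled in at install time using standard pip extras syntax, e.g. pip install -e ".[gpu]" for the FAISS GPU build or pip install -e ".[dev]" for the test and linting tools; installing without extras leaves the base requirements unchanged.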
training_plotter.py
ADDED
@@ -0,0 +1,127 @@
+from pathlib import Path
+from typing import Dict, List, Optional
+import matplotlib.pyplot as plt
+from datetime import datetime
+import logging
+
+logger = logging.getLogger(__name__)
+
+class TrainingPlotter:
+    def __init__(self, save_dir: Optional[Path] = None):
+        self.save_dir = save_dir
+        if save_dir:
+            self.save_dir.mkdir(parents=True, exist_ok=True)
+
+    def plot_training_history(self, history: Dict[str, List[float]], title: str = "Training History"):
+        """Plot and optionally save training metrics history.
+
+        Args:
+            history: Dictionary containing training metrics
+            title: Title for the plot
+        """
+        # Silence matplotlib debug messages
+        logger.setLevel(logging.WARNING)
+
+        # Create figure with subplots
+        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))
+
+        # Plot losses
+        ax1.plot(history['train_loss'], label='Train Loss')
+        ax1.plot(history['val_loss'], label='Validation Loss')
+        ax1.set_xlabel('Epoch')
+        ax1.set_ylabel('Loss')
+        ax1.set_title('Training and Validation Loss')
+        ax1.legend()
+        ax1.grid(True)
+
+        # Plot learning rate if available
+        if 'learning_rate' in history:
+            ax2.plot(history['learning_rate'], label='Learning Rate')
+            ax2.set_xlabel('Step')
+            ax2.set_ylabel('Learning Rate')
+            ax2.set_title('Learning Rate Schedule')
+            ax2.legend()
+            ax2.grid(True)
+
+        plt.suptitle(title)
+        plt.tight_layout()
+
+        # Reset the logger level
+        logger.setLevel(logging.INFO)
+
+        # Save if directory provided
+        if self.save_dir:
+            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+            save_path = self.save_dir / f'training_history_{timestamp}.png'
+            plt.savefig(save_path)
+            logger.info(f"Saved training history plot to {save_path}")
+
+        plt.show()
+
+    def plot_validation_metrics(self, metrics: Dict[str, float]):
+        """Plot validation metrics as a bar chart.
+
+        Args:
+            metrics: Dictionary of validation metrics. Can handle nested dictionaries.
+        """
+        # Silence matplotlib debug messages
+        logger.setLevel(logging.WARNING)
+
+        # Flatten nested metrics dictionary
+        flat_metrics = {}
+        for key, value in metrics.items():
+            # Skip num_queries_tested
+            if key == 'num_queries_tested':
+                continue
+
+            if isinstance(value, dict):
+                # If value is a dictionary, flatten it with key prefix
+                for subkey, subvalue in value.items():
+                    if isinstance(subvalue, (int, float)):  # Only include numeric values
+                        flat_metrics[f"{key}_{subkey}"] = subvalue
+            elif isinstance(value, (int, float)):  # Only include numeric values
+                flat_metrics[key] = value
+
+        if not flat_metrics:
+            logger.warning("No numeric metrics to plot")
+            return
+
+        plt.figure(figsize=(12, 6))
+
+        # Extract metrics and values
+        metric_names = list(flat_metrics.keys())
+        values = list(flat_metrics.values())
+
+        # Create bar chart
+        bars = plt.bar(range(len(metric_names)), values)
+
+        # Customize the plot
+        plt.title('Validation Metrics')
+        plt.xticks(range(len(metric_names)), metric_names, rotation=45, ha='right')
+        plt.ylabel('Value')
+
+        # Add value labels on top of bars
+        for bar in bars:
+            height = bar.get_height()
+            plt.text(bar.get_x() + bar.get_width()/2., height,
+                     f'{height:.3f}',
+                     ha='center', va='bottom')
+
+        # Set y-axis limits to focus on metrics between 0 and 1
+        plt.ylim(0, 1.1)  # Slight padding above 1 for label visibility
+
+        # Adjust layout to prevent label cutoff
+        plt.tight_layout()
+
+        # Reset the logger level
+        logger.setLevel(logging.INFO)
+
+        # Save if directory provided
+        if self.save_dir:
+            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+            save_path = self.save_dir / f'validation_metrics_{timestamp}.png'
+            plt.savefig(save_path)
+            logger.info(f"Saved validation metrics plot to {save_path}")
+
+        plt.show()
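
Note: a minimal standalone usage sketch of the new TrainingPlotter. The history keys ('train_loss', 'val_loss', optional 'learning_rate') match what run_model_train.py passes; the literal values and metric names below are illustrative only, not real training output.

from pathlib import Path
from training_plotter import TrainingPlotter

# Sketch only: illustrative numbers, not real results.
history = {
    'train_loss': [0.9, 0.6, 0.4],
    'val_loss': [1.0, 0.7, 0.5],
    'learning_rate': [1e-5, 2e-5, 1e-5],
}
metrics = {'top_score': 0.72, 'response_diversity': 0.41, 'num_queries_tested': 5}

plotter = TrainingPlotter(save_dir=Path('plots'))
plotter.plot_training_history(history, title="Debug run")
plotter.plot_validation_metrics(metrics)  # num_queries_tested is skipped by design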