JoeArmani committed
Commit f7b283c · 1 Parent(s): 300fe5d

summarization, reranker, environment setup, and response quality checker

chatbot.py DELETED
@@ -1,261 +0,0 @@
1
- import numpy as np
2
- import tensorflow as tf
3
- import keras
4
- print(tf.__version__)
5
- print(keras.__version__)
6
- import spacy
7
- import random
8
- from tqdm import trange
9
-
10
- class RetrievalChatbot:
11
- def __init__(
12
- self,
13
- vocab_size: int = 10000,
14
- max_sequence_length: int = 80,
15
- embedding_dim: int = 256,
16
- lstm_units: int = 256,
17
- num_attention_heads: int = 8,
18
- margin: float = 0.3
19
- ):
20
- self.vocab_size = vocab_size
21
- self.max_sequence_length = max_sequence_length
22
- self.embedding_dim = embedding_dim
23
- self.lstm_units = lstm_units
24
- self.num_attention_heads = num_attention_heads
25
- self.margin = margin
26
-
27
- self.nlp = spacy.load('en_core_web_md')
28
- self.tokenizer = tf.keras.preprocessing.text.Tokenizer(
29
- num_words=vocab_size,
30
- oov_token="<OOV>"
31
- )
32
-
33
- self.query_encoder_model, self.response_encoder_model = self._build_encoders()
34
-
35
- def _positional_encoding(self, position: int, d_model: int) -> tf.Tensor:
36
- angles = np.arange(position)[:, np.newaxis] / np.power(
37
- 10000,
38
- (2 * (np.arange(d_model)[np.newaxis, :] // 2)) / d_model
39
- )
40
- sines = np.sin(angles[:, 0::2])
41
- cosines = np.cos(angles[:, 1::2])
42
- pos_encoding = np.concatenate([sines, cosines], axis=-1)
43
- pos_encoding = pos_encoding[np.newaxis, ...]
44
- return tf.cast(pos_encoding, dtype=tf.float32)
45
-
46
- def _build_single_encoder(self, name_prefix: str):
47
- input_layer = tf.keras.Input(shape=(self.max_sequence_length,), name=f"{name_prefix}_input")
48
- embedding = tf.keras.layers.Embedding(
49
- self.vocab_size,
50
- self.embedding_dim,
51
- mask_zero=True,
52
- name=f"{name_prefix}_embedding"
53
- )(input_layer)
54
-
55
- pos_encoding = self._positional_encoding(self.max_sequence_length, self.embedding_dim)
56
- x = embedding + pos_encoding
57
-
58
- # # Multi-head attention
59
- # attention_output = tf.keras.layers.MultiHeadAttention(
60
- # num_heads=self.num_attention_heads,
61
- # key_dim=self.embedding_dim // self.num_attention_heads
62
- # )(x, x)
63
- # x = tf.keras.layers.LayerNormalization()(x + attention_output)
64
-
65
- for i in range(2):
66
- lstm_out = tf.keras.layers.LSTM(
67
- self.lstm_units,
68
- return_sequences=True,
69
- kernel_regularizer=tf.keras.regularizers.l2(0.01),
70
- name=f"{name_prefix}_lstm_{i}"
71
- )(x)
72
- x = tf.keras.layers.LayerNormalization()(x + lstm_out)
73
-
74
- encoder_output = tf.keras.layers.LSTM(
75
- self.lstm_units,
76
- name=f"{name_prefix}_final_lstm"
77
- )(x)
78
- encoder_output = tf.keras.layers.Dropout(0.2)(encoder_output)
79
- encoder_output = tf.keras.layers.Lambda(lambda t: tf.nn.l2_normalize(t, axis=1))(encoder_output)
80
-
81
- return tf.keras.Model(input_layer, encoder_output, name=f"{name_prefix}_encoder")
82
-
83
- def _build_encoders(self):
84
- query_encoder = self._build_single_encoder("query")
85
- response_encoder = self._build_single_encoder("response")
86
- return query_encoder, response_encoder
87
-
88
- def _spacy_similarity(self, text1: str, text2: str) -> float:
89
- doc1 = self.nlp(text1)
90
- doc2 = self.nlp(text2)
91
- print('doc1:', doc1)
92
- print('doc2:', doc2)
93
- print('doc1.similarity(doc2):', doc1.similarity(doc2))
94
- return doc1.similarity(doc2)
95
-
96
- def prepare_dataset(self, dialogues: list, neg_samples_per_pos=3):
97
- # Create triplets: (query, positive, negative)
98
- response_pool = [
99
- turn['text'] for d in dialogues for turn in d['turns'] if turn['speaker'] == 'assistant'
100
- ]
101
- queries, positives, negatives = [], [], []
102
-
103
- for dialogue in dialogues:
104
- turns = dialogue['turns']
105
- for i in range(0, len(turns)-1):
106
- if turns[i]['speaker'] == 'user' and turns[i+1]['speaker'] == 'assistant':
107
- q = turns[i]['text']
108
- p = turns[i+1]['text']
109
-
110
- # Find negatives using spaCy similarity
111
- neg_candidates = []
112
- attempts = 0
113
- while len(neg_candidates) < neg_samples_per_pos and attempts < 200:
114
- cand = random.choice(response_pool)
115
- if cand != p:
116
- sim = self._spacy_similarity(cand, p)
117
- # Choose thresholds that produce hard negatives
118
- if 0.4 < sim < 0.9:
119
- neg_candidates.append(cand)
120
- attempts += 1
121
-
122
- if len(neg_candidates) == neg_samples_per_pos:
123
- for neg in neg_candidates:
124
- queries.append(q)
125
- positives.append(p)
126
- negatives.append(neg)
127
-
128
- # Fit tokenizer
129
- all_text = queries + positives + negatives
130
- self.tokenizer.fit_on_texts(all_text)
131
-
132
- def seq_pad(txts):
133
- seq = self.tokenizer.texts_to_sequences(txts)
134
- return tf.keras.preprocessing.sequence.pad_sequences(seq, maxlen=self.max_sequence_length, padding='post')
135
-
136
- q_pad = seq_pad(queries)
137
- p_pad = seq_pad(positives)
138
- n_pad = seq_pad(negatives)
139
-
140
- return q_pad, p_pad, n_pad
141
-
142
- def triplet_loss(self, q_emb, p_emb, n_emb):
143
- pos_dist = tf.reduce_sum(tf.square(q_emb - p_emb), axis=1)
144
- neg_dist = tf.reduce_sum(tf.square(q_emb - n_emb), axis=1)
145
- loss = tf.maximum(0.0, self.margin + pos_dist - neg_dist)
146
- return tf.reduce_mean(loss)
147
-
148
- def train_with_triplet_loss(
149
- self, q_pad, p_pad, n_pad,
150
- epochs=3,
151
- batch_size=16,
152
- validation_split=0.2,
153
- early_stopping_patience=3,
154
- use_tqdm=True
155
- ):
156
- train_losses = []
157
- val_losses = []
158
-
159
- total_samples = len(q_pad)
160
- idxs = np.arange(total_samples)
161
- np.random.shuffle(idxs)
162
- train_size = int((1 - validation_split) * total_samples)
163
-
164
- train_idxs = idxs[:train_size]
165
- val_idxs = idxs[train_size:]
166
-
167
- q_train, p_train, n_train = q_pad[train_idxs], p_pad[train_idxs], n_pad[train_idxs]
168
- q_val, p_val, n_val = q_pad[val_idxs], p_pad[val_idxs], n_pad[val_idxs]
169
-
170
- optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
171
- best_val_loss = float('inf')
172
- wait = 0
173
-
174
- for epoch in range(epochs):
175
- # Shuffle training data each epoch
176
- perm = np.random.permutation(len(q_train))
177
- q_train, p_train, n_train = q_train[perm], p_train[perm], n_train[perm]
178
-
179
- num_batches = len(q_train) // batch_size
180
- epoch_train_loss = 0.0
181
-
182
- batch_iter = range(num_batches)
183
- if use_tqdm:
184
- batch_iter = trange(num_batches, desc=f"Epoch {epoch+1}/{epochs}")
185
-
186
- for i in batch_iter:
187
- q_batch = q_train[i*batch_size:(i+1)*batch_size]
188
- p_batch = p_train[i*batch_size:(i+1)*batch_size]
189
- n_batch = n_train[i*batch_size:(i+1)*batch_size]
190
-
191
- with tf.GradientTape() as tape:
192
- q_emb = self.query_encoder_model(q_batch, training=True)
193
- p_emb = self.response_encoder_model(p_batch, training=True)
194
- n_emb = self.response_encoder_model(n_batch, training=True)
195
- loss = self.triplet_loss(q_emb, p_emb, n_emb)
196
-
197
- grads = tape.gradient(
198
- loss,
199
- self.query_encoder_model.trainable_variables +
200
- self.response_encoder_model.trainable_variables
201
- )
202
- optimizer.apply_gradients(zip(
203
- grads,
204
- self.query_encoder_model.trainable_variables +
205
- self.response_encoder_model.trainable_variables
206
- ))
207
- epoch_train_loss += loss.numpy()
208
-
209
- epoch_train_loss /= num_batches
210
-
211
- # Validation loss
212
- val_batches = len(q_val) // batch_size
213
- epoch_val_loss = 0.0
214
- for i in range(val_batches):
215
- q_batch = q_val[i*batch_size:(i+1)*batch_size]
216
- p_batch = p_val[i*batch_size:(i+1)*batch_size]
217
- n_batch = n_val[i*batch_size:(i+1)*batch_size]
218
-
219
- q_emb = self.query_encoder_model(q_batch, training=False)
220
- p_emb = self.response_encoder_model(p_batch, training=False)
221
- n_emb = self.response_encoder_model(n_batch, training=False)
222
- v_loss = self.triplet_loss(q_emb, p_emb, n_emb)
223
- epoch_val_loss += v_loss.numpy()
224
-
225
- if val_batches > 0:
226
- epoch_val_loss /= val_batches
227
-
228
- train_losses.append(epoch_train_loss)
229
- val_losses.append(epoch_val_loss)
230
-
231
- print(f"Epoch {epoch+1}/{epochs}, Train Loss: {epoch_train_loss:.4f}, Val Loss: {epoch_val_loss:.4f}")
232
-
233
- # Early Stopping logic
234
- if epoch_val_loss < best_val_loss:
235
- best_val_loss = epoch_val_loss
236
- wait = 0
237
- # (Optional) Save best weights
238
- else:
239
- wait += 1
240
- if wait >= early_stopping_patience:
241
- print("Early stopping triggered.")
242
- break
243
-
244
- return train_losses, val_losses
245
-
246
- def encode_texts(self, texts, is_query=True):
247
- seq = self.tokenizer.texts_to_sequences(texts)
248
- pad_seq = tf.keras.preprocessing.sequence.pad_sequences(seq, maxlen=self.max_sequence_length, padding='post')
249
- if is_query:
250
- return self.query_encoder_model(pad_seq, training=False)
251
- else:
252
- return self.response_encoder_model(pad_seq, training=False)
253
-
254
- def retrieve_top_n(self, query: str, candidates: list, top_n=5):
255
- q_emb = self.encode_texts([query], is_query=True) # shape (1, d)
256
- c_emb = self.encode_texts(candidates, is_query=False) # shape (num_cand, d)
257
- sim = tf.matmul(q_emb, c_emb, transpose_b=True).numpy()[0] # dot product similarity
258
- top_indices = np.argsort(sim)[::-1][:top_n]
259
- return [(candidates[i], sim[i]) for i in top_indices]
260
-
261
-
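For orientation, a minimal usage sketch of the RetrievalChatbot defined in the deleted chatbot.py above. The dialogue list, candidate pool, and hyperparameters shown here are illustrative assumptions, not values from the repository:

# Hypothetical usage of the deleted chatbot.py API (dialogue format inferred from prepare_dataset).
from chatbot import RetrievalChatbot

dialogues = [
    {"turns": [
        {"speaker": "user", "text": "Can you help me reset my password?"},
        {"speaker": "assistant", "text": "Sure, I can walk you through resetting your password."},
    ]},
    # ... more dialogues; negative sampling needs a reasonably large pool of assistant responses
]

bot = RetrievalChatbot(vocab_size=10000, max_sequence_length=80)
q_pad, p_pad, n_pad = bot.prepare_dataset(dialogues, neg_samples_per_pos=3)
train_losses, val_losses = bot.train_with_triplet_loss(q_pad, p_pad, n_pad, epochs=3, batch_size=16)

candidates = ["Sure, I can walk you through resetting your password.",
              "Your order shipped yesterday."]
print(bot.retrieve_top_n("I forgot my password", candidates, top_n=2))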
chatbot2.py DELETED
@@ -1,839 +0,0 @@
1
- import numpy as np
2
- import tensorflow as tf
3
- import spacy
4
- import random
5
- from typing import List, Tuple, Dict, Optional, Union
6
- from dataclasses import dataclass
7
- from tqdm import tqdm
8
- import logging
9
- from pathlib import Path
10
- import json
11
-
12
- # Configure logging
13
- logging.basicConfig(
14
- level=logging.INFO,
15
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
16
- )
17
- logger = logging.getLogger(__name__)
18
-
19
- @dataclass
20
- class ChatbotConfig:
21
- """Configuration for the retrieval chatbot."""
22
- vocab_size: int = 10000
23
- max_sequence_length: int = 512
24
- embedding_dim: int = 256
25
- encoder_units: int = 256
26
- num_attention_heads: int = 8
27
- dropout_rate: float = 0.2
28
- l2_reg_weight: float = 0.001
29
- margin: float = 0.3
30
- learning_rate: float = 0.001
31
- min_text_length: int = 3 # Reduced from 10 to allow shorter responses
32
- max_context_turns: int = 5
33
- warmup_steps: int = 200
34
- spacy_model: str = 'en_core_web_md'
35
-
36
- def to_dict(self) -> dict:
37
- """Convert config to dictionary."""
38
- return {k: str(v) if isinstance(v, Path) else v
39
- for k, v in self.__dict__.items()}
40
-
41
- @classmethod
42
- def from_dict(cls, config_dict: dict) -> 'ChatbotConfig':
43
- """Create config from dictionary."""
44
- return cls(**{k: v for k, v in config_dict.items()
45
- if k in cls.__dataclass_fields__})
46
-
47
- class TransformerBlock(tf.keras.layers.Layer):
48
- """Custom Transformer block with pre-layer normalization."""
49
- def __init__(
50
- self,
51
- embed_dim: int,
52
- num_heads: int,
53
- ff_dim: int,
54
- dropout: float = 0.1,
55
- **kwargs
56
- ):
57
- super().__init__(**kwargs)
58
- self.embed_dim = embed_dim
59
- self.num_heads = num_heads
60
- self.ff_dim = ff_dim
61
- self.dropout = dropout
62
-
63
- self.attention = tf.keras.layers.MultiHeadAttention(
64
- num_heads=num_heads,
65
- key_dim=embed_dim // num_heads,
66
- dropout=dropout
67
- )
68
- self.ffn = tf.keras.Sequential([
69
- tf.keras.layers.Dense(ff_dim, activation="gelu"),
70
- tf.keras.layers.Dense(embed_dim),
71
- ])
72
-
73
- self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
74
- self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
75
- self.dropout1 = tf.keras.layers.Dropout(dropout)
76
- self.dropout2 = tf.keras.layers.Dropout(dropout)
77
-
78
- def call(self, inputs: tf.Tensor, training: bool, mask: Optional[tf.Tensor] = None) -> tf.Tensor:
79
- # Pre-layer normalization
80
- norm_inputs = self.layernorm1(inputs)
81
-
82
- # Self-attention
83
- attention_output = self.attention(
84
- query=norm_inputs,
85
- value=norm_inputs,
86
- key=norm_inputs,
87
- attention_mask=mask,
88
- training=training
89
- )
90
- attention_output = self.dropout1(attention_output, training=training)
91
- attention_output = inputs + attention_output
92
-
93
- # Feed-forward network
94
- norm_attention = self.layernorm2(attention_output)
95
- ffn_output = self.ffn(norm_attention)
96
- ffn_output = self.dropout2(ffn_output, training=training)
97
-
98
- return attention_output + ffn_output
99
-
100
- def get_config(self) -> dict:
101
- config = super().get_config()
102
- config.update({
103
- "embed_dim": self.embed_dim,
104
- "num_heads": self.num_heads,
105
- "ff_dim": self.ff_dim,
106
- "dropout": self.dropout,
107
- })
108
- return config
109
-
110
- class EncoderModel(tf.keras.Model):
111
- """Dual encoder model with shared weights option."""
112
- def __init__(
113
- self,
114
- config: ChatbotConfig,
115
- name: str = "encoder",
116
- shared_weights: bool = False,
117
- **kwargs
118
- ):
119
- super().__init__(name=name, **kwargs)
120
- self.config = config
121
- self.shared_weights = shared_weights
122
-
123
- # Input embedding layer
124
- self.embedding = tf.keras.layers.Embedding(
125
- config.vocab_size,
126
- config.embedding_dim,
127
- mask_zero=True,
128
- name=f"{name}_embedding"
129
- )
130
-
131
- # Positional encoding
132
- self.pos_encoding = self._get_positional_encoding()
133
-
134
- # Transformer blocks
135
- self.transformer_blocks = [
136
- TransformerBlock(
137
- config.embedding_dim,
138
- config.num_attention_heads,
139
- config.encoder_units * 4,
140
- config.dropout_rate,
141
- name=f"{name}_transformer_{i}"
142
- ) for i in range(3)
143
- ]
144
-
145
- # Final LSTM layer
146
- self.final_lstm = tf.keras.layers.LSTM(
147
- config.encoder_units,
148
- kernel_regularizer=tf.keras.regularizers.l2(config.l2_reg_weight),
149
- name=f"{name}_final_lstm"
150
- )
151
-
152
- self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
153
- self.normalize = tf.keras.layers.Lambda(
154
- lambda x: tf.nn.l2_normalize(x, axis=1)
155
- )
156
-
157
- def _get_positional_encoding(self) -> tf.Tensor:
158
- """Generate positional encoding matrix."""
159
- pos = np.arange(self.config.max_sequence_length)[:, np.newaxis]
160
- i = np.arange(self.config.embedding_dim)[np.newaxis, :]
161
- angle = pos / np.power(10000, (2 * (i // 2)) / self.config.embedding_dim)
162
-
163
- pos_encoding = np.zeros_like(angle)
164
- pos_encoding[:, 0::2] = np.sin(angle[:, 0::2])
165
- pos_encoding[:, 1::2] = np.cos(angle[:, 1::2])
166
-
167
- return tf.cast(pos_encoding[np.newaxis, ...], dtype=tf.float32)
168
-
169
- def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
170
- # Get input mask
171
- mask = self.embedding.compute_mask(inputs)
172
- mask = mask[:, tf.newaxis, tf.newaxis, :] # Add attention dims
173
-
174
- # Embedding + positional encoding
175
- x = self.embedding(inputs)
176
- x = x + self.pos_encoding
177
-
178
- # Apply transformer blocks
179
- for transformer_block in self.transformer_blocks:
180
- x = transformer_block(x, training=training, mask=mask)
181
-
182
- # Final processing
183
- x = self.final_lstm(x)
184
- x = self.dropout(x, training=training)
185
- return self.normalize(x)
186
-
187
- class RetrievalChatbot:
188
- """Professional implementation of a retrieval-based chatbot."""
189
- def __init__(self, config: ChatbotConfig):
190
- self.config = config
191
- self.nlp = spacy.load(config.spacy_model)
192
-
193
- # Initialize tokenizer
194
- self.tokenizer = tf.keras.preprocessing.text.Tokenizer(
195
- num_words=config.vocab_size,
196
- oov_token="<OOV>"
197
- )
198
-
199
- # Special tokens
200
- self.special_tokens = {
201
- "user": "<USER>",
202
- "assistant": "<ASSISTANT>",
203
- "context": "<CONTEXT>",
204
- "sep": "<SEP>"
205
- }
206
-
207
- # Build models
208
- self._build_models()
209
-
210
- # Training history
211
- self.history = {
212
- "train_loss": [],
213
- "val_loss": [],
214
- "train_metrics": {},
215
- "val_metrics": {}
216
- }
217
-
218
- # Initialize similarity cache
219
- self.similarity_cache = {}
220
-
221
- def _build_models(self):
222
- """Initialize the encoder models."""
223
- # Query encoder
224
- self.query_encoder = EncoderModel(
225
- self.config,
226
- name="query_encoder",
227
- shared_weights=False
228
- )
229
-
230
- # Response encoder (can share weights with query encoder)
231
- self.response_encoder = EncoderModel(
232
- self.config,
233
- name="response_encoder",
234
- shared_weights=False
235
- )
236
-
237
- def save_models(self, save_dir: Union[str, Path]):
238
- """Save models and configuration."""
239
- save_dir = Path(save_dir)
240
- save_dir.mkdir(parents=True, exist_ok=True)
241
-
242
- # Save config
243
- with open(save_dir / "config.json", "w") as f:
244
- json.dump(self.config.to_dict(), f, indent=2)
245
-
246
- # Save models with proper extension
247
- self.query_encoder.save(save_dir / "query_encoder.keras")
248
- self.response_encoder.save(save_dir / "response_encoder.keras")
249
-
250
- # Save tokenizer config
251
- tokenizer_config = {
252
- "word_index": self.tokenizer.word_index,
253
- "word_counts": self.tokenizer.word_counts,
254
- "document_count": self.tokenizer.document_count,
255
- "index_docs": self.tokenizer.index_docs,
256
- "index_word": self.tokenizer.index_word
257
- }
258
- with open(save_dir / "tokenizer_config.json", "w") as f:
259
- json.dump(tokenizer_config, f)
260
-
261
- @classmethod
262
- def load_models(cls, load_dir: Union[str, Path]) -> 'RetrievalChatbot':
263
- """Load saved models and configuration."""
264
- load_dir = Path(load_dir)
265
-
266
- # Load config
267
- with open(load_dir / "config.json", "r") as f:
268
- config = ChatbotConfig.from_dict(json.load(f))
269
-
270
- # Initialize chatbot
271
- chatbot = cls(config)
272
-
273
- # Load models with proper extension
274
- chatbot.query_encoder = tf.keras.models.load_model(
275
- load_dir / "query_encoder.keras",
276
- custom_objects={"TransformerBlock": TransformerBlock}
277
- )
278
- chatbot.response_encoder = tf.keras.models.load_model(
279
- load_dir / "response_encoder.keras",
280
- custom_objects={"TransformerBlock": TransformerBlock}
281
- )
282
-
283
- # Load tokenizer config
284
- with open(load_dir / "tokenizer_config.json", "r") as f:
285
- tokenizer_config = json.load(f)
286
-
287
- chatbot.tokenizer = tf.keras.preprocessing.text.Tokenizer(
288
- num_words=config.vocab_size,
289
- oov_token="<OOV>"
290
- )
291
- chatbot.tokenizer.word_index = tokenizer_config["word_index"]
292
- chatbot.tokenizer.word_counts = tokenizer_config["word_counts"]
293
- chatbot.tokenizer.document_count = tokenizer_config["document_count"]
294
- chatbot.tokenizer.index_docs = tokenizer_config["index_docs"]
295
- chatbot.tokenizer.index_word = tokenizer_config["index_word"]
296
-
297
- return chatbot
298
-
299
- def _improved_spacy_similarity(self, text1: str, text2: str) -> float:
300
- """Calculate semantic similarity between texts with preprocessing."""
301
- def preprocess(text: str) -> str:
302
- # Basic cleaning
303
- text = ' '.join(text.split())
304
- return text if text.strip() else "empty_document"
305
-
306
- # Get cache key
307
- cache_key = f"{hash(text1)}_{hash(text2)}"
308
- if cache_key in self.similarity_cache:
309
- return self.similarity_cache[cache_key]
310
-
311
- # Process texts
312
- text1, text2 = preprocess(text1), preprocess(text2)
313
- doc1, doc2 = self.nlp(text1), self.nlp(text2)
314
-
315
- # Calculate similarity
316
- if doc1.has_vector and doc2.has_vector:
317
- sim = doc1.similarity(doc2)
318
- else:
319
- # Fallback to token overlap similarity
320
- tokens1 = {t.lower_ for t in doc1 if not t.is_stop and not t.is_punct}
321
- tokens2 = {t.lower_ for t in doc2 if not t.is_stop and not t.is_punct}
322
- intersection = len(tokens1.intersection(tokens2))
323
- union = len(tokens1.union(tokens2))
324
- sim = intersection / union if union > 0 else 0.0
325
-
326
- # Cache result
327
- self.similarity_cache[cache_key] = sim
328
- return sim
329
-
330
- def _smart_negative_sampling(
331
- self,
332
- positive: str,
333
- response_pool: List[str],
334
- n_samples: int,
335
- max_attempts: int = 200,
336
- similarity_bounds: Tuple[float, float] = (0.3, 0.8),
337
- batch_size: int = 10
338
- ) -> List[str]:
339
- """Smart negative sampling with similarity bounds and batching."""
340
- candidates = []
341
- seen = set()
342
- attempts = 0
343
-
344
- while len(candidates) < n_samples and attempts < max_attempts:
345
- # Batch process candidates
346
- batch = random.sample(
347
- response_pool,
348
- min(batch_size, max_attempts - attempts)
349
- )
350
-
351
- for candidate in batch:
352
- if candidate != positive and candidate not in seen:
353
- seen.add(candidate)
354
- sim = self._improved_spacy_similarity(candidate, positive)
355
-
356
- # Check similarity bounds
357
- if similarity_bounds[0] < sim < similarity_bounds[1]:
358
- candidates.append(candidate)
359
- if len(candidates) == n_samples:
360
- break
361
-
362
- attempts += len(batch)
363
-
364
- return candidates
365
-
366
- def train(
367
- self,
368
- q_pad: tf.Tensor,
369
- p_pad: tf.Tensor,
370
- n_pad: tf.Tensor,
371
- epochs: int = 3,
372
- batch_size: int = 32,
373
- validation_split: float = 0.2,
374
- checkpoint_dir: Optional[Union[str, Path]] = None
375
- ):
376
- """Train the model with improved training loop."""
377
- # Setup training
378
- total_samples = len(q_pad)
379
- train_size = int((1 - validation_split) * total_samples)
380
-
381
- # Split data
382
- indices = np.random.permutation(total_samples)
383
- train_idx, val_idx = indices[:train_size], indices[train_size:]
384
-
385
- train_data = (q_pad[train_idx], p_pad[train_idx], n_pad[train_idx])
386
- val_data = (q_pad[val_idx], p_pad[val_idx], n_pad[val_idx])
387
-
388
- # Setup optimizer with learning rate schedule
389
- steps_per_epoch = train_size // batch_size
390
- total_steps = steps_per_epoch * epochs
391
-
392
- lr_schedule = self._get_lr_schedule(
393
- total_steps,
394
- self.config.learning_rate,
395
- self.config.warmup_steps
396
- )
397
-
398
- optimizer = tf.keras.optimizers.Adam(lr_schedule)
399
-
400
- # Setup checkpointing
401
- if checkpoint_dir:
402
- checkpoint_dir = Path(checkpoint_dir)
403
- checkpoint_dir.mkdir(parents=True, exist_ok=True)
404
-
405
- # Setup checkpoint callback with correct file format
406
- checkpoint_template = str(checkpoint_dir / "model_epoch_{epoch:04d}.weights.h5")
407
- checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
408
- checkpoint_template,
409
- save_weights_only=True,
410
- save_best_only=True,
411
- monitor='val_loss',
412
- mode='min',
413
- verbose=1
414
- )
415
-
416
- # Training loop
417
- best_val_loss = float('inf')
418
- patience = 5
419
- wait = 0
420
-
421
- for epoch in range(epochs):
422
- # Training
423
- train_loss = self._train_epoch(
424
- train_data,
425
- optimizer,
426
- batch_size,
427
- training=True
428
- )
429
-
430
- # Validation
431
- val_loss = self._train_epoch(
432
- val_data,
433
- optimizer,
434
- batch_size,
435
- training=False
436
- )
437
-
438
- # Update history
439
- self.history['train_loss'].append(train_loss)
440
- self.history['val_loss'].append(val_loss)
441
-
442
- logger.info(
443
- f"Epoch {epoch + 1}/{epochs} - "
444
- f"train_loss: {train_loss:.4f} - "
445
- f"val_loss: {val_loss:.4f}"
446
- )
447
-
448
- # Early stopping
449
- if val_loss < best_val_loss:
450
- best_val_loss = val_loss
451
- wait = 0
452
- if checkpoint_dir:
453
- self.save_models(checkpoint_dir / f"best_model")
454
- else:
455
- wait += 1
456
- if wait >= patience:
457
- logger.info("Early stopping triggered")
458
- break
459
-
460
- def _get_lr_schedule(
461
- self,
462
- total_steps: int,
463
- peak_lr: float,
464
- warmup_steps: int
465
- ) -> tf.keras.optimizers.schedules.LearningRateSchedule:
466
- """Enhanced learning rate schedule with better error handling and logging."""
467
- class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
468
- def __init__(
469
- self,
470
- total_steps: int,
471
- peak_lr: float,
472
- warmup_steps: int
473
- ):
474
- super().__init__()
475
- self.total_steps = tf.cast(total_steps, tf.float32)
476
- self.peak_lr = tf.cast(peak_lr, tf.float32)
477
- self.warmup_steps = tf.cast(max(1, warmup_steps), tf.float32) # Prevent 0
478
-
479
- # Calculate and store constants
480
- self.initial_lr = self.peak_lr * 0.1 # Start at 10% of peak
481
- self.min_lr = self.peak_lr * 0.01 # Minimum 1% of peak
482
-
483
- logger.info(f"Learning rate schedule initialized:")
484
- logger.info(f" Initial LR: {float(self.initial_lr):.6f}")
485
- logger.info(f" Peak LR: {float(self.peak_lr):.6f}")
486
- logger.info(f" Min LR: {float(self.min_lr):.6f}")
487
- logger.info(f" Warmup steps: {int(self.warmup_steps)}")
488
- logger.info(f" Total steps: {int(self.total_steps)}")
489
-
490
- def __call__(self, step):
491
- step = tf.cast(step, tf.float32)
492
-
493
- # Warmup phase
494
- warmup_factor = tf.minimum(1.0, step / self.warmup_steps)
495
- warmup_lr = self.initial_lr + (self.peak_lr - self.initial_lr) * warmup_factor
496
-
497
- # Decay phase
498
- decay_steps = tf.maximum(1.0, self.total_steps - self.warmup_steps)
499
- decay_factor = (step - self.warmup_steps) / decay_steps
500
- decay_factor = tf.minimum(tf.maximum(0.0, decay_factor), 1.0) # Clip to [0,1]
501
-
502
- cosine_decay = 0.5 * (1.0 + tf.cos(tf.constant(np.pi) * decay_factor))
503
- decay_lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay
504
-
505
- # Choose between warmup and decay
506
- final_lr = tf.where(step < self.warmup_steps, warmup_lr, decay_lr)
507
-
508
- # Ensure learning rate is valid
509
- final_lr = tf.maximum(self.min_lr, final_lr)
510
- final_lr = tf.where(tf.math.is_finite(final_lr), final_lr, self.min_lr)
511
-
512
- return final_lr
513
-
514
- def get_config(self):
515
- return {
516
- "total_steps": self.total_steps,
517
- "peak_lr": self.peak_lr,
518
- "warmup_steps": self.warmup_steps,
519
- }
520
-
521
- return CustomSchedule(total_steps, peak_lr, warmup_steps)
522
-
523
- @tf.function
524
- def _train_step(
525
- self,
526
- q_batch: tf.Tensor,
527
- p_batch: tf.Tensor,
528
- n_batch: tf.Tensor,
529
- optimizer: tf.keras.optimizers.Optimizer,
530
- training: bool = True
531
- ) -> tf.Tensor:
532
- """Single training step with triplet loss."""
533
- with tf.GradientTape() as tape:
534
- # Get embeddings
535
- q_emb = self.query_encoder(q_batch, training=training)
536
- p_emb = self.response_encoder(p_batch, training=training)
537
- n_emb = self.response_encoder(n_batch, training=training)
538
-
539
- # Calculate triplet loss
540
- pos_dist = tf.reduce_sum(tf.square(q_emb - p_emb), axis=1)
541
- neg_dist = tf.reduce_sum(tf.square(q_emb - n_emb), axis=1)
542
-
543
- loss = tf.maximum(0.0, self.config.margin + pos_dist - neg_dist)
544
- loss = tf.reduce_mean(loss)
545
-
546
- if training:
547
- # Apply gradients
548
- gradients = tape.gradient(
549
- loss,
550
- self.query_encoder.trainable_variables +
551
- self.response_encoder.trainable_variables
552
- )
553
- optimizer.apply_gradients(zip(
554
- gradients,
555
- self.query_encoder.trainable_variables +
556
- self.response_encoder.trainable_variables
557
- ))
558
-
559
- return loss
560
-
561
- def _train_epoch(
562
- self,
563
- data: Tuple[tf.Tensor, tf.Tensor, tf.Tensor],
564
- optimizer: tf.keras.optimizers.Optimizer,
565
- batch_size: int,
566
- training: bool = True
567
- ) -> float:
568
- """Train for one epoch with enhanced logging and progress tracking."""
569
- q_data, p_data, n_data = data
570
- total_loss = 0
571
- num_batches = len(q_data) // batch_size
572
-
573
- # Log current learning rate at start of epoch
574
- if training:
575
- if hasattr(optimizer.learning_rate, '__call__'):
576
- current_lr = optimizer.learning_rate(optimizer.iterations)
577
- else:
578
- current_lr = optimizer.learning_rate
579
- logger.info(f"Current learning rate: {float(current_lr):.6f}")
580
-
581
- # Shuffle data
582
- indices = np.random.permutation(len(q_data))
583
- q_data = q_data[indices]
584
- p_data = p_data[indices]
585
- n_data = n_data[indices]
586
-
587
- # Create progress bar
588
- mode = "Training" if training else "Validation"
589
- pbar = tqdm(
590
- total=num_batches,
591
- desc=f"{mode} batches",
592
- unit="batch",
593
- dynamic_ncols=True # Automatically adjust width
594
- )
595
-
596
- # Process batches
597
- for i in range(num_batches):
598
- start_idx = i * batch_size
599
- end_idx = start_idx + batch_size
600
-
601
- batch_loss = self._train_step(
602
- q_data[start_idx:end_idx],
603
- p_data[start_idx:end_idx],
604
- n_data[start_idx:end_idx],
605
- optimizer,
606
- training
607
- )
608
- total_loss += batch_loss
609
-
610
- # Update progress bar with current loss
611
- avg_loss = total_loss / (i + 1)
612
- pbar.set_postfix({
613
- 'loss': f'{avg_loss:.4f}',
614
- 'lr': f'{float(current_lr):.6f}' if training else 'N/A'
615
- })
616
- pbar.update(1)
617
-
618
- pbar.close()
619
- return total_loss / num_batches if num_batches > 0 else 0
620
-
621
- def _prepare_sequences(
622
- self,
623
- queries: List[str],
624
- positives: List[str],
625
- negatives: List[str]
626
- ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
627
- """Enhanced sequence preparation with logging and text preprocessing."""
628
- logger.info("Preparing sequences...")
629
-
630
- # Text cleaning function from old version
631
- def clean_text(text: str) -> str:
632
- # Remove excessive whitespace
633
- text = ' '.join(text.split())
634
- # Remove very long repetitive sequences
635
- if len(text) > 500: # Add length limit
636
- text = ' '.join(dict.fromkeys(text.split()))
637
- return text
638
-
639
- # Process texts with special tokens and cleaning
640
- queries = [f"{self.special_tokens['user']} {clean_text(q)}" for q in queries]
641
- positives = [f"{self.special_tokens['assistant']} {clean_text(p)}" for p in positives]
642
- negatives = [f"{self.special_tokens['assistant']} {clean_text(n)}" for n in negatives]
643
-
644
- # Fit tokenizer and log vocabulary statistics
645
- all_texts = queries + positives + negatives
646
- self.tokenizer.fit_on_texts(all_texts)
647
-
648
- # Log vocabulary statistics
649
- vocab_size = len(self.tokenizer.word_index)
650
- logger.info(f"Vocabulary statistics:")
651
- logger.info(f" Total unique tokens: {vocab_size}")
652
- logger.info(f" Vocab limit: {self.config.vocab_size}")
653
-
654
- # Log most common tokens
655
- word_freq = sorted(
656
- self.tokenizer.word_counts.items(),
657
- key=lambda x: x[1],
658
- reverse=True
659
- )[:10]
660
- logger.info("Most common tokens:")
661
- for word, freq in word_freq:
662
- logger.info(f" {word}: {freq}")
663
-
664
- # Padding function from old version
665
- def pad_sequences(texts: List[str]) -> tf.Tensor:
666
- sequences = self.tokenizer.texts_to_sequences(texts)
667
- return tf.keras.preprocessing.sequence.pad_sequences(
668
- sequences,
669
- maxlen=self.config.max_sequence_length,
670
- padding='post',
671
- truncating='post'
672
- )
673
-
674
- # Return padded sequences
675
- return (
676
- pad_sequences(queries),
677
- pad_sequences(positives),
678
- pad_sequences(negatives)
679
- )
680
-
681
- def prepare_dataset(
682
- self,
683
- dialogues: List[dict],
684
- neg_samples_per_pos: int = 3,
685
- debug_samples: Optional[int] = None
686
- ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
687
- """Prepare dataset with enhanced logging and statistics."""
688
- logger.info("Preparing dataset...")
689
-
690
- # Log dataset statistics
691
- total_dialogues = len(dialogues)
692
- total_turns = sum(len(d['turns']) for d in dialogues)
693
- avg_turns = total_turns / total_dialogues
694
-
695
- logger.info(f"Dataset statistics:")
696
- logger.info(f" Total dialogues: {total_dialogues}")
697
- logger.info(f" Total turns: {total_turns}")
698
- logger.info(f" Average turns per dialogue: {avg_turns:.2f}")
699
-
700
- # Extract and filter responses with logging
701
- response_pool = []
702
- skipped_short = 0
703
- skipped_long = 0
704
-
705
- for d in dialogues:
706
- for turn in d['turns']:
707
- if turn['speaker'] == 'assistant':
708
- text = turn['text'].strip()
709
- length = len(text.split())
710
- if length < self.config.min_text_length:
711
- skipped_short += 1
712
- continue
713
- if length > self.config.max_sequence_length:
714
- skipped_long += 1
715
- continue
716
- response_pool.append(text)
717
-
718
- logger.info(f"Response pool statistics:")
719
- logger.info(f" Total responses: {len(response_pool)}")
720
- logger.info(f" Skipped (too short): {skipped_short}")
721
- logger.info(f" Skipped (too long): {skipped_long}")
722
-
723
- # Process dialogues and create training examples
724
- queries, positives, negatives = [], [], []
725
-
726
- for dialogue in tqdm(dialogues, desc="Processing dialogues"):
727
- turns = dialogue['turns']
728
- for i in range(len(turns) - 1):
729
- if turns[i]['speaker'] == 'user' and turns[i+1]['speaker'] == 'assistant':
730
- query = turns[i]['text'].strip()
731
- positive = turns[i+1]['text'].strip()
732
-
733
- # Skip short texts
734
- if (len(query.split()) < self.config.min_text_length or
735
- len(positive.split()) < self.config.min_text_length): # Fixed
736
- continue
737
-
738
- # Get negative samples
739
- neg_samples = self._smart_negative_sampling(
740
- positive,
741
- response_pool,
742
- neg_samples_per_pos
743
- )
744
-
745
- if len(neg_samples) == neg_samples_per_pos:
746
- for neg in neg_samples:
747
- queries.append(query)
748
- positives.append(positive)
749
- negatives.append(neg)
750
-
751
- # Log final dataset statistics
752
- logger.info(f"Final dataset statistics:")
753
- logger.info(f" Training examples: {len(queries)}")
754
- logger.info(f" Unique queries: {len(set(queries))}")
755
- logger.info(f" Unique responses: {len(set(positives))}")
756
-
757
- return self._prepare_sequences(queries, positives, negatives)
758
-
759
- def encode_query(self, query: str, context: Optional[List[Tuple[str, str]]] = None) -> tf.Tensor:
760
- """Encode a query with optional conversation context."""
761
- # Prepare query with context
762
- if context:
763
- context_str = ' '.join([
764
- f"{self.special_tokens['user']} {q} "
765
- f"{self.special_tokens['assistant']} {r}"
766
- for q, r in context[-self.config.max_context_turns:]
767
- ])
768
- query = f"{context_str} {self.special_tokens['user']} {query}"
769
- else:
770
- query = f"{self.special_tokens['user']} {query}"
771
-
772
- # Tokenize and pad
773
- seq = self.tokenizer.texts_to_sequences([query])
774
- padded_seq = tf.keras.preprocessing.sequence.pad_sequences(
775
- seq,
776
- maxlen=self.config.max_sequence_length,
777
- padding='post',
778
- truncating='post'
779
- )
780
-
781
- return self.query_encoder(padded_seq, training=False)
782
-
783
- def encode_responses(self, responses: List[str]) -> tf.Tensor:
784
- """Encode a batch of responses."""
785
- # Prepare responses
786
- responses = [
787
- f"{self.special_tokens['assistant']} {r}"
788
- for r in responses
789
- ]
790
-
791
- # Tokenize and pad
792
- sequences = self.tokenizer.texts_to_sequences(responses)
793
- padded_sequences = tf.keras.preprocessing.sequence.pad_sequences(
794
- sequences,
795
- maxlen=self.config.max_sequence_length,
796
- padding='post',
797
- truncating='post'
798
- )
799
-
800
- return self.response_encoder(padded_sequences, training=False)
801
-
802
- def retrieve_responses(
803
- self,
804
- query: str,
805
- candidates: List[str],
806
- context: Optional[List[Tuple[str, str]]] = None,
807
- top_k: int = 5
808
- ) -> List[Tuple[str, float]]:
809
- """Retrieve top-k responses for a query."""
810
- # Encode query and candidates
811
- q_emb = self.encode_query(query, context)
812
- c_emb = self.encode_responses(candidates)
813
-
814
- # Calculate similarities
815
- similarities = tf.matmul(q_emb, c_emb, transpose_b=True).numpy()[0]
816
-
817
- # Get top-k responses
818
- top_indices = np.argsort(similarities)[::-1][:top_k]
819
-
820
- return [(candidates[i], similarities[i]) for i in top_indices]
821
-
822
- def chat(
823
- self,
824
- query: str,
825
- response_pool: List[str],
826
- conversation_history: Optional[List[Tuple[str, str]]] = None,
827
- top_k: int = 5
828
- ) -> Tuple[str, List[Tuple[str, float]]]:
829
- """Interactive chat with response selection."""
830
- # Get responses with scores
831
- responses = self.retrieve_responses(
832
- query,
833
- response_pool,
834
- conversation_history,
835
- top_k
836
- )
837
-
838
- # Return best response and all candidates with scores
839
- return responses[0][0], responses
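The training loops in the deleted chatbot2.py (and chatbot3.py below) rely on the warmup-plus-cosine-decay schedule built in _get_lr_schedule. As a sanity check, here is a standalone NumPy sketch of the same arithmetic, written for illustration rather than taken from the repository:

# Re-implementation of the schedule logic: linear warmup from 10% of peak_lr to peak_lr
# over warmup_steps, then cosine decay down to a floor of 1% of peak_lr.
import numpy as np

def lr_at(step, total_steps, peak_lr, warmup_steps):
    initial_lr, min_lr = 0.1 * peak_lr, 0.01 * peak_lr
    if step < warmup_steps:
        return initial_lr + (peak_lr - initial_lr) * (step / max(1, warmup_steps))
    decay = min(max((step - warmup_steps) / max(1.0, total_steps - warmup_steps), 0.0), 1.0)
    return min_lr + (peak_lr - min_lr) * 0.5 * (1.0 + np.cos(np.pi * decay))

# With peak_lr=0.001 and warmup_steps=200 (the ChatbotConfig defaults), the rate rises
# from 1e-4 to 1e-3 over the first 200 steps, then decays toward 1e-5 by the final step.
print([round(lr_at(s, 1000, 1e-3, 200), 6) for s in (0, 100, 200, 600, 1000)])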
chatbot3.py DELETED
@@ -1,824 +0,0 @@
1
- from transformers import TFAutoModel, AutoTokenizer
2
- import tensorflow as tf
3
- import numpy as np
4
- from typing import List, Tuple, Dict, Optional, Union
5
- from dataclasses import dataclass
6
- import logging
7
- import spacy
8
- import random
9
- import json
10
- from tqdm import tqdm
11
- from pathlib import Path
12
-
13
- # Configure logging
14
- logging.basicConfig(
15
- level=logging.INFO,
16
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
17
- )
18
- logger = logging.getLogger(__name__)
19
-
20
- @dataclass
21
- class ChatbotConfig:
22
- """Enhanced configuration with pretrained model settings."""
23
- vocab_size: int = 10000
24
- max_sequence_length: int = 512
25
- embedding_dim: int = 768 # Match DistilBERT's dimension
26
- encoder_units: int = 256
27
- num_attention_heads: int = 8
28
- dropout_rate: float = 0.2
29
- l2_reg_weight: float = 0.001
30
- margin: float = 0.3
31
- learning_rate: float = 0.001
32
- min_text_length: int = 3
33
- max_context_turns: int = 5
34
- warmup_steps: int = 200
35
- pretrained_model: str = 'distilbert-base-uncased'
36
- freeze_embeddings: bool = True
37
- spacy_model: str = 'en_core_web_md'
38
-
39
- def to_dict(self) -> dict:
40
- """Convert config to dictionary."""
41
- return {k: str(v) if isinstance(v, Path) else v
42
- for k, v in self.__dict__.items()}
43
-
44
- @classmethod
45
- def from_dict(cls, config_dict: dict) -> 'ChatbotConfig':
46
- """Create config from dictionary."""
47
- return cls(**{k: v for k, v in config_dict.items()
48
- if k in cls.__dataclass_fields__})
49
-
50
- class TransformerBlock(tf.keras.layers.Layer):
51
- """Custom Transformer block with pre-layer normalization."""
52
- def __init__(
53
- self,
54
- embed_dim: int,
55
- num_heads: int,
56
- ff_dim: int,
57
- dropout: float = 0.1,
58
- **kwargs
59
- ):
60
- super().__init__(**kwargs)
61
- self.embed_dim = embed_dim
62
- self.num_heads = num_heads
63
- self.ff_dim = ff_dim
64
- self.dropout = dropout
65
-
66
- self.attention = tf.keras.layers.MultiHeadAttention(
67
- num_heads=num_heads,
68
- key_dim=embed_dim // num_heads,
69
- dropout=dropout
70
- )
71
- self.ffn = tf.keras.Sequential([
72
- tf.keras.layers.Dense(ff_dim, activation="gelu"),
73
- tf.keras.layers.Dense(embed_dim),
74
- ])
75
-
76
- self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
77
- self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
78
- self.dropout1 = tf.keras.layers.Dropout(dropout)
79
- self.dropout2 = tf.keras.layers.Dropout(dropout)
80
-
81
- def call(self, inputs: tf.Tensor, training: bool, mask: Optional[tf.Tensor] = None) -> tf.Tensor:
82
- # Pre-layer normalization
83
- norm_inputs = self.layernorm1(inputs)
84
-
85
- # Self-attention
86
- attention_output = self.attention(
87
- query=norm_inputs,
88
- value=norm_inputs,
89
- key=norm_inputs,
90
- attention_mask=mask,
91
- training=training
92
- )
93
- attention_output = self.dropout1(attention_output, training=training)
94
- attention_output = inputs + attention_output
95
-
96
- # Feed-forward network
97
- norm_attention = self.layernorm2(attention_output)
98
- ffn_output = self.ffn(norm_attention)
99
- ffn_output = self.dropout2(ffn_output, training=training)
100
-
101
- return attention_output + ffn_output
102
-
103
- def get_config(self) -> dict:
104
- config = super().get_config()
105
- config.update({
106
- "embed_dim": self.embed_dim,
107
- "num_heads": self.num_heads,
108
- "ff_dim": self.ff_dim,
109
- "dropout": self.dropout,
110
- })
111
- return config
112
-
113
- class EncoderModel(tf.keras.Model):
114
- """Dual encoder model with pretrained embeddings."""
115
- def __init__(
116
- self,
117
- config: ChatbotConfig,
118
- name: str = "encoder",
119
- shared_weights: bool = False,
120
- **kwargs
121
- ):
122
- super().__init__(name=name, **kwargs)
123
- self.config = config
124
- self.shared_weights = shared_weights
125
-
126
- # Load pretrained model and tokenizer
127
- self.pretrained = TFAutoModel.from_pretrained(config.pretrained_model)
128
-
129
- # Freeze pretrained weights if specified
130
- if config.freeze_embeddings:
131
- self.pretrained.trainable = False
132
-
133
- # Transformer blocks for additional processing
134
- self.transformer_blocks = [
135
- TransformerBlock(
136
- config.embedding_dim,
137
- config.num_attention_heads,
138
- config.encoder_units * 4,
139
- config.dropout_rate,
140
- name=f"{name}_transformer_{i}"
141
- ) for i in range(2) # Reduced number of blocks since we're using pretrained
142
- ]
143
-
144
- # Final LSTM layer
145
- self.final_lstm = tf.keras.layers.LSTM(
146
- config.encoder_units,
147
- kernel_regularizer=tf.keras.regularizers.l2(config.l2_reg_weight),
148
- name=f"{name}_final_lstm"
149
- )
150
-
151
- self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
152
- self.normalize = tf.keras.layers.Lambda(
153
- lambda x: tf.nn.l2_normalize(x, axis=1)
154
- )
155
-
156
- def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
157
- # Get pretrained embeddings
158
- pretrained_outputs = self.pretrained(inputs, training=training)
159
- x = pretrained_outputs.last_hidden_state
160
-
161
- # Get attention mask from input
162
- attention_mask = tf.cast(tf.not_equal(inputs, 0), tf.float32)
163
- attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
164
-
165
- # Apply transformer blocks
166
- for transformer_block in self.transformer_blocks:
167
- x = transformer_block(x, training=training, mask=attention_mask)
168
-
169
- # Final processing
170
- x = self.final_lstm(x)
171
- x = self.dropout(x, training=training)
172
- return self.normalize(x)
173
-
174
- class RetrievalChatbot:
175
- """Modified chatbot using pretrained embeddings with full functionality."""
176
- def __init__(self, config: ChatbotConfig):
177
- self.config = config
178
- self.nlp = spacy.load(config.spacy_model)
179
-
180
- # Use HuggingFace tokenizer instead of Keras
181
- self.tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
182
-
183
- # Special tokens
184
- self.special_tokens = {
185
- "user": "<USER>",
186
- "assistant": "<ASSISTANT>",
187
- "context": "<CONTEXT>",
188
- "sep": "<SEP>"
189
- }
190
-
191
- # Add special tokens to tokenizer
192
- self.tokenizer.add_special_tokens(
193
- {'additional_special_tokens': list(self.special_tokens.values())}
194
- )
195
-
196
- # Build models
197
- self._build_models()
198
-
199
- # Initialize training tracking
200
- self.history = {
201
- "train_loss": [],
202
- "val_loss": [],
203
- "train_metrics": {},
204
- "val_metrics": {}
205
- }
206
-
207
- self.similarity_cache = {}
208
-
209
- def _build_models(self):
210
- """Initialize the encoder models."""
211
- # Query encoder
212
- self.query_encoder = EncoderModel(
213
- self.config,
214
- name="query_encoder",
215
- shared_weights=False
216
- )
217
-
218
- # Response encoder (can share weights with query encoder)
219
- self.response_encoder = EncoderModel(
220
- self.config,
221
- name="response_encoder",
222
- shared_weights=False
223
- )
224
-
225
- # Resize token embeddings to match the tokenizer's vocab size
226
- new_vocab_size = len(self.tokenizer)
227
- self.query_encoder.pretrained.resize_token_embeddings(new_vocab_size)
228
- self.response_encoder.pretrained.resize_token_embeddings(new_vocab_size)
229
-
230
- def save_models(self, save_dir: Union[str, Path]):
231
- """Save models and configuration."""
232
- save_dir = Path(save_dir)
233
- save_dir.mkdir(parents=True, exist_ok=True)
234
-
235
- # Save config
236
- with open(save_dir / "config.json", "w") as f:
237
- json.dump(self.config.to_dict(), f, indent=2)
238
-
239
- # Save models
240
- self.query_encoder.pretrained.save_pretrained(save_dir / "query_encoder")
241
- self.response_encoder.pretrained.save_pretrained(save_dir / "response_encoder")
242
-
243
- # Save tokenizer
244
- self.tokenizer.save_pretrained(save_dir / "tokenizer")
245
-
246
- @classmethod
247
- def load_models(cls, load_dir: Union[str, Path]) -> 'RetrievalChatbot':
248
- """Load saved models and configuration."""
249
- load_dir = Path(load_dir)
250
-
251
- # Load config
252
- with open(load_dir / "config.json", "r") as f:
253
- config = ChatbotConfig.from_dict(json.load(f))
254
-
255
- # Initialize chatbot
256
- chatbot = cls(config)
257
-
258
- # Load models
259
- chatbot.query_encoder.pretrained = TFAutoModel.from_pretrained(
260
- load_dir / "query_encoder",
261
- config=config
262
- )
263
- chatbot.response_encoder.pretrained = TFAutoModel.from_pretrained(
264
- load_dir / "response_encoder",
265
- config=config
266
- )
267
-
268
- # Load tokenizer
269
- chatbot.tokenizer = AutoTokenizer.from_pretrained(load_dir / "tokenizer")
270
-
271
- return chatbot
272
-
273
- def _improved_spacy_similarity(self, text1: str, text2: str) -> float:
274
- """Calculate semantic similarity between texts with preprocessing."""
275
- def preprocess(text: str) -> str:
276
- # Basic cleaning
277
- text = ' '.join(text.split())
278
- return text if text.strip() else "empty_document"
279
-
280
- # Get cache key
281
- cache_key = f"{hash(text1)}_{hash(text2)}"
282
- if cache_key in self.similarity_cache:
283
- return self.similarity_cache[cache_key]
284
-
285
- # Process texts
286
- text1, text2 = preprocess(text1), preprocess(text2)
287
- doc1, doc2 = self.nlp(text1), self.nlp(text2)
288
-
289
- # Calculate similarity
290
- if doc1.has_vector and doc2.has_vector:
291
- sim = doc1.similarity(doc2)
292
- else:
293
- # Fallback to token overlap similarity
294
- tokens1 = {t.lower_ for t in doc1 if not t.is_stop and not t.is_punct}
295
- tokens2 = {t.lower_ for t in doc2 if not t.is_stop and not t.is_punct}
296
- intersection = len(tokens1.intersection(tokens2))
297
- union = len(tokens1.union(tokens2))
298
- sim = intersection / union if union > 0 else 0.0
299
-
300
- # Cache result
301
- self.similarity_cache[cache_key] = sim
302
- return sim
303
-
304
- def prepare_dataset(
305
- self,
306
- dialogues: List[dict],
307
- neg_samples_per_pos: int = 3,
308
- debug_samples: Optional[int] = None
309
- ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
310
- """Prepare dataset with enhanced logging and statistics."""
311
- logger.info("Preparing dataset...")
312
-
313
- # Apply debug_samples limit if specified
314
- if debug_samples is not None:
315
- dialogues = dialogues[:debug_samples]
316
- logger.info(f"Debug mode: Limited to {debug_samples} dialogues")
317
-
318
- # Log dataset statistics
319
- total_dialogues = len(dialogues)
320
- total_turns = sum(len(d['turns']) for d in dialogues)
321
- avg_turns = total_turns / total_dialogues if total_dialogues > 0 else 0
322
-
323
- logger.info(f"Dataset statistics:")
324
- logger.info(f" Total dialogues: {total_dialogues}")
325
- logger.info(f" Total turns: {total_turns}")
326
- logger.info(f" Average turns per dialogue: {avg_turns:.2f}")
327
-
328
- # Extract and filter responses with logging
329
- response_pool = []
330
- skipped_short = 0
331
- skipped_long = 0
332
-
333
- for d in dialogues:
334
- for turn in d['turns']:
335
- if turn.get('speaker') == 'assistant' and 'text' in turn:
336
- text = turn['text'].strip()
337
- length = len(text.split())
338
- if length < self.config.min_text_length:
339
- skipped_short += 1
340
- continue
341
- if length > self.config.max_sequence_length:
342
- skipped_long += 1
343
- continue
344
- response_pool.append(text)
345
-
346
- logger.info(f"Response pool statistics:")
347
- logger.info(f" Total responses: {len(response_pool)}")
348
- logger.info(f" Skipped (too short): {skipped_short}")
349
- logger.info(f" Skipped (too long): {skipped_long}")
350
-
351
- # Process dialogues and create training examples
352
- queries, positives, negatives = [], [], []
353
-
354
- for dialogue in tqdm(dialogues, desc="Processing dialogues"):
355
- turns = dialogue.get('turns', [])
356
- for i in range(len(turns) - 1):
357
- current_turn = turns[i]
358
- next_turn = turns[i+1]
359
-
360
- if (current_turn.get('speaker') == 'user' and
361
- next_turn.get('speaker') == 'assistant' and
362
- 'text' in current_turn and
363
- 'text' in next_turn):
364
-
365
- query = current_turn['text'].strip()
366
- positive = next_turn['text'].strip()
367
-
368
- # Skip short texts
369
- if (len(query.split()) < self.config.min_text_length or
370
- len(positive.split()) < self.config.min_text_length):
371
- continue
372
-
373
- # Get negative samples
374
- neg_samples = self._smart_negative_sampling(
375
- positive,
376
- response_pool,
377
- neg_samples_per_pos
378
- )
379
-
380
- if len(neg_samples) == neg_samples_per_pos:
381
- for neg in neg_samples:
382
- queries.append(query)
383
- positives.append(positive)
384
- negatives.append(neg)
385
- else:
386
- logger.warning(f"Insufficient negative samples for positive response: '{positive}'")
387
-
388
- # Log final dataset statistics
389
- logger.info(f"Final dataset statistics:")
390
- logger.info(f" Training examples: {len(queries)}")
391
- logger.info(f" Unique queries: {len(set(queries))}")
392
- logger.info(f" Unique responses: {len(set(positives))}")
393
-
394
- return self._prepare_sequences(queries, positives, negatives)
395
-
396
- def _smart_negative_sampling(
397
- self,
398
- positive: str,
399
- response_pool: List[str],
400
- n_samples: int,
401
- max_attempts: int = 200,
402
- similarity_bounds: Tuple[float, float] = (0.2, 0.9),
403
- batch_size: int = 10
404
- ) -> List[str]:
405
- """Smart negative sampling with similarity bounds and fallback strategies."""
406
- candidates = []
407
- seen = set()
408
- attempts = 0
409
-
410
- while len(candidates) < n_samples and attempts < max_attempts:
411
- remaining = min(batch_size, len(response_pool) - len(seen), max_attempts - attempts)
412
- if remaining <= 0:
413
- break
414
- batch = random.sample(
415
- [r for r in response_pool if r not in seen and r != positive],
416
- remaining
417
- )
418
-
419
- for candidate in batch:
420
- seen.add(candidate)
421
- sim = self._improved_spacy_similarity(candidate, positive)
422
-
423
- if similarity_bounds[0] < sim < similarity_bounds[1]:
424
- candidates.append(candidate)
425
- if len(candidates) == n_samples:
426
- break
427
-
428
- attempts += len(batch)
429
-
430
- if len(candidates) < n_samples:
431
- logger.warning(f"Only found {len(candidates)} negative samples for positive response: '{positive}'")
432
- # Fallback to random negatives without similarity constraints
433
- fallback_needed = n_samples - len(candidates)
434
- available_negatives = [r for r in response_pool if r != positive and r not in seen]
435
- if available_negatives:
436
- additional_negatives = random.sample(
437
- available_negatives,
438
- min(fallback_needed, len(available_negatives))
439
- )
440
- candidates.extend(additional_negatives)
441
-
442
- return candidates
443
-
444
- def _prepare_sequences(
445
- self,
446
- queries: List[str],
447
- positives: List[str],
448
- negatives: List[str]
449
- ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
450
- """Modified sequence preparation for pretrained tokenizer."""
451
- logger.info("Preparing sequences...")
452
-
453
- # Process texts with special tokens
454
- queries = [f"{self.special_tokens['user']} {q}" for q in queries]
455
- positives = [f"{self.special_tokens['assistant']} {p}" for p in positives]
456
- negatives = [f"{self.special_tokens['assistant']} {n}" for n in negatives]
457
-
458
- # Tokenize using HuggingFace tokenizer
459
- def encode_batch(texts: List[str]) -> tf.Tensor:
460
- # HuggingFace tokenizer returns TensorFlow tensors when return_tensors='tf'
461
- encodings = self.tokenizer(
462
- texts,
463
- padding='max_length',
464
- truncation=True,
465
- max_length=self.config.max_sequence_length,
466
- return_tensors='tf'
467
- )
468
- return encodings['input_ids']
469
-
470
- # Encode all sequences
471
- q_tensor = encode_batch(queries)
472
- p_tensor = encode_batch(positives)
473
- n_tensor = encode_batch(negatives)
474
-
475
- # Log statistics about encoded sequences
476
- logger.info("Sequence statistics:")
477
- logger.info(f" Query sequence shape: {q_tensor.shape}")
478
- logger.info(f" Positive response sequence shape: {p_tensor.shape}")
479
- logger.info(f" Negative response sequence shape: {n_tensor.shape}")
480
-
481
- return q_tensor, p_tensor, n_tensor
482
-
483
- def train(
484
- self,
485
- q_pad: tf.Tensor,
486
- p_pad: tf.Tensor,
487
- n_pad: tf.Tensor,
488
- epochs: int = 3,
489
- batch_size: int = 32,
490
- validation_split: float = 0.2,
491
- checkpoint_dir: Optional[Union[str, Path]] = None
492
- ):
493
- """Train the model with improved training loop."""
494
- # Setup training
495
- total_samples = tf.shape(q_pad)[0]
496
- train_size = int((1 - validation_split) * total_samples.numpy())
497
-
498
- # Shuffle and split data
499
- indices = tf.random.shuffle(tf.range(start=0, limit=total_samples, dtype=tf.int32))
500
- train_idx = indices[:train_size]
501
- val_idx = indices[train_size:]
502
-
503
- # Split data using TF indexing
504
- train_data = (
505
- tf.gather(q_pad, train_idx),
506
- tf.gather(p_pad, train_idx),
507
- tf.gather(n_pad, train_idx)
508
- )
509
- val_data = (
510
- tf.gather(q_pad, val_idx),
511
- tf.gather(p_pad, val_idx),
512
- tf.gather(n_pad, val_idx)
513
- )
514
-
515
- # Setup optimizer with learning rate schedule
516
- steps_per_epoch = train_size // batch_size
517
- total_steps = steps_per_epoch * epochs
518
-
519
- lr_schedule = self._get_lr_schedule(
520
- total_steps,
521
- self.config.learning_rate,
522
- self.config.warmup_steps
523
- )
524
-
525
- optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
526
-
527
- # Setup checkpointing
528
- if checkpoint_dir:
529
- checkpoint_dir = Path(checkpoint_dir)
530
- checkpoint_dir.mkdir(parents=True, exist_ok=True)
531
-
532
- # Setup checkpoint callback with correct file format
533
- checkpoint_template = str(checkpoint_dir / "model_epoch_{epoch:04d}.weights.h5")
534
- checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
535
- checkpoint_template,
536
- save_weights_only=True,
537
- save_best_only=True,
538
- monitor='val_loss',
539
- mode='min',
540
- verbose=1
541
- )
542
-
543
- # Training loop
544
- best_val_loss = float('inf')
545
- patience = 5
546
- wait = 0
547
-
548
- for epoch in range(epochs):
549
- # Training
550
- train_loss = self._train_epoch(
551
- train_data,
552
- optimizer,
553
- batch_size,
554
- training=True
555
- )
556
-
557
- # Validation
558
- val_loss = self._train_epoch(
559
- val_data,
560
- optimizer,
561
- batch_size,
562
- training=False
563
- )
564
-
565
- # Update history
566
- self.history['train_loss'].append(train_loss)
567
- self.history['val_loss'].append(val_loss)
568
-
569
- logger.info(
570
- f"Epoch {epoch + 1}/{epochs} - "
571
- f"train_loss: {train_loss:.4f} - "
572
- f"val_loss: {val_loss:.4f}"
573
- )
574
-
575
- # Early stopping
576
- if val_loss < best_val_loss:
577
- best_val_loss = val_loss
578
- wait = 0
579
- if checkpoint_dir:
580
- self.save_models(checkpoint_dir / f"best_model")
581
- else:
582
- wait += 1
583
- if wait >= patience:
584
- logger.info("Early stopping triggered")
585
- break
586
-
587
- def _train_epoch(
588
- self,
589
- data: Tuple[tf.Tensor, tf.Tensor, tf.Tensor],
590
- optimizer: tf.keras.optimizers.Optimizer,
591
- batch_size: int,
592
- training: bool = True
593
- ) -> float:
594
- """Train for one epoch with enhanced logging and progress tracking."""
595
- q_data, p_data, n_data = data
596
- total_loss = 0.0
597
- num_batches = tf.shape(q_data)[0] // batch_size
598
-
599
- # Log current learning rate at start of epoch
600
- if training:
601
- if hasattr(optimizer.learning_rate, '__call__'):
602
- current_lr = optimizer.learning_rate(optimizer.iterations)
603
- else:
604
- current_lr = optimizer.learning_rate
605
- logger.info(f"Current learning rate: {float(current_lr):.6f}")
606
-
607
- # Create progress bar
608
- mode = "Training" if training else "Validation"
609
- pbar = tqdm(
610
- total=num_batches.numpy(),
611
- desc=f"{mode} batches",
612
- unit="batch",
613
- dynamic_ncols=True
614
- )
615
-
616
- # Process batches
617
- for i in range(num_batches):
618
- start_idx = i * batch_size
619
- end_idx = start_idx + batch_size
620
-
621
- batch_loss = self._train_step(
622
- q_data[start_idx:end_idx],
623
- p_data[start_idx:end_idx],
624
- n_data[start_idx:end_idx],
625
- optimizer,
626
- training
627
- )
628
- total_loss += batch_loss.numpy()
629
-
630
- # Update progress bar with current loss
631
- avg_loss = total_loss / (i + 1)
632
- pbar.set_postfix({
633
- 'loss': f'{avg_loss:.4f}',
634
- 'lr': f'{float(current_lr):.6f}' if training else 'N/A'
635
- })
636
- pbar.update(1)
637
-
638
- pbar.close()
639
- return total_loss / num_batches.numpy() if num_batches > 0 else 0.0
640
-
641
- @tf.function
642
- def _train_step(
643
- self,
644
- q_batch: tf.Tensor,
645
- p_batch: tf.Tensor,
646
- n_batch: tf.Tensor,
647
- optimizer: tf.keras.optimizers.Optimizer,
648
- training: bool = True
649
- ) -> tf.Tensor:
650
- """Single training step with triplet loss."""
651
- with tf.GradientTape() as tape:
652
- # Get embeddings
653
- q_emb = self.query_encoder(q_batch, training=training)
654
- p_emb = self.response_encoder(p_batch, training=training)
655
- n_emb = self.response_encoder(n_batch, training=training)
656
-
657
- # Calculate triplet loss
658
- pos_dist = tf.reduce_sum(tf.square(q_emb - p_emb), axis=1)
659
- neg_dist = tf.reduce_sum(tf.square(q_emb - n_emb), axis=1)
660
-
661
- loss = tf.maximum(0.0, self.config.margin + pos_dist - neg_dist)
662
- loss = tf.reduce_mean(loss)
663
-
664
- if training:
665
- # Apply gradients
666
- gradients = tape.gradient(
667
- loss,
668
- self.query_encoder.trainable_variables +
669
- self.response_encoder.trainable_variables
670
- )
671
- optimizer.apply_gradients(zip(
672
- gradients,
673
- self.query_encoder.trainable_variables +
674
- self.response_encoder.trainable_variables
675
- ))
676
-
677
- return loss
678
-
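
(Note: a small NumPy sanity check of the margin-based triplet loss computed in `_train_step` above; the embeddings are made up for illustration.)

import numpy as np

def triplet_loss(q, p, n, margin=0.3):
    # Squared Euclidean distances, mirroring the loss in _train_step.
    pos_dist = np.sum((q - p) ** 2, axis=1)
    neg_dist = np.sum((q - n) ** 2, axis=1)
    return np.mean(np.maximum(0.0, margin + pos_dist - neg_dist))

q = np.array([[1.0, 0.0]])
p = np.array([[0.9, 0.1]])        # close to the query
n_easy = np.array([[-1.0, 0.0]])  # far from the query
n_hard = np.array([[0.8, 0.2]])   # close to the query
print(triplet_loss(q, p, n_easy))  # 0.0 -> margin already satisfied
print(triplet_loss(q, p, n_hard))  # ~0.24 -> hard negative still incurs loss
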
679
- def _get_lr_schedule(
680
- self,
681
- total_steps: int,
682
- peak_lr: float,
683
- warmup_steps: int
684
- ) -> tf.keras.optimizers.schedules.LearningRateSchedule:
685
- """Enhanced learning rate schedule with better error handling and logging."""
686
- class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
687
- def __init__(
688
- self,
689
- total_steps: int,
690
- peak_lr: float,
691
- warmup_steps: int
692
- ):
693
- super().__init__()
694
- self.total_steps = tf.cast(total_steps, tf.float32)
695
- self.peak_lr = tf.cast(peak_lr, tf.float32)
696
- self.warmup_steps = tf.cast(max(1, warmup_steps), tf.float32) # Prevent 0
697
-
698
- # Calculate and store constants
699
- self.initial_lr = self.peak_lr * 0.1 # Start at 10% of peak
700
- self.min_lr = self.peak_lr * 0.01 # Minimum 1% of peak
701
-
702
- logger.info(f"Learning rate schedule initialized:")
703
- logger.info(f" Initial LR: {float(self.initial_lr):.6f}")
704
- logger.info(f" Peak LR: {float(self.peak_lr):.6f}")
705
- logger.info(f" Min LR: {float(self.min_lr):.6f}")
706
- logger.info(f" Warmup steps: {int(self.warmup_steps)}")
707
- logger.info(f" Total steps: {int(self.total_steps)}")
708
-
709
- def __call__(self, step):
710
- step = tf.cast(step, tf.float32)
711
-
712
- # Warmup phase
713
- warmup_factor = tf.minimum(1.0, step / self.warmup_steps)
714
- warmup_lr = self.initial_lr + (self.peak_lr - self.initial_lr) * warmup_factor
715
-
716
- # Decay phase
717
- decay_steps = tf.maximum(1.0, self.total_steps - self.warmup_steps)
718
- decay_factor = (step - self.warmup_steps) / decay_steps
719
- decay_factor = tf.minimum(tf.maximum(0.0, decay_factor), 1.0) # Clip to [0,1]
720
-
721
- cosine_decay = 0.5 * (1.0 + tf.cos(np.pi * decay_factor))
722
- decay_lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay
723
-
724
- # Choose between warmup and decay
725
- final_lr = tf.where(step < self.warmup_steps, warmup_lr, decay_lr)
726
-
727
- # Ensure learning rate is valid
728
- final_lr = tf.maximum(self.min_lr, final_lr)
729
- final_lr = tf.where(tf.math.is_finite(final_lr), final_lr, self.min_lr)
730
-
731
- return final_lr
732
-
733
- def get_config(self):
734
- return {
735
- "total_steps": self.total_steps,
736
- "peak_lr": self.peak_lr,
737
- "warmup_steps": self.warmup_steps,
738
- }
739
-
740
- return CustomSchedule(total_steps, peak_lr, warmup_steps)
741
-
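
(Note: a plain-Python/NumPy sketch of the same warmup-then-cosine-decay shape as CustomSchedule above, useful for eyeballing the curve; the numbers are illustrative.)

import numpy as np

def lr_at(step, total_steps=1000, peak_lr=2e-5, warmup_steps=100):
    initial_lr, min_lr = peak_lr * 0.1, peak_lr * 0.01   # 10% / 1% of peak, as above
    if step < warmup_steps:
        return initial_lr + (peak_lr - initial_lr) * (step / warmup_steps)
    decay_factor = min(max((step - warmup_steps) / max(1, total_steps - warmup_steps), 0.0), 1.0)
    cosine_decay = 0.5 * (1.0 + np.cos(np.pi * decay_factor))
    return min_lr + (peak_lr - min_lr) * cosine_decay

for s in (0, 50, 100, 500, 1000):
    print(s, lr_at(s))   # ramps up to 2e-5 by step 100, then decays toward 2e-7
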
742
- def encode_query(self, query: str, context: Optional[List[Tuple[str, str]]] = None) -> tf.Tensor:
743
- """Encode a query with optional conversation context."""
744
- # Prepare query with context
745
- if context:
746
- context_str = ' '.join([
747
- f"{self.special_tokens['user']} {q} "
748
- f"{self.special_tokens['assistant']} {r}"
749
- for q, r in context[-self.config.max_context_turns:]
750
- ])
751
- query = f"{context_str} {self.special_tokens['user']} {query}"
752
- else:
753
- query = f"{self.special_tokens['user']} {query}"
754
-
755
- # Tokenize and pad using TensorFlow tensors
756
- encodings = self.tokenizer(
757
- [query],
758
- padding='max_length',
759
- truncation=True,
760
- max_length=self.config.max_sequence_length,
761
- return_tensors='tf'
762
- )
763
- input_ids = encodings['input_ids']
764
-
765
- return self.query_encoder(input_ids, training=False)
766
-
767
- def encode_responses(self, responses: List[str]) -> tf.Tensor:
768
- """Encode a batch of responses."""
769
- # Prepare responses
770
- responses = [
771
- f"{self.special_tokens['assistant']} {r}"
772
- for r in responses
773
- ]
774
-
775
- # Tokenize and pad using TensorFlow tensors
776
- encodings = self.tokenizer(
777
- responses,
778
- padding='max_length',
779
- truncation=True,
780
- max_length=self.config.max_sequence_length,
781
- return_tensors='tf'
782
- )
783
- input_ids = encodings['input_ids']
784
-
785
- return self.response_encoder(input_ids, training=False)
786
-
787
- def retrieve_responses(
788
- self,
789
- query: str,
790
- candidates: List[str],
791
- context: Optional[List[Tuple[str, str]]] = None,
792
- top_k: int = 5
793
- ) -> List[Tuple[str, float]]:
794
- """Retrieve top-k responses for a query."""
795
- # Encode query and candidates
796
- q_emb = self.encode_query(query, context)
797
- c_emb = self.encode_responses(candidates)
798
-
799
- # Calculate similarities
800
- similarities = tf.matmul(q_emb, c_emb, transpose_b=True).numpy()[0]
801
-
802
- # Get top-k responses
803
- top_indices = np.argsort(similarities)[::-1][:top_k]
804
-
805
- return [(candidates[i], similarities[i]) for i in top_indices]
806
-
807
- def chat(
808
- self,
809
- query: str,
810
- response_pool: List[str],
811
- conversation_history: Optional[List[Tuple[str, str]]] = None,
812
- top_k: int = 5
813
- ) -> Tuple[str, List[Tuple[str, float]]]:
814
- """Interactive chat with response selection."""
815
- # Get responses with scores
816
- responses = self.retrieve_responses(
817
- query,
818
- response_pool,
819
- conversation_history,
820
- top_k
821
- )
822
-
823
- # Return best response and all candidates with scores
824
- return responses[0][0], responses
 
chatbot4.py → chatbot_model.py RENAMED
@@ -2,30 +2,25 @@ from transformers import TFAutoModel, AutoTokenizer
2
  import tensorflow as tf
3
  import numpy as np
4
  from typing import List, Tuple, Dict, Optional, Union, Any
 
5
  from dataclasses import dataclass
6
- import logging
7
  import json
8
  from tqdm import tqdm
9
  from pathlib import Path
 
10
  import faiss
11
  from response_quality_checker import ResponseQualityChecker
12
-
13
- policy = tf.keras.mixed_precision.Policy('mixed_float16')
14
- tf.keras.mixed_precision.set_global_policy(policy)
15
-
16
- # Configure logging
17
- logging.basicConfig(
18
- level=logging.DEBUG,
19
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
20
- )
21
- logger = logging.getLogger(__name__)
22
 
23
  @dataclass
24
  class ChatbotConfig:
25
  """Configuration for the RetrievalChatbot."""
26
  vocab_size: int = 30526 # DistilBERT vocab size
27
- max_sequence_length: int = 512
28
- embedding_dim: int = 768 # Match DistilBERT's dimension
29
  encoder_units: int = 256
30
  num_attention_heads: int = 8
31
  dropout_rate: float = 0.2
@@ -70,13 +65,20 @@ class EncoderModel(tf.keras.Model):
70
  # Freeze pretrained weights if specified
71
  self.pretrained.distilbert.embeddings.trainable = False
72
  for i, layer_module in enumerate(self.pretrained.distilbert.transformer.layer):
73
- if i < 3: # freeze first 3 layers
74
  layer_module.trainable = False
75
  else:
76
  layer_module.trainable = True
77
 
78
  # Pooling layer (Global Average Pooling)
79
  self.pooler = tf.keras.layers.GlobalAveragePooling1D()
 
 
 
 
 
 
 
80
 
81
  # Dropout and normalization
82
  self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
@@ -90,14 +92,11 @@ class EncoderModel(tf.keras.Model):
90
  pretrained_outputs = self.pretrained(inputs, training=training)
91
  x = pretrained_outputs.last_hidden_state # Shape: [batch_size, seq_len, embedding_dim]
92
 
93
- # Apply pooling
94
- x = self.pooler(x) # Shape: [batch_size, embedding_dim]
95
-
96
- # Apply dropout
97
- x = self.dropout(x, training=training)
98
-
99
- # L2 normalization
100
- x = self.normalize(x) # Shape: [batch_size, embedding_dim]
101
 
102
  return x
103
 
@@ -110,42 +109,34 @@ class EncoderModel(tf.keras.Model):
110
  "name": self.name
111
  })
112
  return config
113
-
114
- # class CustomLRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
115
- # def __init__(self, initial_lr, peak_lr, min_lr, warmup_steps, total_steps):
116
- # super().__init__()
117
- # self.initial_lr = initial_lr
118
- # self.peak_lr = peak_lr
119
- # self.min_lr = min_lr
120
- # self.warmup_steps = min(warmup_steps, total_steps // 2) # Ensure warmup_steps <= total_steps
121
- # self.total_steps = total_steps
122
-
123
- # def __call__(self, step):
124
- # if step < self.warmup_steps:
125
- # # Linear warmup
126
- # lr = self.initial_lr + (self.peak_lr - self.initial_lr) * (step / self.warmup_steps)
127
- # else:
128
- # # Linear decay
129
- # decay_steps = self.total_steps - self.warmup_steps
130
- # if decay_steps > 0:
131
- # lr = self.peak_lr - (self.peak_lr - self.min_lr) * ((step - self.warmup_steps) / decay_steps)
132
- # else:
133
- # lr = self.peak_lr
134
- # return lr
135
-
136
- # def get_config(self):
137
- # return {
138
- # "initial_lr": self.initial_lr,
139
- # "peak_lr": self.peak_lr,
140
- # "min_lr": self.min_lr,
141
- # "warmup_steps": self.warmup_steps,
142
- # "total_steps": self.total_steps,
143
- # }
144
 
145
- class RetrievalChatbot:
146
  """Retrieval-based chatbot using pretrained embeddings and FAISS for similarity search."""
147
- def __init__(self, config: ChatbotConfig, dialogues: List[dict] = []):
148
  self.config = config
 
149
 
150
  # Special tokens
151
  self.special_tokens = {
@@ -161,8 +152,12 @@ class RetrievalChatbot:
161
  {'additional_special_tokens': list(self.special_tokens.values())}
162
  )
163
 
164
- # Build encoders
165
- self._build_models()
 
 
 
 
166
 
167
  # Initialize FAISS index
168
  self._initialize_faiss()
@@ -193,31 +188,41 @@ class RetrievalChatbot:
193
  self.encoder.pretrained.resize_token_embeddings(new_vocab_size)
194
  logger.info(f"Token embeddings resized to: {new_vocab_size}")
195
 
196
- # Inspect embeddings attributes for debugging
197
  logger.info("Inspecting embeddings attributes:")
198
  for attr in dir(self.encoder.pretrained.distilbert.embeddings):
199
  if not attr.startswith('_'):
200
  logger.info(f" {attr}")
201
 
202
- # Verify embedding layers without accessing word_embeddings directly
203
- embedding_dim = getattr(self.encoder.pretrained.distilbert.embeddings, 'embedding_dim', 'Unknown')
204
- vocab_size = getattr(self.encoder.pretrained.distilbert.embeddings, 'input_dim', len(self.tokenizer))
 
205
  logger.info(f"Encoder Embedding Dimension: {embedding_dim}")
206
  logger.info(f"Encoder Embedding Vocabulary Size: {vocab_size}")
207
-
208
- logger.info("Encoder model built and embeddings resized successfully.")
209
- for var in self.encoder.pretrained.trainable_variables:
210
- logger.info(f"{var.name}, {var.shape}")
211
-
212
- def check_trainable_variables(self):
213
- """Logs the trainable variables in both encoders."""
214
- logger.info("Checking trainable variables in shared_encoder:")
215
- for var in self.encoder.pretrained.trainable_variables:
216
- logger.info(f" {var.name}, shape: {var.shape}")
217
-
218
- # logger.info("Checking trainable variables in response_encoder:")
219
- # for var in self.response_encoder.pretrained.trainable_variables:
220
- # logger.info(f" {var.name}, shape: {var.shape}")
221
 
222
  def _initialize_faiss(self):
223
  """Initialize FAISS index based on available resources."""
@@ -239,10 +244,10 @@ class RetrievalChatbot:
239
  self.index = faiss.IndexFlatIP(self.config.embedding_dim)
240
  logger.info("FAISS index initialized.")
241
 
242
- def verify_faiss_index(chatbot):
243
  """Verify that FAISS index matches the response pool."""
244
- indexed_size = chatbot.index.ntotal
245
- pool_size = len(chatbot.response_pool)
246
  logger.info(f"FAISS index size: {indexed_size}")
247
  logger.info(f"Response pool size: {pool_size}")
248
  if indexed_size != pool_size:
@@ -268,12 +273,12 @@ class RetrievalChatbot:
268
  logger.info(f"Found {len(unique_responses)} unique responses.")
269
 
270
  # Encode responses
 
271
  response_embeddings = self.encode_responses(unique_responses)
272
  response_embeddings = response_embeddings.numpy()
273
 
274
  # Ensure float32
275
  if response_embeddings.dtype != np.float32:
276
- logger.info(f"Converting embeddings from {response_embeddings.dtype} to float32.")
277
  response_embeddings = response_embeddings.astype('float32')
278
 
279
  # Ensure the array is contiguous in memory
@@ -312,14 +317,8 @@ class RetrievalChatbot:
312
  Returns:
313
  tf.Tensor: Tensor of shape (N, emb_dim) with all response embeddings.
314
  """
315
- logger.info(f"Encoding {len(responses)} responses in batches of size {batch_size}...")
316
-
317
- # We'll accumulate embeddings in a list and concatenate at the end
318
  all_embeddings = []
319
-
320
- # Set up a progress bar
321
- from tqdm import tqdm
322
- pbar = tqdm(total=len(responses), desc="Encoding responses")
323
 
324
  # Process the responses in chunks of 'batch_size'
325
  for start_idx in range(0, len(responses), batch_size):
@@ -331,7 +330,7 @@ class RetrievalChatbot:
331
  batch_texts,
332
  padding='max_length',
333
  truncation=True,
334
- max_length=self.config.max_sequence_length,
335
  return_tensors='tf',
336
  )
337
 
@@ -346,11 +345,6 @@ class RetrievalChatbot:
346
  # Collect
347
  all_embeddings.append(embeddings_batch)
348
 
349
- # Update progress bar
350
- pbar.update(len(batch_texts))
351
-
352
- pbar.close()
353
-
354
  # Concatenate all batch embeddings along axis=0
355
  if len(all_embeddings) == 1:
356
  # Only one batch
@@ -359,10 +353,6 @@ class RetrievalChatbot:
359
  # Multiple batches, concatenate
360
  final_embeddings = tf.concat(all_embeddings, axis=0)
361
 
362
- logger.info(
363
- f"Finished encoding {len(responses)} responses. "
364
- f"Final shape: {final_embeddings.shape}"
365
- )
366
  return final_embeddings
367
 
368
  def encode_query(self, query: str, context: Optional[List[Tuple[str, str]]] = None) -> tf.Tensor:
@@ -383,7 +373,7 @@ class RetrievalChatbot:
383
  [query],
384
  padding='max_length',
385
  truncation=True,
386
- max_length=self.config.max_sequence_length,
387
  return_tensors='tf'
388
  )
389
  input_ids = encodings['input_ids']
@@ -391,7 +381,6 @@ class RetrievalChatbot:
391
  # Verify token IDs
392
  max_id = tf.reduce_max(input_ids).numpy()
393
  new_vocab_size = len(self.tokenizer)
394
- logger.info(f"Maximum input_id: {max_id}, Vocab Size: {new_vocab_size}")
395
 
396
  if max_id >= new_vocab_size:
397
  logger.error(f"Token ID {max_id} exceeds the vocabulary size {new_vocab_size}.")
@@ -399,6 +388,46 @@ class RetrievalChatbot:
399
 
400
  # Get embeddings from the shared encoder
401
  return self.encoder(input_ids, training=False)
 
402
 
403
  def retrieve_responses_faiss(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
404
  """Retrieve top-k responses using FAISS."""
@@ -456,10 +485,6 @@ class RetrievalChatbot:
456
  load_dir / "shared_encoder",
457
  config=config
458
  )
459
- # chatbot.response_encoder.pretrained = TFAutoModel.from_pretrained(
460
- # load_dir / "response_encoder",
461
- # config=config
462
- # )
463
 
464
  # Load tokenizer
465
  chatbot.tokenizer = AutoTokenizer.from_pretrained(load_dir / "tokenizer")
@@ -498,77 +523,137 @@ class RetrievalChatbot:
498
  def prepare_dataset(
499
  self,
500
  dialogues: List[dict],
 
501
  debug_samples: int = None
502
  ) -> Tuple[tf.Tensor, tf.Tensor]:
503
  """
504
- Prepares dataset for in-batch negatives:
505
- Only returns (query, positive) pairs.
 
 
 
 
 
 
506
  """
507
- logger.info("Preparing in-batch dataset...")
 
508
 
509
  queries, positives = [], []
510
 
 
511
  for dialogue in dialogues:
512
  turns = dialogue.get('turns', [])
513
  for i in range(len(turns) - 1):
514
  current_turn = turns[i]
515
  next_turn = turns[i+1]
516
 
517
- if (current_turn.get('speaker') == 'user' and
518
- next_turn.get('speaker') == 'assistant' and
519
- 'text' in current_turn and
520
- 'text' in next_turn):
521
 
522
- query = current_turn['text'].strip()
523
- positive = next_turn['text'].strip()
524
 
525
- queries.append(query)
526
- positives.append(positive)
527
 
528
- # Optional debug slicing
529
  if debug_samples is not None:
530
  queries = queries[:debug_samples]
531
  positives = positives[:debug_samples]
532
  logger.info(f"Debug mode: limited to {debug_samples} pairs.")
533
 
534
- logger.info(f"Prepared {len(queries)} (query, positive) pairs.")
 
 
 
 
 
 
 
 
 
 
 
 
 
535
 
536
- # Tokenize queries
 
 
 
 
 
537
  encoded_queries = self.tokenizer(
538
- queries,
539
  padding='max_length',
540
  truncation=True,
541
- max_length=self.config.max_sequence_length,
542
  return_tensors='tf'
543
  )
544
- # Tokenize positives
545
  encoded_positives = self.tokenizer(
546
- positives,
547
  padding='max_length',
548
  truncation=True,
549
- max_length=self.config.max_sequence_length,
550
  return_tensors='tf'
551
  )
552
 
553
  q_tensor = encoded_queries['input_ids']
554
  p_tensor = encoded_positives['input_ids']
555
 
556
- logger.info("Tokenized and padded sequences for in-batch training.")
557
  return q_tensor, p_tensor
558
 
559
  def train(
560
  self,
561
  q_pad: tf.Tensor,
562
  p_pad: tf.Tensor,
563
- epochs: int,
564
- batch_size: int,
565
- validation_split: float,
566
- checkpoint_dir: str,
567
  use_lr_schedule: bool = True,
568
  peak_lr: float = 2e-5,
569
  warmup_steps_ratio: float = 0.1,
570
  early_stopping_patience: int = 3,
571
- min_delta: float = 1e-4
 
572
  ):
573
  dataset_size = tf.shape(q_pad)[0].numpy()
574
  val_size = int(dataset_size * validation_split)
@@ -604,21 +689,20 @@ class RetrievalChatbot:
604
  val_q = q_pad[train_size:]
605
  val_p = p_pad[train_size:]
606
 
607
- train_dataset = tf.data.Dataset.from_tensor_slices((train_q, train_p))
608
- train_dataset = train_dataset.shuffle(buffer_size=4096).batch(batch_size)
 
 
609
 
610
- val_dataset = tf.data.Dataset.from_tensor_slices((val_q, val_p))
611
- val_dataset = val_dataset.batch(batch_size)
 
612
 
613
  # 3) Checkpoint + manager
614
  checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, model=self.encoder)
615
  manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
616
 
617
  # 4) TensorBoard setup
618
- import datetime
619
- import os
620
- from pathlib import Path
621
-
622
  log_dir = Path(checkpoint_dir) / "tensorboard_logs"
623
  log_dir.mkdir(parents=True, exist_ok=True)
624
 
@@ -638,48 +722,91 @@ class RetrievalChatbot:
638
  logger.info("Beginning training loop...")
639
  global_step = 0
640
 
 
 
 
 
 
641
  from tqdm import tqdm
642
  for epoch in range(1, epochs + 1):
643
  logger.info(f"\n=== Epoch {epoch}/{epochs} ===")
644
  epoch_loss_avg = tf.keras.metrics.Mean()
645
 
646
- # Training loop
647
  with tqdm(total=steps_per_epoch, desc=f"Training Epoch {epoch}") as pbar:
648
  for (q_batch, p_batch) in train_dataset:
 
649
  global_step += 1
650
 
651
- # Train step
652
- batch_loss = self._train_step(q_batch, p_batch)
653
- epoch_loss_avg(batch_loss)
654
-
655
- # Get current LR
 
656
  if use_lr_schedule:
 
657
  lr = self.optimizer.learning_rate
658
  if isinstance(lr, tf.keras.optimizers.schedules.LearningRateSchedule):
659
- # Get the current step
660
  current_step = tf.cast(self.optimizer.iterations, tf.float32)
661
- # Compute the current learning rate
662
  current_lr = lr(current_step)
663
  else:
664
- # If learning_rate is not a schedule, use it directly
665
  current_lr = lr
666
- # Convert to float for logging
667
  current_lr_value = float(current_lr.numpy())
668
  else:
669
- # If using fixed learning rate
670
  current_lr_value = float(self.optimizer.learning_rate.numpy())
671
 
672
- # Update tqdm
673
  pbar.update(1)
674
  pbar.set_postfix({
675
- "loss": f"{batch_loss.numpy():.4f}",
676
  "lr": f"{current_lr_value:.2e}"
677
  })
678
 
679
- # TensorBoard: log train metrics per step
680
- with train_summary_writer.as_default():
681
- tf.summary.scalar("loss", batch_loss, step=global_step)
682
- tf.summary.scalar("learning_rate", current_lr_value, step=global_step)
 
 
 
 
 
 
 
 
 
 
 
 
683
 
684
  # Validation
685
  val_loss_avg = tf.keras.metrics.Mean()
@@ -726,90 +853,6 @@ class RetrievalChatbot:
726
 
727
  logger.info("In-batch training completed!")
728
 
729
- @tf.function
730
- def _train_step(self, q_batch, p_batch):
731
- """
732
- Single training step using in-batch negatives.
733
- q_batch: (batch_size, seq_len) int32 input_ids for queries
734
- p_batch: (batch_size, seq_len) int32 input_ids for positives
735
- """
736
- with tf.GradientTape() as tape:
737
- # Encode queries and positives
738
- q_enc = self.encoder(q_batch, training=True) # [B, emb_dim]
739
- p_enc = self.encoder(p_batch, training=True) # [B, emb_dim]
740
-
741
- # Compute similarity matrix: (B, B) = q_enc * p_enc^T
742
- # If embeddings are L2-normalized, this is cosine similarity
743
- sim_matrix = tf.matmul(q_enc, p_enc, transpose_b=True) # [B, B]
744
-
745
- # Labels are just the diagonal indices
746
- batch_size = tf.shape(q_enc)[0]
747
- labels = tf.range(batch_size, dtype=tf.int32) # [0..B-1]
748
-
749
- # Softmax cross-entropy
750
- loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
751
- labels=labels,
752
- logits=sim_matrix
753
- )
754
- loss = tf.reduce_mean(loss)
755
-
756
- # Compute gradients for the pretrained DistilBERT variables only
757
- train_vars = self.encoder.pretrained.trainable_variables
758
- gradients = tape.gradient(loss, train_vars)
759
-
760
- # Remove any None grads (in case some layers are frozen)
761
- grads_and_vars = [(g, v) for g, v in zip(gradients, train_vars) if g is not None]
762
- if grads_and_vars:
763
- self.optimizer.apply_gradients(grads_and_vars)
764
-
765
- return loss
766
-
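
(Note: a tiny TensorFlow illustration of the in-batch objective used in `_train_step` above: every other row in the batch serves as a negative, and the correct pairs sit on the diagonal of the similarity matrix. The embeddings are made up.)

import tensorflow as tf

q_enc = tf.math.l2_normalize(tf.constant([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]), axis=1)
p_enc = tf.math.l2_normalize(tf.constant([[0.9, 0.1], [0.1, 0.9], [1.0, 0.9]]), axis=1)

sim_matrix = tf.matmul(q_enc, p_enc, transpose_b=True)   # [3, 3] cosine similarities
labels = tf.range(tf.shape(q_enc)[0])                    # correct pair is on the diagonal
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=sim_matrix)
)
print(loss.numpy())   # small loss: each query is most similar to its own positive
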
767
- def _prepare_sequences(
768
- self,
769
- queries: List[str],
770
- positives: List[str],
771
- negatives: List[str]
772
- ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
773
- """Prepare and tokenize sequences for training."""
774
- logger.info("Preparing sequences for training...")
775
-
776
- # Handle empty lists
777
- if not queries:
778
- logger.error("No queries to encode. Skipping sequence preparation.")
779
- return tf.constant([]), tf.constant([]), tf.constant([])
780
-
781
- # Process texts with special tokens
782
- queries = [f"{self.special_tokens['user']} {q}" for q in queries]
783
- positives = [f"{self.special_tokens['assistant']} {p}" for p in positives]
784
- negatives = [f"{self.special_tokens['assistant']} {n}" for n in negatives]
785
-
786
- # Tokenize using HuggingFace tokenizer
787
- def encode_batch(texts: List[str]) -> tf.Tensor:
788
- if not texts:
789
- logger.error("Empty text list provided to tokenizer.")
790
- return tf.constant([])
791
- encodings = self.tokenizer(
792
- texts,
793
- padding='max_length',
794
- truncation=True,
795
- max_length=self.config.max_sequence_length,
796
- return_tensors='tf'
797
- )
798
- return encodings['input_ids']
799
-
800
- # Encode all sequences
801
- q_tensor = encode_batch(queries)
802
- p_tensor = encode_batch(positives)
803
- n_tensor = encode_batch(negatives)
804
-
805
- # Log statistics about encoded sequences
806
- logger.info("Sequence statistics:")
807
- logger.info(f" Query sequence shape: {q_tensor.shape}")
808
- logger.info(f" Positive response sequence shape: {p_tensor.shape}")
809
- logger.info(f" Negative response sequence shape: {n_tensor.shape}")
810
-
811
- return q_tensor, p_tensor, n_tensor
812
-
813
  def _get_lr_schedule(
814
  self,
815
  total_steps: int,
@@ -855,7 +898,7 @@ class RetrievalChatbot:
855
  decay_factor = (step - self.warmup_steps) / decay_steps
856
  decay_factor = tf.minimum(tf.maximum(0.0, decay_factor), 1.0) # Clip to [0,1]
857
 
858
- cosine_decay = 0.5 * (1.0 + tf.cos(np.pi * decay_factor))
859
  decay_lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay
860
 
861
  # Choose between warmup and decay
@@ -881,411 +924,334 @@ class RetrievalChatbot:
881
  normalized_emb1 = emb1 / np.linalg.norm(emb1, axis=1, keepdims=True)
882
  normalized_emb2 = emb2 / np.linalg.norm(emb2, axis=1, keepdims=True)
883
  return np.dot(normalized_emb1, normalized_emb2.T)
884
-
885
- def run_automatic_validation(
886
- self,
887
- quality_checker: 'ResponseQualityChecker',
888
- num_examples: int = 5
889
- ) -> Dict[str, Any]:
890
- """
891
- Run automatic validation with quality metrics using FAISS-based retrieval.
892
- """
893
- logger.info("\n=== Running Automatic Validation ===")
894
-
895
- test_queries = [
896
- "Hello, how are you today?",
897
- "What's the weather like?",
898
- "Can you help me with a problem?",
899
- "Tell me a joke",
900
- "What time is it?",
901
- "I need help with my homework",
902
- "Where's a good place to eat?",
903
- "What movies are playing?",
904
- "How do I reset my password?",
905
- "Can you recommend a book?"
906
- ]
907
-
908
- test_queries = test_queries[:num_examples]
909
- metrics_history = []
910
-
911
- for i, query in enumerate(test_queries, 1):
912
- logger.info(f"\nTest Case {i}:")
913
- logger.info(f"Query: {query}")
914
-
915
- # Get responses and scores using FAISS
916
- responses = self.retrieve_responses_faiss(query, top_k=5)
917
-
918
- # Check quality
919
- quality_metrics = quality_checker.check_response_quality(query, responses)
920
- metrics_history.append(quality_metrics)
921
-
922
- # Log results
923
- logger.info(f"Quality Metrics: {quality_metrics}")
924
- logger.info("Top responses:")
925
- for j, (response, score) in enumerate(responses[:3], 1):
926
- logger.info(f"{j}. Score: {score:.4f}")
927
- logger.info(f" Response: {response}")
928
- if j == 1 and not quality_metrics.get('is_confident', False):
929
- logger.info(" [Low Confidence - Would abstain from answering]")
930
-
931
- # Calculate aggregate metrics
932
- aggregate_metrics = {
933
- 'num_queries_tested': len(test_queries),
934
- 'avg_top_response_score': np.mean([m.get('top_score', 0) for m in metrics_history]),
935
- 'avg_diversity': np.mean([m.get('response_diversity', 0) for m in metrics_history]),
936
- 'avg_relevance': np.mean([m.get('query_response_relevance', 0) for m in metrics_history]),
937
- 'avg_length_score': np.mean([m.get('response_length_score', 0) for m in metrics_history]),
938
- 'avg_score_gap': np.mean([m.get('top_3_score_gap', 0) for m in metrics_history]),
939
- 'confidence_rate': np.mean([m.get('is_confident', False) for m in metrics_history]),
940
- }
941
-
942
- logger.info("\n=== Validation Summary ===")
943
- for metric, value in aggregate_metrics.items():
944
- logger.info(f"{metric}: {value:.4f}")
945
-
946
- return aggregate_metrics
947
 
948
  def chat(
949
  self,
950
  query: str,
951
  conversation_history: Optional[List[Tuple[str, str]]] = None,
952
  quality_checker: Optional['ResponseQualityChecker'] = None,
953
- top_k: int = 5
954
  ) -> Tuple[str, List[Tuple[str, float]], Dict[str, Any]]:
955
  """
956
- Interactive chat function with quality checking using FAISS-based retrieval.
957
-
958
- Args:
959
- query (str): The user's input query.
960
- conversation_history (Optional[List[Tuple[str, str]]]): List of past (user, assistant) exchanges.
961
- quality_checker (Optional['ResponseQualityChecker']): Quality checker instance.
962
- top_k (int): Number of top responses to retrieve.
963
-
964
- Returns:
965
- Tuple[str, List[Tuple[str, float]], Dict[str, Any]]: (Response, Candidates, Quality Metrics)
966
  """
967
- # Retrieve responses using FAISS
968
- responses = self.retrieve_responses_faiss(query, top_k)
969
-
970
- # If no quality checker provided, return the top response
971
- if quality_checker is None:
972
- return responses[0][0] if responses else "I'm sorry, I don't have an answer for that.", responses, {}
973
-
974
- # Check quality
975
- quality_metrics = quality_checker.check_response_quality(query, responses)
976
-
977
- if quality_metrics.get('is_confident', False):
978
- return responses[0][0], responses, quality_metrics
979
- else:
980
- uncertainty_response = (
981
- "I apologize, but I don't feel confident providing an answer to that "
982
- "question at the moment. Could you please rephrase or ask something else?"
983
- )
984
- return uncertainty_response, responses, quality_metrics
985
-
986
- # TODO: consider removal
987
- # def prepare_dataset(self, dialogues: List[dict], neg_samples_per_pos: int = 1, debug_samples: int = None) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
988
- # """Prepares the dataset for training."""
989
- # logger.info("Preparing dataset...")
990
-
991
- # # Extract (query, positive, negative) triples
992
- # queries, positives, negatives = [], [], []
993
-
994
- # for dialogue in dialogues:
995
- # turns = dialogue.get('turns', [])
996
- # for i in range(len(turns) - 1):
997
- # current_turn = turns[i]
998
- # next_turn = turns[i+1]
999
-
1000
- # if (current_turn.get('speaker') == 'user' and
1001
- # next_turn.get('speaker') == 'assistant' and
1002
- # 'text' in current_turn and
1003
- # 'text' in next_turn):
1004
-
1005
- # query = current_turn['text'].strip()
1006
- # positive = next_turn['text'].strip()
1007
-
1008
- # # Generate hard negative samples
1009
- # hard_negatives = self.hard_negative_sampling(positive, n_samples=neg_samples_per_pos)
1010
- # for negative in hard_negatives:
1011
- # negatives.append(negative)
1012
- # queries.append(query)
1013
- # positives.append(positive)
1014
-
1015
- # logger.info(f"Prepared {len(queries)} training examples.")
1016
-
1017
- # # Tokenize and pad sequences
1018
- # encoded_queries = self.tokenizer(
1019
- # queries,
1020
- # padding='max_length',
1021
- # truncation=True,
1022
- # max_length=self.config.max_sequence_length,
1023
- # return_tensors='tf'
1024
- # )
1025
- # encoded_positives = self.tokenizer(
1026
- # positives,
1027
- # padding='max_length',
1028
- # truncation=True,
1029
- # max_length=self.config.max_sequence_length,
1030
- # return_tensors='tf'
1031
- # )
1032
- # encoded_negatives = self.tokenizer(
1033
- # negatives,
1034
- # padding='max_length',
1035
- # truncation=True,
1036
- # max_length=self.config.max_sequence_length,
1037
- # return_tensors='tf'
1038
- # )
1039
-
1040
- # q_tensor = encoded_queries['input_ids']
1041
- # p_tensor = encoded_positives['input_ids']
1042
- # n_tensor = encoded_negatives['input_ids']
1043
-
1044
- # logger.info(f"Tokenized and padded sequences.")
1045
-
1046
- # return q_tensor, p_tensor, n_tensor
1047
-
1048
-
1049
- # # TODO: consider removal
1050
- # def hard_negative_sampling(self, positive_response, n_samples=1):
1051
- # """Select hard negatives based on cosine similarity."""
1052
- # try:
1053
- # # Ensure we don't request more negatives than available
1054
- # max_neg_samples = len(self.response_pool) - 1 # Exclude the positive response
1055
- # n_samples = min(n_samples, max_neg_samples)
1056
-
1057
- # if n_samples <= 0:
1058
- # logger.error("Not enough responses to sample negatives.")
1059
- # return []
1060
-
1061
- # # Encode the positive response using the chatbot's encode_responses method
1062
- # pos_emb = self.encode_responses([positive_response]).numpy()
1063
- # faiss.normalize_L2(pos_emb)
1064
- # #logger.info(f"Normalized positive embedding for response: {positive_response}")
1065
-
1066
- # # Search for the top n_samples + 1 most similar responses (including the positive itself)
1067
- # D, I = self.index.search(pos_emb, n_samples + 1)
1068
- # #logger.info(f"FAISS search results: {I}")
1069
-
1070
- # # Exclude the positive response itself (assuming it's indexed)
1071
- # negatives = []
1072
- # for i in range(n_samples):
1073
- # idx = I[0][i + 1] # Skip the first one as it's the positive
1074
- # if idx < len(self.response_pool):
1075
- # negative_response = self.response_pool[idx]
1076
- # negatives.append(negative_response)
1077
- # logger.info(f"Selected negative: {negative_response}")
1078
- # else:
1079
- # logger.warning(f"Index {idx} out of range for response_pool with size {len(self.response_pool)}.")
1080
 
1081
- # return negatives
1082
- # except Exception as e:
1083
- # logger.error(f"An error occurred during hard negative sampling: {e}")
1084
- # return []
1085
-
1086
- # def train(
1087
- # self,
1088
- # q_pad: tf.Tensor,
1089
- # p_pad: tf.Tensor,
1090
- # n_pad: tf.Tensor,
1091
- # epochs: int,
1092
- # batch_size: int,
1093
- # validation_split: float,
1094
- # checkpoint_dir: str,
1095
- # callbacks: Optional[List[tf.keras.callbacks.Callback]] = None
1096
- # ):
1097
- # """
1098
- # Train the chatbot model.
1099
-
1100
- # Args:
1101
- # q_pad (tf.Tensor): Padded query input_ids.
1102
- # p_pad (tf.Tensor): Padded positive response input_ids.
1103
- # n_pad (tf.Tensor): Padded negative response input_ids.
1104
- # epochs (int): Number of training epochs.
1105
- # batch_size (int): Training batch size.
1106
- # validation_split (float): Fraction of data to use for validation.
1107
- # checkpoint_dir (str): Directory to save model checkpoints.
1108
- # callbacks (list, optional): List of Keras callbacks.
1109
- # """
1110
- # dataset_size = tf.shape(q_pad)[0].numpy()
1111
- # val_size = int(dataset_size * validation_split)
1112
- # train_size = dataset_size - val_size
1113
-
1114
- # logger.info(f"Total samples: {dataset_size}")
1115
- # logger.info(f"Training samples: {train_size}")
1116
- # logger.info(f"Validation samples: {val_size}")
1117
-
1118
- # # Calculate steps_per_epoch
1119
- # steps_per_epoch = train_size // batch_size
1120
- # if train_size % batch_size != 0:
1121
- # steps_per_epoch += 1
1122
- # total_steps = steps_per_epoch * epochs
1123
-
1124
- # logger.info(f"Total training steps: {total_steps}")
1125
-
1126
- # # Initialize learning rate schedule with adjusted warmup_steps
1127
- # lr_schedule = self._get_lr_schedule(
1128
- # total_steps=total_steps,
1129
- # peak_lr=self.config.learning_rate,
1130
- # warmup_steps=self.config.warmup_steps
1131
- # )
1132
-
1133
- # # callbacks = []
1134
- # # if checkpoint_dir:
1135
- # # checkpoint_dir = Path(checkpoint_dir)
1136
- # # checkpoint_dir.mkdir(parents=True, exist_ok=True)
1137
-
1138
- # # # Setup checkpoint callback with correct file format
1139
- # # checkpoint_template = str(checkpoint_dir / "model_epoch_{epoch:04d}.weights.h5")
1140
- # # checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
1141
- # # checkpoint_template,
1142
- # # save_weights_only=True,
1143
- # # save_best_only=True,
1144
- # # monitor='val_loss',
1145
- # # mode='min',
1146
- # # verbose=1
1147
- # # )
1148
- # # callbacks.append(checkpoint_callback)
1149
-
1150
- # # # Early stopping callback
1151
- # # early_stopping = tf.keras.callbacks.EarlyStopping(
1152
- # # monitor='val_loss',
1153
- # # patience=5,
1154
- # # restore_best_weights=True,
1155
- # # verbose=1
1156
- # # )
1157
- # # callbacks.append(early_stopping)
1158
 
1159
- # # # TensorBoard callback
1160
- # # tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='./logs')
1161
- # # callbacks.append(tensorboard_callback)
1162
-
1163
- # # Update optimizer with the new learning rate schedule
1164
- # self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
1165
-
1166
- # # Split the data
1167
- # train_q = q_pad[:train_size]
1168
- # train_p = p_pad[:train_size]
1169
- # train_n = n_pad[:train_size]
1170
-
1171
- # val_q = q_pad[train_size:]
1172
- # val_p = p_pad[train_size:]
1173
- # val_n = n_pad[train_size:]
1174
 
1175
- # # Create TensorFlow datasets
1176
- # train_dataset = tf.data.Dataset.from_tensor_slices((train_q, train_p, train_n))
1177
- # train_dataset = train_dataset.shuffle(buffer_size=1000).batch(batch_size)
1178
-
1179
- # val_dataset = tf.data.Dataset.from_tensor_slices((val_q, val_p, val_n))
1180
- # val_dataset = val_dataset.batch(batch_size)
1181
-
1182
- # # Log dataset sizes
1183
- # logger.info(f"Training dataset batches: {len(list(train_dataset))}")
1184
- # logger.info(f"Validation dataset batches: {len(list(val_dataset))}")
1185
-
1186
- # # Create checkpoint manager
1187
- # checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, model=self.encoder)
1188
- # manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
1189
-
1190
- # for epoch in range(1, epochs + 1):
1191
- # logger.info(f"Epoch {epoch}/{epochs}")
1192
- # epoch_loss_avg = tf.keras.metrics.Mean()
1193
 
1194
- # # Training loop
1195
- # for q_batch, p_batch, n_batch in train_dataset:
1196
- # batch_loss = self._train_step(q_batch, p_batch, n_batch)
1197
- # epoch_loss_avg(batch_loss)
 
 
1198
 
1199
- # # Validation loop
1200
- # val_loss_avg = tf.keras.metrics.Mean()
1201
- # try:
1202
- # for q_val, p_val, n_val in val_dataset:
1203
- # # Encode queries, positives, and negatives without training
1204
- # q_enc = self.encoder(q_val, training=False)
1205
- # p_enc = self.encoder(p_val, training=False)
1206
- # n_enc = self.encoder(n_val, training=False)
1207
-
1208
- # # Compute cosine similarities
1209
- # pos_sim = tf.reduce_sum(tf.multiply(q_enc, p_enc), axis=1)
1210
- # neg_sim = tf.reduce_sum(tf.multiply(q_enc, n_enc), axis=1)
1211
-
1212
- # # Ensure similarities are float32
1213
- # pos_sim = tf.cast(pos_sim, tf.float32)
1214
- # neg_sim = tf.cast(neg_sim, tf.float32)
1215
-
1216
- # # Compute loss with margin
1217
- # margin = tf.cast(self.config.margin, tf.float32)
1218
- # loss = tf.maximum(0.0, margin - pos_sim + neg_sim)
1219
-
1220
- # val_loss_avg(tf.reduce_mean(loss))
1221
-
1222
- # # Optional: Log individual batch validation loss
1223
- # logger.debug(f"Batch Validation Loss: {tf.reduce_mean(loss).numpy():.6f}")
1224
-
1225
- # train_loss = epoch_loss_avg.result().numpy()
1226
- # val_loss = val_loss_avg.result().numpy()
1227
- # logger.info(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
1228
 
1229
- # # Save checkpoint
1230
- # manager.save()
1231
 
1232
- # # Update history
1233
- # self.history['train_loss'].append(train_loss)
1234
- # self.history['val_loss'].append(val_loss)
1235
-
1236
- # # Invoke callbacks if any
1237
- # if callbacks:
1238
- # for callback in callbacks:
1239
- # callback.on_epoch_end(epoch, logs={'loss': train_loss, 'val_loss': val_loss})
1240
-
1241
- # except tf.errors.OutOfRangeError:
1242
- # logger.warning("Validation dataset is exhausted before expected.")
1243
- # self.history['val_loss'].append(val_loss_avg.result().numpy())
1244
-
1245
- # logger.info("Training completed.")
1246
-
1247
- # @tf.function
1248
- # def _train_step(self, q_batch, p_batch, n_batch):
1249
- # """
1250
- # Performs a single training step with query, positive, and negative batches.
1251
-
1252
- # Args:
1253
- # q_batch (tf.Tensor): Batch of query input_ids.
1254
- # p_batch (tf.Tensor): Batch of positive response input_ids.
1255
- # n_batch (tf.Tensor): Batch of negative response input_ids.
1256
-
1257
- # Returns:
1258
- # tf.Tensor: Mean loss for the batch.
1259
- # """
1260
- # with tf.GradientTape() as tape:
1261
- # # Encode queries, positives, and negatives using the shared encoder
1262
- # q_enc = self.encoder(q_batch, training=True) # Shape: (batch_size, embedding_dim)
1263
- # p_enc = self.encoder(p_batch, training=True) # Shape: (batch_size, embedding_dim)
1264
- # n_enc = self.encoder(n_batch, training=True) # Shape: (batch_size, embedding_dim)
1265
-
1266
- # # Compute cosine similarities
1267
- # pos_sim = tf.reduce_sum(tf.multiply(q_enc, p_enc), axis=1) # Shape: (batch_size,)
1268
- # neg_sim = tf.reduce_sum(tf.multiply(q_enc, n_enc), axis=1) # Shape: (batch_size,)
1269
-
1270
- # # Ensure similarities are float32
1271
- # pos_sim = tf.cast(pos_sim, tf.float32)
1272
- # neg_sim = tf.cast(neg_sim, tf.float32)
1273
-
1274
- # # Compute loss with margin
1275
- # margin = tf.cast(self.config.margin, tf.float32)
1276
- # loss = tf.maximum(0.0, margin - pos_sim + neg_sim)
1277
-
1278
- # # Compute gradients and update encoder weights
1279
- # gradients = tape.gradient(loss, self.encoder.pretrained.trainable_variables)
1280
-
1281
- # # Filter out None gradients (if any)
1282
- # grads_and_vars = [
1283
- # (g, v) for g, v in zip(gradients, self.encoder.pretrained.trainable_variables)
1284
- # if g is not None
1285
- # ]
1286
-
1287
- # if grads_and_vars:
1288
- # self.optimizer.apply_gradients(grads_and_vars)
1289
-
1290
- # # Return mean loss
1291
- # return tf.reduce_mean(loss)
 
2
  import tensorflow as tf
3
  import numpy as np
4
  from typing import List, Tuple, Dict, Optional, Union, Any
5
+ import math
6
  from dataclasses import dataclass
 
7
  import json
8
  from tqdm import tqdm
9
  from pathlib import Path
10
+ import datetime
11
  import faiss
12
  from response_quality_checker import ResponseQualityChecker
13
+ from cross_encoder_reranker import CrossEncoderReranker
14
+ from conversation_summarizer import DeviceAwareModel, Summarizer
15
+ from logger_config import config_logger
16
+ logger = config_logger(__name__)
 
 
 
 
 
 
17
 
18
  @dataclass
19
  class ChatbotConfig:
20
  """Configuration for the RetrievalChatbot."""
21
  vocab_size: int = 30526 # DistilBERT vocab size
22
+ max_context_token_limit: int = 512
23
+ embedding_dim: int = 512 # Projected embedding size (DistilBERT's 768-dim output is projected down to this)
24
  encoder_units: int = 256
25
  num_attention_heads: int = 8
26
  dropout_rate: float = 0.2
 
65
  # Freeze pretrained weights if specified
66
  self.pretrained.distilbert.embeddings.trainable = False
67
  for i, layer_module in enumerate(self.pretrained.distilbert.transformer.layer):
68
+ if i < 1: # freeze first layer
69
  layer_module.trainable = False
70
  else:
71
  layer_module.trainable = True
72
 
73
  # Pooling layer (Global Average Pooling)
74
  self.pooler = tf.keras.layers.GlobalAveragePooling1D()
75
+
76
+ # Projection layer
77
+ self.projection = tf.keras.layers.Dense(
78
+ config.embedding_dim,
79
+ activation='tanh',
80
+ name="projection"
81
+ )
82
 
83
  # Dropout and normalization
84
  self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
 
92
  pretrained_outputs = self.pretrained(inputs, training=training)
93
  x = pretrained_outputs.last_hidden_state # Shape: [batch_size, seq_len, embedding_dim]
94
 
95
+ # Apply pooling, projection, dropout, and normalization
96
+ x = self.pooler(x) # Shape: [batch_size, 768]
97
+ x = self.projection(x) # Shape: [batch_size, 512]
98
+ x = self.dropout(x, training=training) # Apply dropout
99
+ x = self.normalize(x) # Shape: [batch_size, 512]
 
 
 
100
 
101
  return x
102
 
 
109
  "name": self.name
110
  })
111
  return config
 
112
 
113
+ class RetrievalChatbot(DeviceAwareModel):
114
  """Retrieval-based chatbot using pretrained embeddings and FAISS for similarity search."""
115
+ def __init__(self, config: ChatbotConfig, dialogues: List[dict] = [], device: str = None, strategy=None, reranker: Optional[CrossEncoderReranker] = None, summarizer: Optional[Summarizer] = None):
116
  self.config = config
117
+ self.strategy = strategy
118
+ self.setup_device(device)
119
+
120
+ if reranker is None:
121
+ logger.info("Creating default CrossEncoderReranker...")
122
+ reranker = CrossEncoderReranker(model_name="cross-encoder/ms-marco-MiniLM-L-12-v2")
123
+ self.reranker = reranker
124
+
125
+ if summarizer is None:
126
+ logger.info("Creating default Summarizer...")
127
+ summarizer = Summarizer(device=self.device)
128
+ self.summarizer = summarizer
129
+
130
+ # Configure XLA optimization if on GPU/TPU
131
+ if self.device in ["GPU", "TPU"]:
132
+ tf.config.optimizer.set_jit(True)
133
+ logger.info(f"XLA compilation enabled for {self.device}")
134
+
135
+ # Configure mixed precision for GPU/TPU
136
+ if self.device != "CPU":
137
+ policy = tf.keras.mixed_precision.Policy('mixed_float16')
138
+ tf.keras.mixed_precision.set_global_policy(policy)
139
+ logger.info("Mixed precision training enabled (float16)")
140
 
141
  # Special tokens
142
  self.special_tokens = {
 
152
  {'additional_special_tokens': list(self.special_tokens.values())}
153
  )
154
 
155
+ # Build encoders within device strategy scope
156
+ if self.strategy:
157
+ with self.strategy.scope():
158
+ self._build_models()
159
+ else:
160
+ self._build_models()
161
 
162
  # Initialize FAISS index
163
  self._initialize_faiss()
 
188
  self.encoder.pretrained.resize_token_embeddings(new_vocab_size)
189
  logger.info(f"Token embeddings resized to: {new_vocab_size}")
190
 
191
+ # Debug embeddings attributes
192
  logger.info("Inspecting embeddings attributes:")
193
  for attr in dir(self.encoder.pretrained.distilbert.embeddings):
194
  if not attr.startswith('_'):
195
  logger.info(f" {attr}")
196
 
197
+ # Try different ways to get embedding dimension
198
+ try:
199
+ # First try: from config
200
+ embedding_dim = self.encoder.pretrained.config.dim
201
+ logger.info("Got embedding dim from config")
202
+ except AttributeError:
203
+ try:
204
+ # Second try: from word embeddings
205
+ embedding_dim = self.encoder.pretrained.distilbert.embeddings.word_embeddings.embedding_dim
206
+ logger.info("Got embedding dim from word embeddings")
207
+ except AttributeError:
208
+ try:
209
+ # Third try: from embeddings module
210
+ embedding_dim = self.encoder.pretrained.distilbert.embeddings.embedding_dim
211
+ logger.info("Got embedding dim from embeddings module")
212
+ except AttributeError:
213
+ # Fallback to config value
214
+ embedding_dim = self.config.embedding_dim
215
+ logger.info("Using config embedding dim")
216
+
217
+ vocab_size = len(self.tokenizer)
218
+
219
  logger.info(f"Encoder Embedding Dimension: {embedding_dim}")
220
  logger.info(f"Encoder Embedding Vocabulary Size: {vocab_size}")
221
+ if vocab_size >= embedding_dim:
222
+ logger.info("Encoder model built and embeddings resized successfully.")
223
+ else:
224
+ logger.error("Vocabulary size is less than embedding dimension.")
225
+ raise ValueError("Vocabulary size is less than embedding dimension.")
 
 
 
 
 
 
 
 
 
226
 
227
  def _initialize_faiss(self):
228
  """Initialize FAISS index based on available resources."""
 
244
  self.index = faiss.IndexFlatIP(self.config.embedding_dim)
245
  logger.info("FAISS index initialized.")
246
 
247
+ def verify_faiss_index(self):
248
  """Verify that FAISS index matches the response pool."""
249
+ indexed_size = self.index.ntotal
250
+ pool_size = len(self.response_pool)
251
  logger.info(f"FAISS index size: {indexed_size}")
252
  logger.info(f"Response pool size: {pool_size}")
253
  if indexed_size != pool_size:
 
273
  logger.info(f"Found {len(unique_responses)} unique responses.")
274
 
275
  # Encode responses
276
+ logger.info("Encoding unique responses")
277
  response_embeddings = self.encode_responses(unique_responses)
278
  response_embeddings = response_embeddings.numpy()
279
 
280
  # Ensure float32
281
  if response_embeddings.dtype != np.float32:
 
282
  response_embeddings = response_embeddings.astype('float32')
283
 
284
  # Ensure the array is contiguous in memory
 
317
  Returns:
318
  tf.Tensor: Tensor of shape (N, emb_dim) with all response embeddings.
319
  """
320
+ # Accumulate embeddings in a list and concatenate at the end
 
 
321
  all_embeddings = []
 
 
 
 
322
 
323
  # Process the responses in chunks of 'batch_size'
324
  for start_idx in range(0, len(responses), batch_size):
 
330
  batch_texts,
331
  padding='max_length',
332
  truncation=True,
333
+ max_length=self.config.max_context_token_limit,
334
  return_tensors='tf',
335
  )
336
 
 
345
  # Collect
346
  all_embeddings.append(embeddings_batch)
347
 
 
 
 
 
 
348
  # Concatenate all batch embeddings along axis=0
349
  if len(all_embeddings) == 1:
350
  # Only one batch
 
353
  # Multiple batches, concatenate
354
  final_embeddings = tf.concat(all_embeddings, axis=0)
355
 
 
 
 
 
356
  return final_embeddings
357
 
358
  def encode_query(self, query: str, context: Optional[List[Tuple[str, str]]] = None) -> tf.Tensor:
 
373
  [query],
374
  padding='max_length',
375
  truncation=True,
376
+ max_length=self.config.max_context_token_limit,
377
  return_tensors='tf'
378
  )
379
  input_ids = encodings['input_ids']
 
381
  # Verify token IDs
382
  max_id = tf.reduce_max(input_ids).numpy()
383
  new_vocab_size = len(self.tokenizer)
 
384
 
385
  if max_id >= new_vocab_size:
386
  logger.error(f"Token ID {max_id} exceeds the vocabulary size {new_vocab_size}.")
 
388
 
389
  # Get embeddings from the shared encoder
390
  return self.encoder(input_ids, training=False)
391
+
392
+ def retrieve_responses_cross_encoder(
393
+ self,
394
+ query: str,
395
+ top_k: int,
396
+ reranker: Optional[CrossEncoderReranker] = None,
397
+ summarizer: Optional[Summarizer] = None,
398
+ summarize_threshold: int = 512 # Summarize queries longer than this many whitespace-delimited words
399
+ ) -> List[Tuple[str, float]]:
400
+ """
401
+ Retrieve top-k from FAISS, then re-rank them with a cross-encoder.
402
+ Optionally summarize the user query if it's too long.
403
+ """
404
+ if reranker is None:
405
+ reranker = self.reranker
406
+ if summarizer is None:
407
+ summarizer = self.summarizer
408
+
409
+ # Optional summarization
410
+ if summarizer and len(query.split()) > summarize_threshold:
411
+ logger.info(f"Query is long. Summarizing before cross-encoder. Original length: {len(query.split())}")
412
+ query = summarizer.summarize_text(query)
413
+ logger.info(f"Summarized query: {query}")
414
+
415
+ # 2) Dense retrieval
416
+ dense_topk = self.retrieve_responses_faiss(query, top_k=top_k) # [(resp, dense_score), ...]
417
+
418
+ if not dense_topk:
419
+ return []
420
+
421
+ # 3) Cross-encoder rerank
422
+ candidate_texts = [pair[0] for pair in dense_topk]
423
+ cross_scores = reranker.rerank(query, candidate_texts, max_length=256)
424
+
425
+ # Combine
426
+ combined = [(text, score) for (text, _), score in zip(dense_topk, cross_scores)]
427
+ # Sort descending by cross-encoder score
428
+ combined.sort(key=lambda x: x[1], reverse=True)
429
+
430
+ return combined
431
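
(Note: a rough usage sketch of the retrieve-then-rerank flow above; it assumes a chatbot instance whose encoder and FAISS index have already been built and populated.)

chatbot = RetrievalChatbot(ChatbotConfig())   # assumed to build the models and FAISS index
results = chatbot.retrieve_responses_cross_encoder(
    "Can you help me reset my password?",
    top_k=10,                                 # dense FAISS candidates passed to the cross-encoder
)
for response, cross_score in results[:3]:
    print(f"{cross_score:.4f}  {response}")
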
 
432
  def retrieve_responses_faiss(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
433
  """Retrieve top-k responses using FAISS."""
 
485
  load_dir / "shared_encoder",
486
  config=config
487
  )
 
 
 
 
488
 
489
  # Load tokenizer
490
  chatbot.tokenizer = AutoTokenizer.from_pretrained(load_dir / "tokenizer")
 
523
  def prepare_dataset(
524
  self,
525
  dialogues: List[dict],
526
+ neg_samples: int = 1,
527
  debug_samples: int = None
528
  ) -> Tuple[tf.Tensor, tf.Tensor]:
529
  """
530
+ Prepares dataset for multiple-negatives ranking,
531
+ but also appends 'hard negative' pairs for each query.
532
+
533
+ We'll generate:
534
+ - (query, positive) as usual
535
+ - (query, negative) for each query, using FAISS top-1 approx. negative.
536
+ Then, in-batch training sees them as 'two different positives'
537
+ for the same query, forcing the model to discriminate them.
538
  """
539
+
540
+ logger.info("Preparing in-batch dataset with hard negatives...")
541
 
542
  queries, positives = [], []
543
 
544
+ # Assemble (q, p)
545
  for dialogue in dialogues:
546
  turns = dialogue.get('turns', [])
547
  for i in range(len(turns) - 1):
548
  current_turn = turns[i]
549
  next_turn = turns[i+1]
550
 
551
+ if (current_turn.get('speaker') == 'user'
552
+ and next_turn.get('speaker') == 'assistant'
553
+ and 'text' in current_turn
554
+ and 'text' in next_turn):
555
 
556
+ query_text = current_turn['text'].strip()
557
+ pos_text = next_turn['text'].strip()
558
 
559
+ queries.append(query_text)
560
+ positives.append(pos_text)
561
 
562
+ # Debug slicing
563
  if debug_samples is not None:
564
  queries = queries[:debug_samples]
565
  positives = positives[:debug_samples]
566
  logger.info(f"Debug mode: limited to {debug_samples} pairs.")
567
 
568
+ logger.info(f"Prepared {len(queries)} (query, positive) pairs initially.")
569
+
570
+ # Find a hard negative from FAISS for each (q, p)
571
+ # Create a second 'positive' row => (q, negative). In-batch, it's seen as a different 'positive' row, but is a hard negative.
572
+ augmented_queries = []
573
+ augmented_positives = []
574
+
575
+ for q_text, p_text in zip(queries, positives):
576
+ neg_texts = self._find_hard_negative(q_text, p_text, top_k=5, neg_samples=neg_samples)
577
+ for neg_text in neg_texts:
578
+ augmented_queries.append(q_text)
579
+ augmented_positives.append(neg_text)
580
+
581
+ logger.info(f"Found hard negatives for {len(augmented_queries)} queries.")
582
 
583
+ # Combine them into a single big list -> Original pairs: (q, p) & Hard neg pairs: (q, n)
584
+ final_queries = queries + augmented_queries
585
+ final_positives = positives + augmented_positives
586
+ logger.info(f"Total dataset size after adding hard neg: {len(final_queries)}")
587
+
588
+ # Tokenize
589
  encoded_queries = self.tokenizer(
590
+ final_queries,
591
  padding='max_length',
592
  truncation=True,
593
+ max_length=self.config.max_context_token_limit,
594
  return_tensors='tf'
595
  )
 
596
  encoded_positives = self.tokenizer(
597
+ final_positives,
598
  padding='max_length',
599
  truncation=True,
600
+ max_length=self.config.max_context_token_limit,
601
  return_tensors='tf'
602
  )
603
 
604
  q_tensor = encoded_queries['input_ids']
605
  p_tensor = encoded_positives['input_ids']
606
 
607
+ logger.info("Tokenized and padded sequences for in-batch training + hard negatives.")
608
  return q_tensor, p_tensor
609
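Illustrative data layout (not part of this commit): with neg_samples=1 the returned tensors stack the original pairs first and the hard-negative pairs after them,

    final_queries   = [q1, q2, ..., qN, q1, q2, ..., qN]
    final_positives = [p1, p2, ..., pN, n1, n2, ..., nN]

so each hard negative n_i re-enters a batch as the 'positive' of a duplicated q_i, and the in-batch softmax used during training has to separate it from the true p_i.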
 
+    def _find_hard_negative(
+        self,
+        query_text: str,
+        positive_text: str,
+        top_k: int = 5,
+        neg_samples: int = 1
+    ) -> List[str]:
+        """
+        Return up to `neg_samples` unique negatives from the top_k FAISS results,
+        excluding the known positive_text.
+        """
+        # Encode the query to get its embedding
+        query_emb = self.encode_query(query_text)
+        q_emb_np = query_emb.numpy().astype('float32')
+
+        # Normalize for cosine similarity
+        faiss.normalize_L2(q_emb_np)
+
+        # Search in FAISS
+        distances, indices = self.index.search(q_emb_np, top_k)
+
+        # Exclude the actual positive from these results
+        hard_negatives = []
+        for idx in indices[0]:
+            if idx < len(self.response_pool):
+                candidate = self.response_pool[idx].strip()
+                if candidate != positive_text.strip():
+                    hard_negatives.append(candidate)
+                    if len(hard_negatives) == neg_samples:
+                        break
+
+        return hard_negatives
+
     def train(
         self,
         q_pad: tf.Tensor,
         p_pad: tf.Tensor,
+        epochs: int = 20,
+        batch_size: int = 16,
+        validation_split: float = 0.2,
+        checkpoint_dir: str = "checkpoints/",
         use_lr_schedule: bool = True,
         peak_lr: float = 2e-5,
         warmup_steps_ratio: float = 0.1,
         early_stopping_patience: int = 3,
+        min_delta: float = 1e-4,
+        accum_steps: int = 2  # Gradient accumulation steps
     ):
         dataset_size = tf.shape(q_pad)[0].numpy()
         val_size = int(dataset_size * validation_split)

         val_q = q_pad[train_size:]
         val_p = p_pad[train_size:]

+        train_dataset = (tf.data.Dataset.from_tensor_slices((train_q, train_p))
+                         .shuffle(4096)
+                         .batch(batch_size)
+                         .prefetch(tf.data.AUTOTUNE))

+        val_dataset = (tf.data.Dataset.from_tensor_slices((val_q, val_p))
+                       .batch(batch_size)
+                       .prefetch(tf.data.AUTOTUNE))

         # 3) Checkpoint + manager
         checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, model=self.encoder)
         manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)

         # 4) TensorBoard setup
         log_dir = Path(checkpoint_dir) / "tensorboard_logs"
         log_dir.mkdir(parents=True, exist_ok=True)

         logger.info("Beginning training loop...")
         global_step = 0

+        # Prepare zero-initialized accumulators for the trainable variables.
+        # Gradients are accumulated across mini-batches and applied every accum_steps steps.
+        train_vars = self.encoder.pretrained.trainable_variables
+        accum_grads = [tf.zeros_like(var, dtype=tf.float32) for var in train_vars]
+
         from tqdm import tqdm
         for epoch in range(1, epochs + 1):
             logger.info(f"\n=== Epoch {epoch}/{epochs} ===")
             epoch_loss_avg = tf.keras.metrics.Mean()

+            step_in_epoch = 0
             with tqdm(total=steps_per_epoch, desc=f"Training Epoch {epoch}") as pbar:
                 for (q_batch, p_batch) in train_dataset:
+                    step_in_epoch += 1
                     global_step += 1

+                    with tf.GradientTape() as tape:
+                        q_enc = self.encoder(q_batch, training=True)
+                        p_enc = self.encoder(p_batch, training=True)
+
+                        sim_matrix = tf.matmul(q_enc, p_enc, transpose_b=True)
+                        bsz = tf.shape(q_enc)[0]
+                        labels = tf.range(bsz, dtype=tf.int32)
+                        loss_value = tf.nn.sparse_softmax_cross_entropy_with_logits(
+                            labels=labels, logits=sim_matrix
+                        )
+                        loss_value = tf.reduce_mean(loss_value)
+
+                    gradients = tape.gradient(loss_value, train_vars)
+
+                    # -- Accumulate gradients --
+                    for i, grad in enumerate(gradients):
+                        if grad is not None:
+                            accum_grads[i] += tf.cast(grad, tf.float32)
+
+                    epoch_loss_avg(loss_value)
+
+                    # -- Apply gradients every 'accum_steps' mini-batches --
+                    if (step_in_epoch % accum_steps) == 0:
+                        # Scale by 1/accum_steps so that each accumulation cycle
+                        # is effectively the same as one "normal" update
+                        for i in range(len(accum_grads)):
+                            accum_grads[i] /= accum_steps
+
+                        self.optimizer.apply_gradients(
+                            [(accum_grads[i], train_vars[i]) for i in range(len(accum_grads))]
+                        )
+                        # Reset the accumulators
+                        accum_grads = [tf.zeros_like(var, dtype=tf.float32) for var in train_vars]
+
+                    # Logging / tqdm updates
                     if use_lr_schedule:
+                        # Measure the current LR
                         lr = self.optimizer.learning_rate
                         if isinstance(lr, tf.keras.optimizers.schedules.LearningRateSchedule):
                             current_step = tf.cast(self.optimizer.iterations, tf.float32)
                             current_lr = lr(current_step)
                         else:
                             current_lr = lr
                         current_lr_value = float(current_lr.numpy())
                     else:
                         current_lr_value = float(self.optimizer.learning_rate.numpy())

                     pbar.update(1)
                     pbar.set_postfix({
+                        "loss": f"{loss_value.numpy():.4f}",
                         "lr": f"{current_lr_value:.2e}"
                     })

+            # TensorBoard logging omitted for brevity...
+
+            # -- Handle leftover partial accumulation at epoch end --
+            leftover = (step_in_epoch % accum_steps)
+            if leftover != 0:
+                logger.info(f"Applying leftover accum_grads for partial batch group (size={leftover}).")
+                # Each leftover batch contributes proportionally: scaling by leftover/accum_steps
+                # gives the leftover group the same "average" effect as a full accumulation cycle.
+                for i in range(len(accum_grads)):
+                    accum_grads[i] *= float(leftover) / float(accum_steps)
+
+                self.optimizer.apply_gradients(
+                    [(accum_grads[i], train_vars[i]) for i in range(len(accum_grads))]
+                )
+                accum_grads = [tf.zeros_like(var, dtype=tf.float32) for var in train_vars]

             # Validation
             val_loss_avg = tf.keras.metrics.Mean()

         logger.info("In-batch training completed!")
     def _get_lr_schedule(
         self,
         total_steps: int,

             decay_factor = (step - self.warmup_steps) / decay_steps
             decay_factor = tf.minimum(tf.maximum(0.0, decay_factor), 1.0)  # Clip to [0, 1]

+            cosine_decay = 0.5 * (1.0 + tf.cos(tf.constant(math.pi) * decay_factor))
             decay_lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay

             # Choose between warmup and decay

         normalized_emb1 = emb1 / np.linalg.norm(emb1, axis=1, keepdims=True)
         normalized_emb2 = emb2 / np.linalg.norm(emb2, axis=1, keepdims=True)
         return np.dot(normalized_emb1, normalized_emb2.T)
     def chat(
         self,
         query: str,
         conversation_history: Optional[List[Tuple[str, str]]] = None,
         quality_checker: Optional['ResponseQualityChecker'] = None,
+        top_k: int = 5,
     ) -> Tuple[str, List[Tuple[str, float]], Dict[str, Any]]:
         """
+        Chat entry point. Builds the conversation context, retrieves candidates,
+        and re-ranks them with the cross-encoder (self.reranker) when available.
         """
+        @self.run_on_device
+        def get_response(self_arg, query_arg):  # Parameters match the decorator's expectations
+            # 1) Build conversation context string
+            conversation_str = self_arg._build_conversation_context(query_arg, conversation_history)

+            # 2) Retrieve + cross-encoder re-rank
+            results = self_arg.retrieve_responses_cross_encoder(
+                query=conversation_str,
+                top_k=top_k,
+                reranker=self_arg.reranker,
+                summarizer=self_arg.summarizer,
+                summarize_threshold=512
+            )
+
+            # 3) Handle the empty case and the confidence check
+            if not results:
+                return (
+                    "I'm sorry, but I couldn't find a relevant response.",
+                    [],
+                    {}
+                )
+
+            if quality_checker:
+                metrics = quality_checker.check_response_quality(query_arg, results)
+                if not metrics.get('is_confident', False):
+                    return (
+                        "I need more information to provide a good answer. Could you please clarify?",
+                        results,
+                        metrics
+                    )
+                return results[0][0], results, metrics

+            return results[0][0], results, {}

+        return get_response(self, query)
+
+    def _build_conversation_context(
+        self,
+        query: str,
+        conversation_history: Optional[List[Tuple[str, str]]]
+    ) -> str:
+        """Build the conversation context string from the history plus the new query."""
+        if not conversation_history:
+            return f"{self.special_tokens['user']} {query}"

+        conversation_parts = []
+        for user_txt, assistant_txt in conversation_history:
+            conversation_parts.extend([
+                f"{self.special_tokens['user']} {user_txt}",
+                f"{self.special_tokens['assistant']} {assistant_txt}"
+            ])

+        conversation_parts.append(f"{self.special_tokens['user']} {query}")
+        return "\n".join(conversation_parts)
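For illustration (not part of this commit): if `special_tokens` were {'user': '<USER>', 'assistant': '<ASSISTANT>'} (the actual token strings are defined elsewhere in the class), one prior exchange plus a new query would produce a context string like:

    <USER> I'd like to book a table.
    <ASSISTANT> For how many people?
    <USER> Four of us, tonight at 7pm.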
+
994
+ # def prepare_dataset(
995
+ # self,
996
+ # dialogues: List[dict],
997
+ # debug_samples: int = None
998
+ # ) -> Tuple[tf.Tensor, tf.Tensor]:
999
+ # """
1000
+ # Prepares dataset for in-batch negatives:
1001
+ # Only returns (query, positive) pairs.
1002
+ # """
1003
+ # logger.info("Preparing in-batch dataset...")
1004
+
1005
+ # queries, positives = [], []
1006
+
1007
+ # for dialogue in dialogues:
1008
+ # turns = dialogue.get('turns', [])
1009
+ # for i in range(len(turns) - 1):
1010
+ # current_turn = turns[i]
1011
+ # next_turn = turns[i+1]
1012
+
1013
+ # if (current_turn.get('speaker') == 'user' and
1014
+ # next_turn.get('speaker') == 'assistant' and
1015
+ # 'text' in current_turn and
1016
+ # 'text' in next_turn):
 
 
 
1017
 
1018
+ # query = current_turn['text'].strip()
1019
+ # positive = next_turn['text'].strip()
1020
 
1021
+ # queries.append(query)
1022
+ # positives.append(positive)
1023
+
1024
+ # # Optional debug slicing
1025
+ # if debug_samples is not None:
1026
+ # queries = queries[:debug_samples]
1027
+ # positives = positives[:debug_samples]
1028
+ # logger.info(f"Debug mode: limited to {debug_samples} pairs.")
1029
+
1030
+ # logger.info(f"Prepared {len(queries)} (query, positive) pairs.")
1031
+
1032
+ # # Tokenize queries
1033
+ # encoded_queries = self.tokenizer(
1034
+ # queries,
1035
+ # padding='max_length',
1036
+ # truncation=True,
1037
+ # max_length=self.config.max_sequence_length,
1038
+ # return_tensors='tf'
1039
+ # )
1040
+ # # Tokenize positives
1041
+ # encoded_positives = self.tokenizer(
1042
+ # positives,
1043
+ # padding='max_length',
1044
+ # truncation=True,
1045
+ # max_length=self.config.max_sequence_length,
1046
+ # return_tensors='tf'
1047
+ # )
1048
+
1049
+ # q_tensor = encoded_queries['input_ids']
1050
+ # p_tensor = encoded_positives['input_ids']
1051
+
1052
+ # logger.info("Tokenized and padded sequences for in-batch training.")
1053
+ # return q_tensor, p_tensor
1054
+
1055
+ # def train(
1056
+ # self,
1057
+ # q_pad: tf.Tensor,
1058
+ # p_pad: tf.Tensor,
1059
+ # epochs: int = 20,
1060
+ # batch_size: int = 16,
1061
+ # validation_split: float = 0.2,
1062
+ # checkpoint_dir: str = "checkpoints/",
1063
+ # use_lr_schedule: bool = True,
1064
+ # peak_lr: float = 2e-5,
1065
+ # warmup_steps_ratio: float = 0.1,
1066
+ # early_stopping_patience: int = 3,
1067
+ # min_delta: float = 1e-4
1068
+ # ):
1069
+ # dataset_size = tf.shape(q_pad)[0].numpy()
1070
+ # val_size = int(dataset_size * validation_split)
1071
+ # train_size = dataset_size - val_size
1072
+
1073
+ # logger.info(f"Total samples: {dataset_size}")
1074
+ # logger.info(f"Training samples: {train_size}")
1075
+ # logger.info(f"Validation samples: {val_size}")
1076
+
1077
+ # steps_per_epoch = train_size // batch_size
1078
+ # if train_size % batch_size != 0:
1079
+ # steps_per_epoch += 1
1080
+ # total_steps = steps_per_epoch * epochs
1081
+ # logger.info(f"Total training steps (approx): {total_steps}")
1082
+
1083
+ # # 1) Set up LR schedule or fixed LR
1084
+ # if use_lr_schedule:
1085
+ # warmup_steps = int(total_steps * warmup_steps_ratio)
1086
+ # lr_schedule = self._get_lr_schedule(
1087
+ # total_steps=total_steps,
1088
+ # peak_lr=peak_lr,
1089
+ # warmup_steps=warmup_steps
1090
+ # )
1091
+ # self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
1092
+ # logger.info("Using custom learning rate schedule.")
1093
+ # else:
1094
+ # self.optimizer = tf.keras.optimizers.Adam(learning_rate=peak_lr)
1095
+ # logger.info("Using fixed learning rate.")
1096
+
1097
+ # # 2) Prepare data splits
1098
+ # train_q = q_pad[:train_size]
1099
+ # train_p = p_pad[:train_size]
1100
+ # val_q = q_pad[train_size:]
1101
+ # val_p = p_pad[train_size:]
1102
+
1103
+ # train_dataset = tf.data.Dataset.from_tensor_slices((train_q, train_p))
1104
+ # train_dataset = train_dataset.shuffle(buffer_size=4096).batch(batch_size)
1105
+
1106
+ # val_dataset = tf.data.Dataset.from_tensor_slices((val_q, val_p))
1107
+ # val_dataset = val_dataset.batch(batch_size)
1108
+
1109
+ # # 3) Checkpoint + manager
1110
+ # checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, model=self.encoder)
1111
+ # manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
1112
+
1113
+ # # 4) TensorBoard setup
1114
+ # log_dir = Path(checkpoint_dir) / "tensorboard_logs"
1115
+ # log_dir.mkdir(parents=True, exist_ok=True)
1116
+
1117
+ # current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
1118
+ # train_log_dir = str(log_dir / f"train_{current_time}")
1119
+ # val_log_dir = str(log_dir / f"val_{current_time}")
1120
+
1121
+ # train_summary_writer = tf.summary.create_file_writer(train_log_dir)
1122
+ # val_summary_writer = tf.summary.create_file_writer(val_log_dir)
1123
+
1124
+ # logger.info(f"TensorBoard logs will be saved in {log_dir}")
1125
+
1126
+ # # 5) Early stopping
1127
+ # best_val_loss = float("inf")
1128
+ # epochs_no_improve = 0
1129
+
1130
+ # logger.info("Beginning training loop...")
1131
+ # global_step = 0
1132
+
1133
+ # from tqdm import tqdm
1134
+ # for epoch in range(1, epochs + 1):
1135
+ # logger.info(f"\n=== Epoch {epoch}/{epochs} ===")
1136
+ # epoch_loss_avg = tf.keras.metrics.Mean()
1137
+
1138
+ # # Training loop
1139
+ # with tqdm(total=steps_per_epoch, desc=f"Training Epoch {epoch}") as pbar:
1140
+ # for (q_batch, p_batch) in train_dataset:
1141
+ # global_step += 1
1142
+
1143
+ # # Train step
1144
+ # batch_loss = self._train_step(q_batch, p_batch)
1145
+ # epoch_loss_avg(batch_loss)
1146
+
1147
+ # # Get current LR
1148
+ # if use_lr_schedule:
1149
+ # lr = self.optimizer.learning_rate
1150
+ # if isinstance(lr, tf.keras.optimizers.schedules.LearningRateSchedule):
1151
+ # # Get the current step
1152
+ # current_step = tf.cast(self.optimizer.iterations, tf.float32)
1153
+ # # Compute the current learning rate
1154
+ # current_lr = lr(current_step)
1155
+ # else:
1156
+ # # If learning_rate is not a schedule, use it directly
1157
+ # current_lr = lr
1158
+ # # Convert to float for logging
1159
+ # current_lr_value = float(current_lr.numpy())
1160
+ # else:
1161
+ # # If using fixed learning rate
1162
+ # current_lr_value = float(self.optimizer.learning_rate.numpy())
1163
+
1164
+ # # Update tqdm
1165
+ # pbar.update(1)
1166
+ # pbar.set_postfix({
1167
+ # "loss": f"{batch_loss.numpy():.4f}",
1168
+ # "lr": f"{current_lr_value:.2e}"
1169
+ # })
1170
+
1171
+ # # TensorBoard: log train metrics per step
1172
+ # with train_summary_writer.as_default():
1173
+ # tf.summary.scalar("loss", batch_loss, step=global_step)
1174
+ # tf.summary.scalar("learning_rate", current_lr_value, step=global_step)
1175
+
1176
+ # # Validation
1177
+ # val_loss_avg = tf.keras.metrics.Mean()
1178
+ # for q_val, p_val in val_dataset:
1179
+ # q_enc = self.encoder(q_val, training=False)
1180
+ # p_enc = self.encoder(p_val, training=False)
1181
+ # sim_matrix = tf.matmul(q_enc, p_enc, transpose_b=True)
1182
+ # bs_val = tf.shape(q_enc)[0]
1183
+ # labels_val = tf.range(bs_val, dtype=tf.int32)
1184
+ # loss_val = tf.nn.sparse_softmax_cross_entropy_with_logits(
1185
+ # labels=labels_val,
1186
+ # logits=sim_matrix
1187
+ # )
1188
+ # val_loss_avg(tf.reduce_mean(loss_val))
1189
+
1190
+ # train_loss = epoch_loss_avg.result().numpy()
1191
+ # val_loss = val_loss_avg.result().numpy()
1192
+
1193
+ # logger.info(f"Epoch {epoch} Complete: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")
1194
+
1195
+ # # TensorBoard: validation loss
1196
+ # with val_summary_writer.as_default():
1197
+ # tf.summary.scalar("val_loss", val_loss, step=epoch)
1198
+
1199
+ # # Save checkpoint
1200
+ # manager.save()
1201
+
1202
+ # # Update history
1203
+ # self.history['train_loss'].append(train_loss)
1204
+ # self.history['val_loss'].append(val_loss)
1205
+ # self.history.setdefault('learning_rate', []).append(float(current_lr_value))
1206
+
1207
+ # # Early stopping
1208
+ # if val_loss < best_val_loss - min_delta:
1209
+ # best_val_loss = val_loss
1210
+ # epochs_no_improve = 0
1211
+ # logger.info(f"Validation loss improved to {val_loss:.4f}. Reset patience.")
1212
+ # else:
1213
+ # epochs_no_improve += 1
1214
+ # logger.info(f"No improvement this epoch. Patience: {epochs_no_improve}/{early_stopping_patience}")
1215
+ # if epochs_no_improve >= early_stopping_patience:
1216
+ # logger.info("Early stopping triggered.")
1217
+ # break
1218
+
1219
+ # logger.info("In-batch training completed!")
1220
+
1221
+ # @tf.function
1222
+ # def _train_step(self, q_batch, p_batch):
1223
+ # """
1224
+ # Single training step using in-batch negatives.
1225
+ # q_batch: (batch_size, seq_len) int32 input_ids for queries
1226
+ # p_batch: (batch_size, seq_len) int32 input_ids for positives
1227
+ # """
1228
+ # with tf.GradientTape() as tape:
1229
+ # # Encode queries and positives
1230
+ # q_enc = self.encoder(q_batch, training=True) # [B, emb_dim]
1231
+ # p_enc = self.encoder(p_batch, training=True) # [B, emb_dim]
1232
+
1233
+ # # Compute similarity matrix: (B, B) = q_enc * p_enc^T
1234
+ # # If embeddings are L2-normalized, this is cosine similarity
1235
+ # sim_matrix = tf.matmul(q_enc, p_enc, transpose_b=True) # [B, B]
1236
+
1237
+ # # Labels are just the diagonal indices
1238
+ # batch_size = tf.shape(q_enc)[0]
1239
+ # labels = tf.range(batch_size, dtype=tf.int32) # [0..B-1]
1240
+
1241
+ # # Softmax cross-entropy
1242
+ # loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
1243
+ # labels=labels,
1244
+ # logits=sim_matrix
1245
+ # )
1246
+ # loss = tf.reduce_mean(loss)
1247
+
1248
+ # # Compute gradients for the pretrained DistilBERT variables only
1249
+ # train_vars = self.encoder.pretrained.trainable_variables
1250
+ # gradients = tape.gradient(loss, train_vars)
1251
+
1252
+ # # Remove any None grads (in case some layers are frozen)
1253
+ # grads_and_vars = [(g, v) for g, v in zip(gradients, train_vars) if g is not None]
1254
+ # if grads_and_vars:
1255
+ # self.optimizer.apply_gradients(grads_and_vars)
1256
+
1257
+ # return loss
chatbot_validator.py ADDED
@@ -0,0 +1,207 @@
1
+ from typing import Dict, List, Tuple, Any, Optional
2
+ import numpy as np
3
+ from logger_config import config_logger
4
+
5
+ logger = config_logger(__name__)
6
+
7
+ class ChatbotValidator:
8
+ """Handles automated validation and performance analysis for the chatbot."""
9
+
10
+ def __init__(self, chatbot, quality_checker):
11
+ """
12
+ Initialize the validator.
13
+
14
+ Args:
15
+ chatbot: RetrievalChatbot instance
16
+ quality_checker: ResponseQualityChecker instance
17
+ """
18
+ self.chatbot = chatbot
19
+ self.quality_checker = quality_checker
20
+
21
+ # Domain-specific test queries aligned with Taskmaster-1 and Schema-Guided
22
+ self.domain_queries = {
23
+ 'restaurant': [
24
+ "I'd like to make a reservation for dinner tonight.",
25
+ "Can you book a table for 4 people at an Italian place?",
26
+ "Do you have any availability for tomorrow at 7pm?",
27
+ "I need to change my dinner reservation time.",
28
+ "What's the wait time for a table right now?"
29
+ ],
30
+ 'movie_tickets': [
31
+ "I want to buy tickets for the new Marvel movie.",
32
+ "Are there any showings of Avatar after 6pm?",
33
+ "Can I get 3 tickets for the 8pm show?",
34
+ "What movies are playing this weekend?",
35
+ "Do you have any matinee showtimes available?"
36
+ ],
37
+ 'rideshare': [
38
+ "I need a ride from the airport to downtown.",
39
+ "How much would it cost to get to the mall?",
40
+ "Can you book a car for tomorrow morning?",
41
+ "Is there a driver available now?",
42
+ "What's the estimated arrival time?"
43
+ ],
44
+ 'services': [
45
+ "I need to schedule an oil change for my car.",
46
+ "When can I bring my car in for maintenance?",
47
+ "Do you have any openings for auto repair today?",
48
+ "How long will the service take?",
49
+ "Can I get an estimate for brake repair?"
50
+ ],
51
+ 'events': [
52
+ "I need tickets to the concert this weekend.",
53
+ "What events are happening near me?",
54
+ "Can I book seats for the basketball game?",
55
+ "Are there any comedy shows tonight?",
56
+ "How much are tickets to the theater?"
57
+ ]
58
+ }
59
+
60
+ def run_validation(
61
+ self,
62
+ num_examples: int = 10,
63
+ top_k: int = 10,
64
+ domains: Optional[List[str]] = None
65
+ ) -> Dict[str, Any]:
66
+ """
67
+ Run comprehensive validation across specified domains.
68
+
69
+ Args:
70
+ num_examples: Number of test queries per domain
71
+ top_k: Number of responses to retrieve for each query
72
+ domains: Optional list of specific domains to test
73
+
74
+ Returns:
75
+ Dict containing detailed validation metrics and domain-specific performance
76
+ """
77
+ logger.info("\n=== Running Enhanced Automatic Validation ===")
78
+
79
+ # Select domains to test
80
+ test_domains = domains if domains else list(self.domain_queries.keys())
81
+ metrics_history = []
82
+ domain_metrics = {}
83
+
84
+ # Run validation for each domain
85
+ for domain in test_domains:
86
+ domain_metrics[domain] = []
87
+ queries = self.domain_queries[domain][:num_examples]
88
+
89
+ logger.info(f"\n=== Testing {domain.title()} Domain ===")
90
+
91
+ for i, query in enumerate(queries, 1):
92
+ logger.info(f"\nTest Case {i}:")
93
+ logger.info(f"Query: {query}")
94
+
95
+ # Get responses with increased top_k
96
+ responses = self.chatbot.retrieve_responses_cross_encoder(query, top_k=top_k)
97
+
98
+ # Enhanced quality checking with context
99
+ quality_metrics = self.quality_checker.check_response_quality(query, responses)
100
+
101
+ # Add domain info
102
+ quality_metrics['domain'] = domain
103
+ metrics_history.append(quality_metrics)
104
+ domain_metrics[domain].append(quality_metrics)
105
+
106
+ # Detailed logging
107
+ self._log_validation_results(query, responses, quality_metrics, i)
108
+
109
+ # Calculate and log overall metrics
110
+ aggregate_metrics = self._calculate_aggregate_metrics(metrics_history)
111
+ domain_analysis = self._analyze_domain_performance(domain_metrics)
112
+ confidence_analysis = self._analyze_confidence_distribution(metrics_history)
113
+
114
+ aggregate_metrics.update({
115
+ 'domain_performance': domain_analysis,
116
+ 'confidence_analysis': confidence_analysis
117
+ })
118
+
119
+ self._log_validation_summary(aggregate_metrics)
120
+ return aggregate_metrics
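Illustrative driver code (not part of this commit), assuming a trained `chatbot` and its quality checker:

    quality_checker = ResponseQualityChecker(chatbot=chatbot)
    validator = ChatbotValidator(chatbot=chatbot, quality_checker=quality_checker)
    report = validator.run_validation(num_examples=5, top_k=10, domains=['restaurant', 'rideshare'])
    print(report['confidence_rate'], report['domain_performance']['restaurant'])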
121
+
122
+ def _calculate_aggregate_metrics(self, metrics_history: List[Dict]) -> Dict[str, float]:
123
+ """Calculate comprehensive aggregate metrics."""
124
+ metrics = {
125
+ 'num_queries_tested': len(metrics_history),
126
+ 'avg_top_response_score': np.mean([m.get('top_score', 0) for m in metrics_history]),
127
+ 'avg_diversity': np.mean([m.get('response_diversity', 0) for m in metrics_history]),
128
+ 'avg_relevance': np.mean([m.get('query_response_relevance', 0) for m in metrics_history]),
129
+ 'avg_length_score': np.mean([m.get('response_length_score', 0) for m in metrics_history]),
130
+ 'avg_score_gap': np.mean([m.get('top_3_score_gap', 0) for m in metrics_history]),
131
+ 'confidence_rate': np.mean([m.get('is_confident', False) for m in metrics_history]),
132
+
133
+ # Additional statistical metrics
134
+ 'median_top_score': np.median([m.get('top_score', 0) for m in metrics_history]),
135
+ 'score_std': np.std([m.get('top_score', 0) for m in metrics_history]),
136
+ 'min_score': np.min([m.get('top_score', 0) for m in metrics_history]),
137
+ 'max_score': np.max([m.get('top_score', 0) for m in metrics_history])
138
+ }
139
+ return metrics
140
+
141
+ def _analyze_domain_performance(self, domain_metrics: Dict[str, List[Dict]]) -> Dict[str, Dict]:
142
+ """Analyze performance by domain."""
143
+ domain_analysis = {}
144
+
145
+ for domain, metrics in domain_metrics.items():
146
+ domain_analysis[domain] = {
147
+ 'confidence_rate': np.mean([m.get('is_confident', False) for m in metrics]),
148
+ 'avg_relevance': np.mean([m.get('query_response_relevance', 0) for m in metrics]),
149
+ 'avg_diversity': np.mean([m.get('response_diversity', 0) for m in metrics]),
150
+ 'avg_top_score': np.mean([m.get('top_score', 0) for m in metrics]),
151
+ 'num_samples': len(metrics)
152
+ }
153
+
154
+ return domain_analysis
155
+
156
+ def _analyze_confidence_distribution(self, metrics_history: List[Dict]) -> Dict[str, float]:
157
+ """Analyze the distribution of confidence scores."""
158
+ scores = [m.get('top_score', 0) for m in metrics_history]
159
+
160
+ return {
161
+ 'percentile_25': np.percentile(scores, 25),
162
+ 'percentile_50': np.percentile(scores, 50),
163
+ 'percentile_75': np.percentile(scores, 75),
164
+ 'percentile_90': np.percentile(scores, 90)
165
+ }
166
+
167
+ def _log_validation_results(
168
+ self,
169
+ query: str,
170
+ responses: List[Tuple[str, float]],
171
+ metrics: Dict[str, Any],
172
+ case_num: int
173
+ ):
174
+ """Log detailed validation results."""
175
+ logger.info(f"\nTest Case {case_num}:")
176
+ logger.info(f"Query: {query}")
177
+ logger.info(f"Domain: {metrics.get('domain', 'Unknown')}")
178
+ logger.info(f"Confidence: {'Yes' if metrics.get('is_confident', False) else 'No'}")
179
+ logger.info("\nQuality Metrics:")
180
+ for metric, value in metrics.items():
181
+ if isinstance(value, (int, float)):
182
+ logger.info(f" {metric}: {value:.4f}")
183
+
184
+ logger.info("\nTop Responses:")
185
+ for i, (response, score) in enumerate(responses[:3], 1):
186
+ logger.info(f"{i}. Score: {score:.4f}. Response: {response}")
187
+ if i == 1 and not metrics.get('is_confident', False):
188
+ logger.info(" [Low Confidence]")
189
+
190
+ def _log_validation_summary(self, metrics: Dict[str, Any]):
191
+ """Log comprehensive validation summary."""
192
+ logger.info("\n=== Validation Summary ===")
193
+
194
+ logger.info("\nOverall Metrics:")
195
+ for metric, value in metrics.items():
196
+ if isinstance(value, (int, float)):
197
+ logger.info(f"{metric}: {value:.4f}")
198
+
199
+ logger.info("\nDomain Performance:")
200
+ for domain, domain_metrics in metrics['domain_performance'].items():
201
+ logger.info(f"\n{domain.title()}:")
202
+ for metric, value in domain_metrics.items():
203
+ logger.info(f" {metric}: {value:.4f}")
204
+
205
+ logger.info("\nConfidence Distribution:")
206
+ for percentile, value in metrics['confidence_analysis'].items():
207
+ logger.info(f"{percentile}: {value:.4f}")
conversation_summarizer.py ADDED
@@ -0,0 +1,147 @@
1
+ import tensorflow as tf
2
+ from typing import List, Dict
3
+ from transformers import TFAutoModelForSeq2SeqLM, AutoTokenizer
4
+ import logging
5
+ from dataclasses import dataclass
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+ @dataclass
10
+ class ChatConfig:
11
+ max_sequence_length: int = 512
12
+ default_top_k: int = 5
13
+ chunk_size: int = 512
14
+ chunk_overlap: int = 256
15
+ min_confidence_score: float = 0.7
16
+
17
+ class DeviceAwareModel:
18
+ """Mixin to handle device placement and mixed precision training."""
19
+
20
+ def setup_device(self, device: str = None):
21
+ if device is None:
22
+ device = 'GPU' if tf.config.list_physical_devices('GPU') else 'CPU'
23
+
24
+ self.device = device.upper()
25
+ self.strategy = None
26
+
27
+ if self.device == 'GPU':
28
+ # Enable mixed precision for better performance
29
+ policy = tf.keras.mixed_precision.Policy('mixed_float16')
30
+ tf.keras.mixed_precision.set_global_policy(policy)
31
+
32
+ # Setup distribution strategy for multi-GPU if available
33
+ gpus = tf.config.list_physical_devices('GPU')
34
+ if len(gpus) > 1:
35
+ self.strategy = tf.distribute.MirroredStrategy()
36
+
37
+ return self.device
38
+
39
+ def run_on_device(self, func):
40
+ """Decorator to ensure ops run on the correct device."""
41
+ def wrapper(*args, **kwargs):
42
+ with tf.device(f'/{self.device}:0'):
43
+ return func(*args, **kwargs)
44
+ return wrapper
45
+
46
+ class Summarizer(DeviceAwareModel):
47
+ """
48
+ Enhanced T5-based summarizer with better chunking and device management.
49
+ Handles long conversations by intelligent chunking and progressive summarization.
50
+ """
51
+
52
+ def __init__(self, model_name="t5-small", max_summary_length=128, device=None, max_summary_rounds=2):
53
+ self.setup_device(device)
54
+
55
+ # Initialize model within strategy scope if using distribution
56
+ if self.strategy:
57
+ with self.strategy.scope():
58
+ self._setup_model(model_name)
59
+ else:
60
+ self._setup_model(model_name)
61
+
62
+ self.max_summary_length = max_summary_length
63
+ self.max_summary_rounds = max_summary_rounds
64
+
65
+ def _setup_model(self, model_name):
66
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
67
+ self.model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)
68
+
69
+ # Optimize model for inference
70
+ self.model.predict = tf.function(
71
+ self.model.predict,
72
+ input_signature=[
73
+ {
74
+ 'input_ids': tf.TensorSpec(shape=[None, None], dtype=tf.int32),
75
+ 'attention_mask': tf.TensorSpec(shape=[None, None], dtype=tf.int32)
76
+ }
77
+ ]
78
+ )
79
+
80
+ @tf.function
81
+ def _generate_summary(self, inputs):
82
+ return self.model.generate(
83
+ inputs,
84
+ max_length=self.max_summary_length,
85
+ num_beams=4,
86
+ length_penalty=2.0,
87
+ early_stopping=True,
88
+ no_repeat_ngram_size=3
89
+ )
90
+
91
+ def chunk_text(self, text: str, chunk_size: int = 512, overlap: int = 256) -> List[str]:
92
+ """Split text into overlapping chunks for better context preservation."""
93
+ tokens = self.tokenizer.encode(text)
94
+ chunks = []
95
+
96
+ for i in range(0, len(tokens), chunk_size - overlap):
97
+ chunk = tokens[i:i + chunk_size]
98
+ chunks.append(self.tokenizer.decode(chunk, skip_special_tokens=True))
99
+
100
+ return chunks
101
+
102
+ def summarize_text(
103
+ self,
104
+ text: str,
105
+ progressive: bool = True,
106
+ round_idx: int = 0
107
+ ) -> str:
108
+ """
109
+ Summarize text with optional progressive summarization
110
+ and limit the maximum number of re-summarization rounds.
111
+ """
112
+ @self.run_on_device
113
+ def _summarize_chunk(chunk: str) -> str:
114
+ input_text = "summarize: " + chunk
115
+ inputs = self.tokenizer(
116
+ input_text,
117
+ return_tensors="tf",
118
+ padding=True,
119
+ truncation=True
120
+ )
121
+ summary_ids = self._generate_summary(inputs)
122
+ return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
123
+
124
+ # If we've hit our max allowed summarization rounds, just do a single pass
125
+ if round_idx >= self.max_summary_rounds:
126
+ return _summarize_chunk(text)
127
+
128
+ # If text is longer than threshold and progressive summarization is on
129
+ if len(text.split()) > 512 and progressive:
130
+ chunks = self.chunk_text(text)
131
+ chunk_summaries = [_summarize_chunk(chunk) for chunk in chunks]
132
+
133
+ # Combine chunk-level summaries
134
+ combined_summary = " ".join(chunk_summaries)
135
+
136
+ # If still too long, do another summarization pass but increment round_idx
137
+ if len(combined_summary.split()) > 512:
138
+ return self.summarize_text(
139
+ combined_summary,
140
+ progressive=True,
141
+ round_idx=round_idx + 1
142
+ )
143
+
144
+ return combined_summary
145
+ else:
146
+ # If text is not too long, just summarize once and return
147
+ return _summarize_chunk(text)
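A short usage sketch for the summarizer (illustrative, not part of this commit):

    summarizer = Summarizer(model_name="t5-small", max_summary_length=128)
    long_text = " ".join(["the user described the booking problem in detail"] * 200)
    short_text = summarizer.summarize_text(long_text, progressive=True)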
cross_encoder_reranker.py ADDED
@@ -0,0 +1,51 @@
+from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
+import tensorflow as tf
+from typing import List, Tuple
+
+from logger_config import config_logger
+logger = config_logger(__name__)
+
+class CrossEncoderReranker:
+    """
+    Cross-Encoder Re-Ranker: takes (query, candidate) pairs and
+    outputs a single relevance score (one logit) per pair.
+    """
+    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-12-v2"):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
+        # Model outputs shape [batch_size, 1] -> interpret the logit as the relevance score.
+
+    def rerank(
+        self,
+        query: str,
+        candidates: List[str],
+        max_length: int = 256
+    ) -> List[float]:
+        """
+        Return one relevance score per candidate, indicating how relevant
+        that candidate is to the query.
+        """
+        # Build (query, candidate) pairs
+        pair_texts = [(query, candidate) for candidate in candidates]
+
+        # Tokenize the entire batch
+        encodings = self.tokenizer(
+            pair_texts,
+            padding=True,
+            truncation=True,
+            max_length=max_length,
+            return_tensors="tf"
+        )
+
+        # Forward pass -> logits shape [batch_size, 1]
+        outputs = self.model(
+            input_ids=encodings["input_ids"],
+            attention_mask=encodings["attention_mask"],
+            token_type_ids=encodings.get("token_type_ids")
+        )
+
+        logits = outputs.logits
+        # Flatten to shape [batch_size]
+        scores = tf.reshape(logits, [-1]).numpy()
+
+        return scores.tolist()
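Illustrative use of the re-ranker on its own (not part of this commit):

    reranker = CrossEncoderReranker()
    scores = reranker.rerank(
        query="I need a ride from the airport to downtown.",
        candidates=[
            "A driver can pick you up in about 5 minutes.",
            "The museum opens at 9am on weekdays.",
        ],
    )
    # Higher score means more relevant; the first candidate should win here.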
dialogue_augmenter.py CHANGED
@@ -7,7 +7,6 @@ from pipeline_config import PipelineConfig
 from quality_metrics import QualityMetrics
 from paraphraser import Paraphraser
 import nlpaug.augmenter.word as naw
-from concurrent.futures import ThreadPoolExecutor
 from functools import lru_cache
 from sklearn.metrics.pairwise import cosine_similarity
 
 
environment_setup.py ADDED
@@ -0,0 +1,207 @@
1
+ from typing import Dict, Optional, Tuple
2
+ from pathlib import Path
3
+ import tensorflow as tf
4
+ import os
5
+ import subprocess
6
+ from datetime import datetime
7
+ from logger_config import config_logger
8
+
9
+ logger = config_logger(__name__)
10
+
11
+ class EnvironmentSetup:
12
+ def __init__(self):
13
+ self.device_type, self.strategy = self.setup_devices()
14
+ self.cache_dir = None
15
+
16
+ def initialize(self, cache_dir: Optional[Path] = None):
17
+ self.cache_dir = self.setup_model_cache(cache_dir)
18
+ self.training_dirs = self.setup_training_directories()
19
+
20
+ @staticmethod
21
+ def setup_model_cache(cache_dir: Optional[Path] = None) -> Path:
22
+ """Setup and manage model cache directory."""
23
+ if cache_dir is None:
24
+ cache_dir = Path.home() / '.chatbot_cache'
25
+
26
+ cache_dir.mkdir(parents=True, exist_ok=True)
27
+
28
+ # Set environment variables for various libraries
29
+ os.environ['TRANSFORMERS_CACHE'] = str(cache_dir / 'transformers')
30
+ os.environ['TORCH_HOME'] = str(cache_dir / 'torch')
31
+ os.environ['HF_HOME'] = str(cache_dir / 'huggingface')
32
+
33
+ logger.info(f"Using cache directory: {cache_dir}")
34
+ return cache_dir
35
+
36
+ @staticmethod
37
+ def setup_training_directories(base_dir: str = "chatbot_training") -> Dict[str, Path]:
38
+ """Setup directory structure for training artifacts."""
39
+ base_dir = Path(base_dir)
40
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
41
+ train_dir = base_dir / f"training_run_{timestamp}"
42
+
43
+ directories = {
44
+ 'base': train_dir,
45
+ 'checkpoints': train_dir / 'checkpoints',
46
+ 'plots': train_dir / 'plots',
47
+ 'logs': train_dir / 'logs'
48
+ }
49
+
50
+ # Create directories
51
+ for dir_path in directories.values():
52
+ dir_path.mkdir(parents=True, exist_ok=True)
53
+
54
+ return directories
55
+
56
+ @staticmethod
57
+ def is_colab() -> bool:
58
+ """Check if code is running in Google Colab."""
59
+ try:
60
+ # Handle both import and attribute checks
61
+ import google.colab # type: ignore
62
+ import IPython # type: ignore
63
+ return True
64
+ except (ImportError, AttributeError):
65
+ return False
66
+
67
+ def setup_colab_tpu(self) -> Optional[tf.distribute.Strategy]:
68
+ """Setup TPU in Colab environment if available."""
69
+ if not self.is_colab():
70
+ return None
71
+
72
+ try:
73
+ import requests
74
+ import os
75
+
76
+ # Check TPU availability
77
+ if 'COLAB_TPU_ADDR' not in os.environ:
78
+ return None
79
+
80
+ # TPU address should be set
81
+ tpu_address = 'grpc://' + os.environ['COLAB_TPU_ADDR']
82
+ resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_address)
83
+ tf.config.experimental_connect_to_cluster(resolver)
84
+ tf.tpu.experimental.initialize_tpu_system(resolver)
85
+ strategy = tf.distribute.TPUStrategy(resolver)
86
+
87
+ return strategy
88
+ except Exception as e:
89
+ logger.warning(f"Failed to initialize Colab TPU: {e}")
90
+ return None
91
+
92
+ def setup_devices(self) -> Tuple[str, tf.distribute.Strategy]:
93
+ """Configure available compute devices with Colab-specific optimizations."""
94
+ logger.info("Checking available compute devices...")
95
+
96
+ # Colab-specific setup
97
+ if self.is_colab():
98
+ logger.info("Running in Google Colab environment")
99
+
100
+ # Try TPU first in Colab
101
+ tpu_strategy = self.setup_colab_tpu()
102
+ if tpu_strategy is not None:
103
+ logger.info("Colab TPU detected and initialized")
104
+ return "TPU", tpu_strategy
105
+
106
+ # Colab GPU setup
107
+ gpus = tf.config.list_physical_devices('GPU')
108
+ if gpus:
109
+ try:
110
+ # Colab-specific GPU memory management
111
+ for gpu in gpus:
112
+ tf.config.experimental.set_memory_growth(gpu, True)
113
+
114
+ # Get GPU info using subprocess
115
+ try:
116
+ gpu_name = subprocess.check_output(
117
+ ['nvidia-smi', '--query-gpu=gpu_name', '--format=csv,noheader'],
118
+ stderr=subprocess.DEVNULL
119
+ ).decode('utf-8').strip()
120
+ logger.info(f"Colab GPU detected: {gpu_name}")
121
+
122
+ except (subprocess.SubprocessError, FileNotFoundError):
123
+ logger.warning("Could not detect specific GPU model")
124
+
125
+ # Enable XLA
126
+ tf.config.optimizer.set_jit(True)
127
+ logger.info("XLA compilation enabled for Colab GPU")
128
+
129
+ # Set mixed precision policy
130
+ policy = tf.keras.mixed_precision.Policy('mixed_float16')
131
+ tf.keras.mixed_precision.set_global_policy(policy)
132
+ logger.info("Mixed precision training enabled (float16)")
133
+
134
+ strategy = tf.distribute.OneDeviceStrategy("/GPU:0")
135
+ return "GPU", strategy
136
+
137
+ except Exception as e:
138
+ logger.error(f"Error configuring Colab GPU: {str(e)}")
139
+
140
+ # Non-Colab setup (same as before)
141
+ else:
142
+ # Check for TPU
143
+ try:
144
+ resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
145
+ tf.config.experimental_connect_to_cluster(resolver)
146
+ tf.tpu.experimental.initialize_tpu_system(resolver)
147
+ strategy = tf.distribute.TPUStrategy(resolver)
148
+ logger.info("TPU detected and initialized")
149
+ return "TPU", strategy
150
+ except ValueError:
151
+ logger.info("No TPU detected. Checking for GPUs...")
152
+
153
+ # Check for GPUs
154
+ gpus = tf.config.list_physical_devices('GPU')
155
+ if gpus:
156
+ try:
157
+ for gpu in gpus:
158
+ tf.config.experimental.set_memory_growth(gpu, True)
159
+
160
+ if len(gpus) > 1:
161
+ strategy = tf.distribute.MirroredStrategy()
162
+ logger.info(f"Multi-GPU strategy set up with {len(gpus)} GPUs")
163
+ else:
164
+ strategy = tf.distribute.OneDeviceStrategy("/GPU:0")
165
+ logger.info("Single GPU strategy set up")
166
+
167
+ return "GPU", strategy
168
+
169
+ except Exception as e:
170
+ logger.error(f"Error configuring GPU: {str(e)}")
171
+
172
+ # CPU fallback
173
+ strategy = tf.distribute.OneDeviceStrategy("/CPU:0")
174
+ logger.info("Using CPU strategy")
175
+ return "CPU", strategy
176
+
177
+ def optimize_batch_size(self, base_batch_size: int = 16) -> int:
178
+ """Apply Colab-specific optimizations for training."""
179
+ if not self.is_colab():
180
+ return base_batch_size
181
+
182
+ # Colab-specific batch size optimization
183
+ if self.device_type == "GPU":
184
+ try:
185
+ gpu_name = subprocess.check_output(
186
+ ['nvidia-smi', '--query-gpu=gpu_name', '--format=csv,noheader'],
187
+ stderr=subprocess.DEVNULL
188
+ ).decode('utf-8').strip()
189
+
190
+ if "T4" in gpu_name:
191
+ # T4 optimizations
192
+ logger.info("Optimizing for Colab T4 GPU")
193
+ base_batch_size = min(base_batch_size * 2, 32) # T4 can handle larger batches
194
+ elif "V100" in gpu_name:
195
+ # V100 optimizations
196
+ logger.info("Optimizing for Colab V100 GPU")
197
+ base_batch_size = min(base_batch_size * 3, 48) # V100 can handle even larger batches
198
+ except (subprocess.SubprocessError, FileNotFoundError):
199
+ logger.warning("Could not detect specific GPU model, using default settings")
200
+
201
+ elif self.device_type == "TPU":
202
+ # TPU optimizations
203
+ base_batch_size = min(base_batch_size * 4, 64) # TPUs can handle very large batches
204
+ logger.info("Optimizing for Colab TPU")
205
+
206
+ logger.info(f"Optimized batch size for Colab: {base_batch_size}")
207
+ return base_batch_size
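Typical initialization (illustrative, not part of this commit):

    env = EnvironmentSetup()
    env.initialize(cache_dir=None)             # sets up the model cache and training directories
    batch_size = env.optimize_batch_size(16)   # scaled up on Colab T4/V100/TPU, otherwise unchanged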
logger_config.py ADDED
@@ -0,0 +1,10 @@
+import logging
+
+def config_logger(name):
+    logging.basicConfig(
+        level=logging.DEBUG,
+        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+    )
+    logger = logging.getLogger(name)
+    logger.setLevel(logging.DEBUG)
+    return logger
requirements.txt CHANGED
@@ -1,6 +1,12 @@
+faiss-cpu>=1.7.0         # Required for Facebook AI Similarity Search
+ipython>=8.0.0           # For interactive Python
+loguru>=0.7.0            # Enhanced logging (optional but recommended)
+matplotlib>=3.5.0        # For validation plotting
 nlpaug>=1.1.0            # Data augmentation for NLP
 nltk>=3.6.0              # Natural language toolkit
 numpy>=1.19.0            # General numerical computation
+pandas>=1.5.0            # For data handling
+pyyaml>=6.0.0            # For config management
 scikit-learn>=1.0.0      # Machine learning tools
 sacremoses>=0.0.53       # Required for some HuggingFace models
 sentencepiece>=0.1.99    # Required for HuggingFace transformers
@@ -11,4 +17,10 @@ tokenizers>=0.13.0  # Required for HuggingFace transformers
 torch>=2.0.0             # PyTorch, for deep learning
 tqdm>=4.64.0             # Progress bar
 transformers>=4.30.0     # Hugging Face Transformers library
-faiss-cpu>=1.7.0         # Required for Facebook AI Similarity Search
+typing-extensions>=4.0.0 # For better type hints
+
+# Dev dependencies
+black>=22.0.0            # For code formatting
+isort>=5.10.0            # For import sorting
+mypy>=1.0.0              # For type checking
+pytest>=7.0.0            # For testing
response_quality_checker.py CHANGED
@@ -1,164 +1,170 @@
1
  import numpy as np
2
- from typing import List, Tuple, Dict, Any
3
  from sklearn.metrics.pairwise import cosine_similarity
4
- from chatbot4 import RetrievalChatbot
 
 
 
 
 
5
 
6
  class ResponseQualityChecker:
7
- """Handles quality checking and confidence scoring for chatbot responses."""
8
 
9
  def __init__(
10
  self,
11
- chatbot: RetrievalChatbot,
12
- confidence_threshold: float = 0.5,
13
- diversity_threshold: float = 0.1,
14
- min_response_length: int = 3,
15
- max_similarity_ratio: float = 0.9
16
  ):
17
  self.confidence_threshold = confidence_threshold
18
  self.diversity_threshold = diversity_threshold
19
  self.min_response_length = min_response_length
20
- self.max_similarity_ratio = max_similarity_ratio
21
  self.chatbot = chatbot
22
-
 
 
 
 
 
 
 
23
  def check_response_quality(
24
  self,
25
  query: str,
26
  responses: List[Tuple[str, float]]
27
  ) -> Dict[str, Any]:
28
  """
29
- Evaluate the quality of the responses based on various metrics.
30
- """
31
- # Calculate diversity based on the responses themselves
32
- diversity = self.calculate_diversity(responses)
33
 
34
- # Calculate relevance based on some criteria
35
- relevance = self.calculate_relevance(query, responses)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- # Calculate length scores for each response
38
- length_scores = [self._calculate_length_score(response) for response, _ in responses]
39
- avg_length_score = np.mean(length_scores) if length_scores else 0.0
40
 
41
- # Extract similarity scores
42
- similarity_scores = [score for _, score in responses]
 
 
 
 
 
43
 
44
- # Calculate score gap
45
- score_gap = self._calculate_score_gap(similarity_scores, top_n=3)
 
46
 
47
- # Aggregate metrics
48
- metrics = {
49
- 'top_score': similarity_scores[0] if similarity_scores else 0.0,
50
- 'response_diversity': diversity,
51
- 'query_response_relevance': relevance,
52
- 'response_length_score': avg_length_score,
53
- 'top_3_score_gap': score_gap
54
- }
55
 
56
- # Determine overall confidence
57
- is_confident = self._determine_confidence(metrics)
58
-
59
- return {
60
- 'diversity': diversity,
61
- 'relevance': relevance,
62
- 'is_confident': is_confident,
63
- 'top_score': metrics['top_score'],
64
- 'response_diversity': metrics['response_diversity'],
65
- 'query_response_relevance': metrics['query_response_relevance'],
66
- 'response_length_score': metrics['response_length_score'],
67
- 'top_3_score_gap': metrics['top_3_score_gap']
68
- }
69
-
70
  def calculate_diversity(self, responses: List[Tuple[str, float]]) -> float:
71
- """
72
- Calculate diversity as the average pairwise similarity between responses.
73
- Lower similarity indicates higher diversity.
74
- """
75
  if not responses:
76
  return 0.0
77
 
78
- # Encode responses
79
  embeddings = [self.encode_text(response) for response, _ in responses]
80
  if len(embeddings) < 2:
81
- return 1.0 # Maximum diversity
82
 
83
- # Compute pairwise cosine similarity
84
  similarity_matrix = cosine_similarity(embeddings)
 
85
 
86
- # Exclude diagonal
87
- sum_similarities = np.sum(similarity_matrix) - len(responses)
 
 
 
 
 
 
 
 
88
  num_pairs = len(responses) * (len(responses) - 1)
89
  avg_similarity = sum_similarities / num_pairs if num_pairs > 0 else 0.0
90
- diversity_score = 1 - avg_similarity # Higher value indicates more diversity
91
- return diversity_score
92
-
93
- def calculate_relevance(self, query: str, responses: List[Tuple[str, float]]) -> float:
94
- """
95
- Calculate relevance as the average similarity between the query and each response.
96
- """
97
- if not responses:
98
- return 0.0
99
 
100
- # Encode query
101
- query_embedding = self.encode_query(query)
102
-
103
- # Encode responses
104
- response_embeddings = [self.encode_text(response) for response, _ in responses]
 
 
 
 
 
105
 
106
- # Compute cosine similarity
107
- similarities = cosine_similarity([query_embedding], response_embeddings)[0]
 
 
 
 
108
 
109
- avg_relevance = np.mean(similarities) if similarities.size > 0 else 0.0
110
- return avg_relevance
111
-
112
  def _calculate_length_score(self, response: str) -> float:
113
- """Score based on response length appropriateness."""
114
- length = len(response.split())
115
- if length < self.min_response_length:
116
- return length / self.min_response_length
 
 
 
117
  return 1.0
118
-
119
  def _calculate_score_gap(self, scores: List[float], top_n: int = 3) -> float:
120
- """
121
- Calculate the average gap between the top N scores.
122
-
123
- Args:
124
- scores (List[float]): List of similarity scores.
125
- top_n (int): Number of top scores to consider.
126
-
127
- Returns:
128
- float: Average score gap.
129
- """
130
  if len(scores) < top_n + 1:
131
  return 0.0
132
- gaps = [scores[i] - scores[i + 1] for i in range(top_n)]
133
- avg_gap = np.mean(gaps)
134
- return avg_gap
135
-
136
- def _determine_confidence(self, metrics: Dict[str, float]) -> bool:
137
- """
138
- Determine if we're confident enough in the response.
139
-
140
- Returns:
141
- bool: True if we should use this response, False if we should abstain
142
- """
143
- conditions = [
144
- metrics['top_score'] >= self.confidence_threshold,
145
- metrics['response_diversity'] >= self.diversity_threshold,
146
- metrics['response_length_score'] >= 0.8,
147
- metrics['query_response_relevance'] >= 0.3, # was 0.5
148
- metrics['top_3_score_gap'] >= 0.05 # was 0.1
149
- ]
150
- return all(conditions)
151
-
152
  def encode_text(self, text: str) -> np.ndarray:
153
- # 1) Turn text into a list if your encode_responses() expects a list.
154
- # 2) Then call the method from the chatbot to get the embedding.
155
- embedding_tensor = self.chatbot.encode_responses([text]) # returns tf.Tensor of shape (1, emb_dim)
156
- embedding = embedding_tensor.numpy()[0].astype('float32') # shape: (emb_dim,)
157
- embedding = embedding / np.linalg.norm(embedding) if np.linalg.norm(embedding) > 0 else embedding
158
- return embedding
159
-
160
  def encode_query(self, query: str) -> np.ndarray:
161
- embedding_tensor = self.chatbot.encode_query(query) # returns tf.Tensor of shape (1, emb_dim)
162
- embedding = embedding_tensor.numpy()[0].astype('float32') # shape: (emb_dim,)
163
- embedding = embedding / np.linalg.norm(embedding) if np.linalg.norm(embedding) > 0 else embedding
164
- return embedding
 
 
 
 
 
 
1
  import numpy as np
2
+ from typing import List, Tuple, Dict, Any, TYPE_CHECKING
3
  from sklearn.metrics.pairwise import cosine_similarity
4
+
5
+ from logger_config import config_logger
6
+ logger = config_logger(__name__)
7
+
8
+ if TYPE_CHECKING:
9
+ from chatbot_model import RetrievalChatbot
10
 
11
  class ResponseQualityChecker:
12
+ """Enhanced quality checking with dynamic thresholds."""
13
 
14
  def __init__(
15
  self,
16
+ chatbot: 'RetrievalChatbot',
17
+ confidence_threshold: float = 0.6,
18
+ diversity_threshold: float = 0.15,
19
+ min_response_length: int = 5,
20
+ similarity_cap: float = 0.85 # Renamed from max_similarity_ratio and used in diversity calc
21
  ):
22
  self.confidence_threshold = confidence_threshold
23
  self.diversity_threshold = diversity_threshold
24
  self.min_response_length = min_response_length
25
+ self.similarity_cap = similarity_cap
26
  self.chatbot = chatbot
27
+
28
+ # Dynamic thresholds based on response patterns
29
+ self.thresholds = {
30
+ 'relevance': 0.35,
31
+ 'length_score': 0.85,
32
+ 'score_gap': 0.07
33
+ }
34
+
35
  def check_response_quality(
36
  self,
37
  query: str,
38
  responses: List[Tuple[str, float]]
39
  ) -> Dict[str, Any]:
40
  """
41
+ Evaluate the quality of responses based on various metrics.
 
 
 
42
 
43
+ Args:
44
+ query: The user's query
45
+ responses: List of (response_text, score) tuples
46
+
47
+ Returns:
48
+ Dict containing quality metrics and confidence assessment
49
+ """
50
+ if not responses:
51
+ return {
52
+ 'response_diversity': 0.0,
53
+ 'query_response_relevance': 0.0,
54
+ 'is_confident': False,
55
+ 'top_score': 0.0,
56
+ 'response_length_score': 0.0,
57
+ 'top_3_score_gap': 0.0
58
+ }
59
+
60
+ # Calculate core metrics
61
+ metrics = {
62
+ 'response_diversity': self.calculate_diversity(responses),
63
+ 'query_response_relevance': self.calculate_relevance(query, responses),
64
+ 'response_length_score': np.mean([
65
+ self._calculate_length_score(response) for response, _ in responses
66
+ ]),
67
+ 'top_score': responses[0][1],
68
+ 'top_3_score_gap': self._calculate_score_gap([score for _, score in responses], top_n=3)
69
+ }
70
 
71
+ # Determine confidence using thresholds
72
+ metrics['is_confident'] = self._determine_confidence(metrics)
 
73
 
74
+ logger.info(f"Quality metrics: {metrics}")
75
+ return metrics
76
+
77
+ def calculate_relevance(self, query: str, responses: List[Tuple[str, float]]) -> float:
78
+ """Calculate relevance as weighted similarity between query and responses."""
79
+ if not responses:
80
+ return 0.0
81
 
82
+ # Get embeddings
83
+ query_embedding = self.encode_query(query)
84
+ response_embeddings = [self.encode_text(response) for response, _ in responses]
85
 
86
+ # Compute similarities with decreasing weights for later responses
87
+ similarities = cosine_similarity([query_embedding], response_embeddings)[0]
88
+ weights = np.array([1.0 / (i + 1) for i in range(len(similarities))])
89
 
90
+ return np.average(similarities, weights=weights)
91
+
92
  def calculate_diversity(self, responses: List[Tuple[str, float]]) -> float:
93
+ """Calculate diversity with length normalization and similarity capping."""
94
  if not responses:
95
  return 0.0
96
 
 
97
  embeddings = [self.encode_text(response) for response, _ in responses]
98
  if len(embeddings) < 2:
99
+ return 1.0
100
 
101
+ # Calculate similarities and apply cap
102
  similarity_matrix = cosine_similarity(embeddings)
103
+ similarity_matrix = np.minimum(similarity_matrix, self.similarity_cap)
104
 
105
+ # Apply length normalization
106
+ lengths = [len(resp[0].split()) for resp in responses]
107
+ length_ratios = np.array([min(a, b) / max(a, b) for a in lengths for b in lengths])
108
+ length_ratios = length_ratios.reshape(len(responses), len(responses))
109
+
110
+ # Combine factors with weights
111
+ adjusted_similarity = (similarity_matrix * 0.7 + length_ratios * 0.3)
112
+
113
+ # Calculate final score
114
+ sum_similarities = np.sum(adjusted_similarity) - np.trace(adjusted_similarity)  # exclude self-similarity on the diagonal
115
  num_pairs = len(responses) * (len(responses) - 1)
116
  avg_similarity = sum_similarities / num_pairs if num_pairs > 0 else 0.0
117
 
118
+ return 1 - avg_similarity
119
+
120
+ def _determine_confidence(self, metrics: Dict[str, float]) -> bool:
121
+ """Determine confidence using primary and secondary conditions."""
122
+ # Primary conditions (must all be met)
123
+ primary_conditions = [
124
+ metrics['top_score'] >= self.confidence_threshold,
125
+ metrics['response_diversity'] >= self.diversity_threshold,
126
+ metrics['response_length_score'] >= self.thresholds['length_score']
127
+ ]
128
 
129
+ # Secondary conditions (majority must be met)
130
+ secondary_conditions = [
131
+ metrics['query_response_relevance'] >= self.thresholds['relevance'],
132
+ metrics['top_3_score_gap'] >= self.thresholds['score_gap'],
133
+ metrics['top_score'] >= (self.confidence_threshold * 1.1) # Extra confidence boost
134
+ ]
135
 
136
+ return all(primary_conditions) and sum(secondary_conditions) >= 2
137
+
 
138
  def _calculate_length_score(self, response: str) -> float:
139
+ """Calculate length score with penalty for very short or long responses."""
140
+ words = len(response.split())
141
+
142
+ if words < self.min_response_length:
143
+ return words / self.min_response_length
144
+ elif words > 50: # Penalty for very long responses
145
+ return min(1.0, 50 / words)
146
  return 1.0
147
+
148
  def _calculate_score_gap(self, scores: List[float], top_n: int = 3) -> float:
149
+ """Calculate average gap between top N scores."""
150
  if len(scores) < top_n + 1:
151
  return 0.0
152
+ gaps = [scores[i] - scores[i + 1] for i in range(min(len(scores) - 1, top_n))]
153
+ return np.mean(gaps)
154
+
155
  def encode_text(self, text: str) -> np.ndarray:
156
+ """Encode response text to embedding."""
157
+ embedding_tensor = self.chatbot.encode_responses([text])
158
+ embedding = embedding_tensor.numpy()[0].astype('float32')
159
+ return self._normalize_embedding(embedding)
160
+
 
 
161
  def encode_query(self, query: str) -> np.ndarray:
162
+ """Encode query text to embedding."""
163
+ embedding_tensor = self.chatbot.encode_query(query)
164
+ embedding = embedding_tensor.numpy()[0].astype('float32')
165
+ return self._normalize_embedding(embedding)
166
+
167
+ def _normalize_embedding(self, embedding: np.ndarray) -> np.ndarray:
168
+ """Normalize embedding vector."""
169
+ norm = np.linalg.norm(embedding)
170
+ return embedding / norm if norm > 0 else embedding
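Note: a minimal usage sketch of the reworked checker (illustrative only; `chatbot` is assumed to be an already-trained RetrievalChatbot and the candidate list stands in for real retrieval output):

    from response_quality_checker import ResponseQualityChecker

    checker = ResponseQualityChecker(chatbot=chatbot)
    candidates = [("Sure, I can book that table for you.", 0.71),
                  ("Here are a few nearby options.", 0.64),
                  ("Anything else I can help with?", 0.52),
                  ("Let me check availability.", 0.49)]
    metrics = checker.check_response_quality("Book a table for two tonight", candidates)
    if metrics['is_confident']:
        print(candidates[0][0])  # top response cleared the primary and secondary thresholds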
run_model.py DELETED
@@ -1,162 +0,0 @@
1
- import json
2
- import glob
3
- import os
4
- from chatbot import RetrievalChatbot
5
- import tensorflow as tf
6
- from sklearn.model_selection import train_test_split
7
- import matplotlib.pyplot as plt
8
-
9
- def load_training_data(data_directory: str) -> list:
10
- """Load and combine dialogue data from multiple JSON files."""
11
- all_dialogues = []
12
-
13
- # Get all json files matching the pattern
14
- pattern = os.path.join(data_directory, "batch_*.json")
15
- json_files = sorted(glob.glob(pattern))
16
-
17
- print(f"Found {len(json_files)} batch files")
18
-
19
- for file_path in json_files:
20
- try:
21
- with open(file_path, 'r', encoding='utf-8') as f:
22
- batch_dialogues = json.load(f)
23
- all_dialogues.extend(batch_dialogues)
24
- print(f"Loaded {len(batch_dialogues)} dialogues from {os.path.basename(file_path)}")
25
- except Exception as e:
26
- print(f"Error loading {file_path}: {str(e)}")
27
-
28
- print(f"Total dialogues loaded: {len(all_dialogues)}")
29
- return all_dialogues
30
-
31
- def plot_training_history(train_losses, val_losses):
32
- # Plot training and validation loss
33
- plt.figure()
34
- plt.plot(train_losses, label='Train Loss')
35
- plt.plot(val_losses, label='Val Loss')
36
- plt.xlabel('Epoch')
37
- plt.ylabel('Triplet Loss')
38
- plt.legend()
39
- plt.show()
40
-
41
- dialogues = load_training_data('processed_outputs')
42
-
43
- # Initialize the chatbot
44
- chatbot = RetrievalChatbot(
45
- vocab_size=10000,
46
- max_sequence_length=80,
47
- embedding_dim=256,
48
- lstm_units=256,
49
- num_attention_heads=8,
50
- margin=0.3
51
- )
52
-
53
- # Prepare the dataset for triplet training
54
- q_pad, p_pad, n_pad = chatbot.prepare_dataset(dialogues, neg_samples_per_pos=3)
55
-
56
- # Train with triplet loss
57
- train_losses, val_losses = chatbot.train_with_triplet_loss(
58
- q_pad, p_pad, n_pad,
59
- epochs=1,
60
- batch_size=32,
61
- validation_split=0.2
62
- )
63
-
64
- plot_training_history(train_losses, val_losses)
65
-
66
- # After training, test prediction
67
- response_candidates = [turn['text'] for d in dialogues for turn in d['turns'] if turn['speaker'] == 'assistant']
68
-
69
- # Test retrieval
70
- test_query = "I'd like a recommendation for a Korean restaurant in NYC."
71
- top_responses = chatbot.retrieve_top_n(test_query, response_candidates, top_n=5)
72
- print("Top responses:")
73
- for resp, score in top_responses:
74
- print(f"Score: {score:.4f} - {resp}")
75
-
76
- # Single-turn validation:
77
- test_queries = [
78
- "I want to book a Korean restaurant in NYC.",
79
- "Can I get two tickets for 'What Men Want'?",
80
- "What's the best time to watch the movie today?"
81
- ]
82
- for query in test_queries:
83
- top_responses = chatbot.retrieve_top_n(query, response_candidates, top_n=3)
84
- print(f"\nQuery: {query}")
85
- for resp, score in top_responses:
86
- print(f"Score: {score:.4f} - {resp}")
87
-
88
- # Multi-turn conversation:
89
- multi_turn_history = []
90
-
91
- def update_context(multi_turn_history, query, response, max_context_turns=3):
92
- multi_turn_history.append((query, response))
93
- if len(multi_turn_history) > max_context_turns:
94
- multi_turn_history.pop(0)
95
-
96
- def get_context_enhanced_query(multi_turn_history, query):
97
- if not multi_turn_history:
98
- return query
99
- context = " ".join([f"User: {q} Assistant: {r}" for q, r in multi_turn_history])
100
- return f"{context} User: {query}"
101
-
102
- conversation_queries = [
103
- "I'd like to watch a movie tonight.",
104
- "Is there a showing of 'What Men Want'?",
105
- "What time is the last show?",
106
- "Can I get two tickets?"
107
- ]
108
-
109
- for query in conversation_queries:
110
- context_query = get_context_enhanced_query(multi_turn_history, query)
111
- top_responses = chatbot.retrieve_top_n(context_query, response_candidates, top_n=3)
112
- best_response = top_responses[0][0]
113
- print(f"\nUser: {query}\nAssistant: {best_response}")
114
- update_context(multi_turn_history, query, best_response)
115
-
116
-
117
-
118
-
119
-
120
-
121
-
122
-
123
- #queries, responses, labels = chatbot.prepare_dataset(dialogues, neg_samples_per_pos=3)
124
-
125
- #train_dialogues, val_dialogues = train_test_split(dialogues, test_size=0.2, random_state=20)
126
- #query_train, query_val, response_train, response_val, labels_train, labels_val = train_test_split(queries, responses, labels, test_size=0.2, random_state=20)
127
-
128
- # chatbot.model.compile(
129
- # optimizer=tf.keras.optimizers.Adam(learning_rate=0.001, clipnorm=1.0),
130
- # loss='binary_crossentropy',
131
- # metrics=['accuracy']
132
- # )
133
-
134
- # # Train the model with early stopping to prevent overfitting
135
- # callbacks = [
136
- # tf.keras.callbacks.EarlyStopping(
137
- # monitor='val_loss',
138
- # patience=3,
139
- # restore_best_weights=True
140
- # ),
141
- # tf.keras.callbacks.ReduceLROnPlateau(
142
- # monitor='val_loss',
143
- # factor=0.5,
144
- # patience=2,
145
- # min_lr=1e-6,
146
- # verbose=1
147
- # ),
148
- # tf.keras.callbacks.ModelCheckpoint(
149
- # 'chatbot_model.keras',
150
- # monitor='val_loss',
151
- # save_best_only=True
152
- # )
153
- # ]
154
-
155
- # history = chatbot.model.fit(
156
- # {'query_input': query_train, 'response_input': response_train},
157
- # labels_train,
158
- # validation_data=({'query_input': query_val, 'response_input': response_val}, labels_val),
159
- # epochs=5,
160
- # batch_size=32,
161
- # callbacks=callbacks
162
- # )
run_model2.py DELETED
@@ -1,340 +0,0 @@
1
- from chatbot2 import RetrievalChatbot, ChatbotConfig
2
- import os
3
- import json
4
- import glob
5
- import matplotlib.pyplot as plt
6
- import logging
7
- from pathlib import Path
8
- from typing import List, Dict, Optional, Any, Tuple
9
- import numpy as np
10
- from datetime import datetime
11
- from response_quality_checker import ResponseQualityChecker
12
-
13
- # Configure logging
14
- logging.basicConfig(
15
- level=logging.INFO,
16
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
17
- )
18
- logger = logging.getLogger(__name__)
19
-
20
- def load_training_data(data_directory: str, debug_samples: Optional[int] = None) -> list:
21
- """
22
- Load and combine dialogue data from multiple JSON files.
23
-
24
- Args:
25
- data_directory: Directory containing the dialogue files
26
- debug_samples: If set, only load this many dialogues for debugging
27
- """
28
- all_dialogues = []
29
- data_directory = Path(data_directory)
30
-
31
- # Get all json files matching the pattern
32
- pattern = "batch_*.json"
33
- json_files = sorted(data_directory.glob(pattern))
34
-
35
- logger.info(f"Found {len(json_files)} batch files")
36
-
37
- if debug_samples:
38
- logger.info(f"Debug mode: Will load up to {debug_samples} dialogues")
39
-
40
- for file_path in json_files:
41
- try:
42
- with open(file_path, 'r', encoding='utf-8') as f:
43
- batch_dialogues = json.load(f)
44
-
45
- # If in debug mode, only take what we need from this batch
46
- if debug_samples is not None:
47
- remaining_samples = debug_samples - len(all_dialogues)
48
- if remaining_samples <= 0:
49
- break
50
- batch_dialogues = batch_dialogues[:remaining_samples]
51
-
52
- all_dialogues.extend(batch_dialogues)
53
- logger.info(f"Loaded {len(batch_dialogues)} dialogues from {file_path.name}")
54
-
55
- # If we've reached our debug sample limit, stop loading
56
- if debug_samples is not None and len(all_dialogues) >= debug_samples:
57
- logger.info(f"Debug mode: Reached {debug_samples} samples, stopping load")
58
- break
59
-
60
- except Exception as e:
61
- logger.error(f"Error loading {file_path}: {str(e)}")
62
-
63
- total_loaded = len(all_dialogues)
64
- if debug_samples:
65
- logger.info(f"Debug mode: Loaded {total_loaded}/{debug_samples} requested dialogues")
66
- else:
67
- logger.info(f"Total dialogues loaded: {total_loaded}")
68
-
69
- return all_dialogues
70
-
71
- def plot_training_history(history: Dict[str, List[float]], save_dir: Path = None):
72
- """Plot and optionally save training history."""
73
- # Create figure with two subplots
74
- fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))
75
-
76
- # Plot losses
77
- ax1.plot(history['train_loss'], label='Train Loss')
78
- ax1.plot(history['val_loss'], label='Validation Loss')
79
- ax1.set_xlabel('Epoch')
80
- ax1.set_ylabel('Triplet Loss')
81
- ax1.set_title('Training and Validation Loss')
82
- ax1.legend()
83
- ax1.grid(True)
84
-
85
- # Plot learning rate if available
86
- if 'learning_rate' in history:
87
- ax2.plot(history['learning_rate'], label='Learning Rate')
88
- ax2.set_xlabel('Step')
89
- ax2.set_ylabel('Learning Rate')
90
- ax2.set_title('Learning Rate Schedule')
91
- ax2.legend()
92
- ax2.grid(True)
93
-
94
- plt.tight_layout()
95
-
96
- # Save if directory provided
97
- if save_dir:
98
- save_dir = Path(save_dir)
99
- save_dir.mkdir(parents=True, exist_ok=True)
100
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
101
- plt.savefig(save_dir / f'training_history_{timestamp}.png')
102
-
103
- plt.show()
104
-
105
- def setup_training_directories(base_dir: str = "chatbot_training") -> Dict[str, Path]:
106
- """Setup directory structure for training artifacts."""
107
- base_dir = Path(base_dir)
108
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
109
- train_dir = base_dir / f"training_run_{timestamp}"
110
-
111
- directories = {
112
- 'base': train_dir,
113
- 'checkpoints': train_dir / 'checkpoints',
114
- 'plots': train_dir / 'plots',
115
- 'logs': train_dir / 'logs'
116
- }
117
-
118
- # Create directories
119
- for dir_path in directories.values():
120
- dir_path.mkdir(parents=True, exist_ok=True)
121
-
122
- return directories
123
-
124
- def run_automatic_validation(
125
- chatbot,
126
- response_pool: List[str],
127
- quality_checker: ResponseQualityChecker,
128
- num_examples: int = 5
129
- ) -> Dict[str, Any]:
130
- """
131
- Run automatic validation with quality metrics.
132
- """
133
- logger.info("\n=== Running Automatic Validation ===")
134
-
135
- test_queries = [
136
- "Hello, how are you today?",
137
- "What's the weather like?",
138
- "Can you help me with a problem?",
139
- "Tell me a joke",
140
- "What time is it?",
141
- "I need help with my homework",
142
- "Where's a good place to eat?",
143
- "What movies are playing?",
144
- "How do I reset my password?",
145
- "Can you recommend a book?"
146
- ]
147
-
148
- test_queries = test_queries[:num_examples]
149
- metrics_history = []
150
-
151
- for i, query in enumerate(test_queries, 1):
152
- logger.info(f"\nTest Case {i}:")
153
- logger.info(f"Query: {query}")
154
-
155
- # Get responses and scores
156
- responses = chatbot.retrieve_responses(
157
- query,
158
- response_pool,
159
- context=None,
160
- top_k=5
161
- )
162
-
163
- # Check quality
164
- quality_metrics = quality_checker.check_response_quality(
165
- query, responses, response_pool
166
- )
167
- metrics_history.append(quality_metrics)
168
-
169
- # Log results
170
- logger.info(f"Quality Metrics: {quality_metrics}")
171
- logger.info("Top responses:")
172
- for j, (response, score) in enumerate(responses[:3], 1):
173
- logger.info(f"{j}. Score: {score:.4f}")
174
- logger.info(f" Response: {response}")
175
- if j == 1 and not quality_metrics['is_confident']:
176
- logger.info(" [Low Confidence - Would abstain from answering]")
177
-
178
- # Calculate aggregate metrics
179
- aggregate_metrics = {
180
- 'num_queries_tested': len(test_queries),
181
- 'avg_top_response_score': np.mean([m['top_score'] for m in metrics_history]),
182
- 'avg_diversity': np.mean([m['response_diversity'] for m in metrics_history]),
183
- 'avg_relevance': np.mean([m['query_response_relevance'] for m in metrics_history]),
184
- 'confidence_rate': np.mean([m['is_confident'] for m in metrics_history]),
185
- }
186
-
187
- logger.info("\n=== Validation Summary ===")
188
- for metric, value in aggregate_metrics.items():
189
- logger.info(f"{metric}: {value:.4f}")
190
-
191
- return aggregate_metrics
192
-
193
- def chat_with_quality_check(
194
- chatbot,
195
- query: str,
196
- response_pool: List[str],
197
- conversation_history: List[Tuple[str, str]],
198
- quality_checker: ResponseQualityChecker
199
- ) -> Tuple[Optional[str], List[Tuple[str, float]], Dict[str, Any]]:
200
- """
201
- Enhanced chat function with quality checking.
202
- """
203
- # Get responses and scores
204
- responses = chatbot.retrieve_responses(
205
- query,
206
- response_pool,
207
- conversation_history
208
- )
209
-
210
- # Check quality
211
- quality_metrics = quality_checker.check_response_quality(
212
- query, responses, response_pool
213
- )
214
-
215
- if quality_metrics['is_confident']:
216
- return responses[0][0], responses, quality_metrics
217
- else:
218
- uncertainty_response = (
219
- "I apologize, but I don't feel confident providing an answer to that "
220
- "question at the moment. Could you please rephrase or ask something else?"
221
- )
222
- return uncertainty_response, responses, quality_metrics
223
-
224
- def get_total_steps(dialogues: List[Dict[str, Any]], batch_size: int, epochs: int) -> int:
225
- """
226
- Calculate total training steps based on dialogues and batch size.
227
- Assume 80% of data for training due to validation split
228
- """
229
- estimated_train_samples = int(len(dialogues) * 0.8)
230
- steps_per_epoch = estimated_train_samples // batch_size
231
- return steps_per_epoch * epochs
232
-
233
- def main():
234
- DEBUG_SAMPLES = 350
235
- BATCH_SIZE = 32
236
- EPOCHS = 5 if DEBUG_SAMPLES else 10
237
-
238
- # Setup training directories
239
- dirs = setup_training_directories()
240
-
241
- # Load training data
242
- dialogues = load_training_data('processed_outputs', debug_samples=DEBUG_SAMPLES)
243
- total_steps = get_total_steps(dialogues, BATCH_SIZE, EPOCHS)
244
-
245
- # Initialize configuration
246
- config = ChatbotConfig(
247
- embedding_dim=32, # TODO: 256
248
- encoder_units=32, # TODO: 256
249
- num_attention_heads=2, # TODO: 8
250
- warmup_steps=int(total_steps * 0.1), # 10% of total steps for warmup
251
- )
252
-
253
- # Save configuration
254
- with open(dirs['base'] / 'config.json', 'w') as f:
255
- json.dump(config.to_dict(), f, indent=2)
256
-
257
- # Initialize chatbot
258
- chatbot = RetrievalChatbot(config)
259
-
260
- # Prepare dataset
261
- logger.info("Preparing dataset...")
262
-
263
- # Prepare and train with debug samples
264
- q_pad, p_pad, n_pad = chatbot.prepare_dataset(
265
- dialogues,
266
- neg_samples_per_pos=3,
267
- debug_samples=DEBUG_SAMPLES
268
- )
269
-
270
- # Train model
271
- logger.info("Starting training...")
272
- chatbot.train(
273
- q_pad, p_pad, n_pad,
274
- epochs=EPOCHS,
275
- batch_size=BATCH_SIZE,
276
- checkpoint_dir=dirs['checkpoints']
277
- )
278
-
279
- # Plot and save training history
280
- plot_training_history(chatbot.history, save_dir=dirs['plots'])
281
-
282
- # Save final model
283
- chatbot.save_models(dirs['base'] / 'final_model')
284
-
285
- # Prepare response pool for chat
286
- response_pool = [
287
- turn['text'] for d in dialogues
288
- for turn in d['turns'] if turn['speaker'] == 'assistant'
289
- ]
290
-
291
- # Initialize quality checker with appropriate thresholds
292
- quality_checker = ResponseQualityChecker(
293
- confidence_threshold=0.6 if not DEBUG_SAMPLES else 0.4, # Lower threshold for debug
294
- diversity_threshold=0.2,
295
- min_response_length=10,
296
- max_similarity_ratio=0.9
297
- )
298
-
299
- # Run automatic validation
300
- validation_metrics = run_automatic_validation(
301
- chatbot,
302
- response_pool,
303
- quality_checker,
304
- num_examples=5 if DEBUG_SAMPLES else 10
305
- )
306
-
307
- # Log validation metrics
308
- logger.info(f"Validation Metrics: {validation_metrics}")
309
-
310
- # Now continue with interactive chat
311
- logger.info("\nStarting interactive chat session...")
312
- conversation_history = []
313
-
314
- while True:
315
- query = input("\nYou: ")
316
- if query.lower() in ['quit', 'exit', 'bye']:
317
- break
318
-
319
- try:
320
- response, candidates = chatbot.chat(
321
- query,
322
- response_pool,
323
- conversation_history
324
- )
325
- print(f"\nAssistant: {response}")
326
-
327
- # Print top alternative responses
328
- print("\nAlternative responses:")
329
- for resp, score in candidates[1:4]:
330
- print(f"Score: {score:.4f} - {resp}")
331
-
332
- # Update history
333
- conversation_history.append((query, response))
334
-
335
- except Exception as e:
336
- logger.error(f"Error during chat: {str(e)}")
337
- print("Sorry, I encountered an error. Please try again.")
338
-
339
- if __name__ == "__main__":
340
- main()
run_model3.py DELETED
@@ -1,434 +0,0 @@
1
- from chatbot3 import RetrievalChatbot, ChatbotConfig
2
- import os
3
- import json
4
- import glob
5
- import tensorflow as tf
6
- import matplotlib.pyplot as plt
7
- import logging
8
- from pathlib import Path
9
- from typing import List, Dict, Optional, Any, Tuple
10
- import numpy as np
11
- from datetime import datetime
12
- from response_quality_checker import ResponseQualityChecker
13
- import torch
14
- from transformers import TFAutoModel, AutoTokenizer
15
-
16
-
17
- policy = tf.keras.mixed_precision.Policy('mixed_float16')
18
- tf.keras.mixed_precision.set_global_policy(policy)
19
-
20
- # Configure logging
21
- logging.basicConfig(
22
- level=logging.INFO,
23
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
24
- )
25
- logger = logging.getLogger(__name__)
26
-
27
- def setup_model_cache(cache_dir: Optional[Path] = None) -> Path:
28
- """Setup and manage model cache directory."""
29
- if cache_dir is None:
30
- cache_dir = Path.home() / '.chatbot_cache'
31
-
32
- cache_dir.mkdir(parents=True, exist_ok=True)
33
-
34
- # Set environment variables for various libraries
35
- os.environ['TRANSFORMERS_CACHE'] = str(cache_dir / 'transformers')
36
- os.environ['TORCH_HOME'] = str(cache_dir / 'torch')
37
- os.environ['HF_HOME'] = str(cache_dir / 'huggingface')
38
-
39
- logger.info(f"Using cache directory: {cache_dir}")
40
- return cache_dir
41
-
42
- def setup_gpu():
43
- """Configure GPU settings for optimal performance."""
44
- logger.info("Checking GPU availability...")
45
-
46
- gpus = tf.config.list_physical_devices('GPU')
47
- if gpus:
48
- try:
49
- # Allow memory growth to prevent taking all GPU memory at once
50
- for gpu in gpus:
51
- tf.config.experimental.set_memory_growth(gpu, True)
52
- logger.info(f"Found {len(gpus)} GPU(s). Memory growth enabled.")
53
-
54
- # Log GPU info
55
- for gpu in gpus:
56
- logger.info(f"GPU Device: {gpu}")
57
- return True
58
- except Exception as e:
59
- logger.error(f"Error configuring GPU: {str(e)}")
60
- return False
61
-
62
- else:
63
- logger.info("No GPU found. Using CPU.")
64
- return False
65
-
66
- def preload_models(config: ChatbotConfig, cache_dir: Path):
67
- """Preload and cache models."""
68
- logger.info("Preloading models...")
69
-
70
- # Cache DistilBERT
71
- model_name = config.pretrained_model
72
- cache_path = cache_dir / 'transformers' / model_name
73
-
74
- if not cache_path.exists():
75
- logger.info(f"Downloading and caching {model_name}...")
76
- tokenizer = AutoTokenizer.from_pretrained(model_name)
77
- model = TFAutoModel.from_pretrained(model_name)
78
-
79
- # Save to cache
80
- tokenizer.save_pretrained(cache_path)
81
- model.save_pretrained(cache_path)
82
- else:
83
- logger.info(f"Using cached model from {cache_path}")
84
-
85
- return cache_path
86
-
87
- def load_training_data(data_directory: str, debug_samples: Optional[int] = None) -> list:
88
- """
89
- Load and combine dialogue data from multiple JSON files.
90
-
91
- Args:
92
- data_directory: Directory containing the dialogue files
93
- debug_samples: If set, only load this many dialogues for debugging
94
- """
95
- all_dialogues = []
96
- data_directory = Path(data_directory)
97
-
98
- # Get all json files matching the pattern
99
- pattern = "batch_*.json"
100
- json_files = sorted(data_directory.glob(pattern))
101
-
102
- logger.info(f"Found {len(json_files)} batch files")
103
-
104
- if debug_samples:
105
- logger.info(f"Debug mode: Will load up to {debug_samples} dialogues")
106
-
107
- for file_path in json_files:
108
- try:
109
- with open(file_path, 'r', encoding='utf-8') as f:
110
- batch_dialogues = json.load(f)
111
-
112
- # If in debug mode, only take what we need from this batch
113
- if debug_samples is not None:
114
- remaining_samples = debug_samples - len(all_dialogues)
115
- if remaining_samples <= 0:
116
- break
117
- batch_dialogues = batch_dialogues[:remaining_samples]
118
-
119
- all_dialogues.extend(batch_dialogues)
120
- logger.info(f"Loaded {len(batch_dialogues)} dialogues from {file_path.name}")
121
-
122
- # If we've reached our debug sample limit, stop loading
123
- if debug_samples is not None and len(all_dialogues) >= debug_samples:
124
- logger.info(f"Debug mode: Reached {debug_samples} samples, stopping load")
125
- break
126
-
127
- except Exception as e:
128
- logger.error(f"Error loading {file_path}: {str(e)}")
129
-
130
- total_loaded = len(all_dialogues)
131
- if debug_samples:
132
- logger.info(f"Debug mode: Loaded {total_loaded}/{debug_samples} requested dialogues")
133
- else:
134
- logger.info(f"Total dialogues loaded: {total_loaded}")
135
-
136
- return all_dialogues
137
-
138
- def plot_training_history(history: Dict[str, List[float]], save_dir: Path = None):
139
- """Plot and optionally save training history."""
140
- # Create figure with two subplots
141
- fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))
142
-
143
- # Plot losses
144
- ax1.plot(history['train_loss'], label='Train Loss')
145
- ax1.plot(history['val_loss'], label='Validation Loss')
146
- ax1.set_xlabel('Epoch')
147
- ax1.set_ylabel('Triplet Loss')
148
- ax1.set_title('Training and Validation Loss')
149
- ax1.legend()
150
- ax1.grid(True)
151
-
152
- # Plot learning rate if available
153
- if 'learning_rate' in history:
154
- ax2.plot(history['learning_rate'], label='Learning Rate')
155
- ax2.set_xlabel('Step')
156
- ax2.set_ylabel('Learning Rate')
157
- ax2.set_title('Learning Rate Schedule')
158
- ax2.legend()
159
- ax2.grid(True)
160
-
161
- plt.tight_layout()
162
-
163
- # Save if directory provided
164
- if save_dir:
165
- save_dir = Path(save_dir)
166
- save_dir.mkdir(parents=True, exist_ok=True)
167
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
168
- plt.savefig(save_dir / f'training_history_{timestamp}.png')
169
-
170
- plt.show()
171
-
172
- def setup_training_directories(base_dir: str = "chatbot_training") -> Dict[str, Path]:
173
- """Setup directory structure for training artifacts."""
174
- base_dir = Path(base_dir)
175
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
176
- train_dir = base_dir / f"training_run_{timestamp}"
177
-
178
- directories = {
179
- 'base': train_dir,
180
- 'checkpoints': train_dir / 'checkpoints',
181
- 'plots': train_dir / 'plots',
182
- 'logs': train_dir / 'logs'
183
- }
184
-
185
- # Create directories
186
- for dir_path in directories.values():
187
- dir_path.mkdir(parents=True, exist_ok=True)
188
-
189
- return directories
190
-
191
- def run_automatic_validation(
192
- chatbot,
193
- response_pool: List[str],
194
- quality_checker: ResponseQualityChecker,
195
- num_examples: int = 5
196
- ) -> Dict[str, Any]:
197
- """
198
- Run automatic validation with quality metrics.
199
- """
200
- logger.info("\n=== Running Automatic Validation ===")
201
-
202
- test_queries = [
203
- "Hello, how are you today?",
204
- "What's the weather like?",
205
- "Can you help me with a problem?",
206
- "Tell me a joke",
207
- "What time is it?",
208
- "I need help with my homework",
209
- "Where's a good place to eat?",
210
- "What movies are playing?",
211
- "How do I reset my password?",
212
- "Can you recommend a book?"
213
- ]
214
-
215
- test_queries = test_queries[:num_examples]
216
- metrics_history = []
217
-
218
- for i, query in enumerate(test_queries, 1):
219
- logger.info(f"\nTest Case {i}:")
220
- logger.info(f"Query: {query}")
221
-
222
- # Get responses and scores
223
- responses = chatbot.retrieve_responses(
224
- query,
225
- response_pool,
226
- context=None,
227
- top_k=5
228
- )
229
-
230
- # Check quality
231
- quality_metrics = quality_checker.check_response_quality(
232
- query, responses, response_pool
233
- )
234
- metrics_history.append(quality_metrics)
235
-
236
- # Log results
237
- logger.info(f"Quality Metrics: {quality_metrics}")
238
- logger.info("Top responses:")
239
- for j, (response, score) in enumerate(responses[:3], 1):
240
- logger.info(f"{j}. Score: {score:.4f}")
241
- logger.info(f" Response: {response}")
242
- if j == 1 and not quality_metrics['is_confident']:
243
- logger.info(" [Low Confidence - Would abstain from answering]")
244
-
245
- # Calculate aggregate metrics
246
- aggregate_metrics = {
247
- 'num_queries_tested': len(test_queries),
248
- 'avg_top_response_score': np.mean([m['top_score'] for m in metrics_history]),
249
- 'avg_diversity': np.mean([m['response_diversity'] for m in metrics_history]),
250
- 'avg_relevance': np.mean([m['query_response_relevance'] for m in metrics_history]),
251
- 'confidence_rate': np.mean([m['is_confident'] for m in metrics_history]),
252
- }
253
-
254
- logger.info("\n=== Validation Summary ===")
255
- for metric, value in aggregate_metrics.items():
256
- logger.info(f"{metric}: {value:.4f}")
257
-
258
- return aggregate_metrics
259
-
260
- def chat_with_quality_check(
261
- chatbot,
262
- query: str,
263
- response_pool: List[str],
264
- conversation_history: List[Tuple[str, str]],
265
- quality_checker: ResponseQualityChecker
266
- ) -> Tuple[Optional[str], List[Tuple[str, float]], Dict[str, Any]]:
267
- """
268
- Enhanced chat function with quality checking.
269
- """
270
- # Get responses and scores
271
- responses = chatbot.retrieve_responses(
272
- query,
273
- response_pool,
274
- conversation_history
275
- )
276
-
277
- # Check quality
278
- quality_metrics = quality_checker.check_response_quality(
279
- query, responses, response_pool
280
- )
281
-
282
- if quality_metrics['is_confident']:
283
- return responses[0][0], responses, quality_metrics
284
- else:
285
- uncertainty_response = (
286
- "I apologize, but I don't feel confident providing an answer to that "
287
- "question at the moment. Could you please rephrase or ask something else?"
288
- )
289
- return uncertainty_response, responses, quality_metrics
290
-
291
- def get_total_steps(dialogues: List[Dict[str, Any]], batch_size: int, epochs: int) -> int:
292
- """
293
- Calculate total training steps based on dialogues and batch size.
294
- Assume 80% of data for training due to validation split
295
- """
296
- estimated_train_samples = int(len(dialogues) * 0.8)
297
- steps_per_epoch = estimated_train_samples // batch_size
298
- return steps_per_epoch * epochs
299
-
300
- def main():
301
- # Set up GPU
302
- is_gpu = setup_gpu()
303
-
304
- DEBUG_SAMPLES = 350
305
- BATCH_SIZE = 64 if is_gpu else 32
306
- EPOCHS = 5 if DEBUG_SAMPLES else 10
307
-
308
- # Set device
309
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
310
- logger.info(f"Using device: {device}")
311
-
312
- # Set up caching
313
- cache_dir = setup_model_cache()
314
-
315
- # Set up training directories
316
- dirs = setup_training_directories()
317
-
318
- # Load training data
319
- dialogues = load_training_data('processed_outputs', debug_samples=DEBUG_SAMPLES)
320
- total_steps = get_total_steps(dialogues, BATCH_SIZE, EPOCHS)
321
-
322
- # Initialize configuration
323
- config = ChatbotConfig(
324
- embedding_dim=768, # Match DistilBERT's dimension
325
- encoder_units=256,
326
- num_attention_heads=8,
327
- warmup_steps=int(total_steps * 0.1),
328
- learning_rate=0.0003,
329
- margin=0.5,
330
- pretrained_model='distilbert-base-uncased'
331
- )
332
-
333
- # Preload models
334
- preload_models(config, cache_dir)
335
-
336
- # Save configuration
337
- with open(dirs['base'] / 'config.json', 'w') as f:
338
- json.dump(config.to_dict(), f, indent=2)
339
-
340
- # Initialize chatbot
341
- chatbot = RetrievalChatbot(config)
342
-
343
- # Prepare dataset
344
- logger.info("Preparing dataset...")
345
-
346
- # Prepare and train with debug samples
347
- q_pad, p_pad, n_pad = chatbot.prepare_dataset(
348
- dialogues,
349
- neg_samples_per_pos=3,
350
- debug_samples=DEBUG_SAMPLES
351
- )
352
-
353
- tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='./logs')
354
-
355
- # Train model
356
- logger.info("Starting training...")
357
- chatbot.train(
358
- q_pad, p_pad, n_pad,
359
- epochs=EPOCHS,
360
- batch_size=BATCH_SIZE,
361
- validation_split=0.2,
362
- checkpoint_dir=dirs['checkpoints'],
363
- callbacks=[tensorboard_callback]
364
- )
365
-
366
- # Plot and save training history
367
- plot_training_history(chatbot.history, save_dir=dirs['plots'])
368
-
369
- # Save final model
370
- chatbot.save_models(dirs['base'] / 'final_model')
371
-
372
- # Prepare response pool for chat
373
- response_pool = [
374
- turn['text'] for d in dialogues
375
- for turn in d['turns'] if turn['speaker'] == 'assistant'
376
- ]
377
-
378
- # Initialize quality checker with appropriate thresholds
379
- quality_checker = ResponseQualityChecker(
380
- confidence_threshold=0.6 if not DEBUG_SAMPLES else 0.4, # Lower threshold for debug
381
- diversity_threshold=0.2,
382
- min_response_length=10,
383
- max_similarity_ratio=0.9
384
- )
385
-
386
- # Run automatic validation
387
- validation_metrics = run_automatic_validation(
388
- chatbot,
389
- response_pool,
390
- quality_checker,
391
- num_examples=5 if DEBUG_SAMPLES else 10
392
- )
393
-
394
- # Log validation metrics
395
- logger.info(f"Validation Metrics: {validation_metrics}")
396
-
397
- # Now continue with interactive chat
398
- logger.info("\nStarting interactive chat session...")
399
- conversation_history = []
400
-
401
- while True:
402
- query = input("\nYou: ")
403
- if query.lower() in ['quit', 'exit', 'bye']:
404
- break
405
-
406
- try:
407
- response, candidates, quality_metrics = chat_with_quality_check(
408
- chatbot,
409
- query,
410
- response_pool,
411
- conversation_history,
412
- quality_checker
413
- )
414
- print(f"\nAssistant: {response}")
415
-
416
- # Print top alternative responses if confident
417
- if quality_metrics['is_confident']:
418
- print("\nAlternative responses:")
419
- for resp, score in candidates[1:4]:
420
- print(f"Score: {score:.4f} - {resp}")
421
-
422
- # Update history only for confident responses
423
- conversation_history.append((query, response))
424
- else:
425
- print("\nQuality metrics indicated low confidence:")
426
- print(f"Confidence score: {quality_metrics['top_score']:.4f}")
427
- print(f"Response relevance: {quality_metrics['query_response_relevance']:.4f}")
428
-
429
- except Exception as e:
430
- logger.error(f"Error during chat: {str(e)}")
431
- print("Sorry, I encountered an error. Please try again.")
432
-
433
- if __name__ == "__main__":
434
- main()
run_model4.py DELETED
@@ -1,237 +0,0 @@
1
- from chatbot4 import RetrievalChatbot, ChatbotConfig
2
- import os
3
- import tensorflow as tf
4
- import matplotlib.pyplot as plt
5
- import logging
6
- from pathlib import Path
7
- from typing import List, Dict, Optional
8
- from datetime import datetime
9
- from response_quality_checker import ResponseQualityChecker
10
-
11
- # Configure logging
12
- logging.basicConfig(
13
- level=logging.INFO,
14
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
15
- )
16
- logger = logging.getLogger(__name__)
17
-
18
- def setup_model_cache(cache_dir: Optional[Path] = None) -> Path:
19
- """Setup and manage model cache directory."""
20
- if cache_dir is None:
21
- cache_dir = Path.home() / '.chatbot_cache'
22
-
23
- cache_dir.mkdir(parents=True, exist_ok=True)
24
-
25
- # Set environment variables for various libraries
26
- os.environ['TRANSFORMERS_CACHE'] = str(cache_dir / 'transformers')
27
- os.environ['TORCH_HOME'] = str(cache_dir / 'torch')
28
- os.environ['HF_HOME'] = str(cache_dir / 'huggingface')
29
-
30
- logger.info(f"Using cache directory: {cache_dir}")
31
- return cache_dir
32
-
33
- def setup_gpu():
34
- """Configure GPU settings for optimal performance."""
35
- logger.info("Checking GPU availability...")
36
-
37
- gpus = tf.config.list_physical_devices('GPU')
38
- if gpus:
39
- try:
40
- # Allow memory growth to prevent taking all GPU memory at once
41
- for gpu in gpus:
42
- tf.config.experimental.set_memory_growth(gpu, True)
43
- logger.info(f"Found {len(gpus)} GPU(s). Memory growth enabled.")
44
-
45
- # Log GPU info
46
- for gpu in gpus:
47
- logger.info(f"GPU Device: {gpu}")
48
- return True
49
- except Exception as e:
50
- logger.error(f"Error configuring GPU: {str(e)}")
51
- return False
52
-
53
- else:
54
- logger.info("No GPU found. Using CPU.")
55
- return False
56
-
57
- # def preload_models(config: ChatbotConfig, cache_dir: Path):
58
- # """Preload and cache models."""
59
- # logger.info("Preloading models...")
60
-
61
- # # Cache DistilBERT
62
- # model_name = config.pretrained_model
63
- # cache_path = cache_dir / 'transformers' / model_name
64
-
65
- # if not cache_path.exists():
66
- # logger.info(f"Downloading and caching {model_name}...")
67
- # tokenizer = AutoTokenizer.from_pretrained(model_name)
68
- # model = TFAutoModel.from_pretrained(model_name)
69
-
70
- # # Save to cache
71
- # tokenizer.save_pretrained(cache_path)
72
- # model.save_pretrained(cache_path)
73
- # else:
74
- # logger.info(f"Using cached model from {cache_path}")
75
-
76
- # return cache_path
77
-
78
- def plot_training_history(history: Dict[str, List[float]], save_dir: Path = None):
79
- """Plot and optionally save training history."""
80
- # Create figure with two subplots
81
- fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))
82
-
83
- # Plot losses
84
- ax1.plot(history['train_loss'], label='Train Loss')
85
- ax1.plot(history['val_loss'], label='Validation Loss')
86
- ax1.set_xlabel('Epoch')
87
- ax1.set_ylabel('Triplet Loss')
88
- ax1.set_title('Training and Validation Loss')
89
- ax1.legend()
90
- ax1.grid(True)
91
-
92
- # Plot learning rate if available
93
- if 'learning_rate' in history:
94
- ax2.plot(history['learning_rate'], label='Learning Rate')
95
- ax2.set_xlabel('Step')
96
- ax2.set_ylabel('Learning Rate')
97
- ax2.set_title('Learning Rate Schedule')
98
- ax2.legend()
99
- ax2.grid(True)
100
-
101
- plt.tight_layout()
102
-
103
- # Save if directory provided
104
- if save_dir:
105
- save_dir = Path(save_dir)
106
- save_dir.mkdir(parents=True, exist_ok=True)
107
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
108
- plt.savefig(save_dir / f'training_history_{timestamp}.png')
109
-
110
- plt.show()
111
-
112
- def setup_training_directories(base_dir: str = "chatbot_training") -> Dict[str, Path]:
113
- """Setup directory structure for training artifacts."""
114
- base_dir = Path(base_dir)
115
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
116
- train_dir = base_dir / f"training_run_{timestamp}"
117
-
118
- directories = {
119
- 'base': train_dir,
120
- 'checkpoints': train_dir / 'checkpoints',
121
- 'plots': train_dir / 'plots',
122
- 'logs': train_dir / 'logs'
123
- }
124
-
125
- # Create directories
126
- for dir_path in directories.values():
127
- dir_path.mkdir(parents=True, exist_ok=True)
128
-
129
- return directories
130
-
131
- def main():
132
- # Set up GPU
133
- is_gpu = setup_gpu()
134
-
135
- DEBUG_SAMPLES = 2000
136
- BATCH_SIZE = 128 if is_gpu else 64
137
- EPOCHS = 5 if DEBUG_SAMPLES else 10
138
-
139
- # Set up caching
140
- cache_dir = setup_model_cache()
141
-
142
- # Set up training directories
143
- dirs = setup_training_directories()
144
-
145
- # Initialize configuration
146
- config = ChatbotConfig(
147
- embedding_dim=768, # Match DistilBERT's dimension
148
- max_sequence_length=512,
149
- freeze_embeddings=False
150
- )
151
-
152
- # Preload models
153
- #preload_models(config, cache_dir)
154
-
155
- # Save configuration
156
- # with open(dirs['base'] / 'config.json', 'w') as f:
157
- # json.dump(config.to_dict(), f, indent=4)
158
-
159
- # Load training data
160
- dialogues = RetrievalChatbot.load_training_data(data_path='processed_outputs/batch_group_0010.json', debug_samples=DEBUG_SAMPLES)
161
-
162
- # Initialize chatbot
163
- chatbot = RetrievalChatbot(config, dialogues)
164
-
165
- # Check trainable variables
166
- chatbot.check_trainable_variables()
167
-
168
- # Verify FAISS
169
- chatbot.verify_faiss_index()
170
-
171
- # Prepare dataset
172
- logger.info("Preparing dataset...")
173
-
174
- # Prepare and train with debug samples
175
- q_tensor, p_tensor = chatbot.prepare_dataset(dialogues)
176
-
177
- quality_checker = ResponseQualityChecker(chatbot=chatbot)
178
-
179
- # Train model
180
- logger.info("Starting training...")
181
-
182
- tf.config.optimizer.set_jit(True) # XLA
183
- policy = tf.keras.mixed_precision.Policy('mixed_float16')
184
- tf.keras.mixed_precision.set_global_policy(policy)
185
-
186
- chatbot.train(
187
- q_pad=q_tensor,
188
- p_pad=p_tensor,
189
- epochs=EPOCHS,
190
- batch_size=BATCH_SIZE,
191
- validation_split=0.2,
192
- checkpoint_dir="checkpoints/",
193
- use_lr_schedule=True, # Enable custom schedule
194
- peak_lr=2e-5, # Peak learning rate
195
- warmup_steps_ratio=0.1, # 10% warmup
196
- early_stopping_patience=3 # Stop if no improvement for 3 epochs
197
- )
198
-
199
- # Plot and save training history
200
- #plot_training_history(chatbot.history, save_dir=dirs['plots'])
201
-
202
- # Save final model
203
- chatbot.save_models(dirs['base'] / 'final_model')
204
-
205
- # Run automatic validation
206
- validation_metrics = chatbot.run_automatic_validation(quality_checker, num_examples=5)
207
-
208
- # Log validation metrics
209
- logger.info(f"Validation Metrics: {validation_metrics}")
210
-
211
- # Now continue with interactive chat
212
- logger.info("\nStarting interactive chat session...")
213
- conversation_history = []
214
-
215
- while True:
216
- user_input = input("You: ")
217
- if user_input.lower() in ['quit', 'exit', 'bye']:
218
- print("Assistant: Goodbye!")
219
- break
220
-
221
- response, candidates, metrics = chatbot.chat(
222
- query=user_input,
223
- conversation_history=None, # Pass conversation history if available
224
- quality_checker=quality_checker,
225
- top_k=5
226
- )
227
-
228
- print(f"Assistant: {response}")
229
-
230
- # Optionally, display alternative responses
231
- if metrics.get('is_confident', False):
232
- print("\nAlternative responses:")
233
- for resp, score in candidates[1:4]:
234
- print(f"Score: {score:.4f} - {resp}")
235
-
236
- if __name__ == "__main__":
237
- main()
run_model_train.py ADDED
@@ -0,0 +1,96 @@
1
+ from chatbot_model import RetrievalChatbot, ChatbotConfig
2
+ from environment_setup import EnvironmentSetup
3
+ from response_quality_checker import ResponseQualityChecker
4
+ from chatbot_validator import ChatbotValidator
5
+ from training_plotter import TrainingPlotter
6
+
7
+
8
+ # Configure logging
9
+ from logger_config import config_logger
10
+ logger = config_logger(__name__)
11
+
12
+ def run_interactive_chat(chatbot, quality_checker):
13
+ """Separate function for interactive chat loop"""
14
+ while True:
15
+ user_input = input("You: ")
16
+ if user_input.lower() in ['quit', 'exit', 'bye']:
17
+ print("Assistant: Goodbye!")
18
+ break
19
+
20
+ response, candidates, metrics = chatbot.chat(
21
+ query=user_input,
22
+ conversation_history=None,
23
+ quality_checker=quality_checker,
24
+ top_k=5
25
+ )
26
+
27
+ print(f"Assistant: {response}")
28
+
29
+ if metrics.get('is_confident', False):
30
+ print("\nAlternative responses:")
31
+ for resp, score in candidates[1:4]:
32
+ print(f"Score: {score:.4f} - {resp}")
33
+
34
+ def main():
35
+ # Initialize environment
36
+ env = EnvironmentSetup()
37
+ env.initialize()
38
+
39
+ DEBUG_SAMPLES = 5
40
+ EPOCHS = 1 if DEBUG_SAMPLES else 20
41
+ TRAINING_DATA_PATH = 'processed_outputs/batch_group_0010.json'
42
+
43
+ # Optimize batch size for Colab
44
+ batch_size = env.optimize_batch_size(base_batch_size=16)
45
+
46
+ # Initialize configuration
47
+ config = ChatbotConfig(
48
+ embedding_dim=512, # 768, # Match DistilBERT's dimension
49
+ max_context_token_limit=512,
50
+ freeze_embeddings=False,
51
+ )
52
+
53
+ # Load training data
54
+ dialogues = RetrievalChatbot.load_training_data(data_path=TRAINING_DATA_PATH, debug_samples=DEBUG_SAMPLES)
55
+
56
+ # Initialize chatbot and verify FAISS index
57
+ with env.strategy.scope():
58
+ chatbot = RetrievalChatbot(config, dialogues)
59
+ chatbot.verify_faiss_index()
60
+
61
+ # Prepare dataset
62
+ logger.info("Preparing dataset...")
63
+ q_tensor, p_tensor = chatbot.prepare_dataset(dialogues)
64
+ quality_checker = ResponseQualityChecker(chatbot=chatbot)
65
+
66
+ # Train model
67
+ logger.info("Starting training...")
68
+ chatbot.train(
69
+ q_pad=q_tensor,
70
+ p_pad=p_tensor,
71
+ epochs=EPOCHS,
72
+ batch_size=batch_size,
73
+ validation_split=0.2,
74
+ )
75
+
76
+ # Save final model
77
+ model_save_path = env.training_dirs['base'] / 'final_model'
78
+ chatbot.save_models(model_save_path)
79
+
80
+ # Run automatic validation
81
+ quality_checker = ResponseQualityChecker(chatbot=chatbot)
82
+ validator = ChatbotValidator(chatbot, quality_checker)
83
+ validation_metrics = validator.run_validation(num_examples=5)
84
+ logger.info(f"Validation Metrics: {validation_metrics}")
85
+
86
+ # Plot and save training history
87
+ plotter = TrainingPlotter(save_dir=env.training_dirs['plots'])
88
+ plotter.plot_training_history(chatbot.history)
89
+ plotter.plot_validation_metrics(validation_metrics)
90
+
91
+ # Run interactive chat
92
+ logger.info("\nStarting interactive chat session...")
93
+ run_interactive_chat(chatbot, quality_checker)
94
+
95
+ if __name__ == "__main__":
96
+ main()
setup.py CHANGED
@@ -25,6 +25,27 @@ def setup_spacy_models(models=['en_core_web_sm', 'en_core_web_md']):
25
  print(f"Error downloading spaCy model: {model}")
26
  print(e)
27
 
28
  def setup_models():
29
  """
30
  Download other required models.
@@ -40,7 +61,7 @@ def setup_models():
40
  DistilBertModel
41
  )
42
 
43
- # Download DistilBERT for chatbot
44
  tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
45
  model = DistilBertModel.from_pretrained('distilbert-base-uncased')
46
 
@@ -116,11 +137,11 @@ def setup_faiss():
116
  print(e)
117
 
118
  setup(
119
- name="text-data-augmenter",
120
- version="0.1.0",
121
  author="Joe Armani",
122
  author_email="[email protected]",
123
- description="A tool for generating high-quality dialogue variations",
124
  long_description=long_description,
125
  long_description_content_type="text/markdown",
126
  packages=find_packages(),
@@ -132,24 +153,39 @@ setup(
132
  "Programming Language :: Python :: 3",
133
  "Programming Language :: Python :: 3.8",
134
  "Programming Language :: Python :: 3.9",
 
135
  "Topic :: Scientific/Engineering :: Artificial Intelligence",
136
  "Topic :: Text Processing :: Linguistic",
137
  ],
138
  python_requires=">=3.8",
139
  install_requires=requirements,
140
  entry_points={
141
  "console_scripts": [
142
  "dialogue-augment=dialogue_augmenter.main:main",
 
143
  ],
144
  },
145
  include_package_data=True,
146
  package_data={
 
147
  "dialogue_augmenter": ["data/*.json", "config/*.yaml"],
148
  },
149
  )
150
 
151
  if __name__ == '__main__':
152
  setup_spacy_models()
 
153
  setup_models()
154
  setup_nltk()
155
  setup_faiss()
 
25
  print(f"Error downloading spaCy model: {model}")
26
  print(e)
27
 
28
+ def setup_gpu_dependencies():
29
+ """Setup GPU-specific dependencies."""
30
+ cuda_available = False
31
+
32
+ # Check CUDA availability
33
+ try:
34
+ import torch
35
+ cuda_available = torch.cuda.is_available()
36
+ except ImportError:
37
+ pass
38
+
39
+ if cuda_available:
40
+ try:
41
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "faiss-gpu>=1.7.0"])
42
+ print("Successfully installed faiss-gpu")
43
+ except subprocess.CalledProcessError:
44
+ print("Failed to install faiss-gpu. Falling back to faiss-cpu")
45
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "faiss-cpu>=1.7.0"])
46
+ else:
47
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "faiss-cpu>=1.7.0"])
48
+
49
  def setup_models():
50
  """
51
  Download other required models.
 
61
  DistilBertModel
62
  )
63
 
64
+ # Cache the models
65
  tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
66
  model = DistilBertModel.from_pretrained('distilbert-base-uncased')
67
 
 
137
  print(e)
138
 
139
  setup(
140
+ name="retrieval-chatbot",
141
+ version="0.2.0",
142
  author="Joe Armani",
143
  author_email="[email protected]",
144
+ description="A retrieval-based chatbot with enhanced validation",
145
  long_description=long_description,
146
  long_description_content_type="text/markdown",
147
  packages=find_packages(),
 
153
  "Programming Language :: Python :: 3",
154
  "Programming Language :: Python :: 3.8",
155
  "Programming Language :: Python :: 3.9",
156
+ "Programming Language :: Python :: 3.10",
157
  "Topic :: Scientific/Engineering :: Artificial Intelligence",
158
  "Topic :: Text Processing :: Linguistic",
159
  ],
160
  python_requires=">=3.8",
161
  install_requires=requirements,
162
+ extras_require={
163
+ 'dev': [
164
+ 'pytest>=7.0.0',
165
+ 'black>=22.0.0',
166
+ 'isort>=5.10.0',
167
+ 'mypy>=1.0.0',
168
+ ],
169
+ 'gpu': [
170
+ 'faiss-gpu>=1.7.0',
171
+ ],
172
+ },
173
  entry_points={
174
  "console_scripts": [
175
  "dialogue-augment=dialogue_augmenter.main:main",
176
+ "run-chatbot=chatbot.main:main",
177
  ],
178
  },
179
  include_package_data=True,
180
  package_data={
181
+ "chatbot": ["config/*.yaml"],
182
  "dialogue_augmenter": ["data/*.json", "config/*.yaml"],
183
  },
184
  )
185
 
186
  if __name__ == '__main__':
187
  setup_spacy_models()
188
+ setup_gpu_dependencies()
189
  setup_models()
190
  setup_nltk()
191
  setup_faiss()
training_plotter.py ADDED
@@ -0,0 +1,127 @@
1
+ from pathlib import Path
2
+ from typing import Dict, List, Optional
3
+ import matplotlib.pyplot as plt
4
+ from datetime import datetime
5
+ import logging
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+ class TrainingPlotter:
10
+ def __init__(self, save_dir: Optional[Path] = None):
11
+ self.save_dir = save_dir
12
+ if save_dir:
13
+ self.save_dir.mkdir(parents=True, exist_ok=True)
14
+
15
+ def plot_training_history(self, history: Dict[str, List[float]], title: str = "Training History"):
16
+ """Plot and optionally save training metrics history.
17
+
18
+ Args:
19
+ history: Dictionary containing training metrics
20
+ title: Title for the plot
21
+ """
22
+ # Silence matplotlib debug messages
23
+ logger.setLevel(logging.WARNING)
24
+
25
+ # Create figure with subplots
26
+ fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))
27
+
28
+ # Plot losses
29
+ ax1.plot(history['train_loss'], label='Train Loss')
30
+ ax1.plot(history['val_loss'], label='Validation Loss')
31
+ ax1.set_xlabel('Epoch')
32
+ ax1.set_ylabel('Loss')
33
+ ax1.set_title('Training and Validation Loss')
34
+ ax1.legend()
35
+ ax1.grid(True)
36
+
37
+ # Plot learning rate if available
38
+ if 'learning_rate' in history:
39
+ ax2.plot(history['learning_rate'], label='Learning Rate')
40
+ ax2.set_xlabel('Step')
41
+ ax2.set_ylabel('Learning Rate')
42
+ ax2.set_title('Learning Rate Schedule')
43
+ ax2.legend()
44
+ ax2.grid(True)
45
+
46
+ plt.suptitle(title)
47
+ plt.tight_layout()
48
+
49
+ # Reset the logger level
50
+ logger.setLevel(logging.INFO)
51
+
52
+ # Save if directory provided
53
+ if self.save_dir:
54
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
55
+ save_path = self.save_dir / f'training_history_{timestamp}.png'
56
+ plt.savefig(save_path)
57
+ logger.info(f"Saved training history plot to {save_path}")
58
+
59
+ plt.show()
60
+
61
+ def plot_validation_metrics(self, metrics: Dict[str, float]):
62
+ """Plot validation metrics as a bar chart.
63
+
64
+ Args:
65
+ metrics: Dictionary of validation metrics. Can handle nested dictionaries.
66
+ """
67
+ # Silence matplotlib debug messages
68
+ logger.setLevel(logging.WARNING)
69
+
70
+ # Flatten nested metrics dictionary
71
+ flat_metrics = {}
72
+ for key, value in metrics.items():
73
+ # Skip num_queries_tested
74
+ if key == 'num_queries_tested':
75
+ continue
76
+
77
+ if isinstance(value, dict):
78
+ # If value is a dictionary, flatten it with key prefix
79
+ for subkey, subvalue in value.items():
80
+ if isinstance(subvalue, (int, float)): # Only include numeric values
81
+ flat_metrics[f"{key}_{subkey}"] = subvalue
82
+ elif isinstance(value, (int, float)): # Only include numeric values
83
+ flat_metrics[key] = value
84
+
85
+ if not flat_metrics:
86
+ logger.warning("No numeric metrics to plot")
87
+ return
88
+
89
+ plt.figure(figsize=(12, 6))
90
+
91
+ # Extract metrics and values
92
+ metric_names = list(flat_metrics.keys())
93
+ values = list(flat_metrics.values())
94
+
95
+ # Create bar chart
96
+ bars = plt.bar(range(len(metric_names)), values)
97
+
98
+ # Customize the plot
99
+ plt.title('Validation Metrics')
100
+ plt.xticks(range(len(metric_names)), metric_names, rotation=45, ha='right')
101
+ plt.ylabel('Value')
102
+
103
+ # Add value labels on top of bars
104
+ for bar in bars:
105
+ height = bar.get_height()
106
+ plt.text(bar.get_x() + bar.get_width()/2., height,
107
+ f'{height:.3f}',
108
+ ha='center', va='bottom')
109
+
110
+ # Set y-axis limits to focus on metrics between 0 and 1
111
+ plt.ylim(0, 1.1) # Slight padding above 1 for label visibility
112
+
113
+ # Adjust layout to prevent label cutoff
114
+ plt.tight_layout()
115
+
116
+ # Reset the logger level
117
+ logger.setLevel(logging.INFO)
118
+
119
+ # Save if directory provided
120
+ if self.save_dir:
121
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
122
+ save_path = self.save_dir / f'validation_metrics_{timestamp}.png'
123
+ plt.savefig(save_path)
124
+ logger.info(f"Saved validation metrics plot to {save_path}")
125
+
126
+ plt.show()
127
+
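Note: a minimal sketch of driving the plotter (history keys mirror the ones plotted above; the values are illustrative):

    from pathlib import Path
    from training_plotter import TrainingPlotter

    history = {'train_loss': [0.92, 0.61, 0.48], 'val_loss': [0.98, 0.70, 0.63]}
    plotter = TrainingPlotter(save_dir=Path('plots'))
    plotter.plot_training_history(history)
    plotter.plot_validation_metrics({'avg_top_response_score': 0.62, 'confidence_rate': 0.40})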