JoeArmani committed
Commit 64e7c31 · 1 Parent(s): a763857

sentence transformer

.gitignore CHANGED
@@ -183,4 +183,6 @@ training_data/*
 augmented_dialogues.json
 
 raw_datasets/*
+st/*
+
 
chatbot_config.py ADDED
@@ -0,0 +1,30 @@
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Dict
+
+@dataclass
+class ChatbotConfig:
+    """RetrievalChatbot Config"""
+    max_context_token_limit: int = 512
+    embedding_dim: int = 384  # Match the Sentence Transformer embedding dimension (384 for all-MiniLM-L6-v2)
+    learning_rate: float = 0.0005
+    min_text_length: int = 3
+    max_context_turns: int = 20
+    pretrained_model: str = 'sentence-transformers/all-MiniLM-L6-v2'
+    cross_encoder_model: str = 'cross-encoder/ms-marco-MiniLM-L-12-v2'
+    summarizer_model: str = 't5-small'
+    embedding_batch_size: int = 64
+    search_batch_size: int = 64
+    max_batch_size: int = 64
+    max_retries: int = 3
+
+    def to_dict(self) -> Dict:
+        """Convert config to dictionary."""
+        return {k: (str(v) if isinstance(v, Path) else v)
+                for k, v in self.__dict__.items()}
+
+    @classmethod
+    def from_dict(cls, config_dict: Dict) -> 'ChatbotConfig':
+        """Create config from dictionary."""
+        return cls(**{k: v for k, v in config_dict.items()
+                      if k in cls.__dataclass_fields__})
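For reference, a minimal sketch of how the new config round-trips through to_dict/from_dict. The "extra_field" key below is hypothetical, added only to show that from_dict drops unknown keys, so older or newer config.json files still load:

import json
from chatbot_config import ChatbotConfig

config = ChatbotConfig()
config_json = json.dumps(config.to_dict(), indent=2)  # e.g. what gets written to models/config.json

# from_dict filters by __dataclass_fields__, so the hypothetical
# 'extra_field' key is silently dropped instead of raising TypeError.
loaded = ChatbotConfig.from_dict({**json.loads(config_json), "extra_field": 1})
assert loaded == config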
chatbot_model.py CHANGED
@@ -1,6 +1,6 @@
 import os
 import numpy as np
-from transformers import TFAutoModel, AutoTokenizer
+from sentence_transformers import SentenceTransformer
 import tensorflow as tf
 from typing import List, Tuple, Dict, Optional, Union, Any
 import math
@@ -11,125 +11,24 @@ import datetime
 import faiss
 import gc
 import re
-from tf_data_pipeline import TFDataPipeline
 from response_quality_checker import ResponseQualityChecker
 from cross_encoder_reranker import CrossEncoderReranker
 from conversation_summarizer import DeviceAwareModel, Summarizer
+from chatbot_config import ChatbotConfig
+from tf_data_pipeline import TFDataPipeline
 import absl.logging
 from logger_config import config_logger
 from tqdm.auto import tqdm
 
 absl.logging.set_verbosity(absl.logging.WARNING)
 logger = config_logger(__name__)
-
-@dataclass
-class ChatbotConfig:
-    """RetrievalChatbot Config"""
-    max_context_token_limit: int = 512
-    embedding_dim: int = 768
-    encoder_units: int = 256
-    num_attention_heads: int = 8
-    dropout_rate: float = 0.2
-    l2_reg_weight: float = 0.001
-    learning_rate: float = 0.0005
-    min_text_length: int = 3
-    max_context_turns: int = 20
-    warmup_steps: int = 200
-    pretrained_model: str = 'distilbert-base-uncased'
-    cross_encoder_model: str = 'cross-encoder/ms-marco-MiniLM-L-12-v2'
-    summarizer_model: str = 't5-small'
-    dtype: str = 'float32'
-    freeze_embeddings: bool = False
-    embedding_batch_size: int = 64
-    search_batch_size: int = 64
-    max_batch_size: int = 64
-    max_retries: int = 3
-
-    def to_dict(self) -> Dict:
-        """Convert config to dictionary."""
-        return {k: (str(v) if isinstance(v, Path) else v)
-                for k, v in self.__dict__.items()}
-
-    @classmethod
-    def from_dict(cls, config_dict: Dict) -> 'ChatbotConfig':
-        """Create config from dictionary."""
-        return cls(**{k: v for k, v in config_dict.items()
-                      if k in cls.__dataclass_fields__})
-
-class EncoderModel(tf.keras.Model):
-    """Dual encoder model with pretrained DistilBERT embeddings."""
-    def __init__(
-        self,
-        config: ChatbotConfig,
-        name: str = "encoder",
-        **kwargs
-    ):
-        super().__init__(name=name, **kwargs)
-        self.config = config
-
-        # Load pretrained model and freeze layers based on config
-        self.pretrained = TFAutoModel.from_pretrained(config.pretrained_model)
-        self._freeze_layers()
-
-        # Add Global Average Pooling, Projection, Dropout, and Normalization layers
-        self.pooler = tf.keras.layers.GlobalAveragePooling1D()
-        self.projection = tf.keras.layers.Dense(
-            config.embedding_dim,
-            activation='tanh',
-            name="projection",
-            dtype=tf.float32
-        )
-        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
-        self.normalize = tf.keras.layers.Lambda(
-            lambda x: tf.nn.l2_normalize(x, axis=1),
-            name="l2_normalize"
-        )
-
-    def _freeze_layers(self):
-        """Freeze n layers of the pretrained model"""
-        if self.config.freeze_embeddings:
-            self.pretrained.trainable = False
-            logger.info("All pretrained layers frozen.")
-        else:
-            # Freeze only the first 'n' transformer layers
-            for i, layer in enumerate(self.pretrained.layers):
-                if isinstance(layer, tf.keras.layers.Layer):
-                    if hasattr(layer, 'trainable'):
-                        if i < 1:
-                            layer.trainable = False
-                            logger.info(f"Layer {i} frozen.")
-                        else:
-                            layer.trainable = True
-                            logger.info(f"Layer {i} trainable.")
-
-    def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
-        """Forward pass."""
-        # Get pretrained embeddings
-        pretrained_outputs = self.pretrained(inputs, training=training)
-        x = pretrained_outputs.last_hidden_state  # Shape: [batch_size, seq_len, embedding_dim]
-
-        # Apply pooling, projection, dropout, and normalization
-        x = self.pooler(x)      # Shape: [batch_size, 768]
-        x = self.projection(x)  # Shape: [batch_size, 768]
-        x = self.dropout(x, training=training)
-        x = self.normalize(x)   # Shape: [batch_size, 768]
-
-        return x
-
-    def get_config(self) -> dict:
-        """Return the model config"""
-        config = super().get_config()
-        config.update({
-            "config": self.config.to_dict(),
-            "name": self.name
-        })
-        return config
 
 class RetrievalChatbot(DeviceAwareModel):
     """
     Retrieval-based learning chatbot model.
     Uses trained embeddings and FAISS for similarity search.
     """
+
     def __init__(
         self,
         config: ChatbotConfig,
@@ -139,6 +38,7 @@ class RetrievalChatbot(DeviceAwareModel):
         summarizer: Optional[Summarizer] = None,
         mode: str = 'training'
     ):
+
         super().__init__()
         self.config = config
         self.strategy = strategy
@@ -146,13 +46,14 @@ class RetrievalChatbot(DeviceAwareModel):
         self.mode = mode.lower()
 
         # Initialize reranker, summarizer, tokenizer, and encoder
-        self.reranker = reranker or self._initialize_reranker()
-        self.tokenizer = self._initialize_tokenizer()
         self.encoder = self._initialize_encoder()
+        self.tokenizer = self.encoder.tokenizer
+        self.reranker = reranker or self._initialize_reranker()
         self.summarizer = summarizer or self._initialize_summarizer()
 
         # Initialize data pipeline
         logger.info("Initializing TFDataPipeline.")
+
         self.data_pipeline = TFDataPipeline(
             config=self.config,
             tokenizer=self.tokenizer,
@@ -177,7 +78,6 @@ class RetrievalChatbot(DeviceAwareModel):
             "train_metrics": {},
             "val_metrics": {}
         }
-
 
     def _setup_default_device(self) -> str:
        """Set up default device if none is provided."""
@@ -200,34 +100,11 @@ class RetrievalChatbot(DeviceAwareModel):
             device=self.device,
             max_summary_rounds=2
         )
-
-    def _initialize_tokenizer(self) -> AutoTokenizer:
-        """Initialize the tokenizer and add special tokens."""
-        logger.info("Initializing tokenizer and adding special tokens...")
-        tokenizer = AutoTokenizer.from_pretrained(self.config.pretrained_model)
-        special_tokens = {
-            "user": "<USER>",
-            "assistant": "<ASSISTANT>",
-            "context": "<CONTEXT>",
-            "sep": "<SEP>"
-        }
-        tokenizer.add_special_tokens(
-            {'additional_special_tokens': list(special_tokens.values())}
-        )
-        return tokenizer
 
-    def _initialize_encoder(self) -> EncoderModel:
-        """Initialize the EncoderModel and resize token embeddings."""
-        logger.info("Initializing encoder model...")
-        encoder = EncoderModel(
-            self.config,
-            name="shared_encoder",
-        )
-
-        new_vocab_size = len(self.tokenizer)
-        encoder.pretrained.resize_token_embeddings(new_vocab_size)
-        logger.info(f"Token embeddings resized to: {new_vocab_size}")
-
+    def _initialize_encoder(self) -> SentenceTransformer:
+        """Initialize the Sentence Transformer model."""
+        logger.info("Initializing SentenceTransformer encoder model...")
+        encoder = SentenceTransformer(self.config.pretrained_model)
         return encoder
 
     def _load_faiss_index_and_responses(self) -> None:
@@ -254,43 +131,35 @@ class RetrievalChatbot(DeviceAwareModel):
         except Exception as e:
             logger.error(f"Failed to load FAISS index and response pool: {e}")
             raise
 
     @classmethod
     def load_model(cls, load_dir: Union[str, Path], mode: str = 'training') -> 'RetrievalChatbot':
-        """
-        Load saved models and configuration.
-        """
+        """Load chatbot model and configuration."""
         load_dir = Path(load_dir)
 
         # Load config
-        with open(load_dir / "config.json", "r") as f:
-            config = ChatbotConfig.from_dict(json.load(f))
+        config_path = load_dir / "config.json"
+        if config_path.exists():
+            with open(config_path, "r") as f:
+                config = ChatbotConfig.from_dict(json.load(f))
+            logger.info("Loaded ChatbotConfig from config.json.")
+        else:
+            raise FileNotFoundError(f"Config file not found at {config_path}. Please ensure it exists.")
 
         # Initialize chatbot
         chatbot = cls(config, mode=mode)
 
-        # Load DistilBERT
-        chatbot.encoder.pretrained = TFAutoModel.from_pretrained(load_dir / "shared_encoder", config=config)
-
-        dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
-        _ = chatbot.encoder(dummy_input, training=False)
-
-        # Load tokenizer
-        chatbot.tokenizer = AutoTokenizer.from_pretrained(load_dir / "tokenizer")
-        logger.info(f"Models and tokenizer loaded from {load_dir}")
-
-        # Load the custom weights
-        custom_weights_path = load_dir / "encoder_custom_weights.weights.h5"
-        if custom_weights_path.exists():
-            chatbot.encoder.load_weights(str(custom_weights_path))
-            logger.info("Loaded custom encoder weights for projection/dropout/etc.")
+        # Load Sentence Transformer
+        model_path = load_dir / "sentence_transformer"
+        if model_path.exists():
+            # Load locally saved model
+            chatbot.encoder = SentenceTransformer(str(model_path))
+            logger.info("Loaded SentenceTransformer model from local path successfully.")
         else:
-            logger.warning(f"No custom encoder weights found at {custom_weights_path}. The top-level projection layer won't have learned parameters.")
-
-        # Handle 'inference' mode: load FAISS, etc.
-        if mode == 'inference':
-            cls._prepare_model_for_inference(chatbot, load_dir)
-
+            # Load from pre-trained model hub
+            chatbot.encoder = SentenceTransformer(config.pretrained_model)
+            logger.info(f"Loaded SentenceTransformer model '{config.pretrained_model}' from the hub successfully.")
+
         return chatbot
 
     @classmethod
@@ -324,21 +193,19 @@ class RetrievalChatbot(DeviceAwareModel):
         except Exception as e:
             logger.error(f"Error loading inference components: {e}")
             raise
 
     def save_models(self, save_dir: Union[str, Path]):
-        """Save model and config"""
+        """Save SentenceTransformer model and config."""
         save_dir = Path(save_dir)
         save_dir.mkdir(parents=True, exist_ok=True)
 
         # Save config
         with open(save_dir / "config.json", "w") as f:
             json.dump(self.config.to_dict(), f, indent=2)
-
-        # Save the HF DistilBERT submodule, custom top-level layers, and tokenizer
-        self.encoder.pretrained.save_pretrained(save_dir / "shared_encoder")
-        self.encoder.save_weights(save_dir / "encoder_custom_weights.weights.h5")
-        self.tokenizer.save_pretrained(save_dir / "tokenizer")
-        logger.info(f"Models and tokenizer saved to {save_dir}.")
+
+        # Save Sentence Transformer
+        self.encoder.save(save_dir / "sentence_transformer")
+        logger.info(f"Model and config saved to {save_dir}.")
 
     def retrieve_responses(
         self,
@@ -346,59 +213,73 @@ class RetrievalChatbot(DeviceAwareModel):
         top_k: int = 10,
         reranker: Optional[CrossEncoderReranker] = None,
         summarizer: Optional[Summarizer] = None,
-        summarize_threshold: int = 512
+        summarize_threshold: int = 512,
+        boost_factor: float = 1.15
     ) -> List[Tuple[str, float]]:
         """
         Retrieve top-k responses using FAISS and cross-encoder re-ranking.
+
         Args:
             query: The user's input text.
-            top_k: Number of FAISS results to return
-            reranker: CrossEncoderReranker for refined scoring
-            summarizer: Summarizer for long queries
-            summarize_threshold: Summarize if conversation tokens > threshold.
+            top_k: Number of responses to return.
+            reranker: Optional reranker for refined scoring.
+            summarizer: Optional summarizer for long queries.
+            summarize_threshold: Threshold to summarize long queries.
+            boost_factor: Factor to boost scores for keyword matches.
+
         Returns:
             List of (response_text, final_score).
         """
         def sigmoid(x: float) -> float:
             return 1 / (1 + np.exp(-x))
 
-        # Query summarization
+        # Summarize long queries
        if summarizer and len(query.split()) > summarize_threshold:
-            logger.info(f"Query is long ({len(query.split())} words). Summarizing.")
+            logger.info(f"Query is long ({len(query.split())} words). Summarizing...")
             query = summarizer.summarize_text(query)
-            logger.info(f"Summarized Query: {query}")
+            logger.info(f"Summarized query: {query}")
 
+        # Detect domain for query
         detected_domain = self.detect_domain_from_query(query)
 
-        # Retrieve initial candidates from FAISS
-        initial_k = min(top_k * 10, len(self.data_pipeline.response_pool))
-        faiss_candidates = self.faiss_search(query, domain=detected_domain, top_k=initial_k)
+        # Step 1: Retrieve candidates from FAISS
+        logger.info("Retrieving initial candidates from FAISS...")
+        faiss_candidates = self.data_pipeline.retrieve_responses(query, top_k=top_k * 10)
 
         if not faiss_candidates:
+            logger.warning("No candidates retrieved from FAISS.")
            return []
 
-        texts = [item[0] for item in faiss_candidates]
+        # Step 2: Re-rank candidates using Cross-Encoder
+        logger.info("Re-ranking candidates using Cross-Encoder...")
+        texts = [item[0] for item in faiss_candidates]  # Extract response texts
+        faiss_scores = [item[1] for item in faiss_candidates]
 
-        if not reranker:
+        if reranker is None:
             reranker = CrossEncoderReranker(model_name=self.config.cross_encoder_model)
 
-        # Re-rank the texts (candidates) from FAISS search using the cross-encoder
-        ce_logits = reranker.rerank(query, texts, max_length=256)
-
-        # Combine scores from FAISS and cross-encoder
+        ce_logits = reranker.rerank(query, texts, max_length=256)  # Re-rank responses
+
+        # Combine FAISS and Cross-Encoder scores
         final_candidates = []
-        for (resp_text, faiss_score), logit in zip(faiss_candidates, ce_logits):
-            ce_prob = sigmoid(logit)              # now in range [0...1]
-            faiss_norm = (faiss_score + 1) / 2.0  # now in range [0...1]
-            combined_score = 0.85 * ce_prob + 0.15 * faiss_norm
+        for resp_text, faiss_score, logit in zip(texts, faiss_scores, ce_logits):
+            ce_prob = sigmoid(logit)            # Cross-encoder score in range [0, 1]
+            faiss_norm = (faiss_score + 1) / 2  # Normalize FAISS score to range [0, 1]
+            combined_score = 0.75 * ce_prob + 0.25 * faiss_norm
+
+            # Boost score based on keyword match
+            query_keywords = self.extract_keywords(query)
+            if query_keywords and any(kw in resp_text.lower() for kw in query_keywords):
+                combined_score *= boost_factor
+
+            # Adjust score based on length
            length_adjusted_score = self.length_adjust_score(resp_text, combined_score)
 
             final_candidates.append((resp_text, length_adjusted_score))
 
-        # Sort descending by combined score
+        # Step 3: Sort and return top-k results
         final_candidates.sort(key=lambda x: x[1], reverse=True)
-
-        # Return top_k
+        logger.info(f"Returning top-{top_k} re-ranked responses.")
         return final_candidates[:top_k]
 
     def extract_keywords(self, query: str) -> List[str]:
@@ -636,21 +517,45 @@ class RetrievalChatbot(DeviceAwareModel):
         conversation_history: Optional[List[Tuple[str, str]]]
     ) -> str:
         """
-        Build conversation context string from conversation history.
+        Build conversation context string from conversation history,
+        using literal <USER> and <ASSISTANT> tokens (no tokenizer special index).
         """
+        USER_TOKEN = "<USER>"
+        ASSISTANT_TOKEN = "<ASSISTANT>"
+
         if not conversation_history:
-            return f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"
-
+            return f"{USER_TOKEN} {query}"
+
         conversation_parts = []
         for user_txt, assistant_txt in conversation_history:
-            conversation_parts.extend([
-                f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {user_txt}",
-                f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {assistant_txt}"
-            ])
-
-        conversation_parts.append(f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}")
+            # Insert literal tokens
+            conversation_parts.append(f"{USER_TOKEN} {user_txt}")
+            conversation_parts.append(f"{ASSISTANT_TOKEN} {assistant_txt}")
+
+        conversation_parts.append(f"{USER_TOKEN} {query}")
         return "\n".join(conversation_parts)
 
+    # def _build_conversation_context(
+    #     self,
+    #     query: str,
+    #     conversation_history: Optional[List[Tuple[str, str]]]
+    # ) -> str:
+    #     """
+    #     Build conversation context string from conversation history.
+    #     """
+    #     if not conversation_history:
+    #         return f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"
+    #
+    #     conversation_parts = []
+    #     for user_txt, assistant_txt in conversation_history:
+    #         conversation_parts.extend([
+    #             f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {user_txt}",
+    #             f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {assistant_txt}"
+    #         ])
+    #
+    #     conversation_parts.append(f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}")
+    #     return "\n".join(conversation_parts)
+
     def train_model(
         self,
         tfrecord_file_path: str,
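The reworked retrieve_responses above fuses two signals: a FAISS cosine score in [-1, 1] and a cross-encoder logit. The sketch below shows that arithmetic standalone; the 0.75/0.25 weights and 1.15 boost mirror the diff, while the candidate texts, scores, logits, and keywords are made up for illustration:

import numpy as np

def sigmoid(x: float) -> float:
    return 1 / (1 + np.exp(-x))

# Hypothetical candidates: (text, FAISS cosine score in [-1, 1], cross-encoder logit)
candidates = [
    ("I found three pizza places near you.", 0.82, 2.1),
    ("Your table is booked for 7 pm.",       0.64, -0.3),
]

boost_factor = 1.15
query_keywords = ["pizza"]  # would come from extract_keywords(query)

scored = []
for text, faiss_score, logit in candidates:
    ce_prob = sigmoid(logit)            # cross-encoder logit -> [0, 1]
    faiss_norm = (faiss_score + 1) / 2  # cosine [-1, 1] -> [0, 1]
    score = 0.75 * ce_prob + 0.25 * faiss_norm
    if any(kw in text.lower() for kw in query_keywords):
        score *= boost_factor           # keyword-match boost
    scored.append((text, score))

scored.sort(key=lambda x: x[1], reverse=True)
print(scored[0])  # the pizza response wins on both signals plus the boost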
chatbot_validator.py CHANGED
@@ -13,7 +13,7 @@ class ChatbotValidator:
     This testing module executes domain-specific queries, obtains chatbot responses, and evaluates them with a quality checker.
     """
 
-    def __init__(self, chatbot, quality_checker):
+    def __init__(self, chatbot, quality_checker, cross_encoder_model='cross-encoder/ms-marco-MiniLM-L-12-v2'):
         """
         Initialize the validator.
         Args:
@@ -22,6 +22,7 @@ class ChatbotValidator:
         """
         self.chatbot = chatbot
         self.quality_checker = quality_checker
+        self.reranker = CrossEncoderReranker(model_name=cross_encoder_model)
 
         # Domain-specific test queries (aligns with Taskmaster-1 dataset)
         self.domain_queries = {
@@ -85,9 +86,6 @@ class ChatbotValidator:
         metrics_history = []
         domain_metrics = {}
 
-        # Init the cross-encoder reranker to pass to the chatbot
-        reranker = CrossEncoderReranker(model_name=self.chatbot.config.cross_encoder_model)
-
         # Prepare random selection if needed
         rng = random.Random(seed)
 
@@ -113,7 +111,7 @@ class ChatbotValidator:
             logger.info(f"TEST CASE {i}: QUERY: {query}")
 
             # Retrieve top_k responses, then evaluate with quality checker
-            responses = self.chatbot.retrieve_responses(query, top_k=top_k, reranker=reranker)
+            responses = self.chatbot.retrieve_responses(query, top_k=top_k, reranker=self.reranker)
             quality_metrics = self.quality_checker.check_response_quality(query, responses)
 
             # Aggregate metrics and log
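The validator now builds its cross-encoder once in __init__ and reuses it for every test query, instead of re-instantiating it inside run_validation(). The project's CrossEncoderReranker class is not shown in this commit; the sketch below illustrates the same construct-once pattern with the sentence-transformers CrossEncoder directly:

from sentence_transformers import CrossEncoder

# Load once; reuse across all validation queries.
reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')

query = "Can you find me a pizza place nearby?"
candidates = [
    "Sure, there are three pizza restaurants within a mile.",
    "Your oil change is scheduled for Tuesday.",
]
# predict() scores each (query, candidate) pair; higher means more relevant.
scores = reranker.predict([(query, c) for c in candidates])
ranked = sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True)
print(ranked[0])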
prepare_data.py CHANGED
@@ -1,14 +1,12 @@
 import os
-import sys
-import faiss
 import json
 import pickle
-import tensorflow as tf
-from transformers import AutoTokenizer, TFAutoModel
+import faiss
 from tqdm.auto import tqdm
 from pathlib import Path
-from chatbot_model import ChatbotConfig, EncoderModel
+from sentence_transformers import SentenceTransformer
 from tf_data_pipeline import TFDataPipeline
+from chatbot_config import ChatbotConfig
 from logger_config import config_logger
 
 logger = config_logger(__name__)
@@ -23,15 +21,10 @@ def main():
     FAISS_INDICES_DIR = os.path.join(MODELS_DIR, 'faiss_indices')
     TF_RECORD_DIR = 'training_data'
     FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_production.index')
-    JSON_TRAINING_DATA_PATH = os.path.join(PROCESSED_DATA_DIR, 'taskmaster_dialogues.json')
+    JSON_TRAINING_DATA_PATH = os.path.join(PROCESSED_DATA_DIR, 'taskmaster_only.json')
     CACHE_FILE = os.path.join(CACHE_DIR, 'query_embeddings_cache.pkl')
     TF_RECORD_PATH = os.path.join(TF_RECORD_DIR, 'training_data_3.tfrecord')
 
-    # Decide whether to load the **custom** model or base DistilBERT (Base used for first iteration).
-    # True for custom, False for base DistilBERT.
-    LOAD_CUSTOM_MODEL = True
-    NUM_NEG_SAMPLES = 10
-
     # Ensure output directories exist
     os.makedirs(MODELS_DIR, exist_ok=True)
     os.makedirs(PROCESSED_DATA_DIR, exist_ok=True)
@@ -40,7 +33,7 @@ def main():
     os.makedirs(FAISS_INDICES_DIR, exist_ok=True)
     os.makedirs(TF_RECORD_DIR, exist_ok=True)
 
-    # Init config
+    # Load ChatbotConfig
     config_json = Path(MODELS_DIR) / "config.json"
     if config_json.exists():
         with open(config_json, "r", encoding="utf-8") as f:
@@ -50,187 +43,77 @@ def main():
     else:
         config = ChatbotConfig()
         logger.warning("No config.json found. Using default ChatbotConfig.")
+        try:
+            with open(config_json, "w", encoding="utf-8") as f:
+                json.dump(config.to_dict(), f, indent=2)
+            logger.info(f"Default ChatbotConfig saved to {config_json}")
+        except Exception as e:
+            logger.error(f"Failed to save default ChatbotConfig: {e}")
+            raise
 
-    # Ensure negative samples are set
-    config.neg_samples = NUM_NEG_SAMPLES
-
-    # Load or init tokenizer
-    try:
-        if Path(TOKENIZER_DIR).exists() and list(Path(TOKENIZER_DIR).iterdir()):
-            logger.info(f"Loading tokenizer from {TOKENIZER_DIR}")
-            tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)
-        else:
-            logger.info(f"Loading base tokenizer for {config.pretrained_model}")
-            tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
-
-            Path(TOKENIZER_DIR).mkdir(parents=True, exist_ok=True)
-            tokenizer.save_pretrained(TOKENIZER_DIR)
-            logger.info(f"New tokenizer saved to {TOKENIZER_DIR}")
-    except Exception as e:
-        logger.error(f"Failed to load or create tokenizer: {e}")
-        sys.exit(1)
-
-    # Init the encoder
-    try:
-        encoder = EncoderModel(config=config)
-        logger.info("EncoderModel initialized successfully.")
-
-        if LOAD_CUSTOM_MODEL:
-            # Load the DistilBERT submodule from 'shared_encoder'
-            shared_encoder_path = Path(MODELS_DIR) / "shared_encoder"
-            if shared_encoder_path.exists():
-                logger.info(f"Loading DistilBERT submodule from {shared_encoder_path}")
-                encoder.pretrained = TFAutoModel.from_pretrained(shared_encoder_path)
-            else:
-                logger.warning(f"No shared_encoder found at {shared_encoder_path}, using base DistilBERT instead.")
-
-            # Load custom .weights.h5 (projection, dropout, etc.)
-            custom_weights_path = Path(MODELS_DIR) / "encoder_custom_weights.weights.h5"
-            if custom_weights_path.exists():
-                logger.info(f"Loading custom top-level weights from {custom_weights_path}")
-
-                # Dummy forward pass forces model build to ensure all layers are built
-                dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
-                _ = encoder(dummy_input, training=False)
-
-                encoder.load_weights(str(custom_weights_path))
-                logger.info("Custom encoder weights loaded successfully.")
-            else:
-                logger.warning(f"Custom weights file not found at {custom_weights_path}. Using only submodule weights.")
-        else:
-            # Base DistilBERT with special tokens
-            logger.info("Using the base DistilBERT without loading custom weights.")
-
-        # Resize token embeddings in case we added special tokens (EncoderModel class)
-        encoder.pretrained.resize_token_embeddings(len(tokenizer))
-        logger.info(f"Token embeddings resized to: {len(tokenizer)}")
-
-    except Exception as e:
-        logger.error(f"Failed to initialize EncoderModel: {e}")
-        sys.exit(1)
-
-    # Load JSON dialogues
-    try:
-        if not Path(JSON_TRAINING_DATA_PATH).exists():
-            logger.warning(f"No dialogues found at {JSON_TRAINING_DATA_PATH}, skipping.")
-            dialogues = []
-        else:
-            dialogues = TFDataPipeline.load_json_training_data(JSON_TRAINING_DATA_PATH, debug_samples=None)
-            logger.info(f"Loaded {len(dialogues)} dialogues from {JSON_TRAINING_DATA_PATH}.")
-    except Exception as e:
-        logger.error(f"Failed to load dialogues: {e}")
-        sys.exit(1)
-
-    # Load or init query_embeddings_cache. NOTE: recompute after each training. This was a bug source.
+    # Init SentenceTransformer
+    encoder = SentenceTransformer(config.pretrained_model)
+    logger.info(f"Initialized SentenceTransformer model: {config.pretrained_model}")
+
+    # Load dialogues
+    if Path(JSON_TRAINING_DATA_PATH).exists():
+        dialogues = TFDataPipeline.load_json_training_data(JSON_TRAINING_DATA_PATH)
+        logger.info(f"Loaded {len(dialogues)} dialogues.")
+    else:
+        logger.warning(f"No dialogues found at {JSON_TRAINING_DATA_PATH}.")
+        dialogues = []
+
+    # Load or init query embeddings cache
     query_embeddings_cache = {}
     if os.path.exists(CACHE_FILE):
-        try:
-            with open(CACHE_FILE, 'rb') as f:
-                query_embeddings_cache = pickle.load(f)
-            logger.info(f"Loaded {len(query_embeddings_cache)} query embeddings from {CACHE_FILE}.")
-        except Exception as e:
-            logger.warning(f"Failed to load query embeddings cache: {e}")
+        with open(CACHE_FILE, 'rb') as f:
+            query_embeddings_cache = pickle.load(f)
+        logger.info(f"Loaded query embeddings cache with {len(query_embeddings_cache)} entries.")
     else:
         logger.info("No existing query embeddings cache found. Starting fresh.")
 
-    # Initialize TFDataPipeline
-    try:
-        # Load or init FAISS index
-        if Path(FAISS_INDEX_PRODUCTION_PATH).exists():
-            logger.info(f"Loading existing FAISS index from {FAISS_INDEX_PRODUCTION_PATH}...")
-            faiss_index = faiss.read_index(FAISS_INDEX_PRODUCTION_PATH)
-            logger.info("FAISS index loaded successfully.")
-        else:
-            logger.info("No existing FAISS index found. Initializing a new index.")
-            dimension = config.embedding_dim  # Ensure this matches your encoder's output
-            faiss_index = faiss.IndexFlatIP(dimension)  # Using Inner Product for cosine similarity
-            logger.info(f"Initialized new FAISS index with dimension {dimension}.")
-
-        # Init TFDataPipeline with the FAISS index
-        data_pipeline = TFDataPipeline(
-            config=config,
-            tokenizer=tokenizer,
-            encoder=encoder,
-            index_file_path=FAISS_INDEX_PRODUCTION_PATH,
-            response_pool=[],
-            max_length=config.max_context_token_limit,
-            neg_samples=config.neg_samples,
-            query_embeddings_cache=query_embeddings_cache,
-            index_type='IndexFlatIP',
-            nlist=100,  # Not used for IndexFlatIP. Retained for future use of IndexIVFFlat
-            max_retries=config.max_retries
-        )
-        logger.info("TFDataPipeline initialized successfully.")
-    except Exception as e:
-        logger.error(f"Failed to initialize TFDataPipeline: {e}")
-        sys.exit(1)
-
-    # Collect response pool from dialogues
-    try:
-        if dialogues:
-            response_pool = data_pipeline.collect_responses_with_domain(dialogues)
-            data_pipeline.response_pool = response_pool
-            logger.info(f"Collected {len(response_pool)} unique assistant responses from dialogues.")
-        else:
-            logger.warning("No dialogues loaded. response_pool remains empty.")
-    except Exception as e:
-        logger.error(f"Failed to collect responses: {e}")
-        sys.exit(1)
-
-    # Build FAISS index with response embeddings
-    try:
-        if data_pipeline.response_pool:
-            data_pipeline.build_text_to_domain_map()
-            logger.info("Computing and adding response embeddings to FAISS index using TFDataPipeline...")
-            data_pipeline.compute_and_index_response_embeddings()
-            logger.info("Response embeddings computed and added to FAISS index.")
-
-            # Save the FAISS index
-            data_pipeline.save_faiss_index(FAISS_INDEX_PRODUCTION_PATH)
-
-            # Also save response pool JSON
-            response_pool_path = FAISS_INDEX_PRODUCTION_PATH.replace('.index', '_responses.json')
-            with open(response_pool_path, 'w', encoding='utf-8') as f:
-                json.dump(data_pipeline.response_pool, f, indent=2)
-            logger.info(f"Response pool saved to {response_pool_path}.")
-        else:
-            logger.warning("No responses to embed. Skipping FAISS indexing.")
-
-    except Exception as e:
-        logger.error(f"Failed to compute or add response embeddings: {e}")
-        sys.exit(1)
-
-    # Prepare training data as TFRecords (TensorFlow Record format)
-    try:
-        if dialogues:
-            logger.info("Starting data preparation and saving as TFRecord...")
-            data_pipeline.prepare_and_save_data(dialogues, TF_RECORD_PATH)
-            logger.info(f"Data saved as TFRecord at {TF_RECORD_PATH}.")
-        else:
-            logger.warning("No dialogues to build TFRecord from. Skipping TFRecord creation.")
-    except Exception as e:
-        logger.error(f"Failed during data preparation and saving: {e}")
-        sys.exit(1)
+    # Init FAISS index
+    dimension = encoder.get_sentence_embedding_dimension()
+    if Path(FAISS_INDEX_PRODUCTION_PATH).exists():
+        faiss_index = faiss.read_index(FAISS_INDEX_PRODUCTION_PATH)
+        logger.info(f"Loaded FAISS index from {FAISS_INDEX_PRODUCTION_PATH}.")
+    else:
+        faiss_index = faiss.IndexFlatIP(dimension)
+        logger.info(f"Initialized new FAISS index with dimension {dimension}.")
+
+    # Init TFDataPipeline
+    data_pipeline = TFDataPipeline(
+        config=config,
+        tokenizer=encoder.tokenizer,
+        encoder=encoder,
+        response_pool=[],
+        query_embeddings_cache=query_embeddings_cache,
+        index_type='IndexFlatIP',
+        faiss_index_file_path=FAISS_INDEX_PRODUCTION_PATH
+    )
+
+    # Collect and embed responses
+    if dialogues:
+        response_pool = data_pipeline.collect_responses_with_domain(dialogues)
+        data_pipeline.response_pool = response_pool
+
+        # Save the response pool
+        response_pool_path = FAISS_INDEX_PRODUCTION_PATH.replace('.index', '_responses.json')
+        with open(response_pool_path, 'w', encoding='utf-8') as f:
+            json.dump(response_pool, f, indent=2)
+        logger.info(f"Response pool saved to {response_pool_path}.")
+        data_pipeline.compute_and_index_response_embeddings()
+        data_pipeline.save_faiss_index(FAISS_INDEX_PRODUCTION_PATH)
+        logger.info(f"FAISS index saved at {FAISS_INDEX_PRODUCTION_PATH}.")
+    else:
+        logger.warning("No responses to embed. Skipping FAISS indexing.")
 
     # Save query embeddings cache
-    try:
-        with open(CACHE_FILE, 'wb') as f:
-            pickle.dump(data_pipeline.query_embeddings_cache, f)
-        logger.info(f"Saved {len(data_pipeline.query_embeddings_cache)} query embeddings to {CACHE_FILE}.")
-    except Exception as e:
-        logger.error(f"Failed to save query embeddings cache: {e}")
-        sys.exit(1)
-
-    # Save Tokenizer
-    try:
-        tokenizer.save_pretrained(TOKENIZER_DIR)
-        logger.info(f"Tokenizer saved to {TOKENIZER_DIR}.")
-    except Exception as e:
-        logger.error(f"Failed to save tokenizer: {e}")
-        sys.exit(1)
-
-    logger.info("Data preparation pipeline completed successfully.")
+    with open(CACHE_FILE, 'wb') as f:
+        pickle.dump(query_embeddings_cache, f)
+    logger.info(f"Query embeddings cache saved at {CACHE_FILE}.")
+
+    logger.info("Pipeline completed successfully.")
 
 if __name__ == "__main__":
     main()
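The new prepare_data.py flow reduces to: encode the response pool with the SentenceTransformer, then add the vectors to an inner-product FAISS index. A minimal self-contained sketch of that core step (the two response strings are invented; note that embeddings must be L2-normalized for inner product to behave as cosine similarity, which IndexFlatIP relies on here):

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
responses = [
    "I found three pizza places near you.",
    "Your table is booked for 7 pm.",
]

# normalize_embeddings=True makes inner product equal cosine similarity.
emb = model.encode(responses, normalize_embeddings=True)

index = faiss.IndexFlatIP(model.get_sentence_embedding_dimension())  # 384 for MiniLM-L6
index.add(np.asarray(emb, dtype=np.float32))

query = model.encode(["any good pizza nearby?"], normalize_embeddings=True)
scores, ids = index.search(np.asarray(query, dtype=np.float32), k=2)
print(responses[ids[0][0]], scores[0][0])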
run_chatbot_chat.py CHANGED
@@ -1,6 +1,7 @@
 import os
 import json
-from chatbot_model import ChatbotConfig, RetrievalChatbot
+from chatbot_model import RetrievalChatbot
+from chatbot_config import ChatbotConfig
 from response_quality_checker import ResponseQualityChecker
 from environment_setup import EnvironmentSetup
 from logger_config import config_logger
run_chatbot_validation.py CHANGED
@@ -1,24 +1,27 @@
 import os
 import json
-from chatbot_model import ChatbotConfig, RetrievalChatbot
+from sentence_transformers import SentenceTransformer
+from chatbot_config import ChatbotConfig
+from chatbot_model import RetrievalChatbot
 from response_quality_checker import ResponseQualityChecker
 from chatbot_validator import ChatbotValidator
 from plotter import Plotter
 from environment_setup import EnvironmentSetup
 from logger_config import config_logger
+from tf_data_pipeline import TFDataPipeline
 
 logger = config_logger(__name__)
 
 def run_chatbot_validation():
     # Initialize environment
     env = EnvironmentSetup()
     env.initialize()
 
     MODEL_DIR = "models"
     FAISS_INDICES_DIR = os.path.join(MODEL_DIR, "faiss_indices")
     FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, "faiss_index_production.index")
     FAISS_INDEX_TEST_PATH = os.path.join(FAISS_INDICES_DIR, "faiss_index_test.index")
 
     # Toggle 'production' or 'test' env
     ENVIRONMENT = "production"
     if ENVIRONMENT == "test":
@@ -27,7 +30,7 @@ def run_chatbot_validation():
     else:
         FAISS_INDEX_PATH = FAISS_INDEX_PRODUCTION_PATH
         RESPONSE_POOL_PATH = FAISS_INDEX_PRODUCTION_PATH.replace(".index", "_responses.json")
 
     # Load the config
     config_path = os.path.join(MODEL_DIR, "config.json")
     if os.path.exists(config_path):
@@ -38,55 +41,62 @@ def run_chatbot_validation():
     else:
         config = ChatbotConfig()
         logger.warning("No config.json found. Using default ChatbotConfig.")
 
-    # Load RetrievalChatbot in 'inference' mode
+    # Init SentenceTransformer
     try:
-        chatbot = RetrievalChatbot.load_model(load_dir=MODEL_DIR, mode="inference")
-        logger.info("RetrievalChatbot loaded in 'inference' mode successfully.")
+        model_name = "sentence-transformers/all-MiniLM-L6-v2"  # Replace with your chosen model
+        encoder = SentenceTransformer(model_name)
+        logger.info(f"Loaded SentenceTransformer model: {model_name}")
     except Exception as e:
-        logger.error(f"Failed to load RetrievalChatbot: {e}")
-        return
-
-    # Confirm FAISS index & response pool exist
-    if not os.path.exists(FAISS_INDEX_PATH) or not os.path.exists(RESPONSE_POOL_PATH):
-        logger.error("FAISS index or response pool file is missing.")
+        logger.error(f"Failed to load SentenceTransformer: {e}")
         return
 
     # Load FAISS index and response pool
     try:
-        chatbot.data_pipeline.load_faiss_index(FAISS_INDEX_PATH)
+        # Initialize TFDataPipeline
+        data_pipeline = TFDataPipeline(
+            config=config,
+            tokenizer=encoder.tokenizer,
+            encoder=encoder,
+            response_pool=[],
+            query_embeddings_cache={},
+            index_type='IndexFlatIP',
+            faiss_index_file_path=FAISS_INDEX_PATH
+        )
+
+        if not os.path.exists(FAISS_INDEX_PATH) or not os.path.exists(RESPONSE_POOL_PATH):
+            logger.error("FAISS index or response pool file is missing.")
+            return
+
+        data_pipeline.load_faiss_index(FAISS_INDEX_PATH)
         logger.info(f"FAISS index loaded from {FAISS_INDEX_PATH}.")
-        logger.info(f"FAISS dimensions: {chatbot.data_pipeline.index.d}")
-        logger.info(f"FAISS index type: {type(chatbot.data_pipeline.index)}")
-        logger.info(f"FAISS index total vectors: {chatbot.data_pipeline.index.ntotal}")
-        logger.info(f"FAISS is_trained: {chatbot.data_pipeline.index.is_trained}")
-
+
         with open(RESPONSE_POOL_PATH, "r", encoding="utf-8") as f:
-            chatbot.data_pipeline.response_pool = json.load(f)
+            data_pipeline.response_pool = json.load(f)
         logger.info(f"Response pool loaded from {RESPONSE_POOL_PATH}.")
-        logger.info(f"\nTotal responses in pool: {len(chatbot.data_pipeline.response_pool)}")
-
+        logger.info(f"Total responses in pool: {len(data_pipeline.response_pool)}")
+
         # Validate dimension consistency
-        chatbot.data_pipeline.validate_faiss_index()
+        data_pipeline.validate_faiss_index()
         logger.info("FAISS index and response pool validated successfully.")
-
     except Exception as e:
         logger.error(f"Failed to load or validate FAISS index: {e}")
         return
 
     # Init QualityChecker and Validator
-    quality_checker = ResponseQualityChecker(data_pipeline=chatbot.data_pipeline)
-    validator = ChatbotValidator(chatbot=chatbot, quality_checker=quality_checker)
-    logger.info("ResponseQualityChecker and ChatbotValidator initialized.")
-
-    # Run validation
     try:
+        chatbot = RetrievalChatbot.load_model(load_dir=MODEL_DIR, mode="inference")
+        quality_checker = ResponseQualityChecker(data_pipeline=data_pipeline)
+        validator = ChatbotValidator(chatbot=chatbot, quality_checker=quality_checker)
+        logger.info("ResponseQualityChecker and ChatbotValidator initialized.")
+
+        # Run validation
         validation_metrics = validator.run_validation(num_examples=5)
         logger.info(f"Validation Metrics: {validation_metrics}")
     except Exception as e:
         logger.error(f"Validation process failed: {e}")
         return
 
     # Plot metrics
     try:
         plotter = Plotter(save_dir=env.training_dirs["plots"])
@@ -94,10 +104,22 @@ def run_chatbot_validation():
         logger.info("Validation metrics plotted successfully.")
     except Exception as e:
         logger.error(f"Failed to plot validation metrics: {e}")
 
     # Run interactive chat loop
-    logger.info("\nStarting interactive chat session...")
-    chatbot.run_interactive_chat(quality_checker, show_alternatives=True)
+    try:
+        logger.info("\nStarting interactive chat session...")
+        while True:
+            user_input = input("You: ")
+            if user_input.lower() in ["exit", "quit"]:
+                logger.info("Exiting chat session.")
+                break
+
+            responses = data_pipeline.retrieve_responses(user_input, top_k=3)
+            print("Top Responses:")
+            for i, (response, score) in enumerate(responses, start=1):
+                print(f"{i}. {response} (Score: {score:.4f})")
+    except KeyboardInterrupt:
+        logger.info("Interactive chat session interrupted by user.")
 
 if __name__ == "__main__":
     run_chatbot_validation()
run_taskmaster_processor.py CHANGED
@@ -5,7 +5,7 @@ from taskmaster_processor import TaskmasterProcessor, RawDataProcessingConfig
 
 def main():
     # Setup config and processor
-    base_dir = "datasets/taskmaster"
+    base_dir = "raw_datasets/taskmaster"
     config = RawDataProcessingConfig(
         debug=True,
         max_length=512,
taskmaster_processor.py CHANGED
@@ -4,6 +4,9 @@ import json
4
  from pathlib import Path
5
  from typing import List, Dict, Optional, Any
6
  from dataclasses import dataclass, field
 
 
 
7
 
8
  @dataclass
9
  class TaskmasterDialogue:
@@ -28,7 +31,7 @@ class RawDataProcessingConfig:
28
  self,
29
  debug: bool = True,
30
  max_length: int = 512,
31
- min_turns: int = 2,
32
  min_user_words: int = 3
33
  ):
34
  self.debug = debug
@@ -68,7 +71,7 @@ class TaskmasterProcessor:
68
  with open(ontology_path, 'r', encoding='utf-8') as f:
69
  ontology = json.load(f)
70
  if self.config.debug:
71
- print(f"[TaskmasterProcessor] Loaded ontology with {len(ontology.keys())} top-level keys (unused).")
72
 
73
  dialogues: List[TaskmasterDialogue] = []
74
 
@@ -106,7 +109,7 @@ class TaskmasterProcessor:
106
  break
107
 
108
  if self.config.debug:
109
- print(f"[TaskmasterProcessor] Loaded {len(dialogues)} total dialogues from Taskmaster-1.")
110
  return dialogues
111
 
112
  def _extract_domain(self, scenario: str, turns: List[Dict[str, str]]) -> str:
@@ -130,43 +133,15 @@ class TaskmasterProcessor:
130
 
131
  for domain, pattern in domain_patterns.items():
132
  if re.search(pattern, combined_text):
133
- # Optional: print if debug
134
  if self.config.debug:
135
- print(f"Matched domain: {domain} in scenario/turns")
136
  return domain
137
 
138
  if self.config.debug:
139
- print("No domain match, returning 'other'")
140
  return 'other'
141
 
142
- def _process_utterances(self, utterances: List[Dict[str, Any]]) -> List[Dict[str, str]]:
143
- """
144
- Convert "utterances" to a cleaned List -> (speaker, text).
145
- Skip lines that are numeric, too short, or empty.
146
- """
147
- cleaned_turns = []
148
- for utt in utterances:
149
- speaker = 'assistant' if utt.get('speaker') == 'ASSISTANT' else 'user'
150
- raw_text = utt.get('text', '').strip()
151
-
152
- # Text cleaning
153
- text = self._clean_text(raw_text)
154
-
155
- # Skip blank or numeric lines (e.g. "4 3 13")
156
- if not text or self._is_numeric_line(text):
157
- continue
158
-
159
- # Skip too short (no training benefit from 1-word user turns). E.g. "ok","yes", etc.
160
- if len(text.split()) < 3:
161
- continue
162
-
163
- # Add to cleaned turns
164
- cleaned_turns.append({
165
- 'speaker': speaker,
166
- 'text': text
167
- })
168
- return cleaned_turns
169
-
170
  def _clean_text(self, text: str) -> str:
171
  """
172
  Simple text normalization
@@ -193,13 +168,20 @@ class TaskmasterProcessor:
193
  "turns": [ {"speaker": "user", "text": "..."}, ... ]
194
  }
195
  """
 
 
 
 
196
  results = []
 
197
  for dlg in dialogues:
198
  if not dlg.validate():
 
199
  continue
200
 
201
  # Skip if too few turns
202
  if len(dlg.turns) < self.config.min_turns:
 
203
  continue
204
 
205
  # Skip if any user turn is too short
@@ -208,6 +190,7 @@ class TaskmasterProcessor:
208
  if turn['speaker'] == 'user':
209
  words_count = len(turn['text'].split())
210
  if words_count < self.config.min_user_words:
 
211
  keep = False
212
  break
213
 
@@ -217,10 +200,59 @@ class TaskmasterProcessor:
217
  pipeline_dlg = {
218
  'dialogue_id': dlg.conversation_id,
219
  'domain': dlg.domain,
220
- 'turns': dlg.turns # already cleaned
221
  }
222
  results.append(pipeline_dlg)
223
 
224
  if self.config.debug:
225
- print(f"[TaskmasterProcessor] Filtered down to {len(results)} dialogues after cleaning.")
 
 
 
 
 
 
 
226
  return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
  from pathlib import Path
  from typing import List, Dict, Optional, Any
  from dataclasses import dataclass, field
+ from logger_config import config_logger
+
+ logger = config_logger(__name__)
 
  @dataclass
  class TaskmasterDialogue:

  self,
  debug: bool = True,
  max_length: int = 512,
+ min_turns: int = 4,
  min_user_words: int = 3
  ):
  self.debug = debug

  with open(ontology_path, 'r', encoding='utf-8') as f:
  ontology = json.load(f)
  if self.config.debug:
+ logger.info(f"[TaskmasterProcessor] Loaded ontology with {len(ontology.keys())} top-level keys (unused).")
 
  dialogues: List[TaskmasterDialogue] = []

  break
 
  if self.config.debug:
+ logger.info(f"[TaskmasterProcessor] Loaded {len(dialogues)} total dialogues from Taskmaster-1.")
  return dialogues
 
  def _extract_domain(self, scenario: str, turns: List[Dict[str, str]]) -> str:

  for domain, pattern in domain_patterns.items():
  if re.search(pattern, combined_text):
+ # Optional: logger.info if debug
  if self.config.debug:
+ logger.info(f"Matched domain: {domain} in scenario/turns")
  return domain
 
  if self.config.debug:
+ logger.info("No domain match, returning 'other'")
  return 'other'

  def _clean_text(self, text: str) -> str:
  """
  Simple text normalization

  "turns": [ {"speaker": "user", "text": "..."}, ... ]
  }
  """
+ total = len(dialogues)
+ invalid = 0
+ too_few_turns = 0
+ short_user_turns = 0
  results = []
+
  for dlg in dialogues:
  if not dlg.validate():
+ invalid += 1
  continue
 
  # Skip if too few turns
  if len(dlg.turns) < self.config.min_turns:
+ too_few_turns += 1
  continue
 
  # Skip if any user turn is too short

  if turn['speaker'] == 'user':
  words_count = len(turn['text'].split())
  if words_count < self.config.min_user_words:
+ short_user_turns += 1
  keep = False
  break

  pipeline_dlg = {
  'dialogue_id': dlg.conversation_id,
  'domain': dlg.domain,
+ 'turns': dlg.turns
  }
  results.append(pipeline_dlg)
 
  if self.config.debug:
+ logger.info(f"\nFiltering Statistics:")
+ logger.info(f"Total dialogues: {total}")
+ logger.info(f"Invalid dialogues: {invalid}")
+ logger.info(f"Too few turns: {too_few_turns}")
+ logger.info(f"Short user turns: {short_user_turns}")
+ logger.info(f"Remaining dialogues: {len(results)}")
+ logger.info(f"Filtering rate: {((total - len(results)) / total) * 100:.1f}%\n")
+
  return results
+
+ def _process_utterances(self, utterances: List[Dict[str, Any]]) -> List[Dict[str, str]]:
+ """Added logging to track utterance filtering"""
+ total = len(utterances)
+ empty = 0
+ numeric = 0
+ too_short = 0
+ cleaned_turns = []
+
+ for utt in utterances:
+ speaker = 'assistant' if utt.get('speaker') == 'ASSISTANT' else 'user'
+ raw_text = utt.get('text', '').strip()
+
+ text = self._clean_text(raw_text)
+
+ if not text:
+ empty += 1
+ continue
+
+ if self._is_numeric_line(text):
+ numeric += 1
+ continue
+
+ if len(text.split()) < 3:
+ too_short += 1
+ continue
+
+ cleaned_turns.append({
+ 'speaker': speaker,
+ 'text': text
+ })
+
+ if self.config.debug and total > 0:
+ logger.info(f"\nUtterance Cleaning Statistics (Dialogue {utterances[0].get('conversation_id', 'unknown')}):")
+ logger.info(f"Total utterances: {total}")
+ logger.info(f"Empty/blank: {empty}")
+ logger.info(f"Numeric only: {numeric}")
+ logger.info(f"Too short (<3 words): {too_short}")
+ logger.info(f"Remaining turns: {len(cleaned_turns)}")
+ logger.info(f"Filtering rate: {((total - len(cleaned_turns)) / total) * 100:.1f}%\n")
+
+ return cleaned_turns
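
Review note: the thresholds introduced above (min_turns=4, min_user_words=3) are easy to sanity-check in isolation. Below is a minimal sketch of the same skip logic, assuming the dict-based turn format used in this file; the helper name and toy dialogue are hypothetical, not part of the commit.

# Hypothetical helper mirroring the two skip conditions in process_taskmaster_dialogues.
from typing import Dict, List

MIN_TURNS = 4        # matches the min_turns default added in this commit
MIN_USER_WORDS = 3   # matches the min_user_words default

def keep_dialogue(turns: List[Dict[str, str]]) -> bool:
    if len(turns) < MIN_TURNS:
        return False  # skipped: too few turns
    for turn in turns:
        if turn['speaker'] == 'user' and len(turn['text'].split()) < MIN_USER_WORDS:
            return False  # skipped: a user turn is too short
    return True

toy = [
    {'speaker': 'user', 'text': 'book a table tonight'},
    {'speaker': 'assistant', 'text': 'For how many people?'},
    {'speaker': 'user', 'text': 'two'},  # one word, below MIN_USER_WORDS
    {'speaker': 'assistant', 'text': 'Done.'},
]
print(keep_dialogue(toy))  # False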
tf_data_pipeline.py CHANGED
@@ -11,6 +11,8 @@ import json
  from pathlib import Path
  from typing import Union, Optional, Dict, List, Tuple, Generator
  from transformers import AutoTokenizer
+ from sentence_transformers import SentenceTransformer
+ from chatbot_config import ChatbotConfig
  from typing import List, Tuple, Generator
  from transformers import AutoTokenizer
  import random
@@ -21,26 +23,30 @@ logger = config_logger(__name__)
  class TFDataPipeline:
  def __init__(
  self,
- config,
- tokenizer,
- encoder,
+ config: ChatbotConfig,
+ tokenizer: AutoTokenizer,
+ encoder: SentenceTransformer,
  response_pool: List[str],
  query_embeddings_cache: dict,
+ model_name: str = 'sentence-transformers/all-MiniLM-L6-v2',
  max_length: int = 512,
  neg_samples: int = 10,
  index_type: str = 'IndexFlatIP',
  faiss_index_file_path: str = 'models/faiss_indices/faiss_index_production.index',
+ dimension: int = 384,
  nlist: int = 100,
  max_retries: int = 3
  ):
  self.config = config
  self.tokenizer = tokenizer
  self.encoder = encoder
+ self.model = SentenceTransformer(model_name)
  self.faiss_index_file_path = faiss_index_file_path
  self.response_pool = response_pool
  self.max_length = max_length
  self.neg_samples = neg_samples
  self.query_embeddings_cache = query_embeddings_cache # In-memory cache for embeddings
+ self.dimension = config.embedding_dim
  self.index_type = index_type
  self.nlist = nlist
  self.embedding_batch_size = 16 if len(response_pool) < 100 else 64
@@ -59,9 +65,8 @@ class TFDataPipeline:
  self.validate_faiss_index()
  logger.info("FAISS index loaded and validated successfully.")
  else:
- dimension = self.encoder.config.embedding_dim
- self.index = faiss.IndexFlatIP(dimension)
- logger.info(f"Initialized FAISS IndexFlatIP with dimension {dimension}.")
+ self.index = faiss.IndexFlatIP(self.dimension)
+ logger.info(f"Initialized FAISS IndexFlatIP with dimension {self.dimension}.")
 
  if not self.index.is_trained:
  # Train the index if it's not trained. IndexFlatIP doesn't need training, but others do (Future switch to IndexIVFFlat)
@@ -98,7 +103,7 @@
 
  def validate_faiss_index(self):
  """Validates FAISS index dimensionality."""
- expected_dim = self.encoder.config.embedding_dim
+ expected_dim = self.dimension
  if self.index.d != expected_dim:
  logger.error(f"FAISS index dimension {self.index.d} does not match encoder embedding dimension {expected_dim}.")
  raise ValueError("FAISS index dimensionality mismatch.")
@@ -186,44 +191,49 @@
  pairs.append((query, positive))
 
  return pairs
-
+
  def compute_and_index_response_embeddings(self):
  """
- Compute embeddings for the response pool and add them to the FAISS index.
- self.response_pool: List[Dict[str, str]] with keys "domain" and "text".
+ Compute embeddings for the response pool using SentenceTransformer
+ and add them to the FAISS index.
  """
- logger.info("Computing embeddings for the response pool...")
+ if not self.response_pool:
+ logger.warning("Response pool is empty. No embeddings to compute.")
+ return
 
- # Extract the assistant text
+ logger.info("Computing embeddings for the response pool...")
  texts = [resp["text"] for resp in self.response_pool]
  logger.debug(f"Total texts to embed: {len(texts)}")
 
- batch_size = getattr(self, 'embedding_batch_size', 64)
  embeddings = []
+ batch_size = self.embedding_batch_size
 
+ # Use SentenceTransformer to compute embeddings in batches
  with tqdm(total=len(texts), desc="Computing Embeddings", unit="response") as pbar:
  for i in range(0, len(texts), batch_size):
- batch_texts = texts[i:i+batch_size]
- encodings = self.tokenizer(
+ batch_texts = texts[i:i + batch_size]
+
+ # Compute embeddings
+ batch_embeddings = self.encoder.encode(
  batch_texts,
- padding=True,
- truncation=True,
- max_length=self.max_length,
- return_tensors='tf'
+ batch_size=batch_size,
+ convert_to_numpy=True,
+ normalize_embeddings=True # Normalizes for cosine similarity
  )
- batch_embeds = self.encoder(encodings['input_ids'], training=False).numpy()
 
- embeddings.append(batch_embeds)
+ embeddings.append(batch_embeddings)
  pbar.update(len(batch_texts))
 
- # Combine embeddings and add to FAISS
+ # Combine all embeddings
  all_embeddings = np.vstack(embeddings).astype(np.float32)
  logger.info(f"Adding {len(all_embeddings)} response embeddings to FAISS index...")
+
+ # Add to FAISS index
  self.index.add(all_embeddings)
 
  # Store in memory
  self.response_embeddings = all_embeddings
- logger.info(f"FAISS index now has {self.index.ntotal} vectors.")
+ logger.info(f"FAISS index now contains {self.index.ntotal} vectors.")
 
  def _find_hard_negatives(self, queries: List[str], positives: List[str], batch_size: int = 128) -> List[List[str]]:
  """
@@ -385,106 +395,41 @@
  self._text_domain_map[stripped_text] = domain
 
  logger.info(f"Built text -> domain map with {len(self._text_domain_map)} unique text entries.")
-
- def encode_query(
- self,
- query: str,
- context: Optional[List[Tuple[str, str]]] = None
- ) -> np.ndarray:
- """
- Encode a user query (and optional conversation context) into an embedding vector.
-
- Args:
- query: The user query.
- context: Optional conversation history as a list of (user_text, assistant_text).
- Returns:
- np.ndarray of shape [embedding_dim], typically L2-normalized already.
- """
- # Prepare context: concat user/assistant pairs
- if context:
- # Take the last N turns
- relevant_history = context[-self.config.max_context_turns:]
- context_str_parts = []
- for (u_text, a_text) in relevant_history:
- context_str_parts.append(
- f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {u_text} "
- f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {a_text}"
- )
- context_str = " ".join(context_str_parts)
-
- # Append the new query
- full_query = (
- f"{context_str} "
- f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"
- )
- else:
- # Single user turn
- full_query = (
- f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"
- )
 
- # Tokenize
- encodings = self.tokenizer(
- [full_query],
- padding='max_length',
- truncation=True,
- max_length=self.max_length,
- return_tensors='np' # to keep it compatible with FAISS
- )
- input_ids = encodings['input_ids']
-
- # Debug out-of-vocab IDs
- max_id = np.max(input_ids)
- vocab_size = len(self.tokenizer)
- if max_id >= vocab_size:
- logger.error(f"Token ID {max_id} exceeds tokenizer vocab size {vocab_size}.")
- raise ValueError("Token ID exceeds vocabulary size.")
-
- # Get embeddings from the model. These are already L2-normalized by the model's final layer.
- embeddings = self.encoder(input_ids, training=False).numpy()
-
- return embeddings[0]
+ def encode_query(self, query: str) -> np.ndarray:
+ """Generate embedding for a query string."""
+ return self.encoder.encode(query, convert_to_numpy=True)
 
  def encode_responses(
- self,
- responses: List[str],
+ self,
+ responses: List[str],
  context: Optional[List[Tuple[str, str]]] = None
  ) -> np.ndarray:
  """
- Encode multiple response texts into embedding vectors.
- Args:
- responses: List of assistant responses.
- context: Optional conversation context (last N turns).
- Returns:
- np.ndarray of shape [num_responses, embedding_dim].
+ Encode multiple response texts into embeddings, injecting <ASSISTANT> literally.
  """
- # Incorporate context into response encoding. Note: Undecided on benefit of this
+ USER_TOKEN = "<USER>"
+ ASSISTANT_TOKEN = "<ASSISTANT>"
+
  if context:
  relevant_history = context[-self.config.max_context_turns:]
  prepared = []
  for resp in responses:
  context_str_parts = []
+ # Build all user->assistant text
  for (u_text, a_text) in relevant_history:
  context_str_parts.append(
- f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {u_text} "
- f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {a_text}"
+ f"{USER_TOKEN} {u_text} {ASSISTANT_TOKEN} {a_text}"
  )
  context_str = " ".join(context_str_parts)
-
- # Treat resp as an assistant turn
- full_resp = (
- f"{context_str} "
- f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {resp}"
- )
+ # Treat resp as an assistant turn:
+ full_resp = f"{context_str} {ASSISTANT_TOKEN} {resp}"
  prepared.append(full_resp)
  else:
  # Single response from the assistant
- prepared = [
- f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {r}"
- for r in responses
- ]
+ prepared = [f"{ASSISTANT_TOKEN} {r}" for r in responses]
 
- # Tokenize
+ # Pass the prepared strings to the SentenceTransformer tokenizer:
  encodings = self.tokenizer(
  prepared,
  padding='max_length',
@@ -493,19 +438,42 @@
  return_tensors='np'
  )
  input_ids = encodings['input_ids']
-
+
  # Debug for out-of-vocab
  max_id = np.max(input_ids)
  vocab_size = len(self.tokenizer)
  if max_id >= vocab_size:
- logger.error(f"Token ID {max_id} exceeds tokenizer vocab size {vocab_size}.")
+ logger.error(f"Token ID {max_id} >= tokenizer vocab size {vocab_size}")
  raise ValueError("Token ID exceeds vocabulary size.")
 
- # Get embeddings from the model. These are already L2-normalized by the model's final layer.
- embeddings = self.encoder(input_ids, training=False).numpy()
+ # Get embeddings from SentenceTransformer
+ embeddings = self.encoder.encode(prepared, convert_to_numpy=True)
 
  return embeddings.astype('float32')
 
+ def retrieve_responses(self, query: str, top_k: int = 10) -> List[Tuple[str, float]]:
+ """
+ Retrieve top-k responses for a query using FAISS.
+
+ Args:
+ query: User's query text.
+ top_k: Number of responses to return.
+
+ Returns:
+ List of tuples (response text, similarity score).
+ """
+ query_embedding = self.encode_query(query).reshape(1, -1).astype("float32")
+ distances, indices = self.index.search(query_embedding, top_k)
+
+ results = []
+ for idx, dist in zip(indices[0], distances[0]):
+ if idx < 0:
+ continue
+ response = self.response_pool[idx]
+ results.append((response["text"], dist))
+
+ return results
+
  def prepare_and_save_data(self, dialogues: List[dict], tf_record_path: str, batch_size: int = 32):
  """
  Batch-Process dialogues and save to TFRecord file.
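
Review note: pairing normalize_embeddings=True (in compute_and_index_response_embeddings) with IndexFlatIP is what makes inner-product search behave as cosine-similarity search, since dot(a, b) equals cos(a, b) for unit vectors. The new encode_query does not pass normalize_embeddings=True, so query-side scores are raw inner products unless the query vector is normalized as well. A self-contained sketch of the index behavior, assuming faiss and sentence-transformers are installed; the example sentences are illustrative only.

import faiss
from sentence_transformers import SentenceTransformer

# 384-dim model, matching ChatbotConfig.embedding_dim
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
texts = ["I'd like to book a table for two.", "What's the weather tomorrow?"]

# Unit-norm embeddings: inner product == cosine similarity
emb = model.encode(texts, convert_to_numpy=True, normalize_embeddings=True).astype('float32')
index = faiss.IndexFlatIP(emb.shape[1])
index.add(emb)

query = model.encode(["reserve a restaurant"], convert_to_numpy=True,
                     normalize_embeddings=True).astype('float32')
scores, ids = index.search(query, 2)
print(ids[0], scores[0])  # the booking sentence should rank first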
 
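
Review note: a usage sketch wiring this commit's pieces together end to end. Class, method, and config names follow the diff; the single-item response pool, the empty query_embeddings_cache, and the assumption that no saved FAISS index exists at the default path are illustrative, not part of the commit.

from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
from chatbot_config import ChatbotConfig
from tf_data_pipeline import TFDataPipeline

config = ChatbotConfig()
tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
encoder = SentenceTransformer(config.pretrained_model)

pipeline = TFDataPipeline(
    config=config,
    tokenizer=tokenizer,
    encoder=encoder,
    response_pool=[{"domain": "restaurant", "text": "Sure, I can book that table for you."}],
    query_embeddings_cache={},
)

# Embed the pool (normalized) and add it to the FAISS index, then search it.
pipeline.compute_and_index_response_embeddings()
for text, score in pipeline.retrieve_responses("book me a table", top_k=1):
    print(f"{score:.3f}  {text}")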