JoeArmani committed · Commit 64e7c31 · Parent(s): a763857

sentence transformer

- .gitignore +2 -0
- chatbot_config.py +30 -0
- chatbot_model.py +109 -204
- chatbot_validator.py +3 -5
- prepare_data.py +65 -182
- run_chatbot_chat.py +2 -1
- run_chatbot_validation.py +60 -38
- run_taskmaster_processor.py +1 -1
- taskmaster_processor.py +68 -36
- tf_data_pipeline.py +75 -107
.gitignore CHANGED
@@ -183,4 +183,6 @@ training_data/*
 augmented_dialogues.json
 
 raw_datasets/*
+st/*
+
 
chatbot_config.py ADDED
@@ -0,0 +1,30 @@
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Dict
+
+@dataclass
+class ChatbotConfig:
+    """RetrievalChatbot Config"""
+    max_context_token_limit: int = 512
+    embedding_dim: int = 384  # Match Sentence Transformer dimension
+    learning_rate: float = 0.0005
+    min_text_length: int = 3
+    max_context_turns: int = 20
+    pretrained_model: str = 'sentence-transformers/all-MiniLM-L6-v2'
+    cross_encoder_model: str = 'cross-encoder/ms-marco-MiniLM-L-12-v2'
+    summarizer_model: str = 't5-small'
+    embedding_batch_size: int = 64
+    search_batch_size: int = 64
+    max_batch_size: int = 64
+    max_retries: int = 3
+
+    def to_dict(self) -> Dict:
+        """Convert config to dictionary."""
+        return {k: (str(v) if isinstance(v, Path) else v)
+                for k, v in self.__dict__.items()}
+
+    @classmethod
+    def from_dict(cls, config_dict: Dict) -> 'ChatbotConfig':
+        """Create config from dictionary."""
+        return cls(**{k: v for k, v in config_dict.items()
+                      if k in cls.__dataclass_fields__})
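For quick reference, a minimal sketch of round-tripping the new ChatbotConfig through JSON, mirroring how prepare_data.py writes models/config.json and load_model() reads it back:

import json
from chatbot_config import ChatbotConfig

config = ChatbotConfig()
with open("models/config.json", "w") as f:
    json.dump(config.to_dict(), f, indent=2)          # same call prepare_data.py makes

with open("models/config.json") as f:
    restored = ChatbotConfig.from_dict(json.load(f))  # unknown keys are filtered out
assert restored.embedding_dim == 384                  # all-MiniLM-L6-v2's output dimension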
chatbot_model.py CHANGED
@@ -1,6 +1,6 @@
 import os
 import numpy as np
-from transformers import AutoTokenizer, TFAutoModel
+from sentence_transformers import SentenceTransformer
 import tensorflow as tf
 from typing import List, Tuple, Dict, Optional, Union, Any
 import math
@@ -11,125 +11,24 @@ import datetime
 import faiss
 import gc
 import re
-from tf_data_pipeline import TFDataPipeline
 from response_quality_checker import ResponseQualityChecker
 from cross_encoder_reranker import CrossEncoderReranker
 from conversation_summarizer import DeviceAwareModel, Summarizer
+from chatbot_config import ChatbotConfig
+from tf_data_pipeline import TFDataPipeline
 import absl.logging
 from logger_config import config_logger
 from tqdm.auto import tqdm
 
 absl.logging.set_verbosity(absl.logging.WARNING)
 logger = config_logger(__name__)
-
-@dataclass
-class ChatbotConfig:
-    """RetrievalChatbot Config"""
-    max_context_token_limit: int = 512
-    embedding_dim: int = 768
-    encoder_units: int = 256
-    num_attention_heads: int = 8
-    dropout_rate: float = 0.2
-    l2_reg_weight: float = 0.001
-    learning_rate: float = 0.0005
-    min_text_length: int = 3
-    max_context_turns: int = 20
-    warmup_steps: int = 200
-    pretrained_model: str = 'distilbert-base-uncased'
-    cross_encoder_model: str = 'cross-encoder/ms-marco-MiniLM-L-12-v2'
-    summarizer_model: str = 't5-small'
-    dtype: str = 'float32'
-    freeze_embeddings: bool = False
-    embedding_batch_size: int = 64
-    search_batch_size: int = 64
-    max_batch_size: int = 64
-    max_retries: int = 3
-
-    def to_dict(self) -> Dict:
-        """Convert config to dictionary."""
-        return {k: (str(v) if isinstance(v, Path) else v)
-                for k, v in self.__dict__.items()}
-
-    @classmethod
-    def from_dict(cls, config_dict: Dict) -> 'ChatbotConfig':
-        """Create config from dictionary."""
-        return cls(**{k: v for k, v in config_dict.items()
-                      if k in cls.__dataclass_fields__})
-
-class EncoderModel(tf.keras.Model):
-    """Dual encoder model with pretrained DistilBERT embeddings."""
-    def __init__(
-        self,
-        config: ChatbotConfig,
-        name: str = "encoder",
-        **kwargs
-    ):
-        super().__init__(name=name, **kwargs)
-        self.config = config
-
-        # Load pretrained model and freeze layers based on config
-        self.pretrained = TFAutoModel.from_pretrained(config.pretrained_model)
-        self._freeze_layers()
-
-        # Add Global Average Pooling, Projection, Dropout, and Normalization layers
-        self.pooler = tf.keras.layers.GlobalAveragePooling1D()
-        self.projection = tf.keras.layers.Dense(
-            config.embedding_dim,
-            activation='tanh',
-            name="projection",
-            dtype=tf.float32
-        )
-        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
-        self.normalize = tf.keras.layers.Lambda(
-            lambda x: tf.nn.l2_normalize(x, axis=1),
-            name="l2_normalize"
-        )
-
-    def _freeze_layers(self):
-        """Freeze n layers of the pretrained model"""
-        if self.config.freeze_embeddings:
-            self.pretrained.trainable = False
-            logger.info("All pretrained layers frozen.")
-        else:
-            # Freeze only the first 'n' transformer layers
-            for i, layer in enumerate(self.pretrained.layers):
-                if isinstance(layer, tf.keras.layers.Layer):
-                    if hasattr(layer, 'trainable'):
-                        if i < 1:
-                            layer.trainable = False
-                            logger.info(f"Layer {i} frozen.")
-                        else:
-                            layer.trainable = True
-                            logger.info(f"Layer {i} trainable.")
-
-    def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
-        """Forward pass."""
-        # Get pretrained embeddings
-        pretrained_outputs = self.pretrained(inputs, training=training)
-        x = pretrained_outputs.last_hidden_state  # Shape: [batch_size, seq_len, embedding_dim]
-
-        # Apply pooling, projection, dropout, and normalization
-        x = self.pooler(x)  # Shape: [batch_size, 768]
-        x = self.projection(x)  # Shape: [batch_size, 768]
-        x = self.dropout(x, training=training)
-        x = self.normalize(x)  # Shape: [batch_size, 768]
-
-        return x
-
-    def get_config(self) -> dict:
-        """Return the model config"""
-        config = super().get_config()
-        config.update({
-            "config": self.config.to_dict(),
-            "name": self.name
-        })
-        return config
 
 class RetrievalChatbot(DeviceAwareModel):
     """
     Retrieval-based learning chatbot model.
     Uses trained embeddings and FAISS for similarity search.
     """
+
     def __init__(
         self,
        config: ChatbotConfig,
@@ -139,6 +38,7 @@ class RetrievalChatbot(DeviceAwareModel):
        summarizer: Optional[Summarizer] = None,
        mode: str = 'training'
    ):
+
        super().__init__()
        self.config = config
        self.strategy = strategy
@@ -146,13 +46,14 @@ class RetrievalChatbot(DeviceAwareModel):
        self.mode = mode.lower()
 
        # Initialize reranker, summarizer, tokenizer, and encoder
-        self.reranker = reranker or self._initialize_reranker()
-        self.tokenizer = self._initialize_tokenizer()
        self.encoder = self._initialize_encoder()
+        self.tokenizer = self.encoder.tokenizer
+        self.reranker = reranker or self._initialize_reranker()
        self.summarizer = summarizer or self._initialize_summarizer()
 
        # Initialize data pipeline
        logger.info("Initializing TFDataPipeline.")
+
        self.data_pipeline = TFDataPipeline(
            config=self.config,
            tokenizer=self.tokenizer,
@@ -177,7 +78,6 @@ class RetrievalChatbot(DeviceAwareModel):
            "train_metrics": {},
            "val_metrics": {}
        }
-
 
    def _setup_default_device(self) -> str:
        """Set up default device if none is provided."""
@@ -200,34 +100,11 @@ class RetrievalChatbot(DeviceAwareModel):
            device=self.device,
            max_summary_rounds=2
        )
-
-    def _initialize_tokenizer(self) -> AutoTokenizer:
-        """Initialize the tokenizer and add special tokens."""
-        logger.info("Initializing tokenizer and adding special tokens...")
-        tokenizer = AutoTokenizer.from_pretrained(self.config.pretrained_model)
-        special_tokens = {
-            "user": "<USER>",
-            "assistant": "<ASSISTANT>",
-            "context": "<CONTEXT>",
-            "sep": "<SEP>"
-        }
-        tokenizer.add_special_tokens(
-            {'additional_special_tokens': list(special_tokens.values())}
-        )
-        return tokenizer
 
-    def _initialize_encoder(self) -> EncoderModel:
-        """Initialize the encoder model."""
-        logger.info("Initializing encoder model...")
-        encoder = EncoderModel(
-            self.config,
-            name="shared_encoder",
-        )
-
-        new_vocab_size = len(self.tokenizer)
-        encoder.pretrained.resize_token_embeddings(new_vocab_size)
-        logger.info(f"Token embeddings resized to: {new_vocab_size}")
-
+    def _initialize_encoder(self) -> SentenceTransformer:
+        """Initialize the Sentence Transformer model."""
+        logger.info("Initializing SentenceTransformer encoder model...")
+        encoder = SentenceTransformer(self.config.pretrained_model)
        return encoder
 
    def _load_faiss_index_and_responses(self) -> None:
@@ -254,43 +131,35 @@ class RetrievalChatbot(DeviceAwareModel):
        except Exception as e:
            logger.error(f"Failed to load FAISS index and response pool: {e}")
            raise
-
+
    @classmethod
    def load_model(cls, load_dir: Union[str, Path], mode: str = 'training') -> 'RetrievalChatbot':
-        """
-        Load saved models and configuration.
-        """
+        """Load chatbot model and configuration."""
        load_dir = Path(load_dir)
 
        # Load config
-        …
-        …
+        config_path = load_dir / "config.json"
+        if config_path.exists():
+            with open(config_path, "r") as f:
+                config = ChatbotConfig.from_dict(json.load(f))
+            logger.info("Loaded ChatbotConfig from config.json.")
+        else:
+            raise FileNotFoundError(f"Config file not found at {config_path}. Please ensure it exists.")
 
        # Initialize chatbot
        chatbot = cls(config, mode=mode)
 
-        # Load tokenizer
-        chatbot.tokenizer = AutoTokenizer.from_pretrained(load_dir / "tokenizer")
-        logger.info(f"Models and tokenizer loaded from {load_dir}")
-
-        # Load the custom weights
-        custom_weights_path = load_dir / "encoder_custom_weights.weights.h5"
-        if custom_weights_path.exists():
-            chatbot.encoder.load_weights(str(custom_weights_path))
-            logger.info("Loaded custom encoder weights for projection/dropout/etc.")
+        # Load Sentence Transformer
+        model_path = load_dir / "sentence_transformer"
+        if model_path.exists():
+            # Load locally saved model
+            chatbot.encoder = SentenceTransformer(str(model_path))
+            logger.info("Loaded SentenceTransformer model from local path successfully.")
        else:
-            …
-            cls._prepare_model_for_inference(chatbot, load_dir)
-
+            # Load from pre-trained model hub
+            chatbot.encoder = SentenceTransformer(config.pretrained_model)
+            logger.info(f"Loaded SentenceTransformer model '{config.pretrained_model}' from the hub successfully.")
+
        return chatbot
 
    @classmethod
@@ -324,21 +193,19 @@ class RetrievalChatbot(DeviceAwareModel):
        except Exception as e:
            logger.error(f"Error loading inference components: {e}")
            raise
-
+
    def save_models(self, save_dir: Union[str, Path]):
-        """Save model and config"""
+        """Save SentenceTransformer model and config."""
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
 
        # Save config
        with open(save_dir / "config.json", "w") as f:
            json.dump(self.config.to_dict(), f, indent=2)
-
-        # Save encoder and tokenizer
-        self.encoder.…
-
-        self.tokenizer.save_pretrained(save_dir / "tokenizer")
-        logger.info(f"Models and tokenizer saved to {save_dir}.")
+
+        # Save Sentence Transformer
+        self.encoder.save(save_dir / "sentence_transformer")
+        logger.info(f"Model and config saved to {save_dir}.")
 
    def retrieve_responses(
        self,
@@ -346,59 +213,73 @@ class RetrievalChatbot(DeviceAwareModel):
        top_k: int = 10,
        reranker: Optional[CrossEncoderReranker] = None,
        summarizer: Optional[Summarizer] = None,
-        summarize_threshold: int = 512
+        summarize_threshold: int = 512,
+        boost_factor: float = 1.15
    ) -> List[Tuple[str, float]]:
        """
        Retrieve top-k responses using FAISS and cross-encoder re-ranking.
+
        Args:
            query: The user's input text.
-            top_k: Number of …
-            reranker: …
-            summarizer: …
-            summarize_threshold: …
+            top_k: Number of responses to return.
+            reranker: Optional reranker for refined scoring.
+            summarizer: Optional summarizer for long queries.
+            summarize_threshold: Threshold to summarize long queries.
+            boost_factor: Factor to boost scores for keyword matches.
+
        Returns:
            List of (response_text, final_score).
        """
        def sigmoid(x: float) -> float:
            return 1 / (1 + np.exp(-x))
 
-        # …
+        # Summarize long queries
        if summarizer and len(query.split()) > summarize_threshold:
-            logger.info(f"Query is long ({len(query.split())} words). Summarizing …
+            logger.info(f"Query is long ({len(query.split())} words). Summarizing...")
            query = summarizer.summarize_text(query)
-            logger.info(f"Summarized …
-
+            logger.info(f"Summarized query: {query}")
+
+        # Detect domain for query
        detected_domain = self.detect_domain_from_query(query)
 
-        # Retrieve …
-        faiss_candidates = self.…
+        # Step 1: Retrieve candidates from FAISS
+        logger.info("Retrieving initial candidates from FAISS...")
+        faiss_candidates = self.data_pipeline.retrieve_responses(query, top_k=top_k * 10)
 
        if not faiss_candidates:
+            logger.warning("No candidates retrieved from FAISS.")
            return []
 
+        # Step 2: Re-rank candidates using Cross-Encoder
+        logger.info("Re-ranking candidates using Cross-Encoder...")
+        texts = [item[0] for item in faiss_candidates]  # Extract response texts
+        faiss_scores = [item[1] for item in faiss_candidates]
 
-        if …
+        if reranker is None:
            reranker = CrossEncoderReranker(model_name=self.config.cross_encoder_model)
+
-        # …
-        ce_logits = reranker.rerank(query, texts, max_length=256)
-
-        # Combine scores from FAISS and cross-encoder
+        ce_logits = reranker.rerank(query, texts, max_length=256)  # Re-rank responses
+
+        # Combine FAISS and Cross-Encoder scores
        final_candidates = []
-        for …
-            ce_prob = sigmoid(logit)  # …
-            faiss_norm = (faiss_score + 1)/2
-            combined_score = 0.…
+        for resp_text, faiss_score, logit in zip(texts, faiss_scores, ce_logits):
+            ce_prob = sigmoid(logit)  # Cross-encoder score in range [0, 1]
+            faiss_norm = (faiss_score + 1) / 2  # Normalize FAISS score to range [0, 1]
+            combined_score = 0.75 * ce_prob + 0.25 * faiss_norm
+
+            # Boost score based on keyword match
+            query_keywords = self.extract_keywords(query)
+            if query_keywords and any(kw in resp_text.lower() for kw in query_keywords):
+                combined_score *= boost_factor
+
+            # Adjust score based on length
            length_adjusted_score = self.length_adjust_score(resp_text, combined_score)
 
            final_candidates.append((resp_text, length_adjusted_score))
 
-        # Sort …
+        # Step 3: Sort and return top-k results
        final_candidates.sort(key=lambda x: x[1], reverse=True)
-
-        # Return top_k
+        logger.info(f"Returning top-{top_k} re-ranked responses.")
        return final_candidates[:top_k]
 
    def extract_keywords(self, query: str) -> List[str]:
@@ -636,21 +517,45 @@ class RetrievalChatbot(DeviceAwareModel):
        conversation_history: Optional[List[Tuple[str, str]]]
    ) -> str:
        """
-        Build conversation context string from conversation history
+        Build conversation context string from conversation history,
+        using literal <USER> and <ASSISTANT> tokens (no tokenizer special index).
        """
+        USER_TOKEN = "<USER>"
+        ASSISTANT_TOKEN = "<ASSISTANT>"
+
        if not conversation_history:
-            return f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"
-
+            return f"{USER_TOKEN} {query}"
+
        conversation_parts = []
        for user_txt, assistant_txt in conversation_history:
-            conversation_parts.extend([
-                f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {user_txt}",
-                f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {assistant_txt}"
-            ])
-
-        conversation_parts.append(f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}")
+            # Insert literal tokens
+            conversation_parts.append(f"{USER_TOKEN} {user_txt}")
+            conversation_parts.append(f"{ASSISTANT_TOKEN} {assistant_txt}")
+
+        conversation_parts.append(f"{USER_TOKEN} {query}")
        return "\n".join(conversation_parts)
 
+    # def _build_conversation_context(
+    #     self,
+    #     query: str,
+    #     conversation_history: Optional[List[Tuple[str, str]]]
+    # ) -> str:
+    #     """
+    #     Build conversation context string from conversation history.
+    #     """
+    #     if not conversation_history:
+    #         return f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"
+
+    #     conversation_parts = []
+    #     for user_txt, assistant_txt in conversation_history:
+    #         conversation_parts.extend([
+    #             f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {user_txt}",
+    #             f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {assistant_txt}"
+    #         ])
+
+    #     conversation_parts.append(f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}")
+    #     return "\n".join(conversation_parts)
+
    def train_model(
        self,
        tfrecord_file_path: str,
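For clarity, a standalone sketch of the score fusion that retrieve_responses now performs; the sigmoid, the 0.75/0.25 weighting, the FAISS-score normalization, and the 1.15 default boost_factor all come from the diff above, while the input values here are illustrative:

import numpy as np

def sigmoid(x: float) -> float:
    return 1 / (1 + np.exp(-x))

def fuse_scores(faiss_score: float, ce_logit: float,
                keyword_match: bool, boost_factor: float = 1.15) -> float:
    ce_prob = sigmoid(ce_logit)          # cross-encoder logit -> [0, 1]
    faiss_norm = (faiss_score + 1) / 2   # cosine/IP score in [-1, 1] -> [0, 1]
    combined = 0.75 * ce_prob + 0.25 * faiss_norm
    return combined * boost_factor if keyword_match else combined

# A strong cross-encoder logit outweighs a mediocre FAISS score:
print(fuse_scores(faiss_score=0.42, ce_logit=3.0, keyword_match=True))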
chatbot_validator.py CHANGED
@@ -13,7 +13,7 @@ class ChatbotValidator:
     This testing module executes domain-specific queries, obtains chatbot responses, and evaluates them with a quality checker.
     """
 
-    def __init__(self, chatbot, quality_checker):
+    def __init__(self, chatbot, quality_checker, cross_encoder_model='cross-encoder/ms-marco-MiniLM-L-12-v2'):
         """
         Initialize the validator.
         Args:
@@ -22,6 +22,7 @@ class ChatbotValidator:
         """
         self.chatbot = chatbot
         self.quality_checker = quality_checker
+        self.reranker = CrossEncoderReranker(model_name=cross_encoder_model)
 
         # Domain-specific test queries (aligns with Taskmaster-1 dataset)
         self.domain_queries = {
@@ -85,9 +86,6 @@ class ChatbotValidator:
         metrics_history = []
         domain_metrics = {}
 
-        # Init the cross-encoder reranker to pass to the chatbot
-        reranker = CrossEncoderReranker(model_name=self.chatbot.config.cross_encoder_model)
-
         # Prepare random selection if needed
         rng = random.Random(seed)
 
@@ -113,7 +111,7 @@ class ChatbotValidator:
         logger.info(f"TEST CASE {i}: QUERY: {query}")
 
         # Retrieve top_k responses, then evaluate with quality checker
-        responses = self.chatbot.retrieve_responses(query, top_k=top_k, reranker=reranker)
+        responses = self.chatbot.retrieve_responses(query, top_k=top_k, reranker=self.reranker)
         quality_metrics = self.quality_checker.check_response_quality(query, responses)
 
         # Aggregate metrics and log
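The reranker now lives on the validator instead of being rebuilt inside run_validation(), so the cross-encoder loads once per validator. A minimal usage sketch (the chatbot and quality_checker instances are assumed to come from run_chatbot_validation.py):

from chatbot_validator import ChatbotValidator

validator = ChatbotValidator(
    chatbot=chatbot,                  # a loaded RetrievalChatbot
    quality_checker=quality_checker,  # a ResponseQualityChecker
    cross_encoder_model='cross-encoder/ms-marco-MiniLM-L-12-v2',
)
metrics = validator.run_validation(num_examples=5)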
prepare_data.py CHANGED
@@ -1,14 +1,12 @@
 import os
-import sys
-import faiss
 import json
 import pickle
-import tensorflow as tf
-from transformers import AutoTokenizer, TFAutoModel
+import faiss
 from tqdm.auto import tqdm
 from pathlib import Path
-from chatbot_model import ChatbotConfig, EncoderModel
+from sentence_transformers import SentenceTransformer
 from tf_data_pipeline import TFDataPipeline
+from chatbot_config import ChatbotConfig
 from logger_config import config_logger
 
 logger = config_logger(__name__)
@@ -23,15 +21,10 @@ def main():
     FAISS_INDICES_DIR = os.path.join(MODELS_DIR, 'faiss_indices')
     TF_RECORD_DIR = 'training_data'
     FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_production.index')
-    JSON_TRAINING_DATA_PATH = os.path.join(PROCESSED_DATA_DIR, '…
+    JSON_TRAINING_DATA_PATH = os.path.join(PROCESSED_DATA_DIR, 'taskmaster_only.json')
     CACHE_FILE = os.path.join(CACHE_DIR, 'query_embeddings_cache.pkl')
     TF_RECORD_PATH = os.path.join(TF_RECORD_DIR, 'training_data_3.tfrecord')
 
-    # Decide whether to load the **custom** model or base DistilBERT (Base used for first iteration).
-    # True for custom, False for base DistilBERT.
-    LOAD_CUSTOM_MODEL = True
-    NUM_NEG_SAMPLES = 10
-
     # Ensure output directories exist
     os.makedirs(MODELS_DIR, exist_ok=True)
     os.makedirs(PROCESSED_DATA_DIR, exist_ok=True)
@@ -40,7 +33,7 @@ def main():
     os.makedirs(FAISS_INDICES_DIR, exist_ok=True)
     os.makedirs(TF_RECORD_DIR, exist_ok=True)
 
-    # …
+    # Load ChatbotConfig
     config_json = Path(MODELS_DIR) / "config.json"
     if config_json.exists():
         with open(config_json, "r", encoding="utf-8") as f:
@@ -50,187 +43,77 @@ def main():
     else:
         config = ChatbotConfig()
         logger.warning("No config.json found. Using default ChatbotConfig.")
+        try:
+            with open(config_json, "w", encoding="utf-8") as f:
+                json.dump(config.to_dict(), f, indent=2)
+            logger.info(f"Default ChatbotConfig saved to {config_json}")
+        except Exception as e:
+            logger.error(f"Failed to save default ChatbotConfig: {e}")
+            raise
 
-    # Load or init tokenizer
-    try:
-        if Path(TOKENIZER_DIR).exists() and list(Path(TOKENIZER_DIR).iterdir()):
-            logger.info(f"Loading tokenizer from {TOKENIZER_DIR}")
-            tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)
-        else:
-            logger.info(f"Loading base tokenizer for {config.pretrained_model}")
-            tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
-
-            Path(TOKENIZER_DIR).mkdir(parents=True, exist_ok=True)
-            tokenizer.save_pretrained(TOKENIZER_DIR)
-            logger.info(f"New tokenizer saved to {TOKENIZER_DIR}")
-    except Exception as e:
-        logger.error(f"Failed to load or create tokenizer: {e}")
-        sys.exit(1)
-
-    # Init the encoder
-    try:
-        encoder = EncoderModel(config=config)
-        logger.info("EncoderModel initialized successfully.")
-
-        if LOAD_CUSTOM_MODEL:
-            # Load the DistilBERT submodule from 'shared_encoder'
-            shared_encoder_path = Path(MODELS_DIR) / "shared_encoder"
-            if shared_encoder_path.exists():
-                logger.info(f"Loading DistilBERT submodule from {shared_encoder_path}")
-                encoder.pretrained = TFAutoModel.from_pretrained(shared_encoder_path)
-            else:
-                logger.warning(f"No shared_encoder found at {shared_encoder_path}, using base DistilBERT instead.")
-
-            # Load custom .weights.h5 (projection, dropout, etc.)
-            custom_weights_path = Path(MODELS_DIR) / "encoder_custom_weights.weights.h5"
-            if custom_weights_path.exists():
-                logger.info(f"Loading custom top-level weights from {custom_weights_path}")
-
-                # Dummy forward pass forces model build to ensure all layers are built
-                dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
-                _ = encoder(dummy_input, training=False)
-
-                encoder.load_weights(str(custom_weights_path))
-                logger.info("Custom encoder weights loaded successfully.")
-            else:
-                logger.warning(f"Custom weights file not found at {custom_weights_path}. Using only submodule weights.")
-        else:
-            # Base DistilBERT with special tokens
-            logger.info("Using the base DistilBERT without loading custom weights.")
-
-        # Resize token embeddings in case we added special tokens (EncoderModel class)
-        encoder.pretrained.resize_token_embeddings(len(tokenizer))
-        logger.info(f"Token embeddings resized to: {len(tokenizer)}")
-
-    except Exception as e:
-        logger.error(f"Failed to initialize EncoderModel: {e}")
-        sys.exit(1)
+    # Init SentenceTransformer
+    encoder = SentenceTransformer(config.pretrained_model)
+    logger.info(f"Initialized SentenceTransformer model: {config.pretrained_model}")
 
-    # Load dialogues
-    try:
-        …
-        logger.info(f"Loaded {len(dialogues)} dialogues from {JSON_TRAINING_DATA_PATH}.")
-    except Exception as e:
-        logger.error(f"Failed to load dialogues: {e}")
-        sys.exit(1)
+    # Load dialogues
+    if Path(JSON_TRAINING_DATA_PATH).exists():
+        dialogues = TFDataPipeline.load_json_training_data(JSON_TRAINING_DATA_PATH)
+        logger.info(f"Loaded {len(dialogues)} dialogues.")
+    else:
+        logger.warning(f"No dialogues found at {JSON_TRAINING_DATA_PATH}.")
+        dialogues = []
 
-    # Load or init query embeddings cache
+    # Load or init query embeddings cache
     query_embeddings_cache = {}
     if os.path.exists(CACHE_FILE):
-        try:
-            with open(CACHE_FILE, 'rb') as f:
-                query_embeddings_cache = pickle.load(f)
-            logger.info(f"Loaded {len(query_embeddings_cache)} query embeddings from {CACHE_FILE}.")
-        except Exception as e:
-            logger.warning(f"Failed to load query embeddings cache: {e}")
+        with open(CACHE_FILE, 'rb') as f:
+            query_embeddings_cache = pickle.load(f)
+        logger.info(f"Loaded query embeddings cache with {len(query_embeddings_cache)} entries.")
     else:
         logger.info("No existing query embeddings cache found. Starting fresh.")
 
-    # …
-    try:
-        …
-        )
-        logger.…
-        …
-        response_pool = data_pipeline.collect_responses_with_domain(dialogues)
-        data_pipeline.response_pool = response_pool
-        logger.info(f"Collected {len(response_pool)} unique assistant responses from dialogues.")
-        else:
-            logger.warning("No dialogues loaded. response_pool remains empty.")
-    except Exception as e:
-        logger.error(f"Failed to collect responses: {e}")
-        sys.exit(1)
-
-    # Build FAISS index with response embeddings
-    try:
-        if data_pipeline.response_pool:
-            data_pipeline.build_text_to_domain_map()
-            logger.info("Computing and adding response embeddings to FAISS index using TFDataPipeline...")
-            data_pipeline.compute_and_index_response_embeddings()
-            logger.info("Response embeddings computed and added to FAISS index.")
-
-            # Save the FAISS index
-            data_pipeline.save_faiss_index(FAISS_INDEX_PRODUCTION_PATH)
-
-            # Also save response pool JSON
-            response_pool_path = FAISS_INDEX_PRODUCTION_PATH.replace('.index', '_responses.json')
-            with open(response_pool_path, 'w', encoding='utf-8') as f:
-                json.dump(data_pipeline.response_pool, f, indent=2)
-            logger.info(f"Response pool saved to {response_pool_path}.")
-        else:
-            logger.warning("No responses to embed. Skipping FAISS indexing.")
-
-    except Exception as e:
-        logger.error(f"Failed to compute or add response embeddings: {e}")
-        sys.exit(1)
-
-    # Prepare training data as TFRecords (TensorFlow Record format)
-    try:
-        if dialogues:
-            logger.info("Starting data preparation and saving as TFRecord...")
-            data_pipeline.prepare_and_save_data(dialogues, TF_RECORD_PATH)
-            logger.info(f"Data saved as TFRecord at {TF_RECORD_PATH}.")
-        else:
-            logger.warning("No dialogues to build TFRecord from. Skipping TFRecord creation.")
-    except Exception as e:
-        logger.error(f"Failed during data preparation and saving: {e}")
-        sys.exit(1)
+    # Init FAISS index
+    dimension = encoder.get_sentence_embedding_dimension()
+    if Path(FAISS_INDEX_PRODUCTION_PATH).exists():
+        faiss_index = faiss.read_index(FAISS_INDEX_PRODUCTION_PATH)
+        logger.info(f"Loaded FAISS index from {FAISS_INDEX_PRODUCTION_PATH}.")
+    else:
+        faiss_index = faiss.IndexFlatIP(dimension)
+        logger.info(f"Initialized new FAISS index with dimension {dimension}.")
+
+    # Init TFDataPipeline
+    data_pipeline = TFDataPipeline(
+        config=config,
+        tokenizer=encoder.tokenizer,
+        encoder=encoder,
+        response_pool=[],
+        query_embeddings_cache=query_embeddings_cache,
+        index_type='IndexFlatIP',
+        faiss_index_file_path=FAISS_INDEX_PRODUCTION_PATH
+    )
+
+    # Collect and embed responses
+    if dialogues:
+        response_pool = data_pipeline.collect_responses_with_domain(dialogues)
+        data_pipeline.response_pool = response_pool
+
+        # Save the response pool
+        response_pool_path = FAISS_INDEX_PRODUCTION_PATH.replace('.index', '_responses.json')
+        with open(response_pool_path, 'w', encoding='utf-8') as f:
+            json.dump(response_pool, f, indent=2)
+        logger.info(f"Response pool saved to {response_pool_path}.")
+        data_pipeline.compute_and_index_response_embeddings()
+        data_pipeline.save_faiss_index(FAISS_INDEX_PRODUCTION_PATH)
+        logger.info(f"FAISS index saved at {FAISS_INDEX_PRODUCTION_PATH}.")
+    else:
+        logger.warning("No responses to embed. Skipping FAISS indexing.")
 
     # Save query embeddings cache
-    try:
-        with open(CACHE_FILE, 'wb') as f:
-            pickle.dump(data_pipeline.query_embeddings_cache, f)
-        logger.info(f"Saved {len(data_pipeline.query_embeddings_cache)} query embeddings to {CACHE_FILE}.")
-    except Exception as e:
-        logger.error(f"Failed to save query embeddings cache: {e}")
-        sys.exit(1)
-
-    # Save Tokenizer
-    try:
-        tokenizer.save_pretrained(TOKENIZER_DIR)
-        logger.info(f"Tokenizer saved to {TOKENIZER_DIR}.")
-    except Exception as e:
-        logger.error(f"Failed to save tokenizer: {e}")
-        sys.exit(1)
-
-    logger.info("Data preparation pipeline completed successfully.")
+    with open(CACHE_FILE, 'wb') as f:
+        pickle.dump(query_embeddings_cache, f)
+    logger.info(f"Query embeddings cache saved at {CACHE_FILE}.")
 
+    logger.info("Pipeline completed successfully.")
 
 if __name__ == "__main__":
     main()
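The core of the new pipeline is "encode responses, add them to an inner-product index". A minimal sketch of that step using sentence-transformers and faiss directly (compute_and_index_response_embeddings wraps this inside TFDataPipeline; normalize_embeddings makes inner product behave as cosine similarity):

import faiss
from sentence_transformers import SentenceTransformer

encoder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
responses = ["Your table is booked for 7pm.", "The flight departs at 9am."]

# L2-normalize so that inner product == cosine similarity
embeddings = encoder.encode(responses, normalize_embeddings=True)

index = faiss.IndexFlatIP(encoder.get_sentence_embedding_dimension())  # 384 for this model
index.add(embeddings.astype('float32'))
print(index.ntotal)  # 2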
run_chatbot_chat.py CHANGED
@@ -1,6 +1,7 @@
 import os
 import json
-from chatbot_model import RetrievalChatbot, ChatbotConfig
+from chatbot_model import RetrievalChatbot
+from chatbot_config import ChatbotConfig
 from response_quality_checker import ResponseQualityChecker
 from environment_setup import EnvironmentSetup
 from logger_config import config_logger
run_chatbot_validation.py CHANGED
@@ -1,24 +1,27 @@
 import os
 import json
-from chatbot_model import RetrievalChatbot, ChatbotConfig
+from sentence_transformers import SentenceTransformer
+from chatbot_config import ChatbotConfig
+from chatbot_model import RetrievalChatbot
 from response_quality_checker import ResponseQualityChecker
 from chatbot_validator import ChatbotValidator
 from plotter import Plotter
 from environment_setup import EnvironmentSetup
 from logger_config import config_logger
+from tf_data_pipeline import TFDataPipeline
 
 logger = config_logger(__name__)
-
+
 def run_chatbot_validation():
     # Initialize environment
     env = EnvironmentSetup()
     env.initialize()
-
+
     MODEL_DIR = "models"
     FAISS_INDICES_DIR = os.path.join(MODEL_DIR, "faiss_indices")
     FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, "faiss_index_production.index")
     FAISS_INDEX_TEST_PATH = os.path.join(FAISS_INDICES_DIR, "faiss_index_test.index")
-
+
     # Toggle 'production' or 'test' env
     ENVIRONMENT = "production"
     if ENVIRONMENT == "test":
@@ -27,7 +30,7 @@ def run_chatbot_validation():
     else:
         FAISS_INDEX_PATH = FAISS_INDEX_PRODUCTION_PATH
         RESPONSE_POOL_PATH = FAISS_INDEX_PRODUCTION_PATH.replace(".index", "_responses.json")
-
+
     # Load the config
     config_path = os.path.join(MODEL_DIR, "config.json")
     if os.path.exists(config_path):
@@ -38,55 +41,62 @@ def run_chatbot_validation():
     else:
         config = ChatbotConfig()
         logger.warning("No config.json found. Using default ChatbotConfig.")
-
-    # …
+
+    # Init SentenceTransformer
     try:
-        …
+        model_name = "sentence-transformers/all-MiniLM-L6-v2"  # Replace with your chosen model
+        encoder = SentenceTransformer(model_name)
+        logger.info(f"Loaded SentenceTransformer model: {model_name}")
     except Exception as e:
-        logger.error(f"Failed to load …
+        logger.error(f"Failed to load SentenceTransformer: {e}")
         return
-
-    # Confirm FAISS index & response pool exist
-    if not os.path.exists(FAISS_INDEX_PATH) or not os.path.exists(RESPONSE_POOL_PATH):
-        logger.error("FAISS index or response pool file is missing.")
-        return
-
+
     # Load FAISS index and response pool
     try:
-        …
+        # Initialize TFDataPipeline
+        data_pipeline = TFDataPipeline(
+            config=config,
+            tokenizer=encoder.tokenizer,
+            encoder=encoder,
+            response_pool=[],
+            query_embeddings_cache={},
+            index_type='IndexFlatIP',
+            faiss_index_file_path=FAISS_INDEX_PATH
+        )
+
+        if not os.path.exists(FAISS_INDEX_PATH) or not os.path.exists(RESPONSE_POOL_PATH):
+            logger.error("FAISS index or response pool file is missing.")
+            return
+
+        data_pipeline.load_faiss_index(FAISS_INDEX_PATH)
         logger.info(f"FAISS index loaded from {FAISS_INDEX_PATH}.")
-
-        logger.info(f"FAISS index type: {type(chatbot.data_pipeline.index)}")
-        logger.info(f"FAISS index total vectors: {chatbot.data_pipeline.index.ntotal}")
-        logger.info(f"FAISS is_trained: {chatbot.data_pipeline.index.is_trained}")
-
+
         with open(RESPONSE_POOL_PATH, "r", encoding="utf-8") as f:
-            …
+            data_pipeline.response_pool = json.load(f)
         logger.info(f"Response pool loaded from {RESPONSE_POOL_PATH}.")
-        logger.info(f"…
-
+        logger.info(f"Total responses in pool: {len(data_pipeline.response_pool)}")
+
         # Validate dimension consistency
-        …
+        data_pipeline.validate_faiss_index()
         logger.info("FAISS index and response pool validated successfully.")
-
     except Exception as e:
         logger.error(f"Failed to load or validate FAISS index: {e}")
         return
-
+
     # Init QualityChecker and Validator
-    quality_checker = ResponseQualityChecker(data_pipeline=chatbot.data_pipeline)
-    validator = ChatbotValidator(chatbot=chatbot, quality_checker=quality_checker)
-    logger.info("ResponseQualityChecker and ChatbotValidator initialized.")
-
-    # Run validation
     try:
+        chatbot = RetrievalChatbot.load_model(load_dir=MODEL_DIR, mode="inference")
+        quality_checker = ResponseQualityChecker(data_pipeline=data_pipeline)
+        validator = ChatbotValidator(chatbot=chatbot, quality_checker=quality_checker)
+        logger.info("ResponseQualityChecker and ChatbotValidator initialized.")
+
+        # Run validation
         validation_metrics = validator.run_validation(num_examples=5)
         logger.info(f"Validation Metrics: {validation_metrics}")
     except Exception as e:
         logger.error(f"Validation process failed: {e}")
         return
-
+
     # Plot metrics
     try:
         plotter = Plotter(save_dir=env.training_dirs["plots"])
@@ -94,10 +104,22 @@ def run_chatbot_validation():
         logger.info("Validation metrics plotted successfully.")
     except Exception as e:
         logger.error(f"Failed to plot validation metrics: {e}")
-
+
     # Run interactive chat loop
-    …
-    …
+    try:
+        logger.info("\nStarting interactive chat session...")
+        while True:
+            user_input = input("You: ")
+            if user_input.lower() in ["exit", "quit"]:
+                logger.info("Exiting chat session.")
+                break
+
+            responses = data_pipeline.retrieve_responses(user_input, top_k=3)
+            print("Top Responses:")
+            for i, (response, score) in enumerate(responses, start=1):
+                print(f"{i}. {response} (Score: {score:.4f})")
+    except KeyboardInterrupt:
+        logger.info("Interactive chat session interrupted by user.")
 
 if __name__ == "__main__":
-    run_chatbot_validation()
+    run_chatbot_validation()
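The interactive loop delegates retrieval to data_pipeline.retrieve_responses. A sketch of the query-side search it implies, reusing the encoder and index from the prepare_data sketch above and assuming the index was built from L2-normalized embeddings:

query_vecs = encoder.encode(["book a table for two"], normalize_embeddings=True)
scores, ids = index.search(query_vecs.astype('float32'), k=3)
for score, idx in zip(scores[0], ids[0]):
    # response texts come from the loaded response pool
    print(f"{data_pipeline.response_pool[idx]} (Score: {score:.4f})")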
run_taskmaster_processor.py CHANGED
@@ -5,7 +5,7 @@ from taskmaster_processor import TaskmasterProcessor, RawDataProcessingConfig
 
 def main():
     # Setup config and processor
-    base_dir = "…
+    base_dir = "raw_datasets/taskmaster"
     config = RawDataProcessingConfig(
         debug=True,
         max_length=512,
taskmaster_processor.py
CHANGED
@@ -4,6 +4,9 @@ import json
|
|
4 |
from pathlib import Path
|
5 |
from typing import List, Dict, Optional, Any
|
6 |
from dataclasses import dataclass, field
|
|
|
|
|
|
|
7 |
|
8 |
@dataclass
|
9 |
class TaskmasterDialogue:
|
@@ -28,7 +31,7 @@ class RawDataProcessingConfig:
|
|
28 |
self,
|
29 |
debug: bool = True,
|
30 |
max_length: int = 512,
|
31 |
-
min_turns: int =
|
32 |
min_user_words: int = 3
|
33 |
):
|
34 |
self.debug = debug
|
@@ -68,7 +71,7 @@ class TaskmasterProcessor:
|
|
68 |
with open(ontology_path, 'r', encoding='utf-8') as f:
|
69 |
ontology = json.load(f)
|
70 |
if self.config.debug:
|
71 |
-
|
72 |
|
73 |
dialogues: List[TaskmasterDialogue] = []
|
74 |
|
@@ -106,7 +109,7 @@ class TaskmasterProcessor:
|
|
106 |
break
|
107 |
|
108 |
if self.config.debug:
|
109 |
-
|
110 |
return dialogues
|
111 |
|
112 |
def _extract_domain(self, scenario: str, turns: List[Dict[str, str]]) -> str:
|
@@ -130,43 +133,15 @@ class TaskmasterProcessor:
|
|
130 |
|
131 |
for domain, pattern in domain_patterns.items():
|
132 |
if re.search(pattern, combined_text):
|
133 |
-
# Optional:
|
134 |
if self.config.debug:
|
135 |
-
|
136 |
return domain
|
137 |
|
138 |
if self.config.debug:
|
139 |
-
|
140 |
return 'other'
|
141 |
|
142 |
-
def _process_utterances(self, utterances: List[Dict[str, Any]]) -> List[Dict[str, str]]:
|
143 |
-
"""
|
144 |
-
Convert "utterances" to a cleaned List -> (speaker, text).
|
145 |
-
Skip lines that are numeric, too short, or empty.
|
146 |
-
"""
|
147 |
-
cleaned_turns = []
|
148 |
-
for utt in utterances:
|
149 |
-
speaker = 'assistant' if utt.get('speaker') == 'ASSISTANT' else 'user'
|
150 |
-
raw_text = utt.get('text', '').strip()
|
151 |
-
|
152 |
-
# Text cleaning
|
153 |
-
text = self._clean_text(raw_text)
|
154 |
-
|
155 |
-
# Skip blank or numeric lines (e.g. "4 3 13")
|
156 |
-
if not text or self._is_numeric_line(text):
|
157 |
-
continue
|
158 |
-
|
159 |
-
# Skip too short (no training benefit from 1-word user turns). E.g. "ok","yes", etc.
|
160 |
-
if len(text.split()) < 3:
|
161 |
-
continue
|
162 |
-
|
163 |
-
# Add to cleaned turns
|
164 |
-
cleaned_turns.append({
|
165 |
-
'speaker': speaker,
|
166 |
-
'text': text
|
167 |
-
})
|
168 |
-
return cleaned_turns
|
169 |
-
|
170 |
def _clean_text(self, text: str) -> str:
|
171 |
"""
|
172 |
Simple text normalization
|
@@ -193,13 +168,20 @@ class TaskmasterProcessor:
|
|
193 |
"turns": [ {"speaker": "user", "text": "..."}, ... ]
|
194 |
}
|
195 |
"""
|
|
|
|
|
|
|
|
|
196 |
results = []
|
|
|
197 |
for dlg in dialogues:
|
198 |
if not dlg.validate():
|
|
|
199 |
continue
|
200 |
|
201 |
# Skip if too few turns
|
202 |
if len(dlg.turns) < self.config.min_turns:
|
|
|
203 |
continue
|
204 |
|
205 |
# Skip if any user turn is too short
|
@@ -208,6 +190,7 @@ class TaskmasterProcessor:
|
|
208 |
if turn['speaker'] == 'user':
|
209 |
words_count = len(turn['text'].split())
|
210 |
if words_count < self.config.min_user_words:
|
|
|
211 |
keep = False
|
212 |
break
|
213 |
|
@@ -217,10 +200,59 @@ class TaskmasterProcessor:
|
|
217 |
pipeline_dlg = {
|
218 |
'dialogue_id': dlg.conversation_id,
|
219 |
'domain': dlg.domain,
|
220 |
-
'turns': dlg.turns
|
221 |
}
|
222 |
results.append(pipeline_dlg)
|
223 |
|
224 |
if self.config.debug:
|
225 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
226 |
return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
from pathlib import Path
|
5 |
from typing import List, Dict, Optional, Any
|
6 |
from dataclasses import dataclass, field
|
7 |
+
from logger_config import config_logger
|
8 |
+
|
9 |
+
logger = config_logger(__name__)
|
10 |
|
11 |
@dataclass
|
12 |
class TaskmasterDialogue:
|
|
|
31 |
self,
|
32 |
debug: bool = True,
|
33 |
max_length: int = 512,
|
34 |
+
min_turns: int = 4,
|
35 |
min_user_words: int = 3
|
36 |
):
|
37 |
self.debug = debug
|
|
|
71 |
with open(ontology_path, 'r', encoding='utf-8') as f:
|
72 |
ontology = json.load(f)
|
73 |
if self.config.debug:
|
74 |
+
logger.info(f"[TaskmasterProcessor] Loaded ontology with {len(ontology.keys())} top-level keys (unused).")
|
75 |
|
76 |
dialogues: List[TaskmasterDialogue] = []
|
77 |
|
|
|
109 |
break
|
110 |
|
111 |
if self.config.debug:
|
112 |
+
logger.info(f"[TaskmasterProcessor] Loaded {len(dialogues)} total dialogues from Taskmaster-1.")
|
113 |
return dialogues
|
114 |
|
115 |
def _extract_domain(self, scenario: str, turns: List[Dict[str, str]]) -> str:
|
|
|
133 |
|
134 |
for domain, pattern in domain_patterns.items():
|
135 |
if re.search(pattern, combined_text):
|
136 |
+
# Optional: logger.info if debug
|
137 |
if self.config.debug:
|
138 |
+
logger.info(f"Matched domain: {domain} in scenario/turns")
|
139 |
return domain
|
140 |
|
141 |
if self.config.debug:
|
142 |
+
logger.info("No domain match, returning 'other'")
|
143 |
return 'other'
|
144 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 | def _clean_text(self, text: str) -> str:
146 |     """
147 |     Simple text normalization

168 |         "turns": [ {"speaker": "user", "text": "..."}, ... ]
169 |     }
170 |     """
171 | +   total = len(dialogues)
172 | +   invalid = 0
173 | +   too_few_turns = 0
174 | +   short_user_turns = 0
175 |     results = []
176 | +
177 |     for dlg in dialogues:
178 |         if not dlg.validate():
179 | +           invalid += 1
180 |             continue
181 |
182 |         # Skip if too few turns
183 |         if len(dlg.turns) < self.config.min_turns:
184 | +           too_few_turns += 1
185 |             continue
186 |
187 |         # Skip if any user turn is too short

190 |             if turn['speaker'] == 'user':
191 |                 words_count = len(turn['text'].split())
192 |                 if words_count < self.config.min_user_words:
193 | +                   short_user_turns += 1
194 |                     keep = False
195 |                     break
196 |

200 |         pipeline_dlg = {
201 |             'dialogue_id': dlg.conversation_id,
202 |             'domain': dlg.domain,
203 | +           'turns': dlg.turns
204 |         }
205 |         results.append(pipeline_dlg)
206 |
207 |     if self.config.debug:
208 | +       logger.info(f"\nFiltering Statistics:")
209 | +       logger.info(f"Total dialogues: {total}")
210 | +       logger.info(f"Invalid dialogues: {invalid}")
211 | +       logger.info(f"Too few turns: {too_few_turns}")
212 | +       logger.info(f"Short user turns: {short_user_turns}")
213 | +       logger.info(f"Remaining dialogues: {len(results)}")
214 | +       logger.info(f"Filtering rate: {((total - len(results)) / total) * 100:.1f}%\n")
215 | +
216 |     return results
217 | +
218 | + def _process_utterances(self, utterances: List[Dict[str, Any]]) -> List[Dict[str, str]]:
219 | +     """Clean utterances, logging how many are filtered and why."""
220 | +     total = len(utterances)
221 | +     empty = 0
222 | +     numeric = 0
223 | +     too_short = 0
224 | +     cleaned_turns = []
225 | +
226 | +     for utt in utterances:
227 | +         speaker = 'assistant' if utt.get('speaker') == 'ASSISTANT' else 'user'
228 | +         raw_text = utt.get('text', '').strip()
229 | +
230 | +         text = self._clean_text(raw_text)
231 | +
232 | +         if not text:
233 | +             empty += 1
234 | +             continue
235 | +
236 | +         if self._is_numeric_line(text):
237 | +             numeric += 1
238 | +             continue
239 | +
240 | +         if len(text.split()) < 3:
241 | +             too_short += 1
242 | +             continue
243 | +
244 | +         cleaned_turns.append({
245 | +             'speaker': speaker,
246 | +             'text': text
247 | +         })
248 | +
249 | +     if self.config.debug and total > 0:
250 | +         logger.info(f"\nUtterance Cleaning Statistics (Dialogue {utterances[0].get('conversation_id', 'unknown')}):")
251 | +         logger.info(f"Total utterances: {total}")
252 | +         logger.info(f"Empty/blank: {empty}")
253 | +         logger.info(f"Numeric only: {numeric}")
254 | +         logger.info(f"Too short (<3 words): {too_short}")
255 | +         logger.info(f"Remaining turns: {len(cleaned_turns)}")
256 | +         logger.info(f"Filtering rate: {((total - len(cleaned_turns)) / total) * 100:.1f}%\n")
257 | +
258 | +     return cleaned_turns
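The cleaning pass above counts and drops empty, numeric-only, and very short turns, but its `_is_numeric_line` helper is defined outside these hunks. A minimal self-contained sketch of the same filtering, with an assumed regex for that helper (and the `_clean_text` normalization omitted for brevity):

    import re

    def is_numeric_line(text: str) -> bool:
        # Assumption: a "numeric line" is digits/punctuation/whitespace only, e.g. "1 2 3".
        return bool(re.fullmatch(r'[\d\s.,-]+', text))

    def clean_utterances(utterances: list) -> list:
        cleaned = []
        for utt in utterances:
            speaker = 'assistant' if utt.get('speaker') == 'ASSISTANT' else 'user'
            text = utt.get('text', '').strip()
            # Mirrors the empty / numeric / too-short (<3 words) counters in the diff.
            if not text or is_numeric_line(text) or len(text.split()) < 3:
                continue
            cleaned.append({'speaker': speaker, 'text': text})
        return cleaned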
tf_data_pipeline.py
CHANGED
@@ -11,6 +11,8 @@ import json
11 | from pathlib import Path
12 | from typing import Union, Optional, Dict, List, Tuple, Generator
13 | from transformers import AutoTokenizer
14 | from typing import List, Tuple, Generator
15 | from transformers import AutoTokenizer
16 | import random
@@ -21,26 +23,30 @@ logger = config_logger(__name__)
21 | class TFDataPipeline:
22 |     def __init__(
23 |         self,
24 | -       config,
25 | -       tokenizer,
26 | -       encoder,
27 |         response_pool: List[str],
28 |         query_embeddings_cache: dict,
29 |         max_length: int = 512,
30 |         neg_samples: int = 10,
31 |         index_type: str = 'IndexFlatIP',
32 |         faiss_index_file_path: str = 'models/faiss_indices/faiss_index_production.index',
33 |         nlist: int = 100,
34 |         max_retries: int = 3
35 |     ):
36 |         self.config = config
37 |         self.tokenizer = tokenizer
38 |         self.encoder = encoder
39 |         self.faiss_index_file_path = faiss_index_file_path
40 |         self.response_pool = response_pool
41 |         self.max_length = max_length
42 |         self.neg_samples = neg_samples
43 |         self.query_embeddings_cache = query_embeddings_cache  # In-memory cache for embeddings
44 |         self.index_type = index_type
45 |         self.nlist = nlist
46 |         self.embedding_batch_size = 16 if len(response_pool) < 100 else 64
@@ -59,9 +65,8 @@ class TFDataPipeline:
59 |             self.validate_faiss_index()
60 |             logger.info("FAISS index loaded and validated successfully.")
61 |         else:
62 | -
63 | -
64 | -           logger.info(f"Initialized FAISS IndexFlatIP with dimension {dimension}.")
65 |
66 |         if not self.index.is_trained:
67 |             # Train the index if it's not trained. IndexFlatIP doesn't need training, but others do (Future switch to IndexIVFFlat)
@@ -98,7 +103,7 @@ class TFDataPipeline:
98 |
99 |     def validate_faiss_index(self):
100 |         """Validates FAISS index dimensionality."""
101 | -       expected_dim = self.
102 |         if self.index.d != expected_dim:
103 |             logger.error(f"FAISS index dimension {self.index.d} does not match encoder embedding dimension {expected_dim}.")
104 |             raise ValueError("FAISS index dimensionality mismatch.")
@@ -186,44 +191,49 @@ class TFDataPipeline:
186 |         pairs.append((query, positive))
187 |
188 |         return pairs
189 | -
190 |     def compute_and_index_response_embeddings(self):
191 |         """
192 | -       Compute embeddings for the response pool
193 | -
194 |         """
195 | -
196 |
197 | -
198 |         texts = [resp["text"] for resp in self.response_pool]
199 |         logger.debug(f"Total texts to embed: {len(texts)}")
200 |
201 | -       batch_size = getattr(self, 'embedding_batch_size', 64)
202 |         embeddings = []
203 |
204 |         with tqdm(total=len(texts), desc="Computing Embeddings", unit="response") as pbar:
205 |             for i in range(0, len(texts), batch_size):
206 | -               batch_texts = texts[i:i+batch_size]
207 | -
208 |                     batch_texts,
209 | -
210 | -
211 | -
212 | -               return_tensors='tf'
213 |                 )
214 | -               batch_embeds = self.encoder(encodings['input_ids'], training=False).numpy()
215 |
216 | -               embeddings.append(
217 |                 pbar.update(len(batch_texts))
218 |
219 | -       # Combine embeddings
220 |         all_embeddings = np.vstack(embeddings).astype(np.float32)
221 |         logger.info(f"Adding {len(all_embeddings)} response embeddings to FAISS index...")
222 |         self.index.add(all_embeddings)
223 |
224 |         # Store in memory
225 |         self.response_embeddings = all_embeddings
226 | -       logger.info(f"FAISS index now
227 |
228 |     def _find_hard_negatives(self, queries: List[str], positives: List[str], batch_size: int = 128) -> List[List[str]]:
229 |         """
|
|
385 |
self._text_domain_map[stripped_text] = domain
|
386 |
|
387 |
logger.info(f"Built text -> domain map with {len(self._text_domain_map)} unique text entries.")
|
388 |
-
|
389 |
-
def encode_query(
|
390 |
-
self,
|
391 |
-
query: str,
|
392 |
-
context: Optional[List[Tuple[str, str]]] = None
|
393 |
-
) -> np.ndarray:
|
394 |
-
"""
|
395 |
-
Encode a user query (and optional conversation context) into an embedding vector.
|
396 |
-
|
397 |
-
Args:
|
398 |
-
query: The user query.
|
399 |
-
context: Optional conversation history as a list of (user_text, assistant_text).
|
400 |
-
Returns:
|
401 |
-
np.ndarray of shape [embedding_dim], typically L2-normalized already.
|
402 |
-
"""
|
403 |
-
# Prepare context: concat user/assistant pairs
|
404 |
-
if context:
|
405 |
-
# Take the last N turns
|
406 |
-
relevant_history = context[-self.config.max_context_turns:]
|
407 |
-
context_str_parts = []
|
408 |
-
for (u_text, a_text) in relevant_history:
|
409 |
-
context_str_parts.append(
|
410 |
-
f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {u_text} "
|
411 |
-
f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {a_text}"
|
412 |
-
)
|
413 |
-
context_str = " ".join(context_str_parts)
|
414 |
-
|
415 |
-
# Append the new query
|
416 |
-
full_query = (
|
417 |
-
f"{context_str} "
|
418 |
-
f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"
|
419 |
-
)
|
420 |
-
else:
|
421 |
-
# Single user turn
|
422 |
-
full_query = (
|
423 |
-
f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"
|
424 |
-
)
|
425 |
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
padding='max_length',
|
430 |
-
truncation=True,
|
431 |
-
max_length=self.max_length,
|
432 |
-
return_tensors='np' # to keep it compatible with FAISS
|
433 |
-
)
|
434 |
-
input_ids = encodings['input_ids']
|
435 |
-
|
436 |
-
# Debug out-of-vocab IDs
|
437 |
-
max_id = np.max(input_ids)
|
438 |
-
vocab_size = len(self.tokenizer)
|
439 |
-
if max_id >= vocab_size:
|
440 |
-
logger.error(f"Token ID {max_id} exceeds tokenizer vocab size {vocab_size}.")
|
441 |
-
raise ValueError("Token ID exceeds vocabulary size.")
|
442 |
-
|
443 |
-
# Get embeddings from the model. These are already L2-normalized by the model's final layer.
|
444 |
-
embeddings = self.encoder(input_ids, training=False).numpy()
|
445 |
-
|
446 |
-
return embeddings[0]
|
447 |
|
448 |
def encode_responses(
|
449 |
-
self,
|
450 |
-
responses: List[str],
|
451 |
context: Optional[List[Tuple[str, str]]] = None
|
452 |
) -> np.ndarray:
|
453 |
"""
|
454 |
-
Encode multiple response texts into
|
455 |
-
Args:
|
456 |
-
responses: List of assistant responses.
|
457 |
-
context: Optional conversation context (last N turns).
|
458 |
-
Returns:
|
459 |
-
np.ndarray of shape [num_responses, embedding_dim].
|
460 |
"""
|
461 |
-
|
|
|
|
|
462 |
if context:
|
463 |
relevant_history = context[-self.config.max_context_turns:]
|
464 |
prepared = []
|
465 |
for resp in responses:
|
466 |
context_str_parts = []
|
|
|
467 |
for (u_text, a_text) in relevant_history:
|
468 |
context_str_parts.append(
|
469 |
-
f"{
|
470 |
-
f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {a_text}"
|
471 |
)
|
472 |
context_str = " ".join(context_str_parts)
|
473 |
-
|
474 |
-
|
475 |
-
full_resp = (
|
476 |
-
f"{context_str} "
|
477 |
-
f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {resp}"
|
478 |
-
)
|
479 |
prepared.append(full_resp)
|
480 |
else:
|
481 |
# Single response from the assistant
|
482 |
-
prepared = [
|
483 |
-
f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {r}"
|
484 |
-
for r in responses
|
485 |
-
]
|
486 |
|
487 |
-
#
|
488 |
encodings = self.tokenizer(
|
489 |
prepared,
|
490 |
padding='max_length',
|
@@ -493,19 +438,42 @@ class TFDataPipeline:
|
|
493 |
return_tensors='np'
|
494 |
)
|
495 |
input_ids = encodings['input_ids']
|
496 |
-
|
497 |
# Debug for out-of-vocab
|
498 |
max_id = np.max(input_ids)
|
499 |
vocab_size = len(self.tokenizer)
|
500 |
if max_id >= vocab_size:
|
501 |
-
logger.error(f"Token ID {max_id}
|
502 |
raise ValueError("Token ID exceeds vocabulary size.")
|
503 |
|
504 |
-
# Get embeddings from
|
505 |
-
embeddings = self.encoder(
|
506 |
|
507 |
return embeddings.astype('float32')
|
508 |
509 |     def prepare_and_save_data(self, dialogues: List[dict], tf_record_path: str, batch_size: int = 32):
510 |         """
511 |         Batch-process dialogues and save to a TFRecord file.
11 | from pathlib import Path
12 | from typing import Union, Optional, Dict, List, Tuple, Generator
13 | from transformers import AutoTokenizer
14 | + from sentence_transformers import SentenceTransformer
15 | + from chatbot_config import ChatbotConfig
16 | from typing import List, Tuple, Generator
17 | from transformers import AutoTokenizer
18 | import random

23 | class TFDataPipeline:
24 |     def __init__(
25 |         self,
26 | +       config: ChatbotConfig,
27 | +       tokenizer: AutoTokenizer,
28 | +       encoder: SentenceTransformer,
29 |         response_pool: List[str],
30 |         query_embeddings_cache: dict,
31 | +       model_name: str = 'sentence-transformers/all-MiniLM-L6-v2',
32 |         max_length: int = 512,
33 |         neg_samples: int = 10,
34 |         index_type: str = 'IndexFlatIP',
35 |         faiss_index_file_path: str = 'models/faiss_indices/faiss_index_production.index',
36 | +       dimension: int = 384,
37 |         nlist: int = 100,
38 |         max_retries: int = 3
39 |     ):
40 |         self.config = config
41 |         self.tokenizer = tokenizer
42 |         self.encoder = encoder
43 | +       self.model = SentenceTransformer(model_name)
44 |         self.faiss_index_file_path = faiss_index_file_path
45 |         self.response_pool = response_pool
46 |         self.max_length = max_length
47 |         self.neg_samples = neg_samples
48 |         self.query_embeddings_cache = query_embeddings_cache  # In-memory cache for embeddings
49 | +       self.dimension = config.embedding_dim
50 |         self.index_type = index_type
51 |         self.nlist = nlist
52 |         self.embedding_batch_size = 16 if len(response_pool) < 100 else 64
65 |             self.validate_faiss_index()
66 |             logger.info("FAISS index loaded and validated successfully.")
67 |         else:
68 | +           self.index = faiss.IndexFlatIP(self.dimension)
69 | +           logger.info(f"Initialized FAISS IndexFlatIP with dimension {self.dimension}.")
70 |
71 |         if not self.index.is_trained:
72 |             # Train the index if it's not trained. IndexFlatIP doesn't need training, but others do (Future switch to IndexIVFFlat)

103 |
104 |     def validate_faiss_index(self):
105 |         """Validates FAISS index dimensionality."""
106 | +       expected_dim = self.dimension
107 |         if self.index.d != expected_dim:
108 |             logger.error(f"FAISS index dimension {self.index.d} does not match encoder embedding dimension {expected_dim}.")
109 |             raise ValueError("FAISS index dimensionality mismatch.")
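The init-or-load branch and the validation check reduce to a few lines of faiss. A standalone sketch, assuming the 384-dimension all-MiniLM-L6-v2 embeddings configured above:

    import faiss

    dimension = 384  # ChatbotConfig.embedding_dim for all-MiniLM-L6-v2
    index = faiss.IndexFlatIP(dimension)  # exact inner-product search

    # validate_faiss_index boils down to this dimension comparison:
    if index.d != dimension:
        raise ValueError("FAISS index dimensionality mismatch.")

    # IndexFlatIP needs no training; IVF-style indices (e.g. IndexIVFFlat) would
    # require index.train(...) before index.add(...), hence the is_trained guard.
    print(index.is_trained, index.ntotal)  # True 0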
191 |         pairs.append((query, positive))
192 |
193 |         return pairs
194 | +
195 |     def compute_and_index_response_embeddings(self):
196 |         """
197 | +       Compute embeddings for the response pool using SentenceTransformer
198 | +       and add them to the FAISS index.
199 |         """
200 | +       if not self.response_pool:
201 | +           logger.warning("Response pool is empty. No embeddings to compute.")
202 | +           return
203 |
204 | +       logger.info("Computing embeddings for the response pool...")
205 |         texts = [resp["text"] for resp in self.response_pool]
206 |         logger.debug(f"Total texts to embed: {len(texts)}")
207 |
208 |         embeddings = []
209 | +       batch_size = self.embedding_batch_size
210 |
211 | +       # Use SentenceTransformer to compute embeddings in batches
212 |         with tqdm(total=len(texts), desc="Computing Embeddings", unit="response") as pbar:
213 |             for i in range(0, len(texts), batch_size):
214 | +               batch_texts = texts[i:i + batch_size]
215 | +
216 | +               # Compute embeddings
217 | +               batch_embeddings = self.encoder.encode(
218 |                     batch_texts,
219 | +                   batch_size=batch_size,
220 | +                   convert_to_numpy=True,
221 | +                   normalize_embeddings=True  # Normalizes for cosine similarity
222 |                 )
223 |
224 | +               embeddings.append(batch_embeddings)
225 |                 pbar.update(len(batch_texts))
226 |
227 | +       # Combine all embeddings
228 |         all_embeddings = np.vstack(embeddings).astype(np.float32)
229 |         logger.info(f"Adding {len(all_embeddings)} response embeddings to FAISS index...")
230 | +
231 | +       # Add to FAISS index
232 |         self.index.add(all_embeddings)
233 |
234 |         # Store in memory
235 |         self.response_embeddings = all_embeddings
236 | +       logger.info(f"FAISS index now contains {self.index.ntotal} vectors.")
237 |
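Because `normalize_embeddings=True` produces unit-length vectors, the inner products returned by `IndexFlatIP` are exactly cosine similarities. A condensed sketch of the compute-then-add loop above, assuming a loaded SentenceTransformer and a small stand-in pool:

    import faiss
    import numpy as np
    from sentence_transformers import SentenceTransformer

    encoder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
    texts = ["Your table is booked for 7pm.", "The movie starts at 9."]  # stand-in pool

    embeddings = encoder.encode(
        texts,
        batch_size=64,
        convert_to_numpy=True,
        normalize_embeddings=True,  # unit vectors: inner product == cosine similarity
    ).astype(np.float32)

    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    print(index.ntotal)  # 2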
238 |     def _find_hard_negatives(self, queries: List[str], positives: List[str], batch_size: int = 128) -> List[List[str]]:
239 |         """

395 |             self._text_domain_map[stripped_text] = domain
396 |
397 |         logger.info(f"Built text -> domain map with {len(self._text_domain_map)} unique text entries.")
398 |
399 | +   def encode_query(self, query: str) -> np.ndarray:
400 | +       """Generate embedding for a query string."""
401 | +       return self.encoder.encode(query, convert_to_numpy=True)
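Note that the rewritten `encode_query` is a one-liner: the old context-window prompt assembly is gone and the SentenceTransformer tokenizes internally. The stock all-MiniLM-L6-v2 pipeline ends in a normalization module, so the returned vector should already be unit-length and compatible with the inner-product index. Assuming a constructed pipeline instance (`pipeline` is a hypothetical name):

    vec = pipeline.encode_query("I need a table for two tonight")
    print(vec.shape)  # (384,) for all-MiniLM-L6-v2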
402 |
403 |     def encode_responses(
404 | +       self,
405 | +       responses: List[str],
406 |         context: Optional[List[Tuple[str, str]]] = None
407 |     ) -> np.ndarray:
408 |         """
409 | +       Encode multiple response texts into embeddings, injecting <ASSISTANT> literally.
410 |         """
411 | +       USER_TOKEN = "<USER>"
412 | +       ASSISTANT_TOKEN = "<ASSISTANT>"
413 | +
414 |         if context:
415 |             relevant_history = context[-self.config.max_context_turns:]
416 |             prepared = []
417 |             for resp in responses:
418 |                 context_str_parts = []
419 | +               # Build all user->assistant text
420 |                 for (u_text, a_text) in relevant_history:
421 |                     context_str_parts.append(
422 | +                       f"{USER_TOKEN} {u_text} {ASSISTANT_TOKEN} {a_text}"
423 |                     )
424 |                 context_str = " ".join(context_str_parts)
425 | +               # Treat resp as an assistant turn:
426 | +               full_resp = f"{context_str} {ASSISTANT_TOKEN} {resp}"
427 |                 prepared.append(full_resp)
428 |         else:
429 |             # Single response from the assistant
430 | +           prepared = [f"{ASSISTANT_TOKEN} {r}" for r in responses]
431 |
432 | +       # Pass the prepared strings to the SentenceTransformer tokenizer:
433 |         encodings = self.tokenizer(
434 |             prepared,
435 |             padding='max_length',

438 |             return_tensors='np'
439 |         )
440 |         input_ids = encodings['input_ids']
441 | +
442 |         # Debug for out-of-vocab
443 |         max_id = np.max(input_ids)
444 |         vocab_size = len(self.tokenizer)
445 |         if max_id >= vocab_size:
446 | +           logger.error(f"Token ID {max_id} >= tokenizer vocab size {vocab_size}")
447 |             raise ValueError("Token ID exceeds vocabulary size.")
448 |
449 | +       # Get embeddings from SentenceTransformer
450 | +       embeddings = self.encoder.encode(prepared, convert_to_numpy=True)
451 |
452 |         return embeddings.astype('float32')
453 |
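The `<USER>`/`<ASSISTANT>` markers are now injected as literal text rather than resolved through `tokenizer.additional_special_tokens`. For one context turn, the prepared string handed to the encoder looks like this (illustrative values):

    context = [("Can I book a table?", "Sure, for how many people?")]
    resp = "For two, at 7pm."

    u_text, a_text = context[-1]
    prepared = f"<USER> {u_text} <ASSISTANT> {a_text} <ASSISTANT> {resp}"
    print(prepared)
    # <USER> Can I book a table? <ASSISTANT> Sure, for how many people? <ASSISTANT> For two, at 7pm.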
454 | +   def retrieve_responses(self, query: str, top_k: int = 10) -> List[Tuple[str, float]]:
455 | +       """
456 | +       Retrieve top-k responses for a query using FAISS.
457 | +
458 | +       Args:
459 | +           query: User's query text.
460 | +           top_k: Number of responses to return.
461 | +
462 | +       Returns:
463 | +           List of tuples (response text, similarity score).
464 | +       """
465 | +       query_embedding = self.encode_query(query).reshape(1, -1).astype("float32")
466 | +       distances, indices = self.index.search(query_embedding, top_k)
467 | +
468 | +       results = []
469 | +       for idx, dist in zip(indices[0], distances[0]):
470 | +           if idx < 0:
471 | +               continue
472 | +           response = self.response_pool[idx]
473 | +           results.append((response["text"], dist))
474 | +
475 | +       return results
476 | +
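Putting the new pieces together, a retrieval round trip would look roughly like the sketch below. The constructor arguments follow the signature in this diff; `config`, `tokenizer`, and `encoder` are assumed to be built elsewhere, and the two-item response pool is a stand-in.

    pipeline = TFDataPipeline(
        config=config,
        tokenizer=tokenizer,
        encoder=encoder,
        response_pool=[{"text": "Your table is booked."}, {"text": "The show starts at 9."}],
        query_embeddings_cache={},
    )
    pipeline.compute_and_index_response_embeddings()

    for text, score in pipeline.retrieve_responses("book me a table", top_k=2):
        print(f"{score:.3f}  {text}")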
477 |     def prepare_and_save_data(self, dialogues: List[dict], tf_record_path: str, batch_size: int = 32):
478 |         """
479 |         Batch-process dialogues and save to a TFRecord file.
|