JoeArmani committed on
Commit cc2577d · 1 Parent(s): 3ea7670

style updates

chatbot_model.py CHANGED
@@ -372,7 +372,7 @@ class RetrievalChatbot(DeviceAwareModel):
372
  reranker: CrossEncoderReranker for refined scoring, if available.
373
  summarizer: Summarizer for long queries, if desired.
374
  summarize_threshold: Summarize if query wordcount > threshold.
375
-
376
  Returns:
377
  List of (response_text, final_score).
378
  """
@@ -383,11 +383,13 @@ class RetrievalChatbot(DeviceAwareModel):
383
  logger.info(f"Summarized Query: {query}")
384
 
385
  detected_domain = self.detect_domain_from_query(query)
386
- #logger.debug(f"Detected domain '{detected_domain}' for query: {query}")
387
-
388
  # Retrieve initial candidates from FAISS
389
  initial_k = min(top_k * 10, len(self.data_pipeline.response_pool))
390
- faiss_candidates = self.retrieve_responses_faiss(query, domain=detected_domain, top_k=initial_k)
 
 
 
391
 
392
  texts = [item[0] for item in faiss_candidates]
393
 
@@ -395,23 +397,18 @@ class RetrievalChatbot(DeviceAwareModel):
395
  if not reranker:
396
  reranker = CrossEncoderReranker(model_name=self.config.cross_encoder_model)
397
 
398
- ce_scores = reranker.rerank(query, texts, max_length=256)
399
-
400
  # Combine cross-encoder score with the base FAISS score (simple multiplicative approach)
401
  final_candidates = []
402
- for (resp_text, faiss_score), ce_score in zip(faiss_candidates, ce_scores):
403
- # TODO: dial this in.
404
- ce_prob = self.sigmoid(ce_score) # ~ relevance in [0..1]
405
- faiss_norm = (faiss_score + 1)/2.0
406
- combined_score = 0.9 * ce_prob + 0.1 * faiss_norm
407
- # alpha = 0.9
408
- # print(f'CE SCORE: {ce_score} FAISS SCORE: {faiss_score}')
409
- # combined_score = alpha * ce_score + (1 - alpha) * faiss_score
410
  length_adjusted_score = self.length_adjust_score(resp_text, combined_score)
411
- #combined_score = ce_score * faiss_score
412
- #final_candidates.append((resp_text, combined_score))
413
  final_candidates.append((resp_text, length_adjusted_score))
414
-
415
  # Sort descending by combined score
416
  final_candidates.sort(key=lambda x: x[1], reverse=True)
417
 
@@ -441,20 +438,18 @@ class RetrievalChatbot(DeviceAwareModel):
441
 
442
  def length_adjust_score(self, text: str, base_score: float) -> float:
443
  """
444
- Penalize very short lines or numeric lines; mildly reward longer lines.
445
- Adjust carefully so you don't overshadow cross-encoder signals.
446
  """
447
  words = text.split()
448
  wcount = len(words)
449
 
450
- # Penalty if under 3 words
451
  if wcount < 4:
452
  return base_score * 0.8
453
 
454
- # Bonus for lines > 12 words
455
- if wcount > 12:
456
- extra = min(wcount - 12, 8)
457
- bonus = 0.0005 * extra
458
  base_score += bonus
459
 
460
  return base_score
@@ -487,7 +482,7 @@ class RetrievalChatbot(DeviceAwareModel):
487
  pattern = r'^[\s]*[\d]+([\s.,\d]+)*[\s]*$'
488
  return bool(re.match(pattern, text.strip()))
489
 
490
- def retrieve_responses_faiss(
491
  self,
492
  query: str,
493
  domain: str = 'other',
@@ -518,9 +513,9 @@ class RetrievalChatbot(DeviceAwareModel):
518
  for rank, idx in enumerate(indices[0]):
519
  if idx < 0:
520
  continue
521
- response = self.data_pipeline.response_pool[idx]
522
- text = response.get('text', '').strip()
523
- cand_domain = response.get('domain', 'other')
524
  score = distances[0][rank]
525
 
526
  # Skip purely numeric or extremely short text (fewer than 3 words):
@@ -554,21 +549,19 @@ class RetrievalChatbot(DeviceAwareModel):
554
  # shares any query keywords, apply a small boost
555
  if query_keywords and any(kw in resp_text.lower() for kw in query_keywords):
556
  new_score *= boost_factor
557
- #logger.debug(f"Boosting response: '{resp_text}' by factor {boost_factor}")
558
-
559
  # Apply length penalty/bonus
560
  new_score = self.length_adjust_score(resp_text, new_score)
561
-
562
  boosted.append((resp_text, new_score))
563
-
564
  # Sort boosted responses
565
  boosted.sort(key=lambda x: x[1], reverse=True)
566
 
567
- # Print top 10
568
- # for resp, score in boosted[:150]:
569
  # logger.debug(f"Candidate: '{resp}' with score {score}")
570
-
571
- # 8) Return top_k
572
  return boosted[:top_k]
573
 
574
  def chat(
@@ -584,10 +577,10 @@ class RetrievalChatbot(DeviceAwareModel):
584
  """
585
  @self.run_on_device
586
  def get_response(self_arg, query_arg):
587
- # 1) Build conversation context string
588
  conversation_str = self_arg._build_conversation_context(query_arg, conversation_history)
589
 
590
- # 2) Retrieve + cross-encoder re-rank
591
  results = self_arg.retrieve_responses_cross_encoder(
592
  query=conversation_str,
593
  top_k=top_k,
@@ -595,26 +588,15 @@ class RetrievalChatbot(DeviceAwareModel):
595
  summarizer=self_arg.summarizer,
596
  summarize_threshold=512
597
  )
598
-
599
- # 3) Handle empty or confidence
600
  if not results:
601
- return (
602
- "I'm sorry, but I couldn't find a relevant response.",
603
- [],
604
- {}
605
- )
606
-
607
- if quality_checker:
608
- metrics = quality_checker.check_response_quality(query_arg, results)
609
- if not metrics.get('is_confident', False):
610
- return (
611
- "I need more information to provide a good answer. Could you please clarify?",
612
- results,
613
- metrics
614
- )
615
- return results[0][0], results, metrics
616
 
617
- return results[0][0], results, {}
 
 
 
618
 
619
  return get_response(self, query)
620
 
 
372
  reranker: CrossEncoderReranker for refined scoring, if available.
373
  summarizer: Summarizer for long queries, if desired.
374
  summarize_threshold: Summarize if query wordcount > threshold.
375
+
376
  Returns:
377
  List of (response_text, final_score).
378
  """
 
383
  logger.info(f"Summarized Query: {query}")
384
 
385
  detected_domain = self.detect_domain_from_query(query)
386
+
 
387
  # Retrieve initial candidates from FAISS
388
  initial_k = min(top_k * 10, len(self.data_pipeline.response_pool))
389
+ faiss_candidates = self.faiss_search(query, domain=detected_domain, top_k=initial_k)
390
+
391
+ if not faiss_candidates:
392
+ return []
393
 
394
  texts = [item[0] for item in faiss_candidates]
395
 
 
397
  if not reranker:
398
  reranker = CrossEncoderReranker(model_name=self.config.cross_encoder_model)
399
 
400
+ ce_logits = reranker.rerank(query, texts, max_length=256)
401
+
402
  # Combine cross-encoder score with the base FAISS score (simple multiplicative approach)
403
  final_candidates = []
404
+ for (resp_text, faiss_score), logit in zip(faiss_candidates, ce_logits):
405
+ ce_prob = self.sigmoid(logit) # [0...1]
406
+ faiss_norm = (faiss_score + 1)/2.0 # [0...1]
407
+ combined_score = 0.85 * ce_prob + 0.15 * faiss_norm
 
 
 
 
408
  length_adjusted_score = self.length_adjust_score(resp_text, combined_score)
409
+
 
410
  final_candidates.append((resp_text, length_adjusted_score))
411
+
412
  # Sort descending by combined score
413
  final_candidates.sort(key=lambda x: x[1], reverse=True)
414
 
 
438
 
439
  def length_adjust_score(self, text: str, base_score: float) -> float:
440
  """
441
+ Penalize very short lines, reward longer lines.
 
442
  """
443
  words = text.split()
444
  wcount = len(words)
445
 
446
+ # Penalty if under 4 words
447
  if wcount < 4:
448
  return base_score * 0.8
449
 
450
+ # Bonus for lines > 15 words
451
+ if wcount > 15:
452
+ bonus = min(0.03, 0.001 * (wcount - 15))
 
453
  base_score += bonus
454
 
455
  return base_score
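A self-contained sketch of the length adjustment added above (penalty under 4 words, small capped bonus past 15 words; the 0.8 / 0.001 / 0.03 constants follow the diff, and the final += is assumed to sit inside the bonus branch):

def length_adjust_score(text: str, base_score: float) -> float:
    """Penalize very short candidates; give long ones a small, capped bonus."""
    wcount = len(text.split())
    if wcount < 4:                       # very short / low-signal responses
        return base_score * 0.8
    if wcount > 15:                      # mild reward for longer responses
        bonus = min(0.03, 0.001 * (wcount - 15))
        base_score += bonus
    return base_score

print(length_adjust_score("ok", 0.9))                     # 0.72
print(length_adjust_score(" ".join(["word"] * 30), 0.9))  # 0.915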
 
482
  pattern = r'^[\s]*[\d]+([\s.,\d]+)*[\s]*$'
483
  return bool(re.match(pattern, text.strip()))
484
 
485
+ def faiss_search(
486
  self,
487
  query: str,
488
  domain: str = 'other',
 
513
  for rank, idx in enumerate(indices[0]):
514
  if idx < 0:
515
  continue
516
+ text_dict = self.data_pipeline.response_pool[idx]
517
+ text = text_dict.get('text', '').strip()
518
+ cand_domain = text_dict.get('domain', 'other')
519
  score = distances[0][rank]
520
 
521
  # Skip purely numeric or extremely short text (fewer than 3 words):
 
549
  # shares any query keywords, apply a small boost
550
  if query_keywords and any(kw in resp_text.lower() for kw in query_keywords):
551
  new_score *= boost_factor
552
+
 
553
  # Apply length penalty/bonus
554
  new_score = self.length_adjust_score(resp_text, new_score)
555
+
556
  boosted.append((resp_text, new_score))
557
+
558
  # Sort boosted responses
559
  boosted.sort(key=lambda x: x[1], reverse=True)
560
 
561
+ # Debug
562
+ # for resp, score in boosted[:100]:
563
  # logger.debug(f"Candidate: '{resp}' with score {score}")
564
+
 
565
  return boosted[:top_k]
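Outside the class, the same retrieve-then-boost flow can be sketched roughly as follows. The index, response_pool, and pre-computed query embedding stand in for the pipeline objects, and the domain/keyword boosts mirror the logic above; this is an illustration, not the exact implementation:

import numpy as np
import faiss  # assumes faiss-cpu or faiss-gpu is installed

def faiss_search_sketch(query_emb: np.ndarray, index: faiss.Index, response_pool: list,
                        domain: str = 'other', top_k: int = 10,
                        boost_factor: float = 1.05, query_keywords: set = frozenset()) -> list:
    """Retrieve candidates from FAISS, prefer same-domain hits, boost keyword overlap."""
    distances, indices = index.search(query_emb.reshape(1, -1).astype(np.float32), top_k * 10)
    boosted = []
    for rank, idx in enumerate(indices[0]):
        if idx < 0:
            continue
        cand = response_pool[idx]                      # {"text": ..., "domain": ...}
        text, cand_domain = cand["text"].strip(), cand.get("domain", "other")
        score = float(distances[0][rank])
        if len(text.split()) < 3:                      # skip very short / numeric candidates
            continue
        if cand_domain == domain:                      # light same-domain preference
            score *= boost_factor
        if query_keywords and any(kw in text.lower() for kw in query_keywords):
            score *= boost_factor                      # keyword-overlap boost
        boosted.append((text, score))
    boosted.sort(key=lambda x: x[1], reverse=True)
    return boosted[:top_k]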
566
 
567
  def chat(
 
577
  """
578
  @self.run_on_device
579
  def get_response(self_arg, query_arg):
580
+ # Build conversation context string
581
  conversation_str = self_arg._build_conversation_context(query_arg, conversation_history)
582
 
583
+ # Retrieve and re-rank
584
  results = self_arg.retrieve_responses_cross_encoder(
585
  query=conversation_str,
586
  top_k=top_k,
 
588
  summarizer=self_arg.summarizer,
589
  summarize_threshold=512
590
  )
591
+
592
+ # Handle low confidence or empty responses
593
  if not results:
594
+ return ("I'm sorry, but I couldn't find a relevant response.", [], {})
595
 
596
+ metrics = quality_checker.check_response_quality(query_arg, results)
597
+ if not metrics.get('is_confident', False):
598
+ return ("I need more information to provide a good answer. Could you please clarify?", results, metrics)
599
+ return results[0][0], results, metrics
600
 
601
  return get_response(self, query)
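A compact sketch of the retrieve -> confidence-gate -> respond flow in chat(). The retrieve and quality_checker callables are placeholders, and the explicit None check on quality_checker is an extra guard added here for illustration, not necessarily present in the code above:

from typing import Callable, Dict, List, Optional, Tuple

def respond(query: str,
            retrieve: Callable[[str], List[Tuple[str, float]]],
            quality_checker: Optional[object] = None) -> Tuple[str, list, dict]:
    """Retrieve candidates, optionally gate on a confidence check, return (reply, candidates, metrics)."""
    results = retrieve(query)
    if not results:
        return ("I'm sorry, but I couldn't find a relevant response.", [], {})
    metrics: Dict = {}
    if quality_checker is not None:
        metrics = quality_checker.check_response_quality(query, results)
        if not metrics.get('is_confident', False):
            return ("I need more information to provide a good answer. Could you please clarify?",
                    results, metrics)
    return results[0][0], results, metrics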
602
 
conversation_summarizer.py CHANGED
@@ -13,9 +13,11 @@ class ChatConfig:
13
  chunk_size: int = 512
14
  chunk_overlap: int = 256
15
  min_confidence_score: float = 0.7
16
-
17
  class DeviceAwareModel:
18
- """Mixin to handle device placement and mixed precision training."""
 
 
19
 
20
  def setup_device(self, device: str = None):
21
  if device is None:
@@ -24,31 +26,33 @@ class DeviceAwareModel:
24
  self.device = device.upper()
25
  self.strategy = None
26
 
 
 
27
  if self.device == 'GPU':
28
  # # Enable mixed precision for better performance
29
  # policy = tf.keras.mixed_precision.Policy('mixed_float16')
30
  # tf.keras.mixed_precision.set_global_policy(policy)
31
 
32
- # Setup distribution strategy for multi-GPU if available
33
  gpus = tf.config.list_physical_devices('GPU')
34
  if len(gpus) > 1:
35
  self.strategy = tf.distribute.MirroredStrategy()
36
 
37
  return self.device
38
-
39
  def run_on_device(self, func):
40
  """Decorator to ensure ops run on the correct device."""
41
  def wrapper(*args, **kwargs):
42
  with tf.device(f'/{self.device}:0'):
43
  return func(*args, **kwargs)
44
  return wrapper
45
-
46
  class Summarizer(DeviceAwareModel):
47
  """
48
- Enhanced T5-based summarizer with better chunking and device management.
49
- Handles long conversations by intelligent chunking and progressive summarization.
50
  """
51
-
52
  def __init__(
53
  self,
54
  tokenizer: AutoTokenizer,
@@ -57,10 +61,10 @@ class Summarizer(DeviceAwareModel):
57
  device=None,
58
  max_summary_rounds=2
59
  ):
60
- self.tokenizer = tokenizer # Injected tokenizer
61
  self.setup_device(device)
62
 
63
- # Initialize model within strategy scope if using distribution
64
  if self.strategy:
65
  with self.strategy.scope():
66
  self._setup_model(model_name)
@@ -69,11 +73,11 @@ class Summarizer(DeviceAwareModel):
69
 
70
  self.max_summary_length = max_summary_length
71
  self.max_summary_rounds = max_summary_rounds
72
-
73
  def _setup_model(self, model_name):
74
  self.model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)
75
 
76
- # Optimize model for inference
77
  self.model.generate = tf.function(
78
  self.model.generate,
79
  input_signature=[
@@ -83,7 +87,7 @@ class Summarizer(DeviceAwareModel):
83
  }
84
  ]
85
  )
86
-
87
  @tf.function
88
  def _generate_summary(self, inputs):
89
  return self.model.generate(
@@ -94,9 +98,9 @@ class Summarizer(DeviceAwareModel):
94
  early_stopping=True,
95
  no_repeat_ngram_size=3
96
  )
97
-
98
  def chunk_text(self, text: str, chunk_size: int = 512, overlap: int = 256) -> List[str]:
99
- """Split text into overlapping chunks for better context preservation."""
100
  tokens = self.tokenizer.encode(text)
101
  chunks = []
102
 
@@ -105,7 +109,7 @@ class Summarizer(DeviceAwareModel):
105
  chunks.append(self.tokenizer.decode(chunk, skip_special_tokens=True))
106
 
107
  return chunks
108
-
109
  def summarize_text(
110
  self,
111
  text: str,
@@ -113,8 +117,7 @@ class Summarizer(DeviceAwareModel):
113
  round_idx: int = 0
114
  ) -> str:
115
  """
116
- Summarize text with optional progressive summarization
117
- and limit the maximum number of re-summarization rounds.
118
  """
119
  @self.run_on_device
120
  def _summarize_chunk(chunk: str) -> str:
@@ -127,28 +130,27 @@ class Summarizer(DeviceAwareModel):
127
  )
128
  summary_ids = self._generate_summary(inputs)
129
  return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
130
-
131
- # If we've hit our max allowed summarization rounds, just do a single pass
132
  if round_idx >= self.max_summary_rounds:
133
  return _summarize_chunk(text)
134
-
135
- # If text is longer than threshold and progressive summarization is on
136
  if len(text.split()) > 512 and progressive:
137
  chunks = self.chunk_text(text)
138
  chunk_summaries = [_summarize_chunk(chunk) for chunk in chunks]
139
-
140
  # Combine chunk-level summaries
141
  combined_summary = " ".join(chunk_summaries)
142
-
143
- # If still too long, do another summarization pass but increment round_idx
144
  if len(combined_summary.split()) > 512:
145
  return self.summarize_text(
146
  combined_summary,
147
  progressive=True,
148
  round_idx=round_idx + 1
149
  )
150
-
151
  return combined_summary
152
  else:
153
- # If text is not too long, just summarize once and return
154
- return _summarize_chunk(text)
 
13
  chunk_size: int = 512
14
  chunk_overlap: int = 256
15
  min_confidence_score: float = 0.7
16
+
17
  class DeviceAwareModel:
18
+ """
19
+ Mixin: Handle device placement and mixed precision training.
20
+ """
21
 
22
  def setup_device(self, device: str = None):
23
  if device is None:
 
26
  self.device = device.upper()
27
  self.strategy = None
28
 
29
+ # NOTE: Needs more testing. Earlier training issues may have come from other bugs fixed since this was last tested.
30
+ # Reminder: Test model saving/loading alongside mixed precision settings
31
  if self.device == 'GPU':
32
  # # Enable mixed precision for better performance
33
  # policy = tf.keras.mixed_precision.Policy('mixed_float16')
34
  # tf.keras.mixed_precision.set_global_policy(policy)
35
 
36
+ # Setup multi-GPU if available
37
  gpus = tf.config.list_physical_devices('GPU')
38
  if len(gpus) > 1:
39
  self.strategy = tf.distribute.MirroredStrategy()
40
 
41
  return self.device
42
+
43
  def run_on_device(self, func):
44
  """Decorator to ensure ops run on the correct device."""
45
  def wrapper(*args, **kwargs):
46
  with tf.device(f'/{self.device}:0'):
47
  return func(*args, **kwargs)
48
  return wrapper
49
+
50
  class Summarizer(DeviceAwareModel):
51
  """
52
+ T5-based summarizer with device management.
53
+ Uses chunking and progressive summarization for long conversations.
54
  """
55
+
56
  def __init__(
57
  self,
58
  tokenizer: AutoTokenizer,
 
61
  device=None,
62
  max_summary_rounds=2
63
  ):
64
+ self.tokenizer = tokenizer
65
  self.setup_device(device)
66
 
67
+ # Strategy scope if using distribution
68
  if self.strategy:
69
  with self.strategy.scope():
70
  self._setup_model(model_name)
 
73
 
74
  self.max_summary_length = max_summary_length
75
  self.max_summary_rounds = max_summary_rounds
76
+
77
  def _setup_model(self, model_name):
78
  self.model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)
79
 
80
+ # Optimize for inference
81
  self.model.generate = tf.function(
82
  self.model.generate,
83
  input_signature=[
 
87
  }
88
  ]
89
  )
90
+
91
  @tf.function
92
  def _generate_summary(self, inputs):
93
  return self.model.generate(
 
98
  early_stopping=True,
99
  no_repeat_ngram_size=3
100
  )
101
+
102
  def chunk_text(self, text: str, chunk_size: int = 512, overlap: int = 256) -> List[str]:
103
+ """Split text into overlapping chunks for context preservation."""
104
  tokens = self.tokenizer.encode(text)
105
  chunks = []
106
 
 
109
  chunks.append(self.tokenizer.decode(chunk, skip_special_tokens=True))
110
 
111
  return chunks
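The overlapping-chunk step can be illustrated standalone: the stride is chunk_size minus overlap, so consecutive windows share `overlap` tokens. A plain list of token strings stands in for the HF tokenizer output:

from typing import List

def chunk_tokens(tokens: List[str], chunk_size: int = 512, overlap: int = 256) -> List[List[str]]:
    """Split a token list into overlapping windows for progressive summarization."""
    stride = chunk_size - overlap
    return [tokens[i:i + chunk_size] for i in range(0, len(tokens), stride)]

toks = [f"t{i}" for i in range(10)]
print(chunk_tokens(toks, chunk_size=4, overlap=2))
# [['t0','t1','t2','t3'], ['t2','t3','t4','t5'], ['t4','t5','t6','t7'], ['t6','t7','t8','t9'], ['t8','t9']]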
112
+
113
  def summarize_text(
114
  self,
115
  text: str,
 
117
  round_idx: int = 0
118
  ) -> str:
119
  """
120
+ Progressive summarization with a limited number of re-summarization rounds.
 
121
  """
122
  @self.run_on_device
123
  def _summarize_chunk(chunk: str) -> str:
 
130
  )
131
  summary_ids = self._generate_summary(inputs)
132
  return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
133
+
134
+ # Do a single summarization pass once max_summary_rounds is reached
135
  if round_idx >= self.max_summary_rounds:
136
  return _summarize_chunk(text)
137
+
138
+ # Chunk and summarize
139
  if len(text.split()) > 512 and progressive:
140
  chunks = self.chunk_text(text)
141
  chunk_summaries = [_summarize_chunk(chunk) for chunk in chunks]
142
+
143
  # Combine chunk-level summaries
144
  combined_summary = " ".join(chunk_summaries)
145
+
 
146
  if len(combined_summary.split()) > 512:
147
  return self.summarize_text(
148
  combined_summary,
149
  progressive=True,
150
  round_idx=round_idx + 1
151
  )
152
+
153
  return combined_summary
154
  else:
155
+ # Summarize once and return
156
+ return _summarize_chunk(text)
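A minimal sketch of the progressive-summarization control flow above. A stub summarize_chunk replaces the T5 call, the 512-word threshold and round limit mirror the code, and the progressive flag is omitted for brevity:

def summarize_text(text: str, summarize_chunk, chunk_fn,
                   max_rounds: int = 2, word_limit: int = 512, round_idx: int = 0) -> str:
    """Chunk, summarize, and re-summarize until short enough or rounds are exhausted."""
    if round_idx >= max_rounds:
        return summarize_chunk(text)                  # final single pass
    if len(text.split()) > word_limit:
        chunks = chunk_fn(text)
        combined = " ".join(summarize_chunk(c) for c in chunks)
        if len(combined.split()) > word_limit:        # still too long: recurse with round_idx + 1
            return summarize_text(combined, summarize_chunk, chunk_fn,
                                  max_rounds, word_limit, round_idx + 1)
        return combined
    return summarize_chunk(text)

# Example with stub functions:
print(summarize_text("short text", summarize_chunk=lambda t: t[:20], chunk_fn=lambda t: [t]))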
cross_encoder_reranker.py CHANGED
@@ -1,23 +1,19 @@
1
  from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
2
  import tensorflow as tf
3
  from typing import List
4
- import numpy as np
5
 
6
  from logger_config import config_logger
7
  logger = config_logger(__name__)
8
 
9
  class CrossEncoderReranker:
10
  """
11
- Cross-Encoder Re-Ranker that takes (query, candidate) pairs,
12
- outputs a single relevance score in [0,1].
13
  """
14
  def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-12-v2"):
15
  """
16
- Initialize the cross-encoder with a pretrained model.
17
-
18
  Args:
19
- model_name: Name of a HF cross-encoder model. Must be
20
- compatible with TFAutoModelForSequenceClassification.
21
  """
22
  logger.info(f"Initializing CrossEncoderReranker with {model_name}...")
23
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -31,21 +27,16 @@ class CrossEncoderReranker:
31
  max_length: int = 256
32
  ) -> List[float]:
33
  """
34
- Compute relevance scores for each candidate w.r.t. the query.
35
-
36
  Args:
37
  query: User's query text.
38
  candidates: List of candidate response texts.
39
  max_length: Max token length for each (query, candidate) pair.
40
-
41
  Returns:
42
- A list of float scores in [0,1], one per candidate,
43
- indicating model's predicted relevance.
44
  """
45
- # 1) Build (query, candidate) pairs
46
  pair_texts = [(query, candidate) for candidate in candidates]
47
-
48
- # 2) Tokenize the entire batch
49
  encodings = self.tokenizer(
50
  pair_texts,
51
  padding=True,
@@ -54,24 +45,20 @@ class CrossEncoderReranker:
54
  return_tensors="tf"
55
  )
56
 
57
- # 3) Forward pass -> logits shape [batch_size, 1]
 
 
58
  outputs = self.model(
59
  input_ids=encodings["input_ids"],
60
  attention_mask=encodings["attention_mask"],
61
- token_type_ids=encodings.get("token_type_ids") # Some models need token_type_ids
62
  )
63
 
64
  logits = outputs.logits # shape [batch_size, 1]
65
- # 4) Convert logits -> [0,1] range via sigmoid
66
- # If the cross-encoder is a single-logit regression to [0,1],
67
- # this is a typical interpretation.
68
  scores = tf.nn.sigmoid(logits) # shape [batch_size, 1]
69
 
70
- # 5) Flatten to a 1D NumPy array of floats
71
  scores = tf.reshape(scores, [-1])
72
  scores = scores.numpy().astype(float)
73
 
74
- # logger.debug(f"Cross-Encoder raw logits: {logits.numpy().flatten().tolist()}")
75
- # logger.debug(f"Cross-Encoder sigmoid scores: {scores.tolist()}")
76
-
77
  return scores.tolist()
 
1
  from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
2
  import tensorflow as tf
3
  from typing import List
 
4
 
5
  from logger_config import config_logger
6
  logger = config_logger(__name__)
7
 
8
  class CrossEncoderReranker:
9
  """
10
+ Cross-Encoder Re-Ranker. Takes (query, candidate) pairs and outputs a relevance score in [0, 1].
 
11
  """
12
  def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-12-v2"):
13
  """
14
+ Init the cross-encoder with a pretrained model.
 
15
  Args:
16
+ model_name: Name of a HF cross-encoder model. Must be compatible with TFAutoModelForSequenceClassification.
 
17
  """
18
  logger.info(f"Initializing CrossEncoderReranker with {model_name}...")
19
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
 
27
  max_length: int = 256
28
  ) -> List[float]:
29
  """
30
+ Compute relevance scores for each candidate w.r.t. query.
 
31
  Args:
32
  query: User's query text.
33
  candidates: List of candidate response texts.
34
  max_length: Max token length for each (query, candidate) pair.
 
35
  Returns:
36
+ A list of float scores in [0, 1], one per candidate, indicating the model's predicted relevance.
 
37
  """
38
+ # Build (query, candidate) pairs, then tokenize
39
  pair_texts = [(query, candidate) for candidate in candidates]
 
 
40
  encodings = self.tokenizer(
41
  pair_texts,
42
  padding=True,
 
45
  return_tensors="tf"
46
  )
47
 
48
+ # Forward pass, logits shape [batch_size, 1]
49
+ # Then convert logits to [0...1] range with sigmoid
50
+ # Note: token_type_ids are optional. .get() avoids KeyError
51
  outputs = self.model(
52
  input_ids=encodings["input_ids"],
53
  attention_mask=encodings["attention_mask"],
54
+ token_type_ids=encodings.get("token_type_ids")
55
  )
56
 
57
  logits = outputs.logits # shape [batch_size, 1]
 
 
 
58
  scores = tf.nn.sigmoid(logits) # shape [batch_size, 1]
59
 
60
+ # Flatten to 1D NumPy array, ensure float type
61
  scores = tf.reshape(scores, [-1])
62
  scores = scores.numpy().astype(float)
63
 
 
 
 
64
  return scores.tolist()
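Typical usage of the re-ranker defined in this file, as a sketch (the model is downloaded on first use; the query and candidates are made-up examples):

from typing import List

reranker = CrossEncoderReranker(model_name="cross-encoder/ms-marco-MiniLM-L-12-v2")
query = "What time do you open tomorrow?"
candidates: List[str] = ["We open at 9 am on weekdays.", "I like pizza."]
scores = reranker.rerank(query, candidates, max_length=256)
best_response = max(zip(candidates, scores), key=lambda pair: pair[1])[0]
print(best_response)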
tf_data_pipeline.py CHANGED
@@ -4,6 +4,8 @@ import faiss
4
  import tensorflow as tf
5
  import h5py
6
  import math
 
 
7
  from tqdm import tqdm
8
  import json
9
  from pathlib import Path
@@ -46,47 +48,47 @@ class TFDataPipeline:
46
  self.max_batch_size = 16 if len(response_pool) < 100 else 64
47
  self.max_retries = max_retries
48
 
49
- # Build a quick text->domain map for O(1) domain lookups
50
  self._text_domain_map = {}
51
  self.build_text_to_domain_map()
52
-
 
53
  if os.path.exists(index_file_path):
54
  logger.info(f"Loading existing FAISS index from {index_file_path}...")
55
  self.index = faiss.read_index(index_file_path)
56
  self.validate_faiss_index()
57
  logger.info("FAISS index loaded and validated successfully.")
58
  else:
59
- # Initialize FAISS index
60
  dimension = self.encoder.config.embedding_dim
61
  self.index = faiss.IndexFlatIP(dimension)
62
  logger.info(f"Initialized FAISS IndexFlatIP with dimension {dimension}.")
63
 
64
  if not self.index.is_trained:
65
- # Train the index if it's not trained. # TODO: Replace 'dimension' with embedding size
66
  dimension = self.query_embeddings_cache[next(iter(self.query_embeddings_cache))].shape[0]
67
  self.index.train(np.array(list(self.query_embeddings_cache.values())).astype(np.float32))
68
  self.index.add(np.array(list(self.query_embeddings_cache.values())).astype(np.float32))
69
-
70
  def save_embeddings_cache_hdf5(self, cache_file_path: str):
71
- """Save the embeddings cache to an HDF5 file."""
72
  with h5py.File(cache_file_path, 'w') as hf:
73
  for query, emb in self.query_embeddings_cache.items():
74
  hf.create_dataset(query, data=emb)
75
  logger.info(f"Embeddings cache saved to {cache_file_path}.")
76
-
77
  def load_embeddings_cache_hdf5(self, cache_file_path: str):
78
- """Load the embeddings cache from an HDF5 file."""
79
  with h5py.File(cache_file_path, 'r') as hf:
80
  for query in hf.keys():
81
  self.query_embeddings_cache[query] = hf[query][:]
82
  logger.info(f"Embeddings cache loaded from {cache_file_path}.")
83
-
84
  def save_faiss_index(self, index_file_path: str):
85
  faiss.write_index(self.index, index_file_path)
86
  logger.info(f"FAISS index saved to {index_file_path}")
87
 
88
  def load_faiss_index(self, index_file_path: str):
89
- """Load the FAISS index from the specified file path."""
90
  if os.path.exists(index_file_path):
91
  self.index = faiss.read_index(index_file_path)
92
  logger.info(f"FAISS index loaded from {index_file_path}.")
@@ -95,7 +97,7 @@ class TFDataPipeline:
95
  raise FileNotFoundError(f"FAISS index file not found at {index_file_path}.")
96
 
97
  def validate_faiss_index(self):
98
- """Validates that the FAISS index has the correct dimensionality."""
99
  expected_dim = self.encoder.config.embedding_dim
100
  if self.index.d != expected_dim:
101
  logger.error(f"FAISS index dimension {self.index.d} does not match encoder embedding dimension {expected_dim}.")
@@ -114,7 +116,6 @@ class TFDataPipeline:
114
  def load_json_training_data(data_path: Union[str, Path], debug_samples: Optional[int] = None) -> List[dict]:
115
  """
116
  Load training data from a JSON file.
117
-
118
  Args:
119
  data_path (Union[str, Path]): Path to the JSON file containing dialogues.
120
  debug_samples (Optional[int]): Number of samples to load for debugging.
@@ -137,17 +138,16 @@ class TFDataPipeline:
137
 
138
  logger.info(f"Loaded {len(dialogues)} dialogues.")
139
  return dialogues
140
-
141
  def collect_responses_with_domain(self, dialogues: List[dict]) -> List[Dict[str, str]]:
142
  """
143
- Extract unique assistant responses from dialogues, along with the domain.
144
- Returns a list of dicts: [{'domain': str, 'text': str}, ...]
145
  """
146
- response_set = set() # We'll store (domain, text) tuples to keep them unique
147
  results = []
148
-
149
  for dialogue in tqdm(dialogues, desc="Processing Dialogues", unit="dialogue"):
150
- # domain is stored at the top level in your new JSON format
151
  domain = dialogue.get('domain', 'other')
152
  turns = dialogue.get('turns', [])
153
  for turn in turns:
@@ -155,7 +155,7 @@ class TFDataPipeline:
155
  text = turn.get('text', '').strip()
156
  if speaker == 'assistant' and text:
157
  if len(text) <= self.max_length:
158
- # Use a tuple as a "set" key to ensure uniqueness
159
  key = (domain, text)
160
  if key not in response_set:
161
  response_set.add(key)
@@ -163,23 +163,9 @@ class TFDataPipeline:
163
  "domain": domain,
164
  "text": text
165
  })
166
-
167
  logger.info(f"Collected {len(results)} unique assistant responses from dialogues.")
168
  return results
169
- # def collect_responses(self, dialogues: List[dict]) -> List[str]:
170
- # """Extract unique assistant responses from dialogues."""
171
- # response_set = set()
172
- # for dialogue in tqdm(dialogues, desc="Processing Dialogues", unit="dialogue"):
173
- # turns = dialogue.get('turns', [])
174
- # for turn in turns:
175
- # speaker = turn.get('speaker')
176
- # text = turn.get('text', '').strip()
177
- # if speaker == 'assistant' and text:
178
- # # Ensure we don't exclude valid shorter responses
179
- # if len(text) <= self.max_length:
180
- # response_set.add(text)
181
- # logger.info(f"Collected {len(response_set)} unique assistant responses from dialogues.")
182
- # return list(response_set)
183
 
184
  def _extract_pairs_from_dialogue(self, dialogue: dict) -> List[Tuple[str, str]]:
185
  """Extract query-response pairs from a dialogue."""
@@ -203,18 +189,18 @@ class TFDataPipeline:
203
 
204
  def compute_and_index_response_embeddings(self):
205
  """
206
- Computes embeddings for the response pool and adds them to the FAISS index.
207
- self.response_pool is now List[Dict[str, str]] with keys "domain" and "text".
208
  """
209
  logger.info("Computing embeddings for the response pool...")
210
-
211
- # Extract just the assistant text
212
  texts = [resp["text"] for resp in self.response_pool]
213
  logger.debug(f"Total texts to embed: {len(texts)}")
214
 
215
  batch_size = getattr(self, 'embedding_batch_size', 64)
216
  embeddings = []
217
-
218
  with tqdm(total=len(texts), desc="Computing Embeddings", unit="response") as pbar:
219
  for i in range(0, len(texts), batch_size):
220
  batch_texts = texts[i:i+batch_size]
@@ -226,36 +212,30 @@ class TFDataPipeline:
226
  return_tensors='tf'
227
  )
228
  batch_embeds = self.encoder(encodings['input_ids'], training=False).numpy()
229
-
230
  embeddings.append(batch_embeds)
231
  pbar.update(len(batch_texts))
232
-
233
  # Combine embeddings and add to FAISS
234
  all_embeddings = np.vstack(embeddings).astype(np.float32)
235
  logger.info(f"Adding {len(all_embeddings)} response embeddings to FAISS index...")
236
  self.index.add(all_embeddings)
237
 
238
- # For debugging or repeated usage, you might store them:
239
  self.response_embeddings = all_embeddings
240
  logger.info(f"FAISS index now has {self.index.ntotal} vectors.")
241
 
242
- def _find_hard_negatives_batch(self, queries: List[str], positives: List[str]) -> List[List[str]]:
243
  """
244
  Find hard negatives for a batch of queries using FAISS search.
245
- Falls back to random negatives if we run out of tries or can't find enough.
246
- Uses domain-based fallback if possible.
247
  """
248
- import random
249
- import gc
250
-
251
  retry_count = 0
252
  total_responses = len(self.response_pool)
253
- k = self.neg_samples # Number of negatives to retrieve from FAISS
254
- batch_size = 128
255
-
256
  while retry_count < self.max_retries:
257
  try:
258
- # 1) Build query embeddings from the cache
259
  query_embeddings = []
260
  for i in range(0, len(queries), batch_size):
261
  sub_queries = queries[i : i + batch_size]
@@ -263,23 +243,24 @@ class TFDataPipeline:
263
  sub_embeds = np.vstack(sub_embeds).astype(np.float32)
264
  faiss.normalize_L2(sub_embeds) # If not already normalized
265
  query_embeddings.append(sub_embeds)
266
-
267
  query_embeddings = np.vstack(query_embeddings)
268
  query_embeddings = np.ascontiguousarray(query_embeddings)
269
-
270
- # 2) Perform FAISS search
271
- distances, indices = self.index.search(query_embeddings, k)
272
-
273
  all_negatives = []
274
- # For each query, find domain from the corresponding positive if possible
275
  for query_indices, query_text, pos_text in zip(indices, queries, positives):
276
  negative_list = []
 
 
277
  seen = {pos_text.strip()}
278
-
279
- # Attempt to detect the domain of the positive text
280
  domain_of_positive = self._detect_domain_for_text(pos_text)
281
-
282
- # Collect hard negatives from FAISS
283
  for idx in query_indices:
284
  if 0 <= idx < total_responses:
285
  candidate_dict = self.response_pool[idx] # e.g. {domain, text}
@@ -289,18 +270,18 @@ class TFDataPipeline:
289
  negative_list.append(candidate_text)
290
  if len(negative_list) >= self.neg_samples:
291
  break
292
-
293
- # If not enough negatives, fallback to random domain-based
294
  if len(negative_list) < self.neg_samples:
295
  needed = self.neg_samples - len(negative_list)
296
- # Pass in domain_of_positive to your updated `_get_random_negatives(...)`
297
  random_negatives = self._get_random_negatives(needed, seen, domain=domain_of_positive)
298
  negative_list.extend(random_negatives)
299
-
300
  all_negatives.append(negative_list)
301
-
302
  return all_negatives
303
-
304
  except KeyError as ke:
305
  retry_count += 1
306
  logger.warning(f"Hard negative search attempt {retry_count} failed due to missing embeddings: {ke}")
@@ -310,7 +291,7 @@ class TFDataPipeline:
310
  gc.collect()
311
  if tf.config.list_physical_devices('GPU'):
312
  tf.keras.backend.clear_session()
313
-
314
  except Exception as e:
315
  retry_count += 1
316
  logger.warning(f"Hard negative search attempt {retry_count} failed: {e}")
@@ -320,29 +301,27 @@ class TFDataPipeline:
320
  gc.collect()
321
  if tf.config.list_physical_devices('GPU'):
322
  tf.keras.backend.clear_session()
323
-
324
  def _detect_domain_for_text(self, text: str) -> Optional[str]:
325
  """
326
- O(1) domain detection by looking up text in our dictionary.
327
- Returns the domain if found, else None.
328
  """
329
  stripped_text = text.strip()
330
  return self._text_domain_map.get(stripped_text, None)
331
 
332
  def _get_random_negatives(self, needed: int, seen: set, domain: Optional[str] = None) -> List[str]:
333
  """
334
- Return a list of 'needed' random negative texts from the same domain if possible,
335
- otherwise fallback to all-domain.
336
  """
337
- # 1) Filter response_pool for domain if provided
338
  if domain:
339
  domain_texts = [r["text"] for r in self.response_pool if r["domain"] == domain]
340
  # fallback to entire set if insufficient domain_texts
341
- if len(domain_texts) < needed * 2: # pick some threshold
342
  domain_texts = [r["text"] for r in self.response_pool]
343
  else:
344
  domain_texts = [r["text"] for r in self.response_pool]
345
-
346
  negatives = []
347
  tries = 0
348
  max_tries = needed * 10
@@ -352,8 +331,7 @@ class TFDataPipeline:
352
  if candidate and candidate not in seen:
353
  negatives.append(candidate)
354
  seen.add(candidate)
355
-
356
- # If still not enough, we do the best we can
357
  if len(negatives) < needed:
358
  logger.warning(f"Could not find enough domain-based random negatives; needed {needed}, got {len(negatives)}.")
359
 
@@ -369,47 +347,44 @@ class TFDataPipeline:
369
  all_negatives = []
370
 
371
  for pos_text in positives:
372
- # Build a 'seen' set with the positive
373
  seen = {pos_text.strip()}
374
-
375
- # Attempt to detect the domain of the positive text
376
  domain_of_positive = self._detect_domain_for_text(pos_text)
377
-
378
- # Use domain-based random negatives if available
379
  negs = self._get_random_negatives(self.neg_samples, seen, domain=domain_of_positive)
380
  all_negatives.append(negs)
381
-
382
  return all_negatives
383
 
384
  def build_text_to_domain_map(self):
385
  """
386
- Build an O(1) lookup dict: text -> domain,
387
- so we don't have to scan the entire self.response_pool each time.
388
  """
389
  self._text_domain_map = {}
390
-
391
  for item in self.response_pool:
392
- # e.g., item = {"domain": "restaurant", "text": "some text..."}
393
  stripped_text = item["text"].strip()
394
  domain = item["domain"]
395
-
396
- # If the same text appears multiple times with the same domain, no big deal.
397
- # If it appears with a different domain, you can decide how to handle collisions.
398
  if stripped_text in self._text_domain_map:
399
- existing_domain = self._text_domain_map[stripped_text]
400
- if existing_domain != domain:
401
- # Log a warning or decide on a policy:
402
- logger.warning(
403
- f"Collision detected: text '{stripped_text}' found with domains "
404
- f"'{existing_domain}' and '{domain}'. Keeping the first."
405
- )
 
406
  # By default, keep the first domain or overwrite. We'll skip overwriting:
407
  continue
408
  else:
409
  # Insert into the dict
410
  self._text_domain_map[stripped_text] = domain
411
-
412
- logger.info(f"Built text->domain map with {len(self._text_domain_map)} unique text entries.")
413
 
414
  def encode_query(
415
  self,
@@ -422,11 +397,10 @@ class TFDataPipeline:
422
  Args:
423
  query: The user query.
424
  context: Optional conversation history as a list of (user_text, assistant_text).
425
-
426
  Returns:
427
  np.ndarray of shape [embedding_dim], typically L2-normalized already.
428
  """
429
- # 1) Prepare context (if any) by concatenating user/assistant pairs
430
  if context:
431
  # Take the last N turns
432
  relevant_history = context[-self.config.max_context_turns:]
@@ -438,18 +412,18 @@ class TFDataPipeline:
438
  )
439
  context_str = " ".join(context_str_parts)
440
 
441
- # Append the user's new query
442
  full_query = (
443
  f"{context_str} "
444
  f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"
445
  )
446
  else:
447
- # Just a single user turn
448
  full_query = (
449
  f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"
450
  )
451
 
452
- # 2) Tokenize
453
  encodings = self.tokenizer(
454
  [full_query],
455
  padding='max_length',
@@ -459,20 +433,18 @@ class TFDataPipeline:
459
  )
460
  input_ids = encodings['input_ids']
461
 
462
- # 3) Check for out-of-vocab IDs
463
  max_id = np.max(input_ids)
464
  vocab_size = len(self.tokenizer)
465
  if max_id >= vocab_size:
466
  logger.error(f"Token ID {max_id} exceeds tokenizer vocab size {vocab_size}.")
467
  raise ValueError("Token ID exceeds vocabulary size.")
468
 
469
- # 4) Get embeddings from the model
470
  embeddings = self.encoder(input_ids, training=False).numpy()
471
- # Typically your custom model already L2-normalizes the final embeddings.
472
-
473
- # 5) Return the single embedding as 1D array
474
  return embeddings[0]
475
-
476
  def encode_responses(
477
  self,
478
  responses: List[str],
@@ -480,16 +452,13 @@ class TFDataPipeline:
480
  ) -> np.ndarray:
481
  """
482
  Encode multiple response texts into embedding vectors.
483
-
484
  Args:
485
- responses: List of raw assistant responses.
486
  context: Optional conversation context (last N turns).
487
-
488
  Returns:
489
  np.ndarray of shape [num_responses, embedding_dim].
490
  """
491
- # 1) If you want to incorporate context into response encoding
492
- # Usually for retrieval we might skip this. But if you want it:
493
  if context:
494
  relevant_history = context[-self.config.max_context_turns:]
495
  prepared = []
@@ -501,21 +470,21 @@ class TFDataPipeline:
501
  f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {a_text}"
502
  )
503
  context_str = " ".join(context_str_parts)
504
-
505
- # Now treat resp as an assistant turn
506
  full_resp = (
507
  f"{context_str} "
508
  f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {resp}"
509
  )
510
  prepared.append(full_resp)
511
  else:
512
- # By default, just mark each response as from the assistant
513
  prepared = [
514
  f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {r}"
515
  for r in responses
516
  ]
517
-
518
- # 2) Tokenize
519
  encodings = self.tokenizer(
520
  prepared,
521
  padding='max_length',
@@ -524,28 +493,22 @@ class TFDataPipeline:
524
  return_tensors='np'
525
  )
526
  input_ids = encodings['input_ids']
527
-
528
- # 3) Check for out-of-vocab
529
  max_id = np.max(input_ids)
530
  vocab_size = len(self.tokenizer)
531
  if max_id >= vocab_size:
532
  logger.error(f"Token ID {max_id} exceeds tokenizer vocab size {vocab_size}.")
533
  raise ValueError("Token ID exceeds vocabulary size.")
534
-
535
- # 4) Model forward
536
  embeddings = self.encoder(input_ids, training=False).numpy()
537
- # Typically already L2-normalized if your final layer is normalized.
538
-
539
  return embeddings.astype('float32')
540
 
541
  def prepare_and_save_data(self, dialogues: List[dict], tf_record_path: str, batch_size: int = 32):
542
  """
543
- Processes dialogues in batches and saves to a TFRecord file using optimized batch tokenization and encoding.
544
-
545
- Args:
546
- dialogues (List[dict]): List of dialogue dictionaries.
547
- tf_record_path (str): Path to save the TFRecord file.
548
- batch_size (int): Number of dialogues to process per batch.
549
  """
550
  logger.info(f"Preparing and saving data to {tf_record_path}...")
551
 
@@ -553,14 +516,13 @@ class TFDataPipeline:
553
  num_batches = math.ceil(num_dialogues / batch_size)
554
 
555
  with tf.io.TFRecordWriter(tf_record_path) as writer:
556
- # Initialize progress bar
557
  with tqdm(total=num_batches, desc="Preparing Data Batches", unit="batch") as pbar:
558
  for i in range(num_batches):
559
  start_idx = i * batch_size
560
  end_idx = min(start_idx + batch_size, num_dialogues)
561
  batch_dialogues = dialogues[start_idx:end_idx]
562
 
563
- # Extract all query-positive pairs in the batch
564
  queries = []
565
  positives = []
566
  for dialogue in batch_dialogues:
@@ -572,7 +534,7 @@ class TFDataPipeline:
572
 
573
  if not queries:
574
  pbar.update(1)
575
- continue # Skip if no valid queries
576
 
577
  # Compute and cache query embeddings
578
  try:
@@ -580,11 +542,11 @@ class TFDataPipeline:
580
  except Exception as e:
581
  logger.error(f"Error computing embeddings: {e}")
582
  pbar.update(1)
583
- continue # Skip to the next batch
584
 
585
- # Find hard negatives for the batch
586
  try:
587
- hard_negatives = self._find_hard_negatives_batch(queries, positives)
588
  except Exception as e:
589
  logger.error(f"Error finding hard negatives: {e}")
590
  pbar.update(1)
@@ -611,8 +573,8 @@ class TFDataPipeline:
611
  pbar.update(1)
612
  continue # Skip to the next batch
613
 
614
- # Flatten hard_negatives while maintaining alignment
615
- # Assuming hard_negatives is a list of lists, where each sublist corresponds to a query
616
  try:
617
  flattened_negatives = [neg for sublist in hard_negatives for neg in sublist]
618
  encoded_negatives = self.tokenizer.batch_encode_plus(
@@ -623,15 +585,15 @@ class TFDataPipeline:
623
  return_tensors='tf'
624
  )
625
 
626
- # Reshape encoded_negatives['input_ids'] to [num_queries, num_negatives, max_length]
627
  num_negatives = self.config.neg_samples
628
  reshaped_negatives = encoded_negatives['input_ids'].numpy().reshape(-1, num_negatives, self.config.max_context_token_limit)
629
  except Exception as e:
630
  logger.error(f"Error during negatives tokenization: {e}")
631
  pbar.update(1)
632
- continue # Skip to the next batch
633
 
634
- # Serialize each example and write to TFRecord
635
  for j in range(len(queries)):
636
  try:
637
  q_id = encoded_queries['input_ids'][j].numpy()
@@ -655,11 +617,14 @@ class TFDataPipeline:
655
  logger.info(f"Data preparation complete. TFRecord saved.")
656
 
657
  def _compute_embeddings(self, queries: List[str]) -> None:
 
 
 
658
  new_queries = [q for q in queries if q not in self.query_embeddings_cache]
659
  if not new_queries:
660
- return # All queries already cached
661
-
662
- # Compute embeddings for new queries
663
  new_embeddings = []
664
  for i in range(0, len(new_queries), self.embedding_batch_size):
665
  batch_queries = new_queries[i:i + self.embedding_batch_size]
@@ -673,49 +638,46 @@ class TFDataPipeline:
673
  batch_embeddings = self.encoder(encoded['input_ids'], training=False).numpy()
674
  faiss.normalize_L2(batch_embeddings)
675
  new_embeddings.extend(batch_embeddings)
676
-
677
  # Update the cache
678
  for query, emb in zip(new_queries, new_embeddings):
679
  self.query_embeddings_cache[query] = emb
680
-
681
  def data_generator(self, dialogues: List[dict]) -> Generator[Tuple[str, str, List[str]], None, None]:
682
  """
683
- Generates training examples: (query, positive, hard_negatives).
684
- Wrapped the outer loop with tqdm for progress tracking.
685
  """
686
  total_dialogues = len(dialogues)
687
  logger.debug(f"Total dialogues to process: {total_dialogues}")
688
-
689
- # Initialize tqdm progress bar
690
  with tqdm(total=total_dialogues, desc="Processing Dialogues", unit="dialogue") as pbar:
691
  for dialogue in dialogues:
692
  pairs = self._extract_pairs_from_dialogue(dialogue)
693
  for query, positive in pairs:
694
  # Ensure embeddings are computed, find hard negatives, etc.
695
  self._compute_embeddings([query])
696
- hard_negatives = self._find_hard_negatives_batch([query], [positive])[0]
697
  yield (query, positive, hard_negatives)
698
  pbar.update(1)
699
 
700
  def get_tf_dataset(self, dialogues: List[dict], batch_size: int) -> tf.data.Dataset:
701
  """
702
- Creates a tf.data.Dataset for streaming training that yields
703
- (input_ids_query, input_ids_positive, input_ids_negatives).
704
  """
705
  # 1) Start with a generator dataset
706
  dataset = tf.data.Dataset.from_generator(
707
  lambda: self.data_generator(dialogues),
708
  output_signature=(
709
- tf.TensorSpec(shape=(), dtype=tf.string), # Query (single string)
710
- tf.TensorSpec(shape=(), dtype=tf.string), # Positive (single string)
711
- tf.TensorSpec(shape=(self.neg_samples,), dtype=tf.string) # Hard Negatives (list of strings)
712
  )
713
  )
714
 
715
- # 2) Batch the raw strings
 
716
  dataset = dataset.batch(batch_size, drop_remainder=True)
717
-
718
- # 3) Map them through a tokenize step using `tf.py_function`
719
  dataset = dataset.map(
720
  lambda q, p, n: self._tokenize_triple(q, p, n),
721
  num_parallel_calls=1 #tf.data.AUTOTUNE
@@ -731,22 +693,19 @@ class TFDataPipeline:
731
  n: tf.Tensor
732
  ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
733
  """
734
- Wraps a Python function via tf.py_function to convert tf.Tensors of strings
735
- -> Python lists of strings -> HF tokenizer -> Tensors of IDs.
736
-
737
- q is shape [batch_size], p is shape [batch_size],
738
- n is shape [batch_size, neg_samples] (i.e., each row is a list of negatives).
739
  """
740
- # Use tf.py_function with limited parallelism
741
  q_ids, p_ids, n_ids = tf.py_function(
742
  func=self._tokenize_triple_py,
743
  inp=[q, p, n, tf.constant(self.max_length), tf.constant(self.neg_samples)],
744
  Tout=[tf.int32, tf.int32, tf.int32]
745
  )
746
 
747
- # Manually set shape information
748
- q_ids.set_shape([None, self.max_length]) # [batch_size, max_length]
749
- p_ids.set_shape([None, self.max_length]) # [batch_size, max_length]
750
  n_ids.set_shape([None, self.neg_samples, self.max_length]) # [batch_size, neg_samples, max_length]
751
 
752
  return q_ids, p_ids, n_ids
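The tf.py_function pattern used here can be shown in isolation. A toy string-length "tokenizer" stands in for the HF tokenizer; the two key points are that the wrapped function runs eagerly inside the graph, and that static shapes must be restored with set_shape afterwards:

import tensorflow as tf

def _lengths_py(q, p):
    # Runs eagerly: decode tf.string tensors to Python strings, return int32 tensors.
    q_len = [len(s.decode("utf-8")) for s in q.numpy()]
    p_len = [len(s.decode("utf-8")) for s in p.numpy()]
    return tf.constant(q_len, tf.int32), tf.constant(p_len, tf.int32)

def tokenize_pair(q, p):
    q_ids, p_ids = tf.py_function(func=_lengths_py, inp=[q, p], Tout=[tf.int32, tf.int32])
    q_ids.set_shape([None])   # py_function loses shape info; restore it manually
    p_ids.set_shape([None])
    return q_ids, p_ids

ds = tf.data.Dataset.from_tensor_slices((["hi there", "hello"], ["yes", "no thanks"])).batch(2)
ds = ds.map(tokenize_pair)
for q_ids, p_ids in ds:
    print(q_ids.numpy(), p_ids.numpy())   # [8 5] [3 9]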
@@ -760,32 +719,30 @@ class TFDataPipeline:
760
  neg_samples: tf.Tensor
761
  ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
762
  """
763
- Python function that:
764
- - Decodes each tf.string Tensor to a Python list of strings
765
- - Calls the HF tokenizer
766
- - Reshapes negatives
767
- - Returns np.array of int32s for (q_ids, p_ids, n_ids).
768
 
769
  q: shape [batch_size], p: shape [batch_size]
770
  n: shape [batch_size, neg_samples]
771
- max_len: scalar int
772
- neg_samples: scalar int
773
  """
774
- max_len = int(max_len.numpy()) # Convert to Python int
775
  neg_samples = int(neg_samples.numpy())
776
 
777
- # 1) Convert Tensors -> Python lists of strings
778
  q_list = [q_i.decode("utf-8") for q_i in q.numpy()] # shape [batch_size]
779
  p_list = [p_i.decode("utf-8") for p_i in p.numpy()] # shape [batch_size]
780
 
781
- # shape [batch_size, neg_samples], decode each row
782
  n_list = []
783
  for row in n.numpy():
784
  # row is shape [neg_samples], each is a tf.string
785
  decoded = [neg.decode("utf-8") for neg in row]
786
  n_list.append(decoded)
787
 
788
- # 2) Tokenize queries & positives
789
  q_enc = self.tokenizer(
790
  q_list,
791
  padding="max_length",
@@ -801,11 +758,11 @@ class TFDataPipeline:
801
  return_tensors="np"
802
  )
803
 
804
- # 3) Tokenize negatives
805
- # Flatten [batch_size, neg_samples] -> single list
806
  flattened_negatives = [neg for row in n_list for neg in row]
807
  if len(flattened_negatives) == 0:
808
- # No negatives at all: return a zero array
809
  n_ids = np.zeros((len(q_list), neg_samples, max_len), dtype=np.int32)
810
  else:
811
  n_enc = self.tokenizer(
@@ -815,11 +772,10 @@ class TFDataPipeline:
815
  max_length=max_len,
816
  return_tensors="np"
817
  )
818
- # shape [batch_size * neg_samples, max_len]
819
  n_input_ids = n_enc["input_ids"]
820
 
821
- # We want to reshape to [batch_size, neg_samples, max_len]
822
- # Handle cases where there might be fewer negatives
823
  batch_size = len(q_list)
824
  n_ids_list = []
825
  for i in range(batch_size):
@@ -827,7 +783,7 @@ class TFDataPipeline:
827
  end_idx = start_idx + neg_samples
828
  row_negs = n_input_ids[start_idx:end_idx]
829
 
830
- # If fewer negatives, pad with zeros
831
  if row_negs.shape[0] < neg_samples:
832
  deficit = neg_samples - row_negs.shape[0]
833
  pad_arr = np.zeros((deficit, max_len), dtype=np.int32)
@@ -835,10 +791,10 @@ class TFDataPipeline:
835
 
836
  n_ids_list.append(row_negs)
837
 
838
- # stack them -> shape [batch_size, neg_samples, max_len]
839
  n_ids = np.stack(n_ids_list, axis=0)
840
 
841
- # 4) Return as np.int32 arrays
842
  q_ids = q_enc["input_ids"].astype(np.int32) # shape [batch_size, max_len]
843
  p_ids = p_enc["input_ids"].astype(np.int32) # shape [batch_size, max_len]
844
  n_ids = n_ids.astype(np.int32) # shape [batch_size, neg_samples, max_len]
 
4
  import tensorflow as tf
5
  import h5py
6
  import math
7
+ import random
8
+ import gc
9
  from tqdm import tqdm
10
  import json
11
  from pathlib import Path
 
48
  self.max_batch_size = 16 if len(response_pool) < 100 else 64
49
  self.max_retries = max_retries
50
 
51
+ # Build text -> domain map for O(1) domain lookups (hard negative sampling)
52
  self._text_domain_map = {}
53
  self.build_text_to_domain_map()
54
+
55
+ # Initialize FAISS index
56
  if os.path.exists(index_file_path):
57
  logger.info(f"Loading existing FAISS index from {index_file_path}...")
58
  self.index = faiss.read_index(index_file_path)
59
  self.validate_faiss_index()
60
  logger.info("FAISS index loaded and validated successfully.")
61
  else:
 
62
  dimension = self.encoder.config.embedding_dim
63
  self.index = faiss.IndexFlatIP(dimension)
64
  logger.info(f"Initialized FAISS IndexFlatIP with dimension {dimension}.")
65
 
66
  if not self.index.is_trained:
67
+ # Train the index if it isn't trained yet. IndexFlatIP needs no training, but other index types do (future switch to IndexIVFFlat).
68
  dimension = self.query_embeddings_cache[next(iter(self.query_embeddings_cache))].shape[0]
69
  self.index.train(np.array(list(self.query_embeddings_cache.values())).astype(np.float32))
70
  self.index.add(np.array(list(self.query_embeddings_cache.values())).astype(np.float32))
71
+
72
  def save_embeddings_cache_hdf5(self, cache_file_path: str):
73
+ """Save embeddings cache to HDF5 file."""
74
  with h5py.File(cache_file_path, 'w') as hf:
75
  for query, emb in self.query_embeddings_cache.items():
76
  hf.create_dataset(query, data=emb)
77
  logger.info(f"Embeddings cache saved to {cache_file_path}.")
78
+
79
  def load_embeddings_cache_hdf5(self, cache_file_path: str):
80
+ """Load embeddings cache from HDF5 file."""
81
  with h5py.File(cache_file_path, 'r') as hf:
82
  for query in hf.keys():
83
  self.query_embeddings_cache[query] = hf[query][:]
84
  logger.info(f"Embeddings cache loaded from {cache_file_path}.")
85
+
86
  def save_faiss_index(self, index_file_path: str):
87
  faiss.write_index(self.index, index_file_path)
88
  logger.info(f"FAISS index saved to {index_file_path}")
89
 
90
  def load_faiss_index(self, index_file_path: str):
91
+ """Load FAISS index from specified file path."""
92
  if os.path.exists(index_file_path):
93
  self.index = faiss.read_index(index_file_path)
94
  logger.info(f"FAISS index loaded from {index_file_path}.")
 
97
  raise FileNotFoundError(f"FAISS index file not found at {index_file_path}.")
98
 
99
  def validate_faiss_index(self):
100
+ """Validates FAISS index dimensionality."""
101
  expected_dim = self.encoder.config.embedding_dim
102
  if self.index.d != expected_dim:
103
  logger.error(f"FAISS index dimension {self.index.d} does not match encoder embedding dimension {expected_dim}.")
 
116
  def load_json_training_data(data_path: Union[str, Path], debug_samples: Optional[int] = None) -> List[dict]:
117
  """
118
  Load training data from a JSON file.
 
119
  Args:
120
  data_path (Union[str, Path]): Path to the JSON file containing dialogues.
121
  debug_samples (Optional[int]): Number of samples to load for debugging.
 
138
 
139
  logger.info(f"Loaded {len(dialogues)} dialogues.")
140
  return dialogues
141
+
142
  def collect_responses_with_domain(self, dialogues: List[dict]) -> List[Dict[str, str]]:
143
  """
144
+ Extract unique assistant responses and their domains from dialogues.
145
+ Returns a list of dicts: [{"domain": str, "text": str}, ...]
146
  """
147
+ response_set = set() # Store (domain, text) unique tuples
148
  results = []
149
+
150
  for dialogue in tqdm(dialogues, desc="Processing Dialogues", unit="dialogue"):
 
151
  domain = dialogue.get('domain', 'other')
152
  turns = dialogue.get('turns', [])
153
  for turn in turns:
 
155
  text = turn.get('text', '').strip()
156
  if speaker == 'assistant' and text:
157
  if len(text) <= self.max_length:
158
+ # Use tuple as set key to ensure uniqueness
159
  key = (domain, text)
160
  if key not in response_set:
161
  response_set.add(key)
 
163
  "domain": domain,
164
  "text": text
165
  })
166
+
167
  logger.info(f"Collected {len(results)} unique assistant responses from dialogues.")
168
  return results
 
169
 
170
  def _extract_pairs_from_dialogue(self, dialogue: dict) -> List[Tuple[str, str]]:
171
  """Extract query-response pairs from a dialogue."""
 
189
 
190
  def compute_and_index_response_embeddings(self):
191
  """
192
+ Compute embeddings for the response pool and add them to the FAISS index.
193
+ self.response_pool: List[Dict[str, str]] with keys "domain" and "text".
194
  """
195
  logger.info("Computing embeddings for the response pool...")
196
+
197
+ # Extract the assistant text
198
  texts = [resp["text"] for resp in self.response_pool]
199
  logger.debug(f"Total texts to embed: {len(texts)}")
200
 
201
  batch_size = getattr(self, 'embedding_batch_size', 64)
202
  embeddings = []
203
+
204
  with tqdm(total=len(texts), desc="Computing Embeddings", unit="response") as pbar:
205
  for i in range(0, len(texts), batch_size):
206
  batch_texts = texts[i:i+batch_size]
 
212
  return_tensors='tf'
213
  )
214
  batch_embeds = self.encoder(encodings['input_ids'], training=False).numpy()
215
+
216
  embeddings.append(batch_embeds)
217
  pbar.update(len(batch_texts))
218
+
219
  # Combine embeddings and add to FAISS
220
  all_embeddings = np.vstack(embeddings).astype(np.float32)
221
  logger.info(f"Adding {len(all_embeddings)} response embeddings to FAISS index...")
222
  self.index.add(all_embeddings)
223
 
224
+ # Store in memory
225
  self.response_embeddings = all_embeddings
226
  logger.info(f"FAISS index now has {self.index.ntotal} vectors.")
227
 
228
+ def _find_hard_negatives(self, queries: List[str], positives: List[str], batch_size: int = 128) -> List[List[str]]:
229
  """
230
  Find hard negatives for a batch of queries using FAISS search.
231
+ Falls back to in-domain random negatives, then any-domain random negatives when needed.
 
232
  """
 
 
 
233
  retry_count = 0
234
  total_responses = len(self.response_pool)
235
+
 
 
236
  while retry_count < self.max_retries:
237
  try:
238
+ # Build query embeddings from the cache
239
  query_embeddings = []
240
  for i in range(0, len(queries), batch_size):
241
  sub_queries = queries[i : i + batch_size]
 
243
  sub_embeds = np.vstack(sub_embeds).astype(np.float32)
244
  faiss.normalize_L2(sub_embeds) # If not already normalized
245
  query_embeddings.append(sub_embeds)
246
+
247
  query_embeddings = np.vstack(query_embeddings)
248
  query_embeddings = np.ascontiguousarray(query_embeddings)
249
+
250
+ # FAISS search for nearest neighbors (hard negatives)
251
+ distances, indices = self.index.search(query_embeddings, self.neg_samples)
252
+
253
  all_negatives = []
254
+ # Extract domain from the positive assistant response
255
  for query_indices, query_text, pos_text in zip(indices, queries, positives):
256
  negative_list = []
257
+
258
+ # Build a 'seen' set with the positive
259
  seen = {pos_text.strip()}
260
+
 
261
  domain_of_positive = self._detect_domain_for_text(pos_text)
262
+
263
+ # Collect hard negatives (from config self.neg_samples)
264
  for idx in query_indices:
265
  if 0 <= idx < total_responses:
266
  candidate_dict = self.response_pool[idx] # e.g. {domain, text}
 
270
  negative_list.append(candidate_text)
271
  if len(negative_list) >= self.neg_samples:
272
  break
273
+
274
+ # Fall back to random domain-based
275
  if len(negative_list) < self.neg_samples:
276
  needed = self.neg_samples - len(negative_list)
277
+
278
  random_negatives = self._get_random_negatives(needed, seen, domain=domain_of_positive)
279
  negative_list.extend(random_negatives)
280
+
281
  all_negatives.append(negative_list)
282
+
283
  return all_negatives
284
+
285
  except KeyError as ke:
286
  retry_count += 1
287
  logger.warning(f"Hard negative search attempt {retry_count} failed due to missing embeddings: {ke}")
 
291
  gc.collect()
292
  if tf.config.list_physical_devices('GPU'):
293
  tf.keras.backend.clear_session()
294
+
295
  except Exception as e:
296
  retry_count += 1
297
  logger.warning(f"Hard negative search attempt {retry_count} failed: {e}")
 
301
  gc.collect()
302
  if tf.config.list_physical_devices('GPU'):
303
  tf.keras.backend.clear_session()
304
+
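# Minimal sketch of the hard-negative filtering idea above (assumptions: `index` is
# a FAISS index over `pool`, `query_vec` is a normalized float32 vector, and the
# positive response must never be returned as a negative).
def pick_hard_negatives(index, pool, query_vec, positive, neg_samples=3):
    _, idxs = index.search(query_vec.reshape(1, -1), neg_samples * 5)
    seen = {positive.strip()}
    negatives = []
    for idx in idxs[0]:
        if idx < 0:
            continue
        text = pool[idx]["text"].strip()
        if text and text not in seen:
            negatives.append(text)
            seen.add(text)
        if len(negatives) >= neg_samples:
            break
    return negatives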
305
  def _detect_domain_for_text(self, text: str) -> Optional[str]:
306
  """
307
+ Look up the domain of a response text, used for domain-aware negative sampling.
 
308
  """
309
  stripped_text = text.strip()
310
  return self._text_domain_map.get(stripped_text, None)
311
 
312
  def _get_random_negatives(self, needed: int, seen: set, domain: Optional[str] = None) -> List[str]:
313
  """
314
+ Return random negative texts, preferring the given domain and falling back to any domain when too few are available.
 
315
  """
316
+ # Filter response_pool for domain
317
  if domain:
318
  domain_texts = [r["text"] for r in self.response_pool if r["domain"] == domain]
319
  # fallback to entire set if insufficient domain_texts
320
+ if len(domain_texts) < needed * 2:
321
  domain_texts = [r["text"] for r in self.response_pool]
322
  else:
323
  domain_texts = [r["text"] for r in self.response_pool]
324
+
325
  negatives = []
326
  tries = 0
327
  max_tries = needed * 10
 
331
  if candidate and candidate not in seen:
332
  negatives.append(candidate)
333
  seen.add(candidate)
334
+
 
335
  if len(negatives) < needed:
336
  logger.warning(f"Could not find enough domain-based random negatives; needed {needed}, got {len(negatives)}.")
337
 
 
347
  all_negatives = []
348
 
349
  for pos_text in positives:
350
+ # Build a 'seen' set with the positive assistant response
351
  seen = {pos_text.strip()}
352
+
353
+ # Detect domain of the positive
354
  domain_of_positive = self._detect_domain_for_text(pos_text)
355
+
356
+ # Use domain-based negatives when available
357
  negs = self._get_random_negatives(self.neg_samples, seen, domain=domain_of_positive)
358
  all_negatives.append(negs)
359
+
360
  return all_negatives
361
 
362
  def build_text_to_domain_map(self):
363
  """
364
+ Build O(1) lookup dict: text -> domain for hard negative sampling.
 
365
  """
366
  self._text_domain_map = {}
367
+
368
  for item in self.response_pool:
 
369
  stripped_text = item["text"].strip()
370
  domain = item["domain"]
371
+
 
 
372
  if stripped_text in self._text_domain_map:
373
+ #existing_domain = self._text_domain_map[stripped_text]
374
+ #if existing_domain != domain:
375
+ # Collision detected. Using first found domain for now.
376
+ # This happens often with low-signal responses such as "ok" or "yes".
377
+ # logger.warning(
378
+ # f"Collision detected: text '{stripped_text}' found with domains "
379
+ # f"'{existing_domain}' and '{domain}'. Keeping the first."
380
+ # )
381
  # By default, keep the first domain or overwrite. We'll skip overwriting:
382
  continue
383
  else:
384
  # Insert into the dict
385
  self._text_domain_map[stripped_text] = domain
386
+
387
+ logger.info(f"Built text -> domain map with {len(self._text_domain_map)} unique text entries.")
388
 
389
  def encode_query(
390
  self,
 
397
  Args:
398
  query: The user query.
399
  context: Optional conversation history as a list of (user_text, assistant_text).
 
400
  Returns:
401
  np.ndarray of shape [embedding_dim], typically L2-normalized already.
402
  """
403
+ # Prepare context: concatenate user/assistant turns
404
  if context:
405
  # Take the last N turns
406
  relevant_history = context[-self.config.max_context_turns:]
 
412
  )
413
  context_str = " ".join(context_str_parts)
414
 
415
+ # Append the new query
416
  full_query = (
417
  f"{context_str} "
418
  f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"
419
  )
420
  else:
421
+ # Single user turn
422
  full_query = (
423
  f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"
424
  )
425
 
426
+ # Tokenize
427
  encodings = self.tokenizer(
428
  [full_query],
429
  padding='max_length',
 
433
  )
434
  input_ids = encodings['input_ids']
435
 
436
+ # Debug out-of-vocab IDs
437
  max_id = np.max(input_ids)
438
  vocab_size = len(self.tokenizer)
439
  if max_id >= vocab_size:
440
  logger.error(f"Token ID {max_id} exceeds tokenizer vocab size {vocab_size}.")
441
  raise ValueError("Token ID exceeds vocabulary size.")
442
 
443
+ # Get embeddings from the model. These are already L2-normalized by the model's final layer.
444
  embeddings = self.encoder(input_ids, training=False).numpy()
445
+
 
 
446
  return embeddings[0]
447
+
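# Sketch of the context-assembly pattern above (assumes the tokenizer was built with
# the additional special tokens "<USER>" and "<ASSISTANT>"; `history` is a
# hypothetical list of (user_text, assistant_text) tuples).
def build_query_string(history, query, max_turns=3):
    parts = []
    for user_text, assistant_text in history[-max_turns:]:
        parts.append(f"<USER> {user_text}")
        parts.append(f"<ASSISTANT> {assistant_text}")
    parts.append(f"<USER> {query}")
    return " ".join(parts)

# build_query_string([("hi", "hello!")], "book a flight")
# -> '<USER> hi <ASSISTANT> hello! <USER> book a flight'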
448
  def encode_responses(
449
  self,
450
  responses: List[str],
 
452
  ) -> np.ndarray:
453
  """
454
  Encode multiple response texts into embedding vectors.
 
455
  Args:
456
+ responses: List of assistant responses.
457
  context: Optional conversation context (last N turns).
 
458
  Returns:
459
  np.ndarray of shape [num_responses, embedding_dim].
460
  """
461
+ # Optionally incorporate conversation context into the response encoding (benefit of this is still undecided)
 
462
  if context:
463
  relevant_history = context[-self.config.max_context_turns:]
464
  prepared = []
 
470
  f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {a_text}"
471
  )
472
  context_str = " ".join(context_str_parts)
473
+
474
+ # Treat resp as an assistant turn
475
  full_resp = (
476
  f"{context_str} "
477
  f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {resp}"
478
  )
479
  prepared.append(full_resp)
480
  else:
481
+ # Single response from the assistant
482
  prepared = [
483
  f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {r}"
484
  for r in responses
485
  ]
486
+
487
+ # Tokenize
488
  encodings = self.tokenizer(
489
  prepared,
490
  padding='max_length',
 
493
  return_tensors='np'
494
  )
495
  input_ids = encodings['input_ids']
496
+
497
+ # Debug for out-of-vocab
498
  max_id = np.max(input_ids)
499
  vocab_size = len(self.tokenizer)
500
  if max_id >= vocab_size:
501
  logger.error(f"Token ID {max_id} exceeds tokenizer vocab size {vocab_size}.")
502
  raise ValueError("Token ID exceeds vocabulary size.")
503
+
504
+ # Get embeddings from the model. These are already L2-normalized by the model's final layer.
505
  embeddings = self.encoder(input_ids, training=False).numpy()
506
+
 
507
  return embeddings.astype('float32')
508
 
509
  def prepare_and_save_data(self, dialogues: List[dict], tf_record_path: str, batch_size: int = 32):
510
  """
511
+ Batch-process dialogues and save them to a TFRecord file.
 
 
 
 
 
512
  """
513
  logger.info(f"Preparing and saving data to {tf_record_path}...")
514
 
 
516
  num_batches = math.ceil(num_dialogues / batch_size)
517
 
518
  with tf.io.TFRecordWriter(tf_record_path) as writer:
 
519
  with tqdm(total=num_batches, desc="Preparing Data Batches", unit="batch") as pbar:
520
  for i in range(num_batches):
521
  start_idx = i * batch_size
522
  end_idx = min(start_idx + batch_size, num_dialogues)
523
  batch_dialogues = dialogues[start_idx:end_idx]
524
 
525
+ # Extract query-positive pairs for the batch
526
  queries = []
527
  positives = []
528
  for dialogue in batch_dialogues:
 
534
 
535
  if not queries:
536
  pbar.update(1)
537
+ continue
538
 
539
  # Compute and cache query embeddings
540
  try:
 
542
  except Exception as e:
543
  logger.error(f"Error computing embeddings: {e}")
544
  pbar.update(1)
545
+ continue
546
 
547
+ # Find hard negatives
548
  try:
549
+ hard_negatives = self._find_hard_negatives(queries, positives)
550
  except Exception as e:
551
  logger.error(f"Error finding hard negatives: {e}")
552
  pbar.update(1)
 
573
  pbar.update(1)
574
  continue # Skip to the next batch
575
 
576
+ # Flatten hard_negatives while maintaining query alignment.
577
+ # hard_negatives is a list of lists; each sublist corresponds to one query.
578
  try:
579
  flattened_negatives = [neg for sublist in hard_negatives for neg in sublist]
580
  encoded_negatives = self.tokenizer.batch_encode_plus(
 
585
  return_tensors='tf'
586
  )
587
 
588
+ # Reshape to [num_queries, num_negatives, max_length]
589
  num_negatives = self.config.neg_samples
590
  reshaped_negatives = encoded_negatives['input_ids'].numpy().reshape(-1, num_negatives, self.config.max_context_token_limit)
591
  except Exception as e:
592
  logger.error(f"Error during negatives tokenization: {e}")
593
  pbar.update(1)
594
+ continue
595
 
596
+ # Serialize and write to TFRecord
597
  for j in range(len(queries)):
598
  try:
599
  q_id = encoded_queries['input_ids'][j].numpy()
 
617
  logger.info(f"Data preparation complete. TFRecord saved.")
618
 
619
  def _compute_embeddings(self, queries: List[str]) -> None:
620
+ """
621
+ Compute embeddings for new queries and update the cache.
622
+ """
623
  new_queries = [q for q in queries if q not in self.query_embeddings_cache]
624
  if not new_queries:
625
+ return
626
+
627
+ # Compute embeddings
628
  new_embeddings = []
629
  for i in range(0, len(new_queries), self.embedding_batch_size):
630
  batch_queries = new_queries[i:i + self.embedding_batch_size]
 
638
  batch_embeddings = self.encoder(encoded['input_ids'], training=False).numpy()
639
  faiss.normalize_L2(batch_embeddings)
640
  new_embeddings.extend(batch_embeddings)
641
+
642
  # Update the cache
643
  for query, emb in zip(new_queries, new_embeddings):
644
  self.query_embeddings_cache[query] = emb
645
+
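# Sketch of the query-embedding cache idea above; `encode_fn` is a hypothetical
# callable that maps a list of strings to a list of vectors.
query_embeddings_cache = {}

def cached_embed(queries, encode_fn):
    new = [q for q in queries if q not in query_embeddings_cache]
    if new:
        for q, vec in zip(new, encode_fn(new)):
            query_embeddings_cache[q] = vec
    return [query_embeddings_cache[q] for q in queries]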
646
  def data_generator(self, dialogues: List[dict]) -> Generator[Tuple[str, str, List[str]], None, None]:
647
  """
648
+ Generate training examples: (query, positive, [hard_negatives]).
 
649
  """
650
  total_dialogues = len(dialogues)
651
  logger.debug(f"Total dialogues to process: {total_dialogues}")
652
+
 
653
  with tqdm(total=total_dialogues, desc="Processing Dialogues", unit="dialogue") as pbar:
654
  for dialogue in dialogues:
655
  pairs = self._extract_pairs_from_dialogue(dialogue)
656
  for query, positive in pairs:
657
  # Ensure embeddings are computed, find hard negatives, etc.
658
  self._compute_embeddings([query])
659
+ hard_negatives = self._find_hard_negatives([query], [positive])[0]
660
  yield (query, positive, hard_negatives)
661
  pbar.update(1)
662
 
663
  def get_tf_dataset(self, dialogues: List[dict], batch_size: int) -> tf.data.Dataset:
664
  """
665
+ Creates a tf.data.Dataset for streaming training.
666
+ Yields (input_ids_query, input_ids_positive, input_ids_negatives).
667
  """
668
  # 1) Start with a generator dataset
669
  dataset = tf.data.Dataset.from_generator(
670
  lambda: self.data_generator(dialogues),
671
  output_signature=(
672
+ tf.TensorSpec(shape=(), dtype=tf.string), # Query (single string)
673
+ tf.TensorSpec(shape=(), dtype=tf.string), # Positive (single string)
674
+ tf.TensorSpec(shape=(self.neg_samples,), dtype=tf.string) # Hard Negatives (list of strings)
675
  )
676
  )
677
 
678
+ # Batch the raw strings, then map through a tokenize step
679
+ # Note: the DistilBERT tokenizer threw an error when using tf.data.AUTOTUNE.
680
  dataset = dataset.batch(batch_size, drop_remainder=True)
 
 
681
  dataset = dataset.map(
682
  lambda q, p, n: self._tokenize_triple(q, p, n),
683
  num_parallel_calls=1 #tf.data.AUTOTUNE
 
693
  n: tf.Tensor
694
  ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
695
  """
696
+ Wraps a Python function: converts tf.Tensors of strings -> Python lists of strings -> HF tokenizer -> Tensors of IDs.
697
+ q is shape [batch_size], p is shape [batch_size], n is shape [batch_size, neg_samples] (list of negatives).
 
 
 
698
  """
699
+ # Use tf.py_function, limit parallelism
700
  q_ids, p_ids, n_ids = tf.py_function(
701
  func=self._tokenize_triple_py,
702
  inp=[q, p, n, tf.constant(self.max_length), tf.constant(self.neg_samples)],
703
  Tout=[tf.int32, tf.int32, tf.int32]
704
  )
705
 
706
+ # Set shape info for the output tensors
707
+ q_ids.set_shape([None, self.max_length]) # [batch_size, max_length]
708
+ p_ids.set_shape([None, self.max_length]) # [batch_size, max_length]
709
  n_ids.set_shape([None, self.neg_samples, self.max_length]) # [batch_size, neg_samples, max_length]
710
 
711
  return q_ids, p_ids, n_ids
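# Toy illustration of the tf.py_function + set_shape pattern used above: py_function
# erases static shape information, so the shape is reattached afterwards.
import tensorflow as tf

def _double_py(x):
    return x.numpy() * 2

def double(x):
    y = tf.py_function(func=_double_py, inp=[x], Tout=tf.int32)
    y.set_shape(x.shape)  # restore the shape lost by py_function
    return y

# double(tf.constant([1, 2, 3], dtype=tf.int32)) -> [2, 4, 6]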
 
719
  neg_samples: tf.Tensor
720
  ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
721
  """
722
+ Decodes tf.string Tensors to Python List[str], then tokenizes.
723
+ Reshapes negatives to [batch_size, neg_samples, max_length].
724
+ Returns np.array(int32) for (q_ids, p_ids, n_ids).
 
 
725
 
726
  q: shape [batch_size], p: shape [batch_size]
727
  n: shape [batch_size, neg_samples]
728
+ max_len: int
729
+ neg_samples: int
730
  """
731
+ max_len = int(max_len.numpy())
732
  neg_samples = int(neg_samples.numpy())
733
 
734
+ # Convert Tensors -> Python List[str]
735
  q_list = [q_i.decode("utf-8") for q_i in q.numpy()] # shape [batch_size]
736
  p_list = [p_i.decode("utf-8") for p_i in p.numpy()] # shape [batch_size]
737
 
738
+ # Shape [batch_size, neg_samples], decode each row
739
  n_list = []
740
  for row in n.numpy():
741
  # row is shape [neg_samples], each is a tf.string
742
  decoded = [neg.decode("utf-8") for neg in row]
743
  n_list.append(decoded)
744
 
745
+ # Tokenize queries & positives
746
  q_enc = self.tokenizer(
747
  q_list,
748
  padding="max_length",
 
758
  return_tensors="np"
759
  )
760
 
761
+ # Tokenize negatives
762
+ # Flatten [batch_size, neg_samples] -> List
763
  flattened_negatives = [neg for row in n_list for neg in row]
764
  if len(flattened_negatives) == 0:
765
+ # No negatives: return a zero array
766
  n_ids = np.zeros((len(q_list), neg_samples, max_len), dtype=np.int32)
767
  else:
768
  n_enc = self.tokenizer(
 
772
  max_length=max_len,
773
  return_tensors="np"
774
  )
775
+ # Shape [batch_size * neg_samples, max_len]
776
  n_input_ids = n_enc["input_ids"]
777
 
778
+ # Reshape to [batch_size, neg_samples, max_len]
 
779
  batch_size = len(q_list)
780
  n_ids_list = []
781
  for i in range(batch_size):
 
783
  end_idx = start_idx + neg_samples
784
  row_negs = n_input_ids[start_idx:end_idx]
785
 
786
+ # Pad with zeros if not enough negatives
787
  if row_negs.shape[0] < neg_samples:
788
  deficit = neg_samples - row_negs.shape[0]
789
  pad_arr = np.zeros((deficit, max_len), dtype=np.int32)
 
791
 
792
  n_ids_list.append(row_negs)
793
 
794
+ # Stack shape [batch_size, neg_samples, max_len]
795
  n_ids = np.stack(n_ids_list, axis=0)
796
 
797
+ # Return np.int32 arrays
798
  q_ids = q_enc["input_ids"].astype(np.int32) # shape [batch_size, max_len]
799
  p_ids = p_enc["input_ids"].astype(np.int32) # shape [batch_size, max_len]
800
  n_ids = n_ids.astype(np.int32) # shape [batch_size, neg_samples, max_len]
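# Quick numeric check of the reshape logic above (standalone, illustrative):
# a flat [batch_size * neg_samples, max_len] block regains its per-query grouping.
import numpy as np

flat = np.arange(2 * 3 * 4, dtype=np.int32).reshape(6, 4)  # batch=2, negs=3, len=4
grouped = flat.reshape(2, 3, 4)
assert (grouped[1, 0] == flat[3]).all()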
train_model.py CHANGED
@@ -14,10 +14,10 @@ def inspect_tfrecord(tfrecord_file_path, num_examples=3):
14
  'negative_ids': tf.io.FixedLenFeature([3 * 512], tf.int64), # Adjust neg_samples if different
15
  }
16
  return tf.io.parse_single_example(example_proto, feature_description)
17
-
18
  dataset = tf.data.TFRecordDataset(tfrecord_file_path)
19
  dataset = dataset.map(parse_example)
20
-
21
  for i, example in enumerate(dataset.take(num_examples)):
22
  print(f"Example {i+1}:")
23
  print(f"Query IDs: {example['query_ids'].numpy()}")
@@ -26,29 +26,27 @@ def inspect_tfrecord(tfrecord_file_path, num_examples=3):
26
  print("-" * 50)
27
 
28
  def main():
 
29
 
30
- # Quick test to inspect TFRecord
31
  # inspect_tfrecord('training_data/training_data.tfrecord', num_examples=3)
32
 
33
- # Initialize environment
34
- tf.keras.backend.clear_session()
35
  env = EnvironmentSetup()
36
  env.initialize()
37
 
38
- # Training configuration
39
  EPOCHS = 20
40
  TF_RECORD_FILE_PATH = 'training_data/training_data.tfrecord'
41
  CHECKPOINT_DIR = 'checkpoints/'
42
- # Optimize batch size for Colab
43
- batch_size = 32 # env.optimize_batch_size(base_batch_size=16)
44
 
45
- # Initialize config
46
- config = ChatbotConfig()
47
 
48
- # Initialize chatbot
 
49
  chatbot = RetrievalChatbot(config, mode='training')
50
 
51
- # Check for existing checkpoint and get initial epoch
52
  latest_checkpoint = tf.train.latest_checkpoint(CHECKPOINT_DIR)
53
  initial_epoch = 0
54
  if latest_checkpoint:
@@ -60,7 +58,7 @@ def main():
60
  logger.error(f"Failed to parse checkpoint number from {latest_checkpoint}")
61
  initial_epoch = 0
62
 
63
- # Train the model
64
  chatbot.train_model(
65
  tfrecord_file_path=TF_RECORD_FILE_PATH,
66
  epochs=EPOCHS,
@@ -71,13 +69,13 @@ def main():
71
  initial_epoch=initial_epoch
72
  )
73
 
74
- # Save final model
75
  model_save_path = env.training_dirs['base'] / 'final_model'
76
  chatbot.save_models(model_save_path)
77
 
78
- # Plot and save training history
79
  plotter = Plotter(save_dir=env.training_dirs['plots'])
80
  plotter.plot_training_history(chatbot.history)
81
-
82
  if __name__ == "__main__":
83
  main()
 
14
  'negative_ids': tf.io.FixedLenFeature([3 * 512], tf.int64), # Adjust neg_samples if different
15
  }
16
  return tf.io.parse_single_example(example_proto, feature_description)
17
+
18
  dataset = tf.data.TFRecordDataset(tfrecord_file_path)
19
  dataset = dataset.map(parse_example)
20
+
21
  for i, example in enumerate(dataset.take(num_examples)):
22
  print(f"Example {i+1}:")
23
  print(f"Query IDs: {example['query_ids'].numpy()}")
 
26
  print("-" * 50)
27
 
28
  def main():
29
+ tf.keras.backend.clear_session()
30
 
31
+ # Optional: inspect the TFRecord
32
  # inspect_tfrecord('training_data/training_data.tfrecord', num_examples=3)
33
 
34
+ # Init env
 
35
  env = EnvironmentSetup()
36
  env.initialize()
37
 
38
+ # Training config
39
  EPOCHS = 20
40
  TF_RECORD_FILE_PATH = 'training_data/training_data.tfrecord'
41
  CHECKPOINT_DIR = 'checkpoints/'
 
 
42
 
43
+ batch_size = 32
 
44
 
45
+ # Initialize config and chatbot model
46
+ config = ChatbotConfig()
47
  chatbot = RetrievalChatbot(config, mode='training')
48
 
49
+ # Check for existing checkpoint
50
  latest_checkpoint = tf.train.latest_checkpoint(CHECKPOINT_DIR)
51
  initial_epoch = 0
52
  if latest_checkpoint:
 
58
  logger.error(f"Failed to parse checkpoint number from {latest_checkpoint}")
59
  initial_epoch = 0
60
 
61
+ # Train
62
  chatbot.train_model(
63
  tfrecord_file_path=TF_RECORD_FILE_PATH,
64
  epochs=EPOCHS,
 
69
  initial_epoch=initial_epoch
70
  )
71
 
72
+ # Save
73
  model_save_path = env.training_dirs['base'] / 'final_model'
74
  chatbot.save_models(model_save_path)
75
 
76
+ # Plot
77
  plotter = Plotter(save_dir=env.training_dirs['plots'])
78
  plotter.plot_training_history(chatbot.history)
79
+
80
  if __name__ == "__main__":
81
  main()
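# Hypothetical helper for resuming training (assumption: checkpoints are named like
# "ckpt-<number>"; the real naming scheme is not visible in this diff).
import re

def epoch_from_checkpoint(path):
    match = re.search(r"ckpt-(\d+)$", path or "")
    return int(match.group(1)) if match else 0

# epoch_from_checkpoint("checkpoints/ckpt-7") -> 7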
validate_model.py CHANGED
@@ -1,6 +1,5 @@
1
  import os
2
  import json
3
-
4
  from chatbot_model import ChatbotConfig, RetrievalChatbot
5
  from response_quality_checker import ResponseQualityChecker
6
  from chatbot_validator import ChatbotValidator
@@ -18,20 +17,20 @@ def run_interactive_chat(chatbot, quality_checker):
18
  except (KeyboardInterrupt, EOFError):
19
  print("\nAssistant: Goodbye!")
20
  break
21
-
22
  if user_input.lower() in ["quit", "exit", "bye"]:
23
  print("Assistant: Goodbye!")
24
  break
25
-
26
  response, candidates, metrics = chatbot.chat(
27
  query=user_input,
28
  conversation_history=None,
29
  quality_checker=quality_checker,
30
  top_k=10
31
  )
32
-
33
  print(f"Assistant: {response}")
34
-
35
  # Show alternative responses if confident
36
  if metrics.get("is_confident", False):
37
  print("\nAlternative responses:")
@@ -39,17 +38,17 @@ def run_interactive_chat(chatbot, quality_checker):
39
  print(f"Score: {score:.4f} - {resp}")
40
  else:
41
  print("\n[Low Confidence]: Consider rephrasing your query for better assistance.")
42
-
43
  def validate_chatbot():
44
  # Initialize environment
45
  env = EnvironmentSetup()
46
  env.initialize()
47
-
48
  MODEL_DIR = "new_iteration/data_prep_iterative_models"
49
  FAISS_INDICES_DIR = os.path.join(MODEL_DIR, "faiss_indices")
50
  FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, "faiss_index_production.index")
51
  FAISS_INDEX_TEST_PATH = os.path.join(FAISS_INDICES_DIR, "faiss_index_test.index")
52
-
53
  # Toggle 'production' or 'test' env
54
  ENVIRONMENT = "production"
55
  if ENVIRONMENT == "test":
@@ -58,7 +57,7 @@ def validate_chatbot():
58
  else:
59
  FAISS_INDEX_PATH = FAISS_INDEX_PRODUCTION_PATH
60
  RESPONSE_POOL_PATH = FAISS_INDEX_PRODUCTION_PATH.replace(".index", "_responses.json")
61
-
62
  # Load the config
63
  config_path = os.path.join(MODEL_DIR, "config.json")
64
  if os.path.exists(config_path):
@@ -69,50 +68,47 @@ def validate_chatbot():
69
  else:
70
  config = ChatbotConfig()
71
  logger.warning("No config.json found. Using default ChatbotConfig.")
72
-
73
- # Load RetrievalChatbot in 'inference' mode using the classmethod
74
  try:
75
  chatbot = RetrievalChatbot.load_model(load_dir=MODEL_DIR, mode="inference")
76
  logger.info("RetrievalChatbot loaded in 'inference' mode successfully.")
77
  except Exception as e:
78
  logger.error(f"Failed to load RetrievalChatbot: {e}")
79
  return
80
-
81
  # Confirm FAISS index & response pool exist
82
  if not os.path.exists(FAISS_INDEX_PATH) or not os.path.exists(RESPONSE_POOL_PATH):
83
  logger.error("FAISS index or response pool file is missing.")
84
  return
85
-
86
- # Load specific FAISS index and response pool
87
  try:
88
- # Even though load_model might auto-load an index, we override here with the specific file
89
  chatbot.data_pipeline.load_faiss_index(FAISS_INDEX_PATH)
90
  logger.info(f"FAISS index loaded from {FAISS_INDEX_PATH}.")
91
-
92
- print("FAISS dimensions:", chatbot.data_pipeline.index.d)
93
- print("FAISS index type:", type(chatbot.data_pipeline.index))
94
- print("FAISS index total vectors:", chatbot.data_pipeline.index.ntotal)
95
- print("FAISS is_trained:", chatbot.data_pipeline.index.is_trained)
96
-
97
  with open(RESPONSE_POOL_PATH, "r", encoding="utf-8") as f:
98
  chatbot.data_pipeline.response_pool = json.load(f)
99
- logger.info(f"Response pool loaded from {RESPONSE_POOL_PATH}.")
100
-
101
- print("\nTotal responses in pool:", len(chatbot.data_pipeline.response_pool))
102
-
103
  # Validate dimension consistency
104
  chatbot.data_pipeline.validate_faiss_index()
105
  logger.info("FAISS index and response pool validated successfully.")
106
-
107
  except Exception as e:
108
  logger.error(f"Failed to load or validate FAISS index: {e}")
109
  return
110
-
111
  # Init QualityChecker and Validator
112
  quality_checker = ResponseQualityChecker(data_pipeline=chatbot.data_pipeline)
113
  validator = ChatbotValidator(chatbot=chatbot, quality_checker=quality_checker)
114
  logger.info("ResponseQualityChecker and ChatbotValidator initialized.")
115
-
116
  # Run validation
117
  try:
118
  validation_metrics = validator.run_validation(num_examples=5)
@@ -120,7 +116,7 @@ def validate_chatbot():
120
  except Exception as e:
121
  logger.error(f"Validation process failed: {e}")
122
  return
123
-
124
  # Plot metrics
125
  # try:
126
  # plotter = Plotter(save_dir=env.training_dirs["plots"])
@@ -128,10 +124,10 @@ def validate_chatbot():
128
  # logger.info("Validation metrics plotted successfully.")
129
  # except Exception as e:
130
  # logger.error(f"Failed to plot validation metrics: {e}")
131
-
132
  # Run interactive chat loop
133
- # logger.info("\nStarting interactive chat session...")
134
- # run_interactive_chat(chatbot, quality_checker)
135
 
136
  if __name__ == "__main__":
137
  validate_chatbot()
 
1
  import os
2
  import json
 
3
  from chatbot_model import ChatbotConfig, RetrievalChatbot
4
  from response_quality_checker import ResponseQualityChecker
5
  from chatbot_validator import ChatbotValidator
 
17
  except (KeyboardInterrupt, EOFError):
18
  print("\nAssistant: Goodbye!")
19
  break
20
+
21
  if user_input.lower() in ["quit", "exit", "bye"]:
22
  print("Assistant: Goodbye!")
23
  break
24
+
25
  response, candidates, metrics = chatbot.chat(
26
  query=user_input,
27
  conversation_history=None,
28
  quality_checker=quality_checker,
29
  top_k=10
30
  )
31
+
32
  print(f"Assistant: {response}")
33
+
34
  # Show alternative responses if confident
35
  if metrics.get("is_confident", False):
36
  print("\nAlternative responses:")
 
38
  print(f"Score: {score:.4f} - {resp}")
39
  else:
40
  print("\n[Low Confidence]: Consider rephrasing your query for better assistance.")
41
+
42
  def validate_chatbot():
43
  # Initialize environment
44
  env = EnvironmentSetup()
45
  env.initialize()
46
+
47
  MODEL_DIR = "new_iteration/data_prep_iterative_models"
48
  FAISS_INDICES_DIR = os.path.join(MODEL_DIR, "faiss_indices")
49
  FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, "faiss_index_production.index")
50
  FAISS_INDEX_TEST_PATH = os.path.join(FAISS_INDICES_DIR, "faiss_index_test.index")
51
+
52
  # Toggle 'production' or 'test' env
53
  ENVIRONMENT = "production"
54
  if ENVIRONMENT == "test":
 
57
  else:
58
  FAISS_INDEX_PATH = FAISS_INDEX_PRODUCTION_PATH
59
  RESPONSE_POOL_PATH = FAISS_INDEX_PRODUCTION_PATH.replace(".index", "_responses.json")
60
+
61
  # Load the config
62
  config_path = os.path.join(MODEL_DIR, "config.json")
63
  if os.path.exists(config_path):
 
68
  else:
69
  config = ChatbotConfig()
70
  logger.warning("No config.json found. Using default ChatbotConfig.")
71
+
72
+ # Load RetrievalChatbot in 'inference' mode
73
  try:
74
  chatbot = RetrievalChatbot.load_model(load_dir=MODEL_DIR, mode="inference")
75
  logger.info("RetrievalChatbot loaded in 'inference' mode successfully.")
76
  except Exception as e:
77
  logger.error(f"Failed to load RetrievalChatbot: {e}")
78
  return
79
+
80
  # Confirm FAISS index & response pool exist
81
  if not os.path.exists(FAISS_INDEX_PATH) or not os.path.exists(RESPONSE_POOL_PATH):
82
  logger.error("FAISS index or response pool file is missing.")
83
  return
84
+
85
+ # Load FAISS index and response pool
86
  try:
 
87
  chatbot.data_pipeline.load_faiss_index(FAISS_INDEX_PATH)
88
  logger.info(f"FAISS index loaded from {FAISS_INDEX_PATH}.")
89
+ logger.info("FAISS dimensions:", chatbot.data_pipeline.index.d)
90
+ logger.info("FAISS index type:", type(chatbot.data_pipeline.index))
91
+ logger.info("FAISS index total vectors:", chatbot.data_pipeline.index.ntotal)
92
+ logger.info("FAISS is_trained:", chatbot.data_pipeline.index.is_trained)
93
+
 
94
  with open(RESPONSE_POOL_PATH, "r", encoding="utf-8") as f:
95
  chatbot.data_pipeline.response_pool = json.load(f)
96
+ logger.info(f"Response pool loaded from {RESPONSE_POOL_PATH}.")
97
+ logger.info("\nTotal responses in pool:", len(chatbot.data_pipeline.response_pool))
98
+
 
99
  # Validate dimension consistency
100
  chatbot.data_pipeline.validate_faiss_index()
101
  logger.info("FAISS index and response pool validated successfully.")
102
+
103
  except Exception as e:
104
  logger.error(f"Failed to load or validate FAISS index: {e}")
105
  return
106
+
107
  # Init QualityChecker and Validator
108
  quality_checker = ResponseQualityChecker(data_pipeline=chatbot.data_pipeline)
109
  validator = ChatbotValidator(chatbot=chatbot, quality_checker=quality_checker)
110
  logger.info("ResponseQualityChecker and ChatbotValidator initialized.")
111
+
112
  # Run validation
113
  try:
114
  validation_metrics = validator.run_validation(num_examples=5)
 
116
  except Exception as e:
117
  logger.error(f"Validation process failed: {e}")
118
  return
119
+
120
  # Plot metrics
121
  # try:
122
  # plotter = Plotter(save_dir=env.training_dirs["plots"])
 
124
  # logger.info("Validation metrics plotted successfully.")
125
  # except Exception as e:
126
  # logger.error(f"Failed to plot validation metrics: {e}")
127
+
128
  # Run interactive chat loop
129
+ logger.info("\nStarting interactive chat session...")
130
+ run_interactive_chat(chatbot, quality_checker)
131
 
132
  if __name__ == "__main__":
133
  validate_chatbot()
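# Standalone sketch of the index/response-pool pairing used above; paths follow the
# "<name>.index" / "<name>_responses.json" convention from this script, and the
# ntotal check is an assumption about how the two files are kept in sync.
import json
import faiss

def load_retrieval_assets(index_path):
    index = faiss.read_index(index_path)
    pool_path = index_path.replace(".index", "_responses.json")
    with open(pool_path, "r", encoding="utf-8") as f:
        responses = json.load(f)
    assert index.ntotal == len(responses), "FAISS index and response pool are out of sync"
    return index, responses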