JoeArmani
committed on
Commit · cc2577d
1 Parent(s): 3ea7670
style updates
Browse files
- chatbot_model.py +38 -56
- conversation_summarizer.py +30 -28
- cross_encoder_reranker.py +11 -24
- tf_data_pipeline.py +148 -192
- train_model.py +14 -16
- validate_model.py +28 -32
chatbot_model.py
CHANGED
@@ -372,7 +372,7 @@ class RetrievalChatbot(DeviceAwareModel):
             reranker: CrossEncoderReranker for refined scoring, if available.
             summarizer: Summarizer for long queries, if desired.
             summarize_threshold: Summarize if query wordcount > threshold.
-
+
         Returns:
             List of (response_text, final_score).
         """
@@ -383,11 +383,13 @@ class RetrievalChatbot(DeviceAwareModel):
            logger.info(f"Summarized Query: {query}")
 
        detected_domain = self.detect_domain_from_query(query)
-
-
+
        # Retrieve initial candidates from FAISS
        initial_k = min(top_k * 10, len(self.data_pipeline.response_pool))
-        faiss_candidates = self.
+        faiss_candidates = self.faiss_search(query, domain=detected_domain, top_k=initial_k)
+
+        if not faiss_candidates:
+            return []
 
        texts = [item[0] for item in faiss_candidates]
 
@@ -395,23 +397,18 @@ class RetrievalChatbot(DeviceAwareModel):
        if not reranker:
            reranker = CrossEncoderReranker(model_name=self.config.cross_encoder_model)
 
-
-
+        ce_logits = reranker.rerank(query, texts, max_length=256)
+
        # Combine cross-encoder score with the base FAISS score (simple multiplicative approach)
        final_candidates = []
-        for (resp_text, faiss_score),
-
-
-
-            combined_score = 0.9 * ce_prob + 0.1 * faiss_norm
-            # alpha = 0.9
-            # print(f'CE SCORE: {ce_score} FAISS SCORE: {faiss_score}')
-            # combined_score = alpha * ce_score + (1 - alpha) * faiss_score
+        for (resp_text, faiss_score), logit in zip(faiss_candidates, ce_logits):
+            ce_prob = self.sigmoid(logit)        # [0...1]
+            faiss_norm = (faiss_score + 1)/2.0   # [0...1]
+            combined_score = 0.85 * ce_prob + 0.15 * faiss_norm
            length_adjusted_score = self.length_adjust_score(resp_text, combined_score)
-
-            #final_candidates.append((resp_text, combined_score))
+
            final_candidates.append((resp_text, length_adjusted_score))
-
+
        # Sort descending by combined score
        final_candidates.sort(key=lambda x: x[1], reverse=True)
 
@@ -441,20 +438,18 @@ class RetrievalChatbot(DeviceAwareModel):
 
    def length_adjust_score(self, text: str, base_score: float) -> float:
        """
-        Penalize very short lines
-        Adjust carefully so you don't overshadow cross-encoder signals.
+        Penalize very short lines, reward longer lines.
        """
        words = text.split()
        wcount = len(words)
 
-        # Penalty if under
+        # Penalty if under 4 words
        if wcount < 4:
            return base_score * 0.8
 
-        # Bonus for lines >
-        if wcount >
-
-            bonus = 0.0005 * extra
+        # Bonus for lines > 15 words
+        if wcount > 15:
+            bonus = min(0.03, 0.001 * (wcount - 15))
            base_score += bonus
 
        return base_score
@@ -487,7 +482,7 @@ class RetrievalChatbot(DeviceAwareModel):
        pattern = r'^[\s]*[\d]+([\s.,\d]+)*[\s]*$'
        return bool(re.match(pattern, text.strip()))
 
-    def
+    def faiss_search(
        self,
        query: str,
        domain: str = 'other',
@@ -518,9 +513,9 @@ class RetrievalChatbot(DeviceAwareModel):
        for rank, idx in enumerate(indices[0]):
            if idx < 0:
                continue
-
-            text =
-            cand_domain =
+            text_dict = self.data_pipeline.response_pool[idx]
+            text = text_dict.get('text', '').strip()
+            cand_domain = text_dict.get('domain', 'other')
            score = distances[0][rank]
 
            # Skip purely numeric or extremely short text (fewer than 3 words):
@@ -554,21 +549,19 @@ class RetrievalChatbot(DeviceAwareModel):
            # shares any query keywords, apply a small boost
            if query_keywords and any(kw in resp_text.lower() for kw in query_keywords):
                new_score *= boost_factor
-
-
+
            # Apply length penalty/bonus
            new_score = self.length_adjust_score(resp_text, new_score)
-
+
            boosted.append((resp_text, new_score))
-
+
        # Sort boosted responses
        boosted.sort(key=lambda x: x[1], reverse=True)
 
-        #
-        # for resp, score in boosted[:
+        # Debug
+        # for resp, score in boosted[:100]:
        #     logger.debug(f"Candidate: '{resp}' with score {score}")
-
-        # 8) Return top_k
+
        return boosted[:top_k]
 
    def chat(
@@ -584,10 +577,10 @@ class RetrievalChatbot(DeviceAwareModel):
        """
        @self.run_on_device
        def get_response(self_arg, query_arg):
-            #
+            # Build conversation context string
            conversation_str = self_arg._build_conversation_context(query_arg, conversation_history)
 
-            #
+            # Retrieve and re-rank
            results = self_arg.retrieve_responses_cross_encoder(
                query=conversation_str,
                top_k=top_k,
@@ -595,26 +588,15 @@ class RetrievalChatbot(DeviceAwareModel):
                summarizer=self_arg.summarizer,
                summarize_threshold=512
            )
-
-            #
+
+            # Handle low confidence or empty responses
            if not results:
-                return (
-                    "I'm sorry, but I couldn't find a relevant response.",
-                    [],
-                    {}
-                )
+                return ("I'm sorry, but I couldn't find a relevant response.", [], {})
-
-            if quality_checker:
-                metrics = quality_checker.check_response_quality(query_arg, results)
-                if not metrics.get('is_confident', False):
-                    return (
-                        "I need more information to provide a good answer. Could you please clarify?",
-                        results,
-                        metrics
-                    )
-                return results[0][0], results, metrics
 
-
+            metrics = quality_checker.check_response_quality(query_arg, results)
+            if not metrics.get('is_confident', False):
+                return ("I need more information to provide a good answer. Could you please clarify?", results, metrics)
+            return results[0][0], results, metrics
 
        return get_response(self, query)
 
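The retrieval change above blends the cross-encoder probability with a normalized FAISS cosine score, then applies the length adjustment from length_adjust_score. A minimal standalone sketch of that scoring step (function and variable names here are illustrative, not the repository's API):

import math
from typing import List, Tuple

def blend_scores(faiss_candidates: List[Tuple[str, float]], ce_logits: List[float]) -> List[Tuple[str, float]]:
    """Combine cross-encoder logits with FAISS inner-product scores (sketch of the idea in the diff)."""
    ranked = []
    for (text, faiss_score), logit in zip(faiss_candidates, ce_logits):
        ce_prob = 1.0 / (1.0 + math.exp(-logit))    # sigmoid -> [0, 1]
        faiss_norm = (faiss_score + 1.0) / 2.0       # cosine in [-1, 1] -> [0, 1]
        score = 0.85 * ce_prob + 0.15 * faiss_norm   # weights taken from the diff above
        # Mild length adjustment, mirroring length_adjust_score
        wcount = len(text.split())
        if wcount < 4:
            score *= 0.8
        elif wcount > 15:
            score += min(0.03, 0.001 * (wcount - 15))
        ranked.append((text, score))
    return sorted(ranked, key=lambda x: x[1], reverse=True)

# Example: blend_scores([("Sure, I can help with that booking.", 0.62)], [2.1])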
conversation_summarizer.py
CHANGED
@@ -13,9 +13,11 @@ class ChatConfig:
    chunk_size: int = 512
    chunk_overlap: int = 256
    min_confidence_score: float = 0.7
-
+
class DeviceAwareModel:
-    """
+    """
+    Mixin: Handle device placement and mixed precision training.
+    """
 
    def setup_device(self, device: str = None):
        if device is None:
@@ -24,31 +26,33 @@ class DeviceAwareModel:
        self.device = device.upper()
        self.strategy = None
 
+        # NOTE: Needs more testing. Training issues may have been from other bugs I found since this was tested.
+        # Reminder: Test model saving/loading alongside mixed precision settings
        if self.device == 'GPU':
            # # Enable mixed precision for better performance
            # policy = tf.keras.mixed_precision.Policy('mixed_float16')
            # tf.keras.mixed_precision.set_global_policy(policy)
 
-            # Setup
+            # Setup multi-GPU if available
            gpus = tf.config.list_physical_devices('GPU')
            if len(gpus) > 1:
                self.strategy = tf.distribute.MirroredStrategy()
 
        return self.device
-
+
    def run_on_device(self, func):
        """Decorator to ensure ops run on the correct device."""
        def wrapper(*args, **kwargs):
            with tf.device(f'/{self.device}:0'):
                return func(*args, **kwargs)
        return wrapper
-
+
class Summarizer(DeviceAwareModel):
    """
-
-
+    T5-based summarizer with chunking and device management.
+    Chunking and progressive summarization for long conversations.
    """
-
+
    def __init__(
        self,
        tokenizer: AutoTokenizer,
@@ -57,10 +61,10 @@ class Summarizer(DeviceAwareModel):
        device=None,
        max_summary_rounds=2
    ):
-        self.tokenizer = tokenizer
+        self.tokenizer = tokenizer
        self.setup_device(device)
 
-        #
+        # Strategy scope if using distribution
        if self.strategy:
            with self.strategy.scope():
                self._setup_model(model_name)
@@ -69,11 +73,11 @@ class Summarizer(DeviceAwareModel):
 
        self.max_summary_length = max_summary_length
        self.max_summary_rounds = max_summary_rounds
-
+
    def _setup_model(self, model_name):
        self.model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)
 
-        # Optimize
+        # Optimize for inference
        self.model.generate = tf.function(
            self.model.generate,
            input_signature=[
@@ -83,7 +87,7 @@ class Summarizer(DeviceAwareModel):
                }
            ]
        )
-
+
    @tf.function
    def _generate_summary(self, inputs):
        return self.model.generate(
@@ -94,9 +98,9 @@ class Summarizer(DeviceAwareModel):
            early_stopping=True,
            no_repeat_ngram_size=3
        )
-
+
    def chunk_text(self, text: str, chunk_size: int = 512, overlap: int = 256) -> List[str]:
-        """Split text into overlapping chunks for
+        """Split text into overlapping chunks for context preservation."""
        tokens = self.tokenizer.encode(text)
        chunks = []
 
@@ -105,7 +109,7 @@ class Summarizer(DeviceAwareModel):
            chunks.append(self.tokenizer.decode(chunk, skip_special_tokens=True))
 
        return chunks
-
+
    def summarize_text(
        self,
        text: str,
@@ -113,8 +117,7 @@ class Summarizer(DeviceAwareModel):
        round_idx: int = 0
    ) -> str:
        """
-
-        and limit the maximum number of re-summarization rounds.
+        Progressive summarization and limited number of resummarization rounds.
        """
        @self.run_on_device
        def _summarize_chunk(chunk: str) -> str:
@@ -127,28 +130,27 @@ class Summarizer(DeviceAwareModel):
            )
            summary_ids = self._generate_summary(inputs)
            return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
-
-        #
+
+        # Do a single pass at resummarizing if max_summary rounds is hit
        if round_idx >= self.max_summary_rounds:
            return _summarize_chunk(text)
-
-        #
+
+        # Chunk and summarize
        if len(text.split()) > 512 and progressive:
            chunks = self.chunk_text(text)
            chunk_summaries = [_summarize_chunk(chunk) for chunk in chunks]
-
+
            # Combine chunk-level summaries
            combined_summary = " ".join(chunk_summaries)
-
-            # If still too long, do another summarization pass but increment round_idx
+
            if len(combined_summary.split()) > 512:
                return self.summarize_text(
                    combined_summary,
                    progressive=True,
                    round_idx=round_idx + 1
                )
-
+
            return combined_summary
        else:
-            #
-            return _summarize_chunk(text)
+            # Summarize once and return
+            return _summarize_chunk(text)
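The summarizer above splits long inputs into overlapping chunks, summarizes each chunk, joins the chunk summaries, and re-summarizes the result until it fits or max_summary_rounds is reached. A rough sketch of that control flow with a placeholder summarize callable (this is not the repository's T5 pipeline, just the recursion pattern):

from typing import Callable, List

def chunk_words(words: List[str], chunk_size: int = 512, overlap: int = 256) -> List[List[str]]:
    """Overlapping word chunks so context is preserved across boundaries (the real class chunks token IDs)."""
    step = chunk_size - overlap
    return [words[i:i + chunk_size] for i in range(0, len(words), step)]

def progressive_summarize(text: str, summarize: Callable[[str], str],
                          max_words: int = 512, max_rounds: int = 2, round_idx: int = 0) -> str:
    """Summarize chunk-by-chunk, then re-summarize the joined result if it is still too long."""
    if round_idx >= max_rounds or len(text.split()) <= max_words:
        # Either the budget of rounds is spent or the text already fits: single pass.
        return summarize(text)
    chunks = [" ".join(w) for w in chunk_words(text.split())]
    combined = " ".join(summarize(c) for c in chunks)
    if len(combined.split()) > max_words:
        return progressive_summarize(combined, summarize, max_words, max_rounds, round_idx + 1)
    return combined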
cross_encoder_reranker.py
CHANGED
@@ -1,23 +1,19 @@
 from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
 import tensorflow as tf
 from typing import List
-import numpy as np
 
 from logger_config import config_logger
 logger = config_logger(__name__)
 
 class CrossEncoderReranker:
    """
-    Cross-Encoder Re-Ranker
-    outputs a single relevance score in [0,1].
+    Cross-Encoder Re-Ranker. Takes (query, candidate) pairs and outputs a relevance score [0...1].
    """
    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-12-v2"):
        """
-
-
+        Init the cross-encoder with a pretrained model.
        Args:
-            model_name: Name of a HF cross-encoder model. Must be
-                compatible with TFAutoModelForSequenceClassification.
+            model_name: Name of a HF cross-encoder model. Must be compatible with TFAutoModelForSequenceClassification.
        """
        logger.info(f"Initializing CrossEncoderReranker with {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -31,21 +27,16 @@ class CrossEncoderReranker:
        max_length: int = 256
    ) -> List[float]:
        """
-        Compute relevance scores for each candidate w.r.t.
-
+        Compute relevance scores for each candidate w.r.t. query.
        Args:
            query: User's query text.
            candidates: List of candidate response texts.
            max_length: Max token length for each (query, candidate) pair.
-
        Returns:
-            A list of float scores
-            indicating model's predicted relevance.
+            A list of float scores [0...1]. One per candidate, indicating model's predicted relevance.
        """
-        #
+        # Build (query, candidate) pairs, then tokenize
        pair_texts = [(query, candidate) for candidate in candidates]
-
-        # 2) Tokenize the entire batch
        encodings = self.tokenizer(
            pair_texts,
            padding=True,
@@ -54,24 +45,20 @@ class CrossEncoderReranker:
            return_tensors="tf"
        )
 
-        #
+        # Forward pass, logits shape [batch_size, 1]
+        # Then convert logits to [0...1] range with sigmoid
+        # Note: token_type_ids are optional. .get() avoids KeyError
        outputs = self.model(
            input_ids=encodings["input_ids"],
            attention_mask=encodings["attention_mask"],
-            token_type_ids=encodings.get("token_type_ids")
+            token_type_ids=encodings.get("token_type_ids")
        )
 
        logits = outputs.logits  # shape [batch_size, 1]
-        # 4) Convert logits -> [0,1] range via sigmoid
-        # If the cross-encoder is a single-logit regression to [0,1],
-        # this is a typical interpretation.
        scores = tf.nn.sigmoid(logits)  # shape [batch_size, 1]
 
-        #
+        # Flatten to 1D NumPy array, ensure float type
        scores = tf.reshape(scores, [-1])
        scores = scores.numpy().astype(float)
 
-        # logger.debug(f"Cross-Encoder raw logits: {logits.numpy().flatten().tolist()}")
-        # logger.debug(f"Cross-Encoder sigmoid scores: {scores.tolist()}")
-
        return scores.tolist()
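For context, this is roughly how the reranker above is consumed: pass the query plus a batch of candidate texts and get back one score per candidate. A usage sketch, assuming the module is importable as cross_encoder_reranker and the example query/candidates are made up:

from cross_encoder_reranker import CrossEncoderReranker

# Candidate texts would normally come from the FAISS search in chatbot_model.py.
reranker = CrossEncoderReranker(model_name="cross-encoder/ms-marco-MiniLM-L-12-v2")
query = "Can you book me a table for two tonight?"
candidates = [
    "Sure, what time would you like the reservation?",
    "The weather tomorrow looks sunny.",
]
scores = reranker.rerank(query, candidates, max_length=256)
for text, score in sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True):
    print(f"{score:.3f}  {text}")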
tf_data_pipeline.py
CHANGED
@@ -4,6 +4,8 @@ import faiss
|
|
4 |
import tensorflow as tf
|
5 |
import h5py
|
6 |
import math
|
|
|
|
|
7 |
from tqdm import tqdm
|
8 |
import json
|
9 |
from pathlib import Path
|
@@ -46,47 +48,47 @@ class TFDataPipeline:
|
|
46 |
self.max_batch_size = 16 if len(response_pool) < 100 else 64
|
47 |
self.max_retries = max_retries
|
48 |
|
49 |
-
# Build
|
50 |
self._text_domain_map = {}
|
51 |
self.build_text_to_domain_map()
|
52 |
-
|
|
|
53 |
if os.path.exists(index_file_path):
|
54 |
logger.info(f"Loading existing FAISS index from {index_file_path}...")
|
55 |
self.index = faiss.read_index(index_file_path)
|
56 |
self.validate_faiss_index()
|
57 |
logger.info("FAISS index loaded and validated successfully.")
|
58 |
else:
|
59 |
-
# Initialize FAISS index
|
60 |
dimension = self.encoder.config.embedding_dim
|
61 |
self.index = faiss.IndexFlatIP(dimension)
|
62 |
logger.info(f"Initialized FAISS IndexFlatIP with dimension {dimension}.")
|
63 |
|
64 |
if not self.index.is_trained:
|
65 |
-
# Train the index if it's not trained.
|
66 |
dimension = self.query_embeddings_cache[next(iter(self.query_embeddings_cache))].shape[0]
|
67 |
self.index.train(np.array(list(self.query_embeddings_cache.values())).astype(np.float32))
|
68 |
self.index.add(np.array(list(self.query_embeddings_cache.values())).astype(np.float32))
|
69 |
-
|
70 |
def save_embeddings_cache_hdf5(self, cache_file_path: str):
|
71 |
-
"""Save
|
72 |
with h5py.File(cache_file_path, 'w') as hf:
|
73 |
for query, emb in self.query_embeddings_cache.items():
|
74 |
hf.create_dataset(query, data=emb)
|
75 |
logger.info(f"Embeddings cache saved to {cache_file_path}.")
|
76 |
-
|
77 |
def load_embeddings_cache_hdf5(self, cache_file_path: str):
|
78 |
-
"""Load
|
79 |
with h5py.File(cache_file_path, 'r') as hf:
|
80 |
for query in hf.keys():
|
81 |
self.query_embeddings_cache[query] = hf[query][:]
|
82 |
logger.info(f"Embeddings cache loaded from {cache_file_path}.")
|
83 |
-
|
84 |
def save_faiss_index(self, index_file_path: str):
|
85 |
faiss.write_index(self.index, index_file_path)
|
86 |
logger.info(f"FAISS index saved to {index_file_path}")
|
87 |
|
88 |
def load_faiss_index(self, index_file_path: str):
|
89 |
-
"""Load
|
90 |
if os.path.exists(index_file_path):
|
91 |
self.index = faiss.read_index(index_file_path)
|
92 |
logger.info(f"FAISS index loaded from {index_file_path}.")
|
@@ -95,7 +97,7 @@ class TFDataPipeline:
|
|
95 |
raise FileNotFoundError(f"FAISS index file not found at {index_file_path}.")
|
96 |
|
97 |
def validate_faiss_index(self):
|
98 |
-
"""Validates
|
99 |
expected_dim = self.encoder.config.embedding_dim
|
100 |
if self.index.d != expected_dim:
|
101 |
logger.error(f"FAISS index dimension {self.index.d} does not match encoder embedding dimension {expected_dim}.")
|
@@ -114,7 +116,6 @@ class TFDataPipeline:
|
|
114 |
def load_json_training_data(data_path: Union[str, Path], debug_samples: Optional[int] = None) -> List[dict]:
|
115 |
"""
|
116 |
Load training data from a JSON file.
|
117 |
-
|
118 |
Args:
|
119 |
data_path (Union[str, Path]): Path to the JSON file containing dialogues.
|
120 |
debug_samples (Optional[int]): Number of samples to load for debugging.
|
@@ -137,17 +138,16 @@ class TFDataPipeline:
|
|
137 |
|
138 |
logger.info(f"Loaded {len(dialogues)} dialogues.")
|
139 |
return dialogues
|
140 |
-
|
141 |
def collect_responses_with_domain(self, dialogues: List[dict]) -> List[Dict[str, str]]:
|
142 |
"""
|
143 |
-
Extract unique assistant responses
|
144 |
-
Returns
|
145 |
"""
|
146 |
-
response_set = set() #
|
147 |
results = []
|
148 |
-
|
149 |
for dialogue in tqdm(dialogues, desc="Processing Dialogues", unit="dialogue"):
|
150 |
-
# domain is stored at the top level in your new JSON format
|
151 |
domain = dialogue.get('domain', 'other')
|
152 |
turns = dialogue.get('turns', [])
|
153 |
for turn in turns:
|
@@ -155,7 +155,7 @@ class TFDataPipeline:
|
|
155 |
text = turn.get('text', '').strip()
|
156 |
if speaker == 'assistant' and text:
|
157 |
if len(text) <= self.max_length:
|
158 |
-
# Use
|
159 |
key = (domain, text)
|
160 |
if key not in response_set:
|
161 |
response_set.add(key)
|
@@ -163,23 +163,9 @@ class TFDataPipeline:
|
|
163 |
"domain": domain,
|
164 |
"text": text
|
165 |
})
|
166 |
-
|
167 |
logger.info(f"Collected {len(results)} unique assistant responses from dialogues.")
|
168 |
return results
|
169 |
-
# def collect_responses(self, dialogues: List[dict]) -> List[str]:
|
170 |
-
# """Extract unique assistant responses from dialogues."""
|
171 |
-
# response_set = set()
|
172 |
-
# for dialogue in tqdm(dialogues, desc="Processing Dialogues", unit="dialogue"):
|
173 |
-
# turns = dialogue.get('turns', [])
|
174 |
-
# for turn in turns:
|
175 |
-
# speaker = turn.get('speaker')
|
176 |
-
# text = turn.get('text', '').strip()
|
177 |
-
# if speaker == 'assistant' and text:
|
178 |
-
# # Ensure we don't exclude valid shorter responses
|
179 |
-
# if len(text) <= self.max_length:
|
180 |
-
# response_set.add(text)
|
181 |
-
# logger.info(f"Collected {len(response_set)} unique assistant responses from dialogues.")
|
182 |
-
# return list(response_set)
|
183 |
|
184 |
def _extract_pairs_from_dialogue(self, dialogue: dict) -> List[Tuple[str, str]]:
|
185 |
"""Extract query-response pairs from a dialogue."""
|
@@ -203,18 +189,18 @@ class TFDataPipeline:
|
|
203 |
|
204 |
def compute_and_index_response_embeddings(self):
|
205 |
"""
|
206 |
-
|
207 |
-
self.response_pool
|
208 |
"""
|
209 |
logger.info("Computing embeddings for the response pool...")
|
210 |
-
|
211 |
-
# Extract
|
212 |
texts = [resp["text"] for resp in self.response_pool]
|
213 |
logger.debug(f"Total texts to embed: {len(texts)}")
|
214 |
|
215 |
batch_size = getattr(self, 'embedding_batch_size', 64)
|
216 |
embeddings = []
|
217 |
-
|
218 |
with tqdm(total=len(texts), desc="Computing Embeddings", unit="response") as pbar:
|
219 |
for i in range(0, len(texts), batch_size):
|
220 |
batch_texts = texts[i:i+batch_size]
|
@@ -226,36 +212,30 @@ class TFDataPipeline:
|
|
226 |
return_tensors='tf'
|
227 |
)
|
228 |
batch_embeds = self.encoder(encodings['input_ids'], training=False).numpy()
|
229 |
-
|
230 |
embeddings.append(batch_embeds)
|
231 |
pbar.update(len(batch_texts))
|
232 |
-
|
233 |
# Combine embeddings and add to FAISS
|
234 |
all_embeddings = np.vstack(embeddings).astype(np.float32)
|
235 |
logger.info(f"Adding {len(all_embeddings)} response embeddings to FAISS index...")
|
236 |
self.index.add(all_embeddings)
|
237 |
|
238 |
-
#
|
239 |
self.response_embeddings = all_embeddings
|
240 |
logger.info(f"FAISS index now has {self.index.ntotal} vectors.")
|
241 |
|
242 |
-
def
|
243 |
"""
|
244 |
Find hard negatives for a batch of queries using FAISS search.
|
245 |
-
|
246 |
-
Uses domain-based fallback if possible.
|
247 |
"""
|
248 |
-
import random
|
249 |
-
import gc
|
250 |
-
|
251 |
retry_count = 0
|
252 |
total_responses = len(self.response_pool)
|
253 |
-
|
254 |
-
batch_size = 128
|
255 |
-
|
256 |
while retry_count < self.max_retries:
|
257 |
try:
|
258 |
-
#
|
259 |
query_embeddings = []
|
260 |
for i in range(0, len(queries), batch_size):
|
261 |
sub_queries = queries[i : i + batch_size]
|
@@ -263,23 +243,24 @@ class TFDataPipeline:
|
|
263 |
sub_embeds = np.vstack(sub_embeds).astype(np.float32)
|
264 |
faiss.normalize_L2(sub_embeds) # If not already normalized
|
265 |
query_embeddings.append(sub_embeds)
|
266 |
-
|
267 |
query_embeddings = np.vstack(query_embeddings)
|
268 |
query_embeddings = np.ascontiguousarray(query_embeddings)
|
269 |
-
|
270 |
-
#
|
271 |
-
distances, indices = self.index.search(query_embeddings,
|
272 |
-
|
273 |
all_negatives = []
|
274 |
-
#
|
275 |
for query_indices, query_text, pos_text in zip(indices, queries, positives):
|
276 |
negative_list = []
|
|
|
|
|
277 |
seen = {pos_text.strip()}
|
278 |
-
|
279 |
-
# Attempt to detect the domain of the positive text
|
280 |
domain_of_positive = self._detect_domain_for_text(pos_text)
|
281 |
-
|
282 |
-
# Collect hard negatives from
|
283 |
for idx in query_indices:
|
284 |
if 0 <= idx < total_responses:
|
285 |
candidate_dict = self.response_pool[idx] # e.g. {domain, text}
|
@@ -289,18 +270,18 @@ class TFDataPipeline:
|
|
289 |
negative_list.append(candidate_text)
|
290 |
if len(negative_list) >= self.neg_samples:
|
291 |
break
|
292 |
-
|
293 |
-
#
|
294 |
if len(negative_list) < self.neg_samples:
|
295 |
needed = self.neg_samples - len(negative_list)
|
296 |
-
|
297 |
random_negatives = self._get_random_negatives(needed, seen, domain=domain_of_positive)
|
298 |
negative_list.extend(random_negatives)
|
299 |
-
|
300 |
all_negatives.append(negative_list)
|
301 |
-
|
302 |
return all_negatives
|
303 |
-
|
304 |
except KeyError as ke:
|
305 |
retry_count += 1
|
306 |
logger.warning(f"Hard negative search attempt {retry_count} failed due to missing embeddings: {ke}")
|
@@ -310,7 +291,7 @@ class TFDataPipeline:
|
|
310 |
gc.collect()
|
311 |
if tf.config.list_physical_devices('GPU'):
|
312 |
tf.keras.backend.clear_session()
|
313 |
-
|
314 |
except Exception as e:
|
315 |
retry_count += 1
|
316 |
logger.warning(f"Hard negative search attempt {retry_count} failed: {e}")
|
@@ -320,29 +301,27 @@ class TFDataPipeline:
|
|
320 |
gc.collect()
|
321 |
if tf.config.list_physical_devices('GPU'):
|
322 |
tf.keras.backend.clear_session()
|
323 |
-
|
324 |
def _detect_domain_for_text(self, text: str) -> Optional[str]:
|
325 |
"""
|
326 |
-
|
327 |
-
Returns the domain if found, else None.
|
328 |
"""
|
329 |
stripped_text = text.strip()
|
330 |
return self._text_domain_map.get(stripped_text, None)
|
331 |
|
332 |
def _get_random_negatives(self, needed: int, seen: set, domain: Optional[str] = None) -> List[str]:
|
333 |
"""
|
334 |
-
Return a list of
|
335 |
-
otherwise fallback to all-domain.
|
336 |
"""
|
337 |
-
#
|
338 |
if domain:
|
339 |
domain_texts = [r["text"] for r in self.response_pool if r["domain"] == domain]
|
340 |
# fallback to entire set if insufficient domain_texts
|
341 |
-
if len(domain_texts) < needed * 2:
|
342 |
domain_texts = [r["text"] for r in self.response_pool]
|
343 |
else:
|
344 |
domain_texts = [r["text"] for r in self.response_pool]
|
345 |
-
|
346 |
negatives = []
|
347 |
tries = 0
|
348 |
max_tries = needed * 10
|
@@ -352,8 +331,7 @@ class TFDataPipeline:
|
|
352 |
if candidate and candidate not in seen:
|
353 |
negatives.append(candidate)
|
354 |
seen.add(candidate)
|
355 |
-
|
356 |
-
# If still not enough, we do the best we can
|
357 |
if len(negatives) < needed:
|
358 |
logger.warning(f"Could not find enough domain-based random negatives; needed {needed}, got {len(negatives)}.")
|
359 |
|
@@ -369,47 +347,44 @@ class TFDataPipeline:
|
|
369 |
all_negatives = []
|
370 |
|
371 |
for pos_text in positives:
|
372 |
-
# Build a 'seen' set with the positive
|
373 |
seen = {pos_text.strip()}
|
374 |
-
|
375 |
-
#
|
376 |
domain_of_positive = self._detect_domain_for_text(pos_text)
|
377 |
-
|
378 |
-
# Use domain-based
|
379 |
negs = self._get_random_negatives(self.neg_samples, seen, domain=domain_of_positive)
|
380 |
all_negatives.append(negs)
|
381 |
-
|
382 |
return all_negatives
|
383 |
|
384 |
def build_text_to_domain_map(self):
|
385 |
"""
|
386 |
-
Build
|
387 |
-
so we don't have to scan the entire self.response_pool each time.
|
388 |
"""
|
389 |
self._text_domain_map = {}
|
390 |
-
|
391 |
for item in self.response_pool:
|
392 |
-
# e.g., item = {"domain": "restaurant", "text": "some text..."}
|
393 |
stripped_text = item["text"].strip()
|
394 |
domain = item["domain"]
|
395 |
-
|
396 |
-
# If the same text appears multiple times with the same domain, no big deal.
|
397 |
-
# If it appears with a different domain, you can decide how to handle collisions.
|
398 |
if stripped_text in self._text_domain_map:
|
399 |
-
existing_domain = self._text_domain_map[stripped_text]
|
400 |
-
if existing_domain != domain:
|
401 |
-
#
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
|
|
406 |
# By default, keep the first domain or overwrite. We'll skip overwriting:
|
407 |
continue
|
408 |
else:
|
409 |
# Insert into the dict
|
410 |
self._text_domain_map[stripped_text] = domain
|
411 |
-
|
412 |
-
logger.info(f"Built text->domain map with {len(self._text_domain_map)} unique text entries.")
|
413 |
|
414 |
def encode_query(
|
415 |
self,
|
@@ -422,11 +397,10 @@ class TFDataPipeline:
|
|
422 |
Args:
|
423 |
query: The user query.
|
424 |
context: Optional conversation history as a list of (user_text, assistant_text).
|
425 |
-
|
426 |
Returns:
|
427 |
np.ndarray of shape [embedding_dim], typically L2-normalized already.
|
428 |
"""
|
429 |
-
#
|
430 |
if context:
|
431 |
# Take the last N turns
|
432 |
relevant_history = context[-self.config.max_context_turns:]
|
@@ -438,18 +412,18 @@ class TFDataPipeline:
|
|
438 |
)
|
439 |
context_str = " ".join(context_str_parts)
|
440 |
|
441 |
-
# Append the
|
442 |
full_query = (
|
443 |
f"{context_str} "
|
444 |
f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"
|
445 |
)
|
446 |
else:
|
447 |
-
#
|
448 |
full_query = (
|
449 |
f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"
|
450 |
)
|
451 |
|
452 |
-
#
|
453 |
encodings = self.tokenizer(
|
454 |
[full_query],
|
455 |
padding='max_length',
|
@@ -459,20 +433,18 @@ class TFDataPipeline:
|
|
459 |
)
|
460 |
input_ids = encodings['input_ids']
|
461 |
|
462 |
-
#
|
463 |
max_id = np.max(input_ids)
|
464 |
vocab_size = len(self.tokenizer)
|
465 |
if max_id >= vocab_size:
|
466 |
logger.error(f"Token ID {max_id} exceeds tokenizer vocab size {vocab_size}.")
|
467 |
raise ValueError("Token ID exceeds vocabulary size.")
|
468 |
|
469 |
-
#
|
470 |
embeddings = self.encoder(input_ids, training=False).numpy()
|
471 |
-
|
472 |
-
|
473 |
-
# 5) Return the single embedding as 1D array
|
474 |
return embeddings[0]
|
475 |
-
|
476 |
def encode_responses(
|
477 |
self,
|
478 |
responses: List[str],
|
@@ -480,16 +452,13 @@ class TFDataPipeline:
|
|
480 |
) -> np.ndarray:
|
481 |
"""
|
482 |
Encode multiple response texts into embedding vectors.
|
483 |
-
|
484 |
Args:
|
485 |
-
responses: List of
|
486 |
context: Optional conversation context (last N turns).
|
487 |
-
|
488 |
Returns:
|
489 |
np.ndarray of shape [num_responses, embedding_dim].
|
490 |
"""
|
491 |
-
#
|
492 |
-
# Usually for retrieval we might skip this. But if you want it:
|
493 |
if context:
|
494 |
relevant_history = context[-self.config.max_context_turns:]
|
495 |
prepared = []
|
@@ -501,21 +470,21 @@ class TFDataPipeline:
|
|
501 |
f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {a_text}"
|
502 |
)
|
503 |
context_str = " ".join(context_str_parts)
|
504 |
-
|
505 |
-
#
|
506 |
full_resp = (
|
507 |
f"{context_str} "
|
508 |
f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {resp}"
|
509 |
)
|
510 |
prepared.append(full_resp)
|
511 |
else:
|
512 |
-
#
|
513 |
prepared = [
|
514 |
f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {r}"
|
515 |
for r in responses
|
516 |
]
|
517 |
-
|
518 |
-
#
|
519 |
encodings = self.tokenizer(
|
520 |
prepared,
|
521 |
padding='max_length',
|
@@ -524,28 +493,22 @@ class TFDataPipeline:
|
|
524 |
return_tensors='np'
|
525 |
)
|
526 |
input_ids = encodings['input_ids']
|
527 |
-
|
528 |
-
#
|
529 |
max_id = np.max(input_ids)
|
530 |
vocab_size = len(self.tokenizer)
|
531 |
if max_id >= vocab_size:
|
532 |
logger.error(f"Token ID {max_id} exceeds tokenizer vocab size {vocab_size}.")
|
533 |
raise ValueError("Token ID exceeds vocabulary size.")
|
534 |
-
|
535 |
-
#
|
536 |
embeddings = self.encoder(input_ids, training=False).numpy()
|
537 |
-
|
538 |
-
|
539 |
return embeddings.astype('float32')
|
540 |
|
541 |
def prepare_and_save_data(self, dialogues: List[dict], tf_record_path: str, batch_size: int = 32):
|
542 |
"""
|
543 |
-
|
544 |
-
|
545 |
-
Args:
|
546 |
-
dialogues (List[dict]): List of dialogue dictionaries.
|
547 |
-
tf_record_path (str): Path to save the TFRecord file.
|
548 |
-
batch_size (int): Number of dialogues to process per batch.
|
549 |
"""
|
550 |
logger.info(f"Preparing and saving data to {tf_record_path}...")
|
551 |
|
@@ -553,14 +516,13 @@ class TFDataPipeline:
|
|
553 |
num_batches = math.ceil(num_dialogues / batch_size)
|
554 |
|
555 |
with tf.io.TFRecordWriter(tf_record_path) as writer:
|
556 |
-
# Initialize progress bar
|
557 |
with tqdm(total=num_batches, desc="Preparing Data Batches", unit="batch") as pbar:
|
558 |
for i in range(num_batches):
|
559 |
start_idx = i * batch_size
|
560 |
end_idx = min(start_idx + batch_size, num_dialogues)
|
561 |
batch_dialogues = dialogues[start_idx:end_idx]
|
562 |
|
563 |
-
# Extract
|
564 |
queries = []
|
565 |
positives = []
|
566 |
for dialogue in batch_dialogues:
|
@@ -572,7 +534,7 @@ class TFDataPipeline:
|
|
572 |
|
573 |
if not queries:
|
574 |
pbar.update(1)
|
575 |
-
continue
|
576 |
|
577 |
# Compute and cache query embeddings
|
578 |
try:
|
@@ -580,11 +542,11 @@ class TFDataPipeline:
|
|
580 |
except Exception as e:
|
581 |
logger.error(f"Error computing embeddings: {e}")
|
582 |
pbar.update(1)
|
583 |
-
continue
|
584 |
|
585 |
-
# Find hard negatives
|
586 |
try:
|
587 |
-
hard_negatives = self.
|
588 |
except Exception as e:
|
589 |
logger.error(f"Error finding hard negatives: {e}")
|
590 |
pbar.update(1)
|
@@ -611,8 +573,8 @@ class TFDataPipeline:
|
|
611 |
pbar.update(1)
|
612 |
continue # Skip to the next batch
|
613 |
|
614 |
-
# Flatten hard_negatives
|
615 |
-
#
|
616 |
try:
|
617 |
flattened_negatives = [neg for sublist in hard_negatives for neg in sublist]
|
618 |
encoded_negatives = self.tokenizer.batch_encode_plus(
|
@@ -623,15 +585,15 @@ class TFDataPipeline:
|
|
623 |
return_tensors='tf'
|
624 |
)
|
625 |
|
626 |
-
# Reshape
|
627 |
num_negatives = self.config.neg_samples
|
628 |
reshaped_negatives = encoded_negatives['input_ids'].numpy().reshape(-1, num_negatives, self.config.max_context_token_limit)
|
629 |
except Exception as e:
|
630 |
logger.error(f"Error during negatives tokenization: {e}")
|
631 |
pbar.update(1)
|
632 |
-
continue
|
633 |
|
634 |
-
# Serialize
|
635 |
for j in range(len(queries)):
|
636 |
try:
|
637 |
q_id = encoded_queries['input_ids'][j].numpy()
|
@@ -655,11 +617,14 @@ class TFDataPipeline:
|
|
655 |
logger.info(f"Data preparation complete. TFRecord saved.")
|
656 |
|
657 |
def _compute_embeddings(self, queries: List[str]) -> None:
|
|
|
|
|
|
|
658 |
new_queries = [q for q in queries if q not in self.query_embeddings_cache]
|
659 |
if not new_queries:
|
660 |
-
return
|
661 |
-
|
662 |
-
# Compute embeddings
|
663 |
new_embeddings = []
|
664 |
for i in range(0, len(new_queries), self.embedding_batch_size):
|
665 |
batch_queries = new_queries[i:i + self.embedding_batch_size]
|
@@ -673,49 +638,46 @@ class TFDataPipeline:
|
|
673 |
batch_embeddings = self.encoder(encoded['input_ids'], training=False).numpy()
|
674 |
faiss.normalize_L2(batch_embeddings)
|
675 |
new_embeddings.extend(batch_embeddings)
|
676 |
-
|
677 |
# Update the cache
|
678 |
for query, emb in zip(new_queries, new_embeddings):
|
679 |
self.query_embeddings_cache[query] = emb
|
680 |
-
|
681 |
def data_generator(self, dialogues: List[dict]) -> Generator[Tuple[str, str, List[str]], None, None]:
|
682 |
"""
|
683 |
-
|
684 |
-
Wrapped the outer loop with tqdm for progress tracking.
|
685 |
"""
|
686 |
total_dialogues = len(dialogues)
|
687 |
logger.debug(f"Total dialogues to process: {total_dialogues}")
|
688 |
-
|
689 |
-
# Initialize tqdm progress bar
|
690 |
with tqdm(total=total_dialogues, desc="Processing Dialogues", unit="dialogue") as pbar:
|
691 |
for dialogue in dialogues:
|
692 |
pairs = self._extract_pairs_from_dialogue(dialogue)
|
693 |
for query, positive in pairs:
|
694 |
# Ensure embeddings are computed, find hard negatives, etc.
|
695 |
self._compute_embeddings([query])
|
696 |
-
hard_negatives = self.
|
697 |
yield (query, positive, hard_negatives)
|
698 |
pbar.update(1)
|
699 |
|
700 |
def get_tf_dataset(self, dialogues: List[dict], batch_size: int) -> tf.data.Dataset:
|
701 |
"""
|
702 |
-
Creates a tf.data.Dataset for streaming training
|
703 |
-
(input_ids_query, input_ids_positive, input_ids_negatives).
|
704 |
"""
|
705 |
# 1) Start with a generator dataset
|
706 |
dataset = tf.data.Dataset.from_generator(
|
707 |
lambda: self.data_generator(dialogues),
|
708 |
output_signature=(
|
709 |
-
tf.TensorSpec(shape=(), dtype=tf.string),
|
710 |
-
tf.TensorSpec(shape=(), dtype=tf.string),
|
711 |
-
tf.TensorSpec(shape=(self.neg_samples,), dtype=tf.string)
|
712 |
)
|
713 |
)
|
714 |
|
715 |
-
#
|
|
|
716 |
dataset = dataset.batch(batch_size, drop_remainder=True)
|
717 |
-
|
718 |
-
# 3) Map them through a tokenize step using `tf.py_function`
|
719 |
dataset = dataset.map(
|
720 |
lambda q, p, n: self._tokenize_triple(q, p, n),
|
721 |
num_parallel_calls=1 #tf.data.AUTOTUNE
|
@@ -731,22 +693,19 @@ class TFDataPipeline:
|
|
731 |
n: tf.Tensor
|
732 |
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
|
733 |
"""
|
734 |
-
Wraps a Python function
|
735 |
-
|
736 |
-
|
737 |
-
q is shape [batch_size], p is shape [batch_size],
|
738 |
-
n is shape [batch_size, neg_samples] (i.e., each row is a list of negatives).
|
739 |
"""
|
740 |
-
# Use tf.py_function
|
741 |
q_ids, p_ids, n_ids = tf.py_function(
|
742 |
func=self._tokenize_triple_py,
|
743 |
inp=[q, p, n, tf.constant(self.max_length), tf.constant(self.neg_samples)],
|
744 |
Tout=[tf.int32, tf.int32, tf.int32]
|
745 |
)
|
746 |
|
747 |
-
#
|
748 |
-
q_ids.set_shape([None, self.max_length])
|
749 |
-
p_ids.set_shape([None, self.max_length])
|
750 |
n_ids.set_shape([None, self.neg_samples, self.max_length]) # [batch_size, neg_samples, max_length]
|
751 |
|
752 |
return q_ids, p_ids, n_ids
|
@@ -760,32 +719,30 @@ class TFDataPipeline:
|
|
760 |
neg_samples: tf.Tensor
|
761 |
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
762 |
"""
|
763 |
-
Python
|
764 |
-
|
765 |
-
|
766 |
-
- Reshapes negatives
|
767 |
-
- Returns np.array of int32s for (q_ids, p_ids, n_ids).
|
768 |
|
769 |
q: shape [batch_size], p: shape [batch_size]
|
770 |
n: shape [batch_size, neg_samples]
|
771 |
-
max_len:
|
772 |
-
neg_samples:
|
773 |
"""
|
774 |
-
max_len = int(max_len.numpy())
|
775 |
neg_samples = int(neg_samples.numpy())
|
776 |
|
777 |
-
#
|
778 |
q_list = [q_i.decode("utf-8") for q_i in q.numpy()] # shape [batch_size]
|
779 |
p_list = [p_i.decode("utf-8") for p_i in p.numpy()] # shape [batch_size]
|
780 |
|
781 |
-
#
|
782 |
n_list = []
|
783 |
for row in n.numpy():
|
784 |
# row is shape [neg_samples], each is a tf.string
|
785 |
decoded = [neg.decode("utf-8") for neg in row]
|
786 |
n_list.append(decoded)
|
787 |
|
788 |
-
#
|
789 |
q_enc = self.tokenizer(
|
790 |
q_list,
|
791 |
padding="max_length",
|
@@ -801,11 +758,11 @@ class TFDataPipeline:
|
|
801 |
return_tensors="np"
|
802 |
)
|
803 |
|
804 |
-
#
|
805 |
-
# Flatten [batch_size, neg_samples] ->
|
806 |
flattened_negatives = [neg for row in n_list for neg in row]
|
807 |
if len(flattened_negatives) == 0:
|
808 |
-
# No negatives
|
809 |
n_ids = np.zeros((len(q_list), neg_samples, max_len), dtype=np.int32)
|
810 |
else:
|
811 |
n_enc = self.tokenizer(
|
@@ -815,11 +772,10 @@ class TFDataPipeline:
|
|
815 |
max_length=max_len,
|
816 |
return_tensors="np"
|
817 |
)
|
818 |
-
#
|
819 |
n_input_ids = n_enc["input_ids"]
|
820 |
|
821 |
-
#
|
822 |
-
# Handle cases where there might be fewer negatives
|
823 |
batch_size = len(q_list)
|
824 |
n_ids_list = []
|
825 |
for i in range(batch_size):
|
@@ -827,7 +783,7 @@ class TFDataPipeline:
|
|
827 |
end_idx = start_idx + neg_samples
|
828 |
row_negs = n_input_ids[start_idx:end_idx]
|
829 |
|
830 |
-
#
|
831 |
if row_negs.shape[0] < neg_samples:
|
832 |
deficit = neg_samples - row_negs.shape[0]
|
833 |
pad_arr = np.zeros((deficit, max_len), dtype=np.int32)
|
@@ -835,10 +791,10 @@ class TFDataPipeline:
|
|
835 |
|
836 |
n_ids_list.append(row_negs)
|
837 |
|
838 |
-
#
|
839 |
n_ids = np.stack(n_ids_list, axis=0)
|
840 |
|
841 |
-
#
|
842 |
q_ids = q_enc["input_ids"].astype(np.int32) # shape [batch_size, max_len]
|
843 |
p_ids = p_enc["input_ids"].astype(np.int32) # shape [batch_size, max_len]
|
844 |
n_ids = n_ids.astype(np.int32) # shape [batch_size, neg_samples, max_len]
|
|
|
4 |
import tensorflow as tf
|
5 |
import h5py
|
6 |
import math
|
7 |
+
import random
|
8 |
+
import gc
|
9 |
from tqdm import tqdm
|
10 |
import json
|
11 |
from pathlib import Path
|
|
|
48 |
self.max_batch_size = 16 if len(response_pool) < 100 else 64
|
49 |
self.max_retries = max_retries
|
50 |
|
51 |
+
# Build text -> domain map for O(1) domain lookups (hard negative sampling)
|
52 |
self._text_domain_map = {}
|
53 |
self.build_text_to_domain_map()
|
54 |
+
|
55 |
+
# Initialize FAISS index
|
56 |
if os.path.exists(index_file_path):
|
57 |
logger.info(f"Loading existing FAISS index from {index_file_path}...")
|
58 |
self.index = faiss.read_index(index_file_path)
|
59 |
self.validate_faiss_index()
|
60 |
logger.info("FAISS index loaded and validated successfully.")
|
61 |
else:
|
|
|
62 |
dimension = self.encoder.config.embedding_dim
|
63 |
self.index = faiss.IndexFlatIP(dimension)
|
64 |
logger.info(f"Initialized FAISS IndexFlatIP with dimension {dimension}.")
|
65 |
|
66 |
if not self.index.is_trained:
|
67 |
+
# Train the index if it's not trained. IndexFlatIP doesn't need training, but others do (Future switch to IndexIVFFlat)
|
68 |
dimension = self.query_embeddings_cache[next(iter(self.query_embeddings_cache))].shape[0]
|
69 |
self.index.train(np.array(list(self.query_embeddings_cache.values())).astype(np.float32))
|
70 |
self.index.add(np.array(list(self.query_embeddings_cache.values())).astype(np.float32))
|
71 |
+
|
72 |
def save_embeddings_cache_hdf5(self, cache_file_path: str):
|
73 |
+
"""Save embeddings cache to HDF5 file."""
|
74 |
with h5py.File(cache_file_path, 'w') as hf:
|
75 |
for query, emb in self.query_embeddings_cache.items():
|
76 |
hf.create_dataset(query, data=emb)
|
77 |
logger.info(f"Embeddings cache saved to {cache_file_path}.")
|
78 |
+
|
79 |
def load_embeddings_cache_hdf5(self, cache_file_path: str):
|
80 |
+
"""Load embeddings cache from HDF5 file."""
|
81 |
with h5py.File(cache_file_path, 'r') as hf:
|
82 |
for query in hf.keys():
|
83 |
self.query_embeddings_cache[query] = hf[query][:]
|
84 |
logger.info(f"Embeddings cache loaded from {cache_file_path}.")
|
85 |
+
|
86 |
def save_faiss_index(self, index_file_path: str):
|
87 |
faiss.write_index(self.index, index_file_path)
|
88 |
logger.info(f"FAISS index saved to {index_file_path}")
|
89 |
|
90 |
def load_faiss_index(self, index_file_path: str):
|
91 |
+
"""Load FAISS index from specified file path."""
|
92 |
if os.path.exists(index_file_path):
|
93 |
self.index = faiss.read_index(index_file_path)
|
94 |
logger.info(f"FAISS index loaded from {index_file_path}.")
|
|
|
97 |
raise FileNotFoundError(f"FAISS index file not found at {index_file_path}.")
|
98 |
|
99 |
def validate_faiss_index(self):
|
100 |
+
"""Validates FAISS index dimensionality."""
|
101 |
expected_dim = self.encoder.config.embedding_dim
|
102 |
if self.index.d != expected_dim:
|
103 |
logger.error(f"FAISS index dimension {self.index.d} does not match encoder embedding dimension {expected_dim}.")
|
|
|
116 |
def load_json_training_data(data_path: Union[str, Path], debug_samples: Optional[int] = None) -> List[dict]:
|
117 |
"""
|
118 |
Load training data from a JSON file.
|
|
|
119 |
Args:
|
120 |
data_path (Union[str, Path]): Path to the JSON file containing dialogues.
|
121 |
debug_samples (Optional[int]): Number of samples to load for debugging.
|
|
|
138 |
|
139 |
logger.info(f"Loaded {len(dialogues)} dialogues.")
|
140 |
return dialogues
|
141 |
+
|
142 |
def collect_responses_with_domain(self, dialogues: List[dict]) -> List[Dict[str, str]]:
|
143 |
"""
|
144 |
+
Extract unique assistant responses and their domains from dialogues.
|
145 |
+
Returns List[Dict[str: "domain", str: text"]]
|
146 |
"""
|
147 |
+
response_set = set() # Store (domain, text) unique tuples
|
148 |
results = []
|
149 |
+
|
150 |
for dialogue in tqdm(dialogues, desc="Processing Dialogues", unit="dialogue"):
|
|
|
151 |
domain = dialogue.get('domain', 'other')
|
152 |
turns = dialogue.get('turns', [])
|
153 |
for turn in turns:
|
|
|
155 |
text = turn.get('text', '').strip()
|
156 |
if speaker == 'assistant' and text:
|
157 |
if len(text) <= self.max_length:
|
158 |
+
# Use tuple as set key to ensure uniqueness
|
159 |
key = (domain, text)
|
160 |
if key not in response_set:
|
161 |
response_set.add(key)
|
|
|
163 |
"domain": domain,
|
164 |
"text": text
|
165 |
})
|
166 |
+
|
167 |
logger.info(f"Collected {len(results)} unique assistant responses from dialogues.")
|
168 |
return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
|
170 |
def _extract_pairs_from_dialogue(self, dialogue: dict) -> List[Tuple[str, str]]:
|
171 |
"""Extract query-response pairs from a dialogue."""
|
|
|
189 |
|
190 |
def compute_and_index_response_embeddings(self):
|
191 |
"""
|
192 |
+
Compute embeddings for the response pool and add them to the FAISS index.
|
193 |
+
self.response_pool: List[Dict[str, str]] with keys "domain" and "text".
|
194 |
"""
|
195 |
logger.info("Computing embeddings for the response pool...")
|
196 |
+
|
197 |
+
# Extract the assistant text
|
198 |
texts = [resp["text"] for resp in self.response_pool]
|
199 |
logger.debug(f"Total texts to embed: {len(texts)}")
|
200 |
|
201 |
batch_size = getattr(self, 'embedding_batch_size', 64)
|
202 |
embeddings = []
|
203 |
+
|
204 |
with tqdm(total=len(texts), desc="Computing Embeddings", unit="response") as pbar:
|
205 |
for i in range(0, len(texts), batch_size):
|
206 |
batch_texts = texts[i:i+batch_size]
|
|
|
212 |
return_tensors='tf'
|
213 |
)
|
214 |
batch_embeds = self.encoder(encodings['input_ids'], training=False).numpy()
|
215 |
+
|
216 |
embeddings.append(batch_embeds)
|
217 |
pbar.update(len(batch_texts))
|
218 |
+
|
219 |
# Combine embeddings and add to FAISS
|
220 |
all_embeddings = np.vstack(embeddings).astype(np.float32)
|
221 |
logger.info(f"Adding {len(all_embeddings)} response embeddings to FAISS index...")
|
222 |
self.index.add(all_embeddings)
|
223 |
|
224 |
+
# Store in memory
|
225 |
self.response_embeddings = all_embeddings
|
226 |
logger.info(f"FAISS index now has {self.index.ntotal} vectors.")
|
227 |
|
228 |
+
def _find_hard_negatives(self, queries: List[str], positives: List[str], batch_size: int = 128) -> List[List[str]]:
|
229 |
"""
|
230 |
Find hard negatives for a batch of queries using FAISS search.
|
231 |
+
Fallback: in-domain negatives, then random negatives when needed.
|
|
|
232 |
"""
|
|
|
|
|
|
|
233 |
retry_count = 0
|
234 |
total_responses = len(self.response_pool)
|
235 |
+
|
|
|
|
|
236 |
while retry_count < self.max_retries:
|
237 |
try:
|
238 |
+
# Build query embeddings from the cache
|
239 |
query_embeddings = []
|
240 |
for i in range(0, len(queries), batch_size):
|
241 |
sub_queries = queries[i : i + batch_size]
|
|
|
243 |
sub_embeds = np.vstack(sub_embeds).astype(np.float32)
|
244 |
faiss.normalize_L2(sub_embeds) # If not already normalized
|
245 |
query_embeddings.append(sub_embeds)
|
246 |
+
|
247 |
query_embeddings = np.vstack(query_embeddings)
|
248 |
query_embeddings = np.ascontiguousarray(query_embeddings)
|
249 |
+
|
250 |
+
# FAISS search for nearest neighbors (hard negatives)
|
251 |
+
distances, indices = self.index.search(query_embeddings, self.neg_samples)
|
252 |
+
|
253 |
all_negatives = []
|
254 |
+
# Extract domain from the positive assistant response
|
255 |
for query_indices, query_text, pos_text in zip(indices, queries, positives):
|
256 |
negative_list = []
|
257 |
+
|
258 |
+
# Build a 'seen' set with the positive
|
259 |
seen = {pos_text.strip()}
|
260 |
+
|
|
|
261 |
domain_of_positive = self._detect_domain_for_text(pos_text)
|
262 |
+
|
263 |
+
# Collect hard negatives (from config self.neg_samples)
|
264 |
for idx in query_indices:
|
265 |
if 0 <= idx < total_responses:
|
266 |
candidate_dict = self.response_pool[idx] # e.g. {domain, text}
|
|
|
270 |
negative_list.append(candidate_text)
|
271 |
if len(negative_list) >= self.neg_samples:
|
272 |
break
|
273 |
+
|
274 |
+
# Fall back to random domain-based
|
275 |
if len(negative_list) < self.neg_samples:
|
276 |
needed = self.neg_samples - len(negative_list)
|
277 |
+
|
278 |
random_negatives = self._get_random_negatives(needed, seen, domain=domain_of_positive)
|
279 |
negative_list.extend(random_negatives)
|
280 |
+
|
281 |
all_negatives.append(negative_list)
|
282 |
+
|
283 |
return all_negatives
|
284 |
+
|
285 |
except KeyError as ke:
|
286 |
retry_count += 1
|
287 |
logger.warning(f"Hard negative search attempt {retry_count} failed due to missing embeddings: {ke}")
|
|
|
291 |
gc.collect()
|
292 |
if tf.config.list_physical_devices('GPU'):
|
293 |
tf.keras.backend.clear_session()
|
294 |
+
|
295 |
except Exception as e:
|
296 |
retry_count += 1
|
297 |
logger.warning(f"Hard negative search attempt {retry_count} failed: {e}")
|
|
|
301 |
gc.collect()
|
302 |
if tf.config.list_physical_devices('GPU'):
|
303 |
tf.keras.backend.clear_session()
|
304 |
+
|
305 |
def _detect_domain_for_text(self, text: str) -> Optional[str]:
|
306 |
"""
|
307 |
+
Domain detection for related negatives.
|
|
|
308 |
"""
|
309 |
stripped_text = text.strip()
|
310 |
return self._text_domain_map.get(stripped_text, None)
|
311
312    def _get_random_negatives(self, needed: int, seen: set, domain: Optional[str] = None) -> List[str]:
313        """
314 +      Return a list of negative texts from the same domain. Fall back to any domain.
315        """
316 +      # Filter response_pool for domain
317        if domain:
318            domain_texts = [r["text"] for r in self.response_pool if r["domain"] == domain]
319            # fallback to entire set if insufficient domain_texts
320 +          if len(domain_texts) < needed * 2:
321                domain_texts = [r["text"] for r in self.response_pool]
322        else:
323            domain_texts = [r["text"] for r in self.response_pool]
324 +
325        negatives = []
326        tries = 0
327        max_tries = needed * 10
331            if candidate and candidate not in seen:
332                negatives.append(candidate)
333                seen.add(candidate)
334 +
335        if len(negatives) < needed:
336            logger.warning(f"Could not find enough domain-based random negatives; needed {needed}, got {len(negatives)}.")
337

347        all_negatives = []
348
349        for pos_text in positives:
350 +          # Build a 'seen' set with the positive assistant response
351            seen = {pos_text.strip()}
352 +
353 +          # Detect domain of the positive
354            domain_of_positive = self._detect_domain_for_text(pos_text)
355 +
356 +          # Use domain-based negatives when available
357            negs = self._get_random_negatives(self.neg_samples, seen, domain=domain_of_positive)
358            all_negatives.append(negs)
359 +
360        return all_negatives
361
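The random fallback above filters the response pool by domain before sampling and widens to the whole pool when the domain is too small. A compact way to express the same fallback with the standard library, assuming the same {"domain", "text"} record shape, could look like this (illustrative only, not the project's helper):

    import random
    from typing import List, Optional

    def random_negatives(pool: List[dict], needed: int, seen: set, domain: Optional[str] = None) -> List[str]:
        texts = [r["text"] for r in pool if domain is None or r["domain"] == domain]
        if len(texts) < needed * 2:          # fall back to the whole pool if the domain is too small
            texts = [r["text"] for r in pool]
        candidates = [t for t in texts if t and t not in seen]
        random.shuffle(candidates)
        return candidates[:needed]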
362    def build_text_to_domain_map(self):
363        """
364 +      Build O(1) lookup dict: text -> domain for hard negative sampling.
365        """
366        self._text_domain_map = {}
367 +
368        for item in self.response_pool:
369            stripped_text = item["text"].strip()
370            domain = item["domain"]
371 +
372            if stripped_text in self._text_domain_map:
373 +              # existing_domain = self._text_domain_map[stripped_text]
374 +              # if existing_domain != domain:
375 +              #     Collision detected. Using first found domain for now.
376 +              #     This happens often with low-signal responses. "ok", "yes", etc.
377 +              #     logger.warning(
378 +              #         f"Collision detected: text '{stripped_text}' found with domains "
379 +              #         f"'{existing_domain}' and '{domain}'. Keeping the first."
380 +              #     )
381                # By default, keep the first domain or overwrite. We'll skip overwriting:
382                continue
383            else:
384                # Insert into the dict
385                self._text_domain_map[stripped_text] = domain
386 +
387 +      logger.info(f"Built text -> domain map with {len(self._text_domain_map)} unique text entries.")
388
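A quick illustration of the collision rule above (the first domain wins when the same text appears under several domains), using a toy response pool; the data here is made up:

    pool = [
        {"text": "ok", "domain": "restaurant"},
        {"text": "ok", "domain": "taxi"},        # collision: the first mapping is kept
        {"text": "Your table is booked.", "domain": "restaurant"},
    ]
    text_domain_map = {}
    for item in pool:
        key = item["text"].strip()
        if key not in text_domain_map:
            text_domain_map[key] = item["domain"]
    print(text_domain_map)   # {'ok': 'restaurant', 'Your table is booked.': 'restaurant'}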
389    def encode_query(
390        self,
397        Args:
398            query: The user query.
399            context: Optional conversation history as a list of (user_text, assistant_text).
400        Returns:
401            np.ndarray of shape [embedding_dim], typically L2-normalized already.
402        """
403 +      # Prepare context: concat user/assistant pairs
404        if context:
405            # Take the last N turns
406            relevant_history = context[-self.config.max_context_turns:]
412                )
413            context_str = " ".join(context_str_parts)
414
415 +          # Append the new query
416            full_query = (
417                f"{context_str} "
418                f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"
419            )
420        else:
421 +          # Single user turn
422            full_query = (
423                f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"
424            )
425
426 +      # Tokenize
427        encodings = self.tokenizer(
428            [full_query],
429            padding='max_length',
433        )
434        input_ids = encodings['input_ids']
435
436 +      # Debug out-of-vocab IDs
437        max_id = np.max(input_ids)
438        vocab_size = len(self.tokenizer)
439        if max_id >= vocab_size:
440            logger.error(f"Token ID {max_id} exceeds tokenizer vocab size {vocab_size}.")
441            raise ValueError("Token ID exceeds vocabulary size.")
442
443 +      # Get embeddings from the model. These are already L2-normalized by the model's final layer.
444        embeddings = self.encoder(input_ids, training=False).numpy()
445 +
446        return embeddings[0]
447 +
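The query string assembled above interleaves the <USER>/<ASSISTANT> special tokens with the recent conversation turns before tokenization. A standalone sketch of that string assembly is below; the special tokens come from the diff, while the helper name and example turns are hypothetical:

    from typing import List, Tuple

    def build_query_text(query: str, context: List[Tuple[str, str]], max_turns: int = 4) -> str:
        parts = []
        for user_text, assistant_text in context[-max_turns:]:
            parts.append(f"<USER> {user_text}")
            parts.append(f"<ASSISTANT> {assistant_text}")
        parts.append(f"<USER> {query}")
        return " ".join(parts)

    print(build_query_text("Any cheap hotels nearby?",
                           [("I need a taxi to the station.", "Sure, a taxi is booked for 5pm.")]))
    # <USER> I need a taxi to the station. <ASSISTANT> Sure, a taxi is booked for 5pm. <USER> Any cheap hotels nearby?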
448    def encode_responses(
449        self,
450        responses: List[str],
452    ) -> np.ndarray:
453        """
454        Encode multiple response texts into embedding vectors.
455        Args:
456 +          responses: List of assistant responses.
457            context: Optional conversation context (last N turns).
458        Returns:
459            np.ndarray of shape [num_responses, embedding_dim].
460        """
461 +      # Incorporate context into response encoding. Note: Undecided on benefit of this
462        if context:
463            relevant_history = context[-self.config.max_context_turns:]
464            prepared = []
470                    f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {a_text}"
471                )
472            context_str = " ".join(context_str_parts)
473 +
474 +          # Treat resp as an assistant turn
475            full_resp = (
476                f"{context_str} "
477                f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {resp}"
478            )
479            prepared.append(full_resp)
480        else:
481 +          # Single response from the assistant
482            prepared = [
483                f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {r}"
484                for r in responses
485            ]
486 +
487 +      # Tokenize
488        encodings = self.tokenizer(
489            prepared,
490            padding='max_length',
493            return_tensors='np'
494        )
495        input_ids = encodings['input_ids']
496 +
497 +      # Debug for out-of-vocab
498        max_id = np.max(input_ids)
499        vocab_size = len(self.tokenizer)
500        if max_id >= vocab_size:
501            logger.error(f"Token ID {max_id} exceeds tokenizer vocab size {vocab_size}.")
502            raise ValueError("Token ID exceeds vocabulary size.")
503 +
504 +      # Get embeddings from the model. These are already L2-normalized by the model's final layer.
505        embeddings = self.encoder(input_ids, training=False).numpy()
506 +
507        return embeddings.astype('float32')
508
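Because both encode_query and encode_responses return L2-normalized vectors, a plain dot product between a query embedding and the response matrix gives the cosine similarity used for ranking. A hedged usage sketch (the method calls are commented out and replaced by stand-in vectors so the snippet runs on its own):

    import numpy as np

    # Assuming `pipeline` is an instance of this class with a loaded encoder/tokenizer:
    # q_vec = pipeline.encode_query("I need a train to Cambridge on Friday.")
    # r_vecs = pipeline.encode_responses(["There is a train at 09:15.", "Your taxi is booked."])

    q_vec = np.array([0.6, 0.8], dtype=np.float32)
    r_vecs = np.array([[0.6, 0.8], [1.0, 0.0]], dtype=np.float32)

    scores = r_vecs @ q_vec          # cosine similarity, since everything is unit-norm
    best = int(np.argmax(scores))
    print(scores, best)              # [1.  0.6] 0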
509    def prepare_and_save_data(self, dialogues: List[dict], tf_record_path: str, batch_size: int = 32):
510        """
511 +      Batch-process dialogues and save to TFRecord file.
512        """
513        logger.info(f"Preparing and saving data to {tf_record_path}...")
514
516        num_batches = math.ceil(num_dialogues / batch_size)
517
518        with tf.io.TFRecordWriter(tf_record_path) as writer:
519            with tqdm(total=num_batches, desc="Preparing Data Batches", unit="batch") as pbar:
520                for i in range(num_batches):
521                    start_idx = i * batch_size
522                    end_idx = min(start_idx + batch_size, num_dialogues)
523                    batch_dialogues = dialogues[start_idx:end_idx]
524
525 +                  # Extract query-positive pairs for the batch
526                    queries = []
527                    positives = []
528                    for dialogue in batch_dialogues:
534
535                    if not queries:
536                        pbar.update(1)
537 +                      continue
538
539                    # Compute and cache query embeddings
540                    try:
542                    except Exception as e:
543                        logger.error(f"Error computing embeddings: {e}")
544                        pbar.update(1)
545 +                      continue
546
547 +                  # Find hard negatives
548                    try:
549 +                      hard_negatives = self._find_hard_negatives(queries, positives)
550                    except Exception as e:
551                        logger.error(f"Error finding hard negatives: {e}")
552                        pbar.update(1)
573                        pbar.update(1)
574                        continue  # Skip to the next batch
575
576 +                  # Flatten hard_negatives. Maintain alignment.
577 +                  # hard_negatives is List of Lists. Each sublist corresponds to a query.
578                    try:
579                        flattened_negatives = [neg for sublist in hard_negatives for neg in sublist]
580                        encoded_negatives = self.tokenizer.batch_encode_plus(
585                            return_tensors='tf'
586                        )
587
588 +                      # Reshape to [num_queries, num_negatives, max_length]
589                        num_negatives = self.config.neg_samples
590                        reshaped_negatives = encoded_negatives['input_ids'].numpy().reshape(-1, num_negatives, self.config.max_context_token_limit)
591                    except Exception as e:
592                        logger.error(f"Error during negatives tokenization: {e}")
593                        pbar.update(1)
594 +                      continue
595
596 +                  # Serialize and write to TFRecord
597                    for j in range(len(queries)):
598                        try:
599                            q_id = encoded_queries['input_ids'][j].numpy()
617        logger.info(f"Data preparation complete. TFRecord saved.")
618
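The write loop above serializes one record per query with three int64 feature lists. A minimal sketch of that serialization step is below; the feature names mirror the query_ids/negative_ids keys visible in inspect_tfrecord in train_model.py further down, while the helper names and toy shapes are assumptions:

    import numpy as np
    import tensorflow as tf

    def int64_feature(values: np.ndarray) -> tf.train.Feature:
        return tf.train.Feature(int64_list=tf.train.Int64List(value=values.flatten().tolist()))

    def serialize_example(q_ids: np.ndarray, p_ids: np.ndarray, n_ids: np.ndarray) -> bytes:
        feature = {
            "query_ids": int64_feature(q_ids),        # [max_length]
            "positive_ids": int64_feature(p_ids),     # [max_length] (key assumed)
            "negative_ids": int64_feature(n_ids),     # [neg_samples * max_length], flattened
        }
        return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()

    # Toy write: one example with max_length=4 and 3 negatives.
    q = np.arange(4); p = np.arange(4); n = np.zeros((3, 4), dtype=np.int64)
    with tf.io.TFRecordWriter("/tmp/example.tfrecord") as writer:
        writer.write(serialize_example(q, p, n))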
619    def _compute_embeddings(self, queries: List[str]) -> None:
620 +      """
621 +      Compute embeddings for new queries and update the cache.
622 +      """
623        new_queries = [q for q in queries if q not in self.query_embeddings_cache]
624        if not new_queries:
625 +          return
626 +
627 +      # Compute embeddings
628        new_embeddings = []
629        for i in range(0, len(new_queries), self.embedding_batch_size):
630            batch_queries = new_queries[i:i + self.embedding_batch_size]
638            batch_embeddings = self.encoder(encoded['input_ids'], training=False).numpy()
639            faiss.normalize_L2(batch_embeddings)
640            new_embeddings.extend(batch_embeddings)
641 +
642        # Update the cache
643        for query, emb in zip(new_queries, new_embeddings):
644            self.query_embeddings_cache[query] = emb
645 +
646    def data_generator(self, dialogues: List[dict]) -> Generator[Tuple[str, str, List[str]], None, None]:
647        """
648 +      Generate training examples: (query, positive, [hard_negatives]).
649        """
650        total_dialogues = len(dialogues)
651        logger.debug(f"Total dialogues to process: {total_dialogues}")
652 +
653        with tqdm(total=total_dialogues, desc="Processing Dialogues", unit="dialogue") as pbar:
654            for dialogue in dialogues:
655                pairs = self._extract_pairs_from_dialogue(dialogue)
656                for query, positive in pairs:
657                    # Ensure embeddings are computed, find hard negatives, etc.
658                    self._compute_embeddings([query])
659 +                  hard_negatives = self._find_hard_negatives([query], [positive])[0]
660                    yield (query, positive, hard_negatives)
661                pbar.update(1)
662
663    def get_tf_dataset(self, dialogues: List[dict], batch_size: int) -> tf.data.Dataset:
664        """
665 +      Creates a tf.data.Dataset for streaming training.
666 +      Yields (input_ids_query, input_ids_positive, input_ids_negatives).
667        """
668        # 1) Start with a generator dataset
669        dataset = tf.data.Dataset.from_generator(
670            lambda: self.data_generator(dialogues),
671            output_signature=(
672 +              tf.TensorSpec(shape=(), dtype=tf.string),                  # Query (single string)
673 +              tf.TensorSpec(shape=(), dtype=tf.string),                  # Positive (single string)
674 +              tf.TensorSpec(shape=(self.neg_samples,), dtype=tf.string)  # Hard Negatives (list of strings)
675            )
676        )
677
678 +      # Batch the raw strings, then map through a tokenize step
679 +      # Note: the DistilBERT tokenizer threw an error when using tf.data.AUTOTUNE.
680        dataset = dataset.batch(batch_size, drop_remainder=True)
681        dataset = dataset.map(
682            lambda q, p, n: self._tokenize_triple(q, p, n),
683            num_parallel_calls=1  # tf.data.AUTOTUNE
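A cut-down sketch of the same dataset shape (string triples batched, then tokenized in a map step) is below. The toy tokenizer simply hashes strings into bucket ids; the real pipeline runs the Hugging Face tokenizer inside _tokenize_triple, and the generator data here is invented:

    import tensorflow as tf

    NEG_SAMPLES = 3

    def toy_generator():
        yield ("book a taxi", "Your taxi is booked.", tuple(f"neg {i}" for i in range(NEG_SAMPLES)))

    dataset = tf.data.Dataset.from_generator(
        toy_generator,
        output_signature=(
            tf.TensorSpec(shape=(), dtype=tf.string),
            tf.TensorSpec(shape=(), dtype=tf.string),
            tf.TensorSpec(shape=(NEG_SAMPLES,), dtype=tf.string),
        ),
    )
    dataset = dataset.batch(1, drop_remainder=True)

    def fake_tokenize(q, p, n):
        # Stand-in for the HF tokenizer: map each string to a hash-bucket id.
        to_ids = lambda t: tf.strings.to_hash_bucket_fast(t, 1000)
        return to_ids(q), to_ids(p), to_ids(n)

    dataset = dataset.map(fake_tokenize, num_parallel_calls=1)
    for q_ids, p_ids, n_ids in dataset:
        print(q_ids.shape, p_ids.shape, n_ids.shape)   # (1,) (1,) (1, 3)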
693        n: tf.Tensor
694    ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
695        """
696 +      Wraps a Python function. Convert tf.Tensors of strings -> Python lists of strings -> HF tokenizer -> Tensors of IDs.
697 +      q is shape [batch_size], p is shape [batch_size], n is shape [batch_size, neg_samples] (list of negatives).
698        """
699 +      # Use tf.py_function, limit parallelism
700        q_ids, p_ids, n_ids = tf.py_function(
701            func=self._tokenize_triple_py,
702            inp=[q, p, n, tf.constant(self.max_length), tf.constant(self.neg_samples)],
703            Tout=[tf.int32, tf.int32, tf.int32]
704        )
705
706 +      # Set shape info for the output tensors
707 +      q_ids.set_shape([None, self.max_length])                    # [batch_size, max_length]
708 +      p_ids.set_shape([None, self.max_length])                    # [batch_size, max_length]
709        n_ids.set_shape([None, self.neg_samples, self.max_length])  # [batch_size, neg_samples, max_length]
710
711        return q_ids, p_ids, n_ids
719        neg_samples: tf.Tensor
720    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
721        """
722 +      Decodes tf.string Tensor to Python List[str], then tokenize.
723 +      Reshapes negatives to [batch_size, neg_samples, max_length].
724 +      Returns np.array(int32) for (q_ids, p_ids, n_ids).
725
726        q: shape [batch_size], p: shape [batch_size]
727        n: shape [batch_size, neg_samples]
728 +      max_len: int
729 +      neg_samples: int
730        """
731 +      max_len = int(max_len.numpy())
732        neg_samples = int(neg_samples.numpy())
733
734 +      # Convert Tensors -> Python List[str]
735        q_list = [q_i.decode("utf-8") for q_i in q.numpy()]  # shape [batch_size]
736        p_list = [p_i.decode("utf-8") for p_i in p.numpy()]  # shape [batch_size]
737
738 +      # Shape [batch_size, neg_samples], decode each row
739        n_list = []
740        for row in n.numpy():
741            # row is shape [neg_samples], each is a tf.string
742            decoded = [neg.decode("utf-8") for neg in row]
743            n_list.append(decoded)
744
745 +      # Tokenize queries & positives
746        q_enc = self.tokenizer(
747            q_list,
748            padding="max_length",
758            return_tensors="np"
759        )
760
761 +      # Tokenize negatives
762 +      # Flatten [batch_size, neg_samples] -> List
763        flattened_negatives = [neg for row in n_list for neg in row]
764        if len(flattened_negatives) == 0:
765 +          # No negatives: return a zero array
766            n_ids = np.zeros((len(q_list), neg_samples, max_len), dtype=np.int32)
767        else:
768            n_enc = self.tokenizer(
772                max_length=max_len,
773                return_tensors="np"
774            )
775 +          # Shape [batch_size * neg_samples, max_len]
776            n_input_ids = n_enc["input_ids"]
777
778 +          # Reshape to [batch_size, neg_samples, max_len]
779            batch_size = len(q_list)
780            n_ids_list = []
781            for i in range(batch_size):
783                end_idx = start_idx + neg_samples
784                row_negs = n_input_ids[start_idx:end_idx]
785
786 +              # Pad with zeros if not enough negatives
787                if row_negs.shape[0] < neg_samples:
788                    deficit = neg_samples - row_negs.shape[0]
789                    pad_arr = np.zeros((deficit, max_len), dtype=np.int32)
791
792                n_ids_list.append(row_negs)
793
794 +          # Stack shape [batch_size, neg_samples, max_len]
795            n_ids = np.stack(n_ids_list, axis=0)
796
797 +      # Return np.int32 arrays
798        q_ids = q_enc["input_ids"].astype(np.int32)  # shape [batch_size, max_len]
799        p_ids = p_enc["input_ids"].astype(np.int32)  # shape [batch_size, max_len]
800        n_ids = n_ids.astype(np.int32)               # shape [batch_size, neg_samples, max_len]
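The negative handling above flattens the [batch_size, neg_samples] strings before tokenizing and then folds the id matrix back row by row. When every row already has a full set of negatives, that loop is equivalent to a single reshape; a small sketch of the equivalence with dummy ids:

    import numpy as np

    batch_size, neg_samples, max_len = 2, 3, 4
    flat_ids = np.arange(batch_size * neg_samples * max_len, dtype=np.int32).reshape(-1, max_len)

    # Loop version (mirrors the per-row slicing above, minus the padding branch):
    rows = [flat_ids[i * neg_samples:(i + 1) * neg_samples] for i in range(batch_size)]
    n_ids_loop = np.stack(rows, axis=0)

    # Direct reshape, valid when no row needs zero-padding:
    n_ids_reshape = flat_ids.reshape(batch_size, neg_samples, max_len)

    print(n_ids_loop.shape, np.array_equal(n_ids_loop, n_ids_reshape))   # (2, 3, 4) True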
train_model.py
CHANGED
@@ -14,10 +14,10 @@ def inspect_tfrecord(tfrecord_file_path, num_examples=3):
14            'negative_ids': tf.io.FixedLenFeature([3 * 512], tf.int64),  # Adjust neg_samples if different
15        }
16        return tf.io.parse_single_example(example_proto, feature_description)
17 -
18    dataset = tf.data.TFRecordDataset(tfrecord_file_path)
19    dataset = dataset.map(parse_example)
20 -
21    for i, example in enumerate(dataset.take(num_examples)):
22        print(f"Example {i+1}:")
23        print(f"Query IDs: {example['query_ids'].numpy()}")
@@ -26,29 +26,27 @@ def inspect_tfrecord(tfrecord_file_path, num_examples=3):
26        print("-" * 50)
27
28 def main():
29
30 -    #
31    # inspect_tfrecord('training_data/training_data.tfrecord', num_examples=3)
32
33 -    #
34 -    tf.keras.backend.clear_session()
35    env = EnvironmentSetup()
36    env.initialize()
37
38 -    # Training
39    EPOCHS = 20
40    TF_RECORD_FILE_PATH = 'training_data/training_data.tfrecord'
41    CHECKPOINT_DIR = 'checkpoints/'
42 -    # Optimize batch size for Colab
43 -    batch_size = 32  # env.optimize_batch_size(base_batch_size=16)
44
45 -
46 -    config = ChatbotConfig()
47
48 -    # Initialize chatbot
49    chatbot = RetrievalChatbot(config, mode='training')
50
51 -    # Check for existing checkpoint
52    latest_checkpoint = tf.train.latest_checkpoint(CHECKPOINT_DIR)
53    initial_epoch = 0
54    if latest_checkpoint:
@@ -60,7 +58,7 @@ def main():
60        logger.error(f"Failed to parse checkpoint number from {latest_checkpoint}")
61        initial_epoch = 0
62
63 -    # Train
64    chatbot.train_model(
65        tfrecord_file_path=TF_RECORD_FILE_PATH,
66        epochs=EPOCHS,
@@ -71,13 +69,13 @@ def main():
71        initial_epoch=initial_epoch
72    )
73
74 -    # Save
75    model_save_path = env.training_dirs['base'] / 'final_model'
76    chatbot.save_models(model_save_path)
77
78 -    # Plot
79    plotter = Plotter(save_dir=env.training_dirs['plots'])
80    plotter.plot_training_history(chatbot.history)
81 -
82 if __name__ == "__main__":
83    main()
14            'negative_ids': tf.io.FixedLenFeature([3 * 512], tf.int64),  # Adjust neg_samples if different
15        }
16        return tf.io.parse_single_example(example_proto, feature_description)
17 +
18    dataset = tf.data.TFRecordDataset(tfrecord_file_path)
19    dataset = dataset.map(parse_example)
20 +
21    for i, example in enumerate(dataset.take(num_examples)):
22        print(f"Example {i+1}:")
23        print(f"Query IDs: {example['query_ids'].numpy()}")
26        print("-" * 50)
27
28 def main():
29 +    tf.keras.backend.clear_session()
30
31 +    # Validate TFRecord
32    # inspect_tfrecord('training_data/training_data.tfrecord', num_examples=3)
33
34 +    # Init env
35    env = EnvironmentSetup()
36    env.initialize()
37
38 +    # Training config
39    EPOCHS = 20
40    TF_RECORD_FILE_PATH = 'training_data/training_data.tfrecord'
41    CHECKPOINT_DIR = 'checkpoints/'
42
43 +    batch_size = 32
44
45 +    # Initialize config and chatbot model
46 +    config = ChatbotConfig()
47    chatbot = RetrievalChatbot(config, mode='training')
48
49 +    # Check for existing checkpoint
50    latest_checkpoint = tf.train.latest_checkpoint(CHECKPOINT_DIR)
51    initial_epoch = 0
52    if latest_checkpoint:
58        logger.error(f"Failed to parse checkpoint number from {latest_checkpoint}")
59        initial_epoch = 0
60
61 +    # Train
62    chatbot.train_model(
63        tfrecord_file_path=TF_RECORD_FILE_PATH,
64        epochs=EPOCHS,
69        initial_epoch=initial_epoch
70    )
71
72 +    # Save
73    model_save_path = env.training_dirs['base'] / 'final_model'
74    chatbot.save_models(model_save_path)
75
76 +    # Plot
77    plotter = Plotter(save_dir=env.training_dirs['plots'])
78    plotter.plot_training_history(chatbot.history)
79 +
80 if __name__ == "__main__":
81    main()
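inspect_tfrecord above stores negative_ids flattened as 3 * 512 int64 values, so a training-side reader has to fold that flat vector back into [neg_samples, max_length]. A small sketch of that parse step; the query_ids/negative_ids keys follow the file above, while positive_ids and max_length = 512 are inferred from the 3 * 512 literal:

    import tensorflow as tf

    NEG_SAMPLES, MAX_LENGTH = 3, 512

    def parse_example(example_proto):
        feature_description = {
            "query_ids": tf.io.FixedLenFeature([MAX_LENGTH], tf.int64),
            "positive_ids": tf.io.FixedLenFeature([MAX_LENGTH], tf.int64),   # key assumed
            "negative_ids": tf.io.FixedLenFeature([NEG_SAMPLES * MAX_LENGTH], tf.int64),
        }
        parsed = tf.io.parse_single_example(example_proto, feature_description)
        q = tf.cast(parsed["query_ids"], tf.int32)
        p = tf.cast(parsed["positive_ids"], tf.int32)
        n = tf.reshape(tf.cast(parsed["negative_ids"], tf.int32), [NEG_SAMPLES, MAX_LENGTH])
        return q, p, n

    # dataset = tf.data.TFRecordDataset("training_data/training_data.tfrecord").map(parse_example)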
validate_model.py
CHANGED
@@ -1,6 +1,5 @@
1  import os
2  import json
3 -
4  from chatbot_model import ChatbotConfig, RetrievalChatbot
5  from response_quality_checker import ResponseQualityChecker
6  from chatbot_validator import ChatbotValidator
@@ -18,20 +17,20 @@ def run_interactive_chat(chatbot, quality_checker):
18        except (KeyboardInterrupt, EOFError):
19            print("\nAssistant: Goodbye!")
20            break
21 -
22        if user_input.lower() in ["quit", "exit", "bye"]:
23            print("Assistant: Goodbye!")
24            break
25 -
26        response, candidates, metrics = chatbot.chat(
27            query=user_input,
28            conversation_history=None,
29            quality_checker=quality_checker,
30            top_k=10
31        )
32 -
33        print(f"Assistant: {response}")
34 -
35        # Show alternative responses if confident
36        if metrics.get("is_confident", False):
37            print("\nAlternative responses:")
@@ -39,17 +38,17 @@ def run_interactive_chat(chatbot, quality_checker):
39            print(f"Score: {score:.4f} - {resp}")
40        else:
41            print("\n[Low Confidence]: Consider rephrasing your query for better assistance.")
42 -
43 def validate_chatbot():
44    # Initialize environment
45    env = EnvironmentSetup()
46    env.initialize()
47 -
48    MODEL_DIR = "new_iteration/data_prep_iterative_models"
49    FAISS_INDICES_DIR = os.path.join(MODEL_DIR, "faiss_indices")
50    FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, "faiss_index_production.index")
51    FAISS_INDEX_TEST_PATH = os.path.join(FAISS_INDICES_DIR, "faiss_index_test.index")
52 -
53    # Toggle 'production' or 'test' env
54    ENVIRONMENT = "production"
55    if ENVIRONMENT == "test":
@@ -58,7 +57,7 @@ def validate_chatbot():
58    else:
59        FAISS_INDEX_PATH = FAISS_INDEX_PRODUCTION_PATH
60        RESPONSE_POOL_PATH = FAISS_INDEX_PRODUCTION_PATH.replace(".index", "_responses.json")
61 -
62    # Load the config
63    config_path = os.path.join(MODEL_DIR, "config.json")
64    if os.path.exists(config_path):
@@ -69,50 +68,47 @@ def validate_chatbot():
69    else:
70        config = ChatbotConfig()
71        logger.warning("No config.json found. Using default ChatbotConfig.")
72 -
73 -    # Load RetrievalChatbot in 'inference' mode
74    try:
75        chatbot = RetrievalChatbot.load_model(load_dir=MODEL_DIR, mode="inference")
76        logger.info("RetrievalChatbot loaded in 'inference' mode successfully.")
77    except Exception as e:
78        logger.error(f"Failed to load RetrievalChatbot: {e}")
79        return
80 -
81    # Confirm FAISS index & response pool exist
82    if not os.path.exists(FAISS_INDEX_PATH) or not os.path.exists(RESPONSE_POOL_PATH):
83        logger.error("FAISS index or response pool file is missing.")
84        return
85 -
86 -    # Load
87    try:
88 -        # Even though load_model might auto-load an index, we override here with the specific file
89        chatbot.data_pipeline.load_faiss_index(FAISS_INDEX_PATH)
90        logger.info(f"FAISS index loaded from {FAISS_INDEX_PATH}.")
91 -
92 -
93 -
94 -
95 -
96 -
97        with open(RESPONSE_POOL_PATH, "r", encoding="utf-8") as f:
98            chatbot.data_pipeline.response_pool = json.load(f)
99 -
100 -
101 -
102 -
103        # Validate dimension consistency
104        chatbot.data_pipeline.validate_faiss_index()
105        logger.info("FAISS index and response pool validated successfully.")
106 -
107    except Exception as e:
108        logger.error(f"Failed to load or validate FAISS index: {e}")
109        return
110 -
111    # Init QualityChecker and Validator
112    quality_checker = ResponseQualityChecker(data_pipeline=chatbot.data_pipeline)
113    validator = ChatbotValidator(chatbot=chatbot, quality_checker=quality_checker)
114    logger.info("ResponseQualityChecker and ChatbotValidator initialized.")
115 -
116    # Run validation
117    try:
118        validation_metrics = validator.run_validation(num_examples=5)
@@ -120,7 +116,7 @@ def validate_chatbot():
120    except Exception as e:
121        logger.error(f"Validation process failed: {e}")
122        return
123 -
124    # Plot metrics
125    # try:
126    #     plotter = Plotter(save_dir=env.training_dirs["plots"])
@@ -128,10 +124,10 @@ def validate_chatbot():
128    #     logger.info("Validation metrics plotted successfully.")
129    # except Exception as e:
130    #     logger.error(f"Failed to plot validation metrics: {e}")
131 -
132    # Run interactive chat loop
133 -
134 -
135
136 if __name__ == "__main__":
137    validate_chatbot()
1  import os
2  import json
3  from chatbot_model import ChatbotConfig, RetrievalChatbot
4  from response_quality_checker import ResponseQualityChecker
5  from chatbot_validator import ChatbotValidator
17        except (KeyboardInterrupt, EOFError):
18            print("\nAssistant: Goodbye!")
19            break
20 +
21        if user_input.lower() in ["quit", "exit", "bye"]:
22            print("Assistant: Goodbye!")
23            break
24 +
25        response, candidates, metrics = chatbot.chat(
26            query=user_input,
27            conversation_history=None,
28            quality_checker=quality_checker,
29            top_k=10
30        )
31 +
32        print(f"Assistant: {response}")
33 +
34        # Show alternative responses if confident
35        if metrics.get("is_confident", False):
36            print("\nAlternative responses:")
38            print(f"Score: {score:.4f} - {resp}")
39        else:
40            print("\n[Low Confidence]: Consider rephrasing your query for better assistance.")
41 +
42 def validate_chatbot():
43    # Initialize environment
44    env = EnvironmentSetup()
45    env.initialize()
46 +
47    MODEL_DIR = "new_iteration/data_prep_iterative_models"
48    FAISS_INDICES_DIR = os.path.join(MODEL_DIR, "faiss_indices")
49    FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, "faiss_index_production.index")
50    FAISS_INDEX_TEST_PATH = os.path.join(FAISS_INDICES_DIR, "faiss_index_test.index")
51 +
52    # Toggle 'production' or 'test' env
53    ENVIRONMENT = "production"
54    if ENVIRONMENT == "test":
57    else:
58        FAISS_INDEX_PATH = FAISS_INDEX_PRODUCTION_PATH
59        RESPONSE_POOL_PATH = FAISS_INDEX_PRODUCTION_PATH.replace(".index", "_responses.json")
60 +
61    # Load the config
62    config_path = os.path.join(MODEL_DIR, "config.json")
63    if os.path.exists(config_path):
68    else:
69        config = ChatbotConfig()
70        logger.warning("No config.json found. Using default ChatbotConfig.")
71 +
72 +    # Load RetrievalChatbot in 'inference' mode
73    try:
74        chatbot = RetrievalChatbot.load_model(load_dir=MODEL_DIR, mode="inference")
75        logger.info("RetrievalChatbot loaded in 'inference' mode successfully.")
76    except Exception as e:
77        logger.error(f"Failed to load RetrievalChatbot: {e}")
78        return
79 +
80    # Confirm FAISS index & response pool exist
81    if not os.path.exists(FAISS_INDEX_PATH) or not os.path.exists(RESPONSE_POOL_PATH):
82        logger.error("FAISS index or response pool file is missing.")
83        return
84 +
85 +    # Load FAISS index and response pool
86    try:
87        chatbot.data_pipeline.load_faiss_index(FAISS_INDEX_PATH)
88        logger.info(f"FAISS index loaded from {FAISS_INDEX_PATH}.")
89 +        logger.info("FAISS dimensions:", chatbot.data_pipeline.index.d)
90 +        logger.info("FAISS index type:", type(chatbot.data_pipeline.index))
91 +        logger.info("FAISS index total vectors:", chatbot.data_pipeline.index.ntotal)
92 +        logger.info("FAISS is_trained:", chatbot.data_pipeline.index.is_trained)
93 +
94        with open(RESPONSE_POOL_PATH, "r", encoding="utf-8") as f:
95            chatbot.data_pipeline.response_pool = json.load(f)
96 +        logger.info(f"Response pool loaded from {RESPONSE_POOL_PATH}.")
97 +        logger.info("\nTotal responses in pool:", len(chatbot.data_pipeline.response_pool))
98 +
99        # Validate dimension consistency
100        chatbot.data_pipeline.validate_faiss_index()
101        logger.info("FAISS index and response pool validated successfully.")
102 +
103    except Exception as e:
104        logger.error(f"Failed to load or validate FAISS index: {e}")
105        return
106 +
107    # Init QualityChecker and Validator
108    quality_checker = ResponseQualityChecker(data_pipeline=chatbot.data_pipeline)
109    validator = ChatbotValidator(chatbot=chatbot, quality_checker=quality_checker)
110    logger.info("ResponseQualityChecker and ChatbotValidator initialized.")
111 +
112    # Run validation
113    try:
114        validation_metrics = validator.run_validation(num_examples=5)
116    except Exception as e:
117        logger.error(f"Validation process failed: {e}")
118        return
119 +
120    # Plot metrics
121    # try:
122    #     plotter = Plotter(save_dir=env.training_dirs["plots"])
124    #     logger.info("Validation metrics plotted successfully.")
125    # except Exception as e:
126    #     logger.error(f"Failed to plot validation metrics: {e}")
127 +
128    # Run interactive chat loop
129 +    logger.info("\nStarting interactive chat session...")
130 +    run_interactive_chat(chatbot, quality_checker)
131
132 if __name__ == "__main__":
133    validate_chatbot()
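validate_chatbot calls validate_faiss_index() above to catch mismatches between the loaded index and the response pool before chatting. A minimal standalone version of that kind of check (a hypothetical helper, not the project's implementation) would compare the index dimensionality and vector count against the encoder dimension and pool size:

    import numpy as np
    import faiss

    def check_index_consistency(index: faiss.Index, response_pool: list, embedding_dim: int) -> None:
        if index.d != embedding_dim:
            raise ValueError(f"Index dim {index.d} != encoder dim {embedding_dim}")
        if index.ntotal != len(response_pool):
            raise ValueError(f"Index holds {index.ntotal} vectors but pool has {len(response_pool)} responses")
        print(f"OK: {index.ntotal} vectors of dim {index.d} match the response pool.")

    # Toy usage:
    dim = 8
    index = faiss.IndexFlatIP(dim)
    index.add(np.zeros((2, dim), dtype="float32"))
    check_index_consistency(index, [{"text": "a"}, {"text": "b"}], embedding_dim=dim)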