JoeArmani committed
Commit · 71ca212
Parent(s): e5be70f
restructuring
Files changed:
- build_faiss_index.py +0 -161
- chatbot_model.py +1 -3
- chatbot_validator.py +42 -74
- processing_pipeline.py → data_augmentation/augmentation_processing_pipeline.py +3 -3
- back_translator.py → data_augmentation/back_translator.py +0 -0
- dialogue_augmenter.py → data_augmentation/dialogue_augmenter.py +3 -3
- main.py → data_augmentation/main.py +5 -5
- paraphraser.py → data_augmentation/paraphraser.py +0 -0
- pipeline_config.py → data_augmentation/pipeline_config.py +0 -0
- quality_metrics.py → data_augmentation/quality_metrics.py +1 -1
- schema_guided_dialogue_processor.py → data_augmentation/schema_guided_dialogue_processor.py +1 -1
- taskmaster_processor.py → data_augmentation/taskmaster_processor.py +1 -1
- deduplicate_augmented_dialogues.py +7 -6
- environment_setup.py +5 -8
- new_iteration/run_taskmaster_processor.py +2 -2
- plotter.py +14 -19
- prepare_data.py +22 -28
- tf_data_pipeline.py +0 -3
- unused/build_faiss_index.py +160 -0
- gpu_monitor.py → unused/gpu_monitor.py +4 -13
- validate_model.py +0 -5
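
The renames above gather all of the data-augmentation modules into a new data_augmentation package, so every flat top-level import of those modules becomes a package-qualified import, as the diffs below show. A minimal sketch of the before/after import style (module names taken from the rename list; whether the commit adds a data_augmentation/__init__.py is not shown here):

# Old flat imports (pre-restructuring):
# from pipeline_config import PipelineConfig
# from taskmaster_processor import TaskmasterProcessor

# New package-qualified imports (post-restructuring):
from data_augmentation.pipeline_config import PipelineConfig
from data_augmentation.taskmaster_processor import TaskmasterProcessor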
build_faiss_index.py
DELETED
@@ -1,161 +0,0 @@
-import os
-import json
-from pathlib import Path
-
-import faiss
-import numpy as np
-import tensorflow as tf
-from transformers import AutoTokenizer, TFAutoModel
-from tqdm.auto import tqdm
-
-from chatbot_model import ChatbotConfig, EncoderModel
-from tf_data_pipeline import TFDataPipeline
-from logger_config import config_logger
-
-logger = config_logger(__name__)
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-def sanity_check(encoder: EncoderModel, tokenizer: AutoTokenizer, config: ChatbotConfig):
-    """
-    Perform a quick sanity check to ensure the model is loaded correctly.
-    """
-    sample_response = "This is a test response."
-    encoded_sample = tokenizer(
-        [sample_response],
-        padding=True,
-        truncation=True,
-        max_length=config.max_context_token_limit,
-        return_tensors='tf'
-    )
-
-    # Get embedding
-    sample_embedding = encoder(encoded_sample['input_ids'], training=False).numpy()
-
-    # Check shape
-    if sample_embedding.shape[1] != config.embedding_dim:
-        logger.error(
-            f"Embedding dimension mismatch: Expected {config.embedding_dim}, "
-            f"got {sample_embedding.shape[1]}"
-        )
-        raise ValueError("Embedding dimension mismatch.")
-    else:
-        logger.info("Embedding dimension matches the configuration.")
-
-    # Check normalization
-    embedding_norm = np.linalg.norm(sample_embedding, axis=1)
-    if not np.allclose(embedding_norm, 1.0, atol=1e-5):
-        logger.error("Embeddings are not properly normalized.")
-        raise ValueError("Embeddings are not normalized.")
-    else:
-        logger.info("Embeddings are properly normalized.")
-
-    logger.info("Sanity check passed: Model loaded correctly and outputs are as expected.")
-
-def build_faiss_index():
-    """
-    Rebuild the FAISS index by:
-    1) Loading your config.json
-    2) Initializing encoder + loading submodule & custom weights
-    3) Loading tokenizer from disk
-    4) Creating a TFDataPipeline
-    5) Setting the pipeline's response_pool from a JSON file
-    6) Using pipeline.compute_and_index_response_embeddings()
-    7) Saving the FAISS index
-    """
-    # Directories
-    MODELS_DIR = Path("models")
-    FAISS_DIR = MODELS_DIR / "faiss_indices"
-    FAISS_INDEX_PATH = FAISS_DIR / "faiss_index_production.index"
-    RESPONSES_PATH = FAISS_DIR / "faiss_index_production_responses.json"
-    TOKENIZER_DIR = MODELS_DIR / "tokenizer"
-    SHARED_ENCODER_DIR = MODELS_DIR / "shared_encoder"
-    CUSTOM_WEIGHTS_PATH = MODELS_DIR / "encoder_custom_weights.weights.h5"
-
-    # 1) Load ChatbotConfig
-    config_path = MODELS_DIR / "config.json"
-    if config_path.exists():
-        with open(config_path, "r", encoding="utf-8") as f:
-            config_dict = json.load(f)
-        config = ChatbotConfig.from_dict(config_dict)
-        logger.info(f"Loaded ChatbotConfig from {config_path}")
-    else:
-        config = ChatbotConfig()
-        logger.warning(f"No config.json found at {config_path}. Using default ChatbotConfig.")
-
-    # 2) Initialize the EncoderModel
-    encoder = EncoderModel(config=config)
-    logger.info("EncoderModel instantiated (empty).")
-
-    # Overwrite the submodule from 'shared_encoder' directory
-    if SHARED_ENCODER_DIR.exists():
-        logger.info(f"Loading DistilBERT submodule from {SHARED_ENCODER_DIR}...")
-        encoder.pretrained = TFAutoModel.from_pretrained(str(SHARED_ENCODER_DIR))
-        logger.info("Loaded HF submodule into encoder.pretrained.")
-    else:
-        logger.warning(f"No shared_encoder directory at {SHARED_ENCODER_DIR}. Using default pretrained model.")
-
-    # Build model once, then load custom weights (projection, etc.)
-    dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
-    _ = encoder(dummy_input, training=False)  # builds the layers
-
-    if CUSTOM_WEIGHTS_PATH.exists():
-        logger.info(f"Loading custom top-level weights from {CUSTOM_WEIGHTS_PATH}")
-        encoder.load_weights(str(CUSTOM_WEIGHTS_PATH))
-        logger.info("Custom top-level weights loaded successfully.")
-    else:
-        logger.warning(f"Custom weights file not found at {CUSTOM_WEIGHTS_PATH}.")
-
-    # 3) Load tokenizer
-    if TOKENIZER_DIR.exists():
-        logger.info(f"Loading tokenizer from {TOKENIZER_DIR}")
-        tokenizer = AutoTokenizer.from_pretrained(str(TOKENIZER_DIR))
-    else:
-        logger.warning(f"No tokenizer dir at {TOKENIZER_DIR}, falling back to default HF tokenizer.")
-        tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
-        #tokenizer.add_special_tokens({'additional_special_tokens': ['<EMPTY_NEGATIVE>']})
-
-    # 4) Quick sanity check
-    sanity_check(encoder, tokenizer, config)
-
-    # 5) Prepare a TFDataPipeline
-    pipeline = TFDataPipeline(
-        config=config,
-        tokenizer=tokenizer,
-        encoder=encoder,
-        index_file_path=str(FAISS_INDEX_PATH),
-        response_pool=[],
-        max_length=config.max_context_token_limit,
-        query_embeddings_cache={},
-        neg_samples=config.neg_samples,
-        index_type='IndexFlatIP',
-        nlist=100,
-        max_retries=config.max_retries
-    )
-
-    # 6) Load the existing response pool
-    if not RESPONSES_PATH.exists():
-        logger.error(f"Response pool JSON file not found at {RESPONSES_PATH}")
-        raise FileNotFoundError(f"No response pool JSON at {RESPONSES_PATH}")
-
-    with open(RESPONSES_PATH, "r", encoding="utf-8") as f:
-        response_pool = json.load(f)
-    logger.info(f"Loaded {len(response_pool)} responses from {RESPONSES_PATH}")
-
-    pipeline.response_pool = response_pool  # assign to pipeline
-
-    # 7) Build (or rebuild) the FAISS index from pipeline method
-    # This does all the compute-embeddings + index.add in one place
-    logger.info("Starting to compute and index response embeddings via TFDataPipeline...")
-    pipeline.compute_and_index_response_embeddings()
-
-    # 8) Save the rebuilt FAISS index
-    pipeline.save_faiss_index(str(FAISS_INDEX_PATH))
-
-    # Verify
-    loaded_index = faiss.read_index(str(FAISS_INDEX_PATH))
-    logger.info(f"Verified the rebuilt FAISS index has {loaded_index.ntotal} vectors.")
-
-    return loaded_index, pipeline.response_pool
-
-if __name__ == "__main__":
-    build_faiss_index()
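
For reference, the heart of the deleted script — encoding a response pool and adding it to an inner-product FAISS index — is small. A minimal sketch under the same assumptions the script makes (L2-normalized embeddings, so inner product equals cosine similarity); the embed() call is a stand-in for the project's encoder, not part of the repo:

import faiss
import numpy as np

def build_flat_ip_index(embeddings: np.ndarray) -> faiss.IndexFlatIP:
    # IndexFlatIP scores by inner product; with L2-normalized rows this is cosine similarity.
    emb = np.ascontiguousarray(embeddings, dtype="float32")
    faiss.normalize_L2(emb)                   # in-place, row-wise normalization
    index = faiss.IndexFlatIP(emb.shape[1])   # dimension must match config.embedding_dim
    index.add(emb)
    return index

# Usage sketch:
# embeddings = embed(response_pool)  # stand-in for encoder(...).numpy()
# index = build_flat_ip_index(embeddings)
# faiss.write_index(index, "models/faiss_indices/faiss_index_production.index")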
chatbot_model.py
CHANGED
@@ -15,7 +15,6 @@ from tf_data_pipeline import TFDataPipeline
 from response_quality_checker import ResponseQualityChecker
 from cross_encoder_reranker import CrossEncoderReranker
 from conversation_summarizer import DeviceAwareModel, Summarizer
-from gpu_monitor import GPUMemoryMonitor
 import absl.logging
 from logger_config import config_logger
 from tqdm.auto import tqdm
@@ -147,7 +146,6 @@ class RetrievalChatbot(DeviceAwareModel):
         self.tokenizer = self._initialize_tokenizer()
         self.encoder = self._initialize_encoder()
         self.summarizer = summarizer or self._initialize_summarizer()
-        self.memory_monitor = GPUMemoryMonitor()
 
         # Initialize data pipeline
         logger.info("Initializing TFDataPipeline.")
@@ -566,7 +564,7 @@ class RetrievalChatbot(DeviceAwareModel):
         boosted.sort(key=lambda x: x[1], reverse=True)
 
         # Print top 10
-        for resp, score in boosted[:
+        for resp, score in boosted[:150]:
             logger.debug(f"Candidate: '{resp}' with score {score}")
 
         # 8) Return top_k
chatbot_validator.py
CHANGED
@@ -10,17 +10,12 @@ logger = config_logger(__name__)
 class ChatbotValidator:
     """
     Handles automated validation and performance analysis for the chatbot.
-
-    This validator executes domain-specific test queries, obtains candidate
-    responses via the chatbot, then evaluates them with a quality checker.
-    It aggregates metrics across queries and domains, logs intermediate
-    results, and returns a comprehensive summary.
+    This testing module executes domain-specific queries, obtains chatbot responses, and evaluates them with a quality checker.
     """
 
     def __init__(self, chatbot, quality_checker):
         """
         Initialize the validator.
-
         Args:
            chatbot: RetrievalChatbot instance for inference
            quality_checker: ResponseQualityChecker instance
@@ -28,75 +23,60 @@ class ChatbotValidator:
         self.chatbot = chatbot
         self.quality_checker = quality_checker
 
-        #
-        # Taskmaster-1 and Schema-Guided style
+        # Domain-specific test queries (aligns with Taskmaster-1 dataset)
         self.domain_queries = {
-
-
-
-
-
-            # "What's the average cost per plate at your restaurant?"
-            # # ],
+            'restaurant': [
+                "Hi, I have a question about your restaurant. Do they take reservations?",
+                "I'd like to make a reservation for dinner tonight after 6pm. Is that time available?",
+                "Can you recommend an Italian restaurant with wood-fired pizza?",
+            ],
             'movie': [
                 "How much are movie tickets for two people?",
                 "I'm looking for showings after 6pm?",
                 "Is this at the new theater with reclining seats?",
-                "Hi, I'm thinking about reserving tickets for the new movie.",
-                "What is the price for your largest popcorn?"
             ],
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            # ],
-            # 'auto': [
-            #     "I need to schedule an oil change for my car.",
-            #     "When can I bring my car in for maintenance?",
-            #     "Do you have any openings for auto repair today?",
-            #     "How long will the service take?",
-            #     "Can I get an estimate for brake repair?"
-            #],
+            'ride_share': [
+                "I need a ride from the airport to downtown.",
+                "What is the cost for Lyft? How about Uber XL?",
+                "Can you book a car for tomorrow morning?",
+            ],
+            'coffee': [
+                "Can I customize my coffee?",
+                "Can I order a mocha from you?",
+                "Can I get my usual venti vanilla latte?",
+            ],
+            'pizza': [
+                "Do you have any pizza specials or deals available?",
+                "How long is the wait until the pizza is ready and delivered to me?",
+                "Please repeat my pizza order for two medium pizzas with thick crust.",
+            ],
+            'auto': [
+                "The car is making a funny noise when I turn, and I'm due for an oil change.",
+                "Is my buddy John available to work on my car?",
+                "My Jeep needs a repair. Can you help me with that?",
+            ],
         }
 
     def run_validation(
         self,
-        num_examples: int =
+        num_examples: int = 3,
         top_k: int = 10,
         domains: Optional[List[str]] = None,
         randomize: bool = False,
         seed: int = 42
     ) -> Dict[str, Any]:
         """
-        Run
-
+        Run validation across testable domains.
         Args:
             num_examples: Number of test queries per domain
             top_k: Number of responses to retrieve for each query
             domains: Optional list of domain keys to test. If None, test all.
             randomize: If True, randomly select queries from the domain lists
             seed: Random seed for consistent sampling if randomize=True
-
         Returns:
-            Dict
+            Dict with validation metrics
         """
-        logger.info("\n=== Running
+        logger.info("\n=== Running Automatic Validation ===")
 
         # Select which domains to test
         test_domains = domains if domains else list(self.domain_queries.keys())
@@ -105,6 +85,7 @@ class ChatbotValidator:
         metrics_history = []
         domain_metrics = {}
 
+        # Init the cross-encoder reranker to pass to the chatbot
         reranker = CrossEncoderReranker(model_name="cross-encoder/ms-marco-MiniLM-L-12-v2")
 
         # Prepare random selection if needed
@@ -131,26 +112,21 @@ class ChatbotValidator:
         for i, query in enumerate(queries, 1):
             logger.info(f"\nTest Case {i}: {query}")
 
-            # Retrieve top_k responses
+            # Retrieve top_k responses, then evaluate with quality checker
             responses = self.chatbot.retrieve_responses_cross_encoder(query, top_k=top_k, reranker=reranker)
-
-            # Evaluate with quality checker
             quality_metrics = self.quality_checker.check_response_quality(query, responses)
 
-            #
+            # Aggregate metrics and log
             quality_metrics['domain'] = domain
             metrics_history.append(quality_metrics)
             domain_metrics[domain].append(quality_metrics)
-
-            # Detailed logging
-            self._log_validation_results(query, responses, quality_metrics, i)
+            self._log_validation_results(query, responses, quality_metrics)
 
         # Final aggregation
         aggregate_metrics = self._calculate_aggregate_metrics(metrics_history)
         domain_analysis = self._analyze_domain_performance(domain_metrics)
         confidence_analysis = self._analyze_confidence_distribution(metrics_history)
 
-        # Combine into one dictionary
         aggregate_metrics.update({
             'domain_performance': domain_analysis,
             'confidence_analysis': confidence_analysis
@@ -161,7 +137,7 @@ class ChatbotValidator:
 
     def _calculate_aggregate_metrics(self, metrics_history: List[Dict]) -> Dict[str, float]:
         """
-        Calculate
+        Calculate aggregate metrics over tested queries.
         """
         if not metrics_history:
             logger.warning("No metrics to aggregate. Returning empty summary.")
@@ -169,7 +145,6 @@ class ChatbotValidator:
 
         top_scores = [m.get('top_score', 0.0) for m in metrics_history]
 
-        # The length-based metrics are robust to missing or zero-length data
         metrics = {
             'num_queries_tested': len(metrics_history),
             'avg_top_response_score': np.mean(top_scores),
@@ -177,10 +152,7 @@ class ChatbotValidator:
             'avg_relevance': np.mean([m.get('query_response_relevance', 0.0) for m in metrics_history]),
             'avg_length_score': np.mean([m.get('response_length_score', 0.0) for m in metrics_history]),
             'avg_score_gap': np.mean([m.get('top_3_score_gap', 0.0) for m in metrics_history]),
-            'confidence_rate': np.mean([1.0 if m.get('is_confident', False) else 0.0
-                                        for m in metrics_history]),
-
-            # Additional statistical metrics
+            'confidence_rate': np.mean([1.0 if m.get('is_confident', False) else 0.0 for m in metrics_history]),
             'median_top_score': np.median(top_scores),
             'score_std': np.std(top_scores),
             'min_score': np.min(top_scores),
@@ -202,12 +174,9 @@ class ChatbotValidator:
         top_scores = [m.get('top_score', 0.0) for m in metrics_list]
 
         analysis[domain] = {
-            'confidence_rate': np.mean([1.0 if m.get('is_confident', False) else 0.0
-                                        for m in metrics_list]),
-            'avg_relevance': np.mean([m.get('query_response_relevance', 0.0)
-                                      for m in metrics_list]),
-            'avg_diversity': np.mean([m.get('response_diversity', 0.0)
-                                      for m in metrics_list]),
+            'confidence_rate': np.mean([1.0 if m.get('is_confident', False) else 0.0 for m in metrics_list]),
+            'avg_relevance': np.mean([m.get('query_response_relevance', 0.0) for m in metrics_list]),
+            'avg_diversity': np.mean([m.get('response_diversity', 0.0) for m in metrics_list]),
             'avg_top_score': np.mean(top_scores),
             'num_samples': len(metrics_list)
         }
@@ -235,7 +204,6 @@ class ChatbotValidator:
         query: str,
         responses: List[Tuple[str, float]],
         metrics: Dict[str, Any],
-        case_num: int
     ):
         """
         Log detailed validation results for each test case.
@@ -249,8 +217,8 @@ class ChatbotValidator:
         # if isinstance(v, (int, float)):
         #     logger.info(f"  {k}: {v:.4f}")
 
-        logger.info("Top
-        for i, (resp_text, score) in enumerate(responses[:
+        logger.info("Top 3 Responses:")
+        for i, (resp_text, score) in enumerate(responses[:3], 1):
             logger.info(f"{i}) Score: {score:.4f} | {resp_text}")
             if i == 1 and not is_confident:
                 logger.info("   [Low Confidence on Top Response]")
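
The validator hands a CrossEncoderReranker built on cross-encoder/ms-marco-MiniLM-L-12-v2 to the chatbot's retrieval call. The wrapper itself is not part of this commit; a minimal sketch of the same reranking idea using the sentence-transformers CrossEncoder class (an assumption — the repo's cross_encoder_reranker.py may be implemented differently):

from sentence_transformers import CrossEncoder

# A cross-encoder scores each (query, candidate) pair jointly, which is slower
# than bi-encoder retrieval but more accurate for the final ranking.
model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")

def rerank(query, candidates, top_k=10):
    scores = model.predict([(query, c) for c in candidates])
    ranked = sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True)
    return [(text, float(score)) for text, score in ranked[:top_k]]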
processing_pipeline.py → data_augmentation/augmentation_processing_pipeline.py
RENAMED
@@ -7,13 +7,13 @@ import hashlib
 import spacy
 import torch
 from tqdm import tqdm
-from pipeline_config import PipelineConfig
-from dialogue_augmenter import DialogueAugmenter
+from data_augmentation.pipeline_config import PipelineConfig
+from data_augmentation.dialogue_augmenter import DialogueAugmenter
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.metrics.pairwise import cosine_similarity
 from typing import Set
 
-class
+class AugmentationProcessingPipeline:
     """
     Complete pipeline combining validation, optimization, and augmentation.
     """
back_translator.py → data_augmentation/back_translator.py
RENAMED
File without changes
dialogue_augmenter.py → data_augmentation/dialogue_augmenter.py
RENAMED
@@ -3,9 +3,9 @@ import numpy as np
 import torch
 import tensorflow as tf
 import tensorflow_hub as hub
-from pipeline_config import PipelineConfig
-from quality_metrics import QualityMetrics
-from paraphraser import Paraphraser
+from data_augmentation.pipeline_config import PipelineConfig
+from data_augmentation.quality_metrics import QualityMetrics
+from data_augmentation.paraphraser import Paraphraser
 import nlpaug.augmenter.word as naw
 from functools import lru_cache
 from sklearn.metrics.pairwise import cosine_similarity
main.py → data_augmentation/main.py
RENAMED
@@ -5,10 +5,10 @@ Description and References in the README.md file.
 import json
 import tensorflow as tf
 from typing import List, Dict
-from pipeline_config import PipelineConfig
-from
-from taskmaster_processor import TaskmasterProcessor
-from schema_guided_dialogue_processor import SchemaGuidedProcessor
+from data_augmentation.pipeline_config import PipelineConfig
+from data_augmentation.augmentation_processing_pipeline import AugmentationProcessingPipeline
+from data_augmentation.taskmaster_processor import TaskmasterProcessor
+from data_augmentation.schema_guided_dialogue_processor import SchemaGuidedProcessor
 
 def combine_datasets(taskmaster_dialogues: List[Dict],
                      schema_guided_dialogues: List[Dict]) -> List[Dict]:
@@ -99,7 +99,7 @@ def main():
 
     # Process through augmentation pipeline
     print("Processing combined dataset")
-    pipeline =
+    pipeline = AugmentationProcessingPipeline(config)
     output_path = pipeline.process_dataset(combined_dialogues)
     print(f"Processing complete. Results saved to {output_path}")
     pipeline.cleanup()
paraphraser.py → data_augmentation/paraphraser.py
RENAMED
File without changes
pipeline_config.py → data_augmentation/pipeline_config.py
RENAMED
File without changes
quality_metrics.py → data_augmentation/quality_metrics.py
RENAMED
@@ -2,7 +2,7 @@ import tensorflow_hub as hub
 import spacy
 from sklearn.metrics.pairwise import cosine_similarity
 from typing import Dict
-from pipeline_config import PipelineConfig
+from data_augmentation.pipeline_config import PipelineConfig
 
 class QualityMetrics:
     """
schema_guided_dialogue_processor.py → data_augmentation/schema_guided_dialogue_processor.py
RENAMED
@@ -3,7 +3,7 @@ from typing import List, Dict, Optional, Any
 import json
 import glob
 from pathlib import Path
-from pipeline_config import PipelineConfig
+from data_augmentation.pipeline_config import PipelineConfig
 
 @dataclass
 class SchemaGuidedDialogue:
taskmaster_processor.py → data_augmentation/taskmaster_processor.py
RENAMED
@@ -3,7 +3,7 @@ from typing import List, Dict, Optional, Any
 import json
 import re
 from pathlib import Path
-from pipeline_config import PipelineConfig
+from data_augmentation.pipeline_config import PipelineConfig
 
 @dataclass
 class TaskmasterDialogue:
deduplicate_augmented_dialogues.py
CHANGED
@@ -2,13 +2,16 @@ import json
 from pathlib import Path
 import logging
 from typing import List, Dict
-from collections import defaultdict
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+"""
+Standalone script to deduplicate dialogues from multiple JSON files.
+"""
+
 def load_json_file(file_path: str) -> List[Dict]:
-    """Load and parse
+    """Load and parse JSON file."""
     try:
         with open(file_path, 'r', encoding='utf-8') as f:
             return json.load(f)
@@ -21,13 +24,12 @@ def load_json_file(file_path: str) -> List[Dict]:
 
 def combine_json_files(input_directory: str, output_file: str):
     """
-    Combine multiple JSON files
-
+    Combine multiple JSON files, removing duplicate dialogues based on dialogue_id.
     Args:
         input_directory: Directory containing JSON files to process
         output_file: Path to save the combined output
     """
-    # Track unique dialogues
+    # Track unique dialogues
     dialogue_map = {}
     duplicate_count = 0
 
@@ -66,7 +68,6 @@ def combine_json_files(input_directory: str, output_file: str):
     except Exception as e:
         logger.error(f"Error writing output file: {e}")
 
-# Usage example
 if __name__ == "__main__":
     combine_json_files(
         input_directory="/Users/joe/Desktop/Grad School/CSC525/CSC525_mod8_option2_joseph_armani/processed_outputs",
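
The deduplication the docstring describes — keep the first dialogue seen for each dialogue_id — reduces to a dict keyed on that id. A minimal sketch (the dialogue_id field name comes from the docstring above; the repo's loop also counts duplicates):

def deduplicate(dialogues):
    # First occurrence of each dialogue_id wins; later duplicates are dropped.
    dialogue_map = {}
    for d in dialogues:
        dialogue_map.setdefault(d["dialogue_id"], d)
    return list(dialogue_map.values())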
environment_setup.py
CHANGED
@@ -90,7 +90,7 @@ class EnvironmentSetup:
         return None
 
     def setup_devices(self) -> Tuple[str, tf.distribute.Strategy]:
-        """Configure available compute devices with Colab
+        """Configure available compute devices with Colab optimizations."""
         logger.info("Checking available compute devices...")
 
         # Colab-specific setup
@@ -128,7 +128,7 @@ class EnvironmentSetup:
         except Exception as e:
             logger.error(f"Error configuring Colab GPU: {str(e)}")
 
-        # Non-Colab setup
+        # Non-Colab setup
         else:
             # Check for TPU
             try:
@@ -166,11 +166,11 @@ class EnvironmentSetup:
         return "CPU", strategy
 
     def optimize_batch_size(self, base_batch_size: int = 16) -> int:
-        """
+        """Colab-specific optimizations for training."""
         if not self.is_colab():
             return base_batch_size
 
-        # Colab
+        # Colab batch size optimization
         if self.device_type == "GPU":
             try:
                 gpu_name = subprocess.check_output(
@@ -179,15 +179,12 @@ class EnvironmentSetup:
                 ).decode('utf-8').strip()
 
                 if "A100" in gpu_name:
-                    # A100 optimizations - has 40GB or 80GB variants
                     logger.info("Optimizing for Colab A100 GPU")
-                    base_batch_size = min(base_batch_size * 8,
+                    base_batch_size = min(base_batch_size * 8, 64)
                 elif "T4" in gpu_name:
-                    # T4 optimizations
                     logger.info("Optimizing for Colab T4 GPU")
                     base_batch_size = min(base_batch_size * 2, 32)
                 elif "V100" in gpu_name:
-                    # V100 optimizations
                     logger.info("Optimizing for Colab V100 GPU")
                     base_batch_size = min(base_batch_size * 3, 48)
             except (subprocess.SubprocessError, FileNotFoundError):
new_iteration/run_taskmaster_processor.py
CHANGED
@@ -2,8 +2,8 @@ import json
 from datetime import datetime
 from pathlib import Path
 
-from pipeline_config import PipelineConfig
-from taskmaster_processor import TaskmasterProcessor
+from data_augmentation.pipeline_config import PipelineConfig
+from data_augmentation.taskmaster_processor import TaskmasterProcessor
 
 def main():
     # 1) Setup config
plotter.py
CHANGED
@@ -10,11 +10,10 @@ class Plotter:
         self.save_dir.mkdir(parents=True, exist_ok=True)
 
     def plot_training_history(self, history: Dict[str, List[float]], title: str = "Training History"):
-        """Plot and
-
+        """Plot and save training metrics history
         Args:
-            history:
-            title:
+            history: Dict with training metrics
+            title: Plot title
         """
         # Create figure with subplots
         fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))
@@ -28,7 +27,7 @@ class Plotter:
         ax1.legend()
         ax1.grid(True)
 
-        # Plot learning rate
+        # Plot learning rate
         if 'learning_rate' in history:
             ax2.plot(history['learning_rate'], label='Learning Rate')
             ax2.set_xlabel('Step')
@@ -40,7 +39,7 @@ class Plotter:
         plt.suptitle(title)
         plt.tight_layout()
 
-        # Save
+        # Save
         if self.save_dir:
             timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
             save_path = self.save_dir / f'training_history_{timestamp}.png'
@@ -49,25 +48,23 @@ class Plotter:
         plt.show()
 
     def plot_validation_metrics(self, metrics: Dict[str, float]):
-        """Plot validation metrics as a bar chart
-
+        """Plot validation metrics as a bar chart
         Args:
             metrics: Dictionary of validation metrics. Can handle nested dictionaries.
         """
 
-        # Flatten nested metrics
+        # Flatten nested metrics dict
         flat_metrics = {}
         for key, value in metrics.items():
-            # Skip num_queries_tested
             if key == 'num_queries_tested':
                 continue
 
+            # Flatten dict values, use numerical values only
             if isinstance(value, dict):
-                # If value is a dictionary, flatten it with key prefix
                 for subkey, subvalue in value.items():
-                    if isinstance(subvalue, (int, float)):
+                    if isinstance(subvalue, (int, float)):
                         flat_metrics[f"{key}_{subkey}"] = subvalue
-            elif isinstance(value, (int, float)):
+            elif isinstance(value, (int, float)):
                 flat_metrics[key] = value
 
         if not flat_metrics:
@@ -87,20 +84,18 @@ class Plotter:
         plt.xticks(range(len(metric_names)), metric_names, rotation=45, ha='right')
         plt.ylabel('Value')
 
-        # Add value labels on
+        # Add value labels on bars
         for bar in bars:
             height = bar.get_height()
             plt.text(bar.get_x() + bar.get_width()/2., height,
                      f'{height:.3f}',
                      ha='center', va='bottom')
 
-        # Set y-axis limits
-        plt.ylim(0, 1.1)
-
-        # Adjust layout to prevent label cutoff
+        # Set y-axis limits and adjust layout
+        plt.ylim(0, 1.1)
         plt.tight_layout()
 
-        # Save
+        # Save
         if self.save_dir:
             timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
             save_path = self.save_dir / f'validation_metrics_{timestamp}.png'
prepare_data.py
CHANGED
@@ -7,8 +7,6 @@ import tensorflow as tf
 from transformers import AutoTokenizer, TFAutoModel
 from tqdm.auto import tqdm
 from pathlib import Path
-
-# Your existing modules
 from chatbot_model import ChatbotConfig, EncoderModel
 from tf_data_pipeline import TFDataPipeline
 from logger_config import config_logger
@@ -18,7 +16,6 @@ logger = config_logger(__name__)
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 def main():
-    # Constants
     MODELS_DIR = 'new_iteration/data_prep_iterative_models'
     PROCESSED_DATA_DIR = 'new_iteration/processed_outputs'
     CACHE_DIR = 'new_iteration/cache'
@@ -30,9 +27,9 @@ def main():
     CACHE_FILE = os.path.join(CACHE_DIR, 'query_embeddings_cache.pkl')
     TF_RECORD_PATH = os.path.join(TF_RECORD_DIR, 'training_data_3.tfrecord')
 
-    # Decide whether to load the **custom**
+    # Decide whether to load the **custom** model or base DistilBERT (Base used for first iteration).
     # True for custom, False for base DistilBERT.
-    LOAD_CUSTOM_MODEL = True
+    LOAD_CUSTOM_MODEL = True
     NUM_NEG_SAMPLES = 10
 
     # Ensure output directories exist
@@ -43,7 +40,7 @@ def main():
     os.makedirs(FAISS_INDICES_DIR, exist_ok=True)
     os.makedirs(TF_RECORD_DIR, exist_ok=True)
 
-    #
+    # Init config
     config_json = Path(MODELS_DIR) / "config.json"
     if config_json.exists():
         with open(config_json, "r", encoding="utf-8") as f:
@@ -54,20 +51,18 @@ def main():
         config = ChatbotConfig()
         logger.warning("No config.json found. Using default ChatbotConfig.")
 
+    # Ensure negative samples are set
     config.neg_samples = NUM_NEG_SAMPLES
 
-    # Load or
+    # Load or init tokenizer
     try:
-        # If the directory has a valid tokenizer
         if Path(TOKENIZER_DIR).exists() and list(Path(TOKENIZER_DIR).iterdir()):
             logger.info(f"Loading tokenizer from {TOKENIZER_DIR}")
             tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)
         else:
-            # Initialize from base DistilBERT
             logger.info(f"Loading base tokenizer for {config.pretrained_model}")
             tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
 
-            # Save to disk
             Path(TOKENIZER_DIR).mkdir(parents=True, exist_ok=True)
             tokenizer.save_pretrained(TOKENIZER_DIR)
             logger.info(f"New tokenizer saved to {TOKENIZER_DIR}")
@@ -75,7 +70,7 @@ def main():
         logger.error(f"Failed to load or create tokenizer: {e}")
         sys.exit(1)
 
-    #
+    # Init the encoder
     try:
         encoder = EncoderModel(config=config)
         logger.info("EncoderModel initialized successfully.")
@@ -89,22 +84,24 @@ def main():
         else:
             logger.warning(f"No shared_encoder found at {shared_encoder_path}, using base DistilBERT instead.")
 
-        # Load
+        # Load custom .weights.h5 (projection, dropout, etc.)
         custom_weights_path = Path(MODELS_DIR) / "encoder_custom_weights.weights.h5"
         if custom_weights_path.exists():
             logger.info(f"Loading custom top-level weights from {custom_weights_path}")
-
+
+            # Dummy forward pass forces model build to ensure all layers are built
             dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
             _ = encoder(dummy_input, training=False)
+
             encoder.load_weights(str(custom_weights_path))
             logger.info("Custom encoder weights loaded successfully.")
         else:
             logger.warning(f"Custom weights file not found at {custom_weights_path}. Using only submodule weights.")
     else:
-        #
+        # Base DistilBERT with special tokens
         logger.info("Using the base DistilBERT without loading custom weights.")
 
-    # Resize token embeddings in case we added special tokens
+    # Resize token embeddings in case we added special tokens (EncoderModel class)
     encoder.pretrained.resize_token_embeddings(len(tokenizer))
     logger.info(f"Token embeddings resized to: {len(tokenizer)}")
 
@@ -124,7 +121,7 @@ def main():
     logger.error(f"Failed to load dialogues: {e}")
     sys.exit(1)
 
-    # Load or
+    # Load or init query_embeddings_cache. NOTE: recompute after each training. This was a bug source.
     query_embeddings_cache = {}
     if os.path.exists(CACHE_FILE):
         try:
@@ -138,20 +135,18 @@ def main():
 
     # Initialize TFDataPipeline
     try:
-        #
+        # Load or init FAISS index
         if Path(FAISS_INDEX_PRODUCTION_PATH).exists():
-            # Load existing index
             logger.info(f"Loading existing FAISS index from {FAISS_INDEX_PRODUCTION_PATH}...")
            faiss_index = faiss.read_index(FAISS_INDEX_PRODUCTION_PATH)
            logger.info("FAISS index loaded successfully.")
        else:
-            # Initialize a new FAISS index
            logger.info("No existing FAISS index found. Initializing a new index.")
            dimension = config.embedding_dim  # Ensure this matches your encoder's output
            faiss_index = faiss.IndexFlatIP(dimension)  # Using Inner Product for cosine similarity
            logger.info(f"Initialized new FAISS index with dimension {dimension}.")
 
-        #
+        # Init TFDataPipeline with the FAISS index
        data_pipeline = TFDataPipeline(
            config=config,
            tokenizer=tokenizer,
@@ -162,7 +157,7 @@ def main():
            neg_samples=config.neg_samples,
            query_embeddings_cache=query_embeddings_cache,
            index_type='IndexFlatIP',
-            nlist=100,
+            nlist=100,  # Not used for IndexFlatIP. Retained for future use of IndexIVFFlat
            max_retries=config.max_retries
        )
        logger.info("TFDataPipeline initialized successfully.")
@@ -170,7 +165,7 @@ def main():
        logger.error(f"Failed to initialize TFDataPipeline: {e}")
        sys.exit(1)
 
-    #
+    # Collect response pool from dialogues
    try:
        if dialogues:
            response_pool = data_pipeline.collect_responses_with_domain(dialogues)
@@ -182,8 +177,7 @@ def main():
        logger.error(f"Failed to collect responses: {e}")
        sys.exit(1)
 
-    #
-    # Instead of manually computing embeddings, we use the pipeline method
+    # Build FAISS index with response embeddings
    try:
        if data_pipeline.response_pool:
            data_pipeline.build_text_to_domain_map()
@@ -191,10 +185,10 @@ def main():
            data_pipeline.compute_and_index_response_embeddings()
            logger.info("Response embeddings computed and added to FAISS index.")
 
-            # Save the
+            # Save the FAISS index
            data_pipeline.save_faiss_index(FAISS_INDEX_PRODUCTION_PATH)
 
-            # Also save
+            # Also save response pool JSON
            response_pool_path = FAISS_INDEX_PRODUCTION_PATH.replace('.index', '_responses.json')
            with open(response_pool_path, 'w', encoding='utf-8') as f:
                json.dump(data_pipeline.response_pool, f, indent=2)
@@ -206,7 +200,7 @@ def main():
        logger.error(f"Failed to compute or add response embeddings: {e}")
        sys.exit(1)
 
-    #
+    # Prepare training data as TFRecords (TensorFlow Record format)
    try:
        if dialogues:
            logger.info("Starting data preparation and saving as TFRecord...")
@@ -218,7 +212,7 @@ def main():
        logger.error(f"Failed during data preparation and saving: {e}")
        sys.exit(1)
 
-    #
+    # Save query embeddings cache
    try:
        with open(CACHE_FILE, 'wb') as f:
            pickle.dump(data_pipeline.query_embeddings_cache, f)
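
prepare_data.py finishes by serializing training data to a TFRecord file. The actual serialization lives in TFDataPipeline and is not shown in this commit; a minimal sketch of writing tokenized query/response ID pairs with tf.train.Example (the feature names are illustrative assumptions, not the pipeline's real schema):

import tensorflow as tf

def _int64_list(values):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))

def write_tfrecord(path, pairs):
    # pairs: iterable of (query_token_ids, response_token_ids)
    with tf.io.TFRecordWriter(path) as writer:
        for query_ids, response_ids in pairs:
            example = tf.train.Example(features=tf.train.Features(feature={
                "query_ids": _int64_list(query_ids),
                "response_ids": _int64_list(response_ids),
            }))
            writer.write(example.SerializeToString())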
tf_data_pipeline.py
CHANGED
@@ -1,5 +1,4 @@
 import os
-import gc
 import numpy as np
 import faiss
 import tensorflow as tf
@@ -12,7 +11,6 @@ from typing import Union, Optional, Dict, List, Tuple, Generator
 from transformers import AutoTokenizer
 from typing import List, Tuple, Generator
 from transformers import AutoTokenizer
-from gpu_monitor import GPUMemoryMonitor
 import random
 
 from logger_config import config_logger
@@ -46,7 +44,6 @@ class TFDataPipeline:
         self.embedding_batch_size = 16 if len(response_pool) < 100 else 64
         self.search_batch_size = 16 if len(response_pool) < 100 else 64
         self.max_batch_size = 16 if len(response_pool) < 100 else 64
-        self.memory_monitor = GPUMemoryMonitor()
         self.max_retries = max_retries
 
         # Build a quick text->domain map for O(1) domain lookups
unused/build_faiss_index.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+# import os
+# import json
+# from pathlib import Path
+
+# import faiss
+# import numpy as np
+# import tensorflow as tf
+# from transformers import AutoTokenizer, TFAutoModel
+# from tqdm.auto import tqdm
+
+# from chatbot_model import ChatbotConfig, EncoderModel
+# from tf_data_pipeline import TFDataPipeline
+# from logger_config import config_logger
+
+# logger = config_logger(__name__)
+# os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+# def sanity_check(encoder: EncoderModel, tokenizer: AutoTokenizer, config: ChatbotConfig):
+#     """
+#     Perform a quick sanity check to ensure the model is loaded correctly.
+#     """
+#     sample_response = "This is a test response."
+#     encoded_sample = tokenizer(
+#         [sample_response],
+#         padding=True,
+#         truncation=True,
+#         max_length=config.max_context_token_limit,
+#         return_tensors='tf'
+#     )
+
+#     # Get embedding
+#     sample_embedding = encoder(encoded_sample['input_ids'], training=False).numpy()
+
+#     # Check shape
+#     if sample_embedding.shape[1] != config.embedding_dim:
+#         logger.error(
+#             f"Embedding dimension mismatch: Expected {config.embedding_dim}, "
+#             f"got {sample_embedding.shape[1]}"
+#         )
+#         raise ValueError("Embedding dimension mismatch.")
+#     else:
+#         logger.info("Embedding dimension matches the configuration.")
+
+#     # Check normalization
+#     embedding_norm = np.linalg.norm(sample_embedding, axis=1)
+#     if not np.allclose(embedding_norm, 1.0, atol=1e-5):
+#         logger.error("Embeddings are not properly normalized.")
+#         raise ValueError("Embeddings are not normalized.")
+#     else:
+#         logger.info("Embeddings are properly normalized.")
+
+#     logger.info("Sanity check passed: Model loaded correctly and outputs are as expected.")
+
+# def build_faiss_index():
+#     """
+#     Rebuild the FAISS index by:
+#     1) Loading your config.json
+#     2) Initializing encoder + loading submodule & custom weights
+#     3) Loading tokenizer from disk
+#     4) Creating a TFDataPipeline
+#     5) Setting the pipeline's response_pool from a JSON file
+#     6) Using pipeline.compute_and_index_response_embeddings()
+#     7) Saving the FAISS index
+#     """
+#     # Directories
+#     MODELS_DIR = Path("models")
+#     FAISS_DIR = MODELS_DIR / "faiss_indices"
+#     FAISS_INDEX_PATH = FAISS_DIR / "faiss_index_production.index"
+#     RESPONSES_PATH = FAISS_DIR / "faiss_index_production_responses.json"
+#     TOKENIZER_DIR = MODELS_DIR / "tokenizer"
+#     SHARED_ENCODER_DIR = MODELS_DIR / "shared_encoder"
+#     CUSTOM_WEIGHTS_PATH = MODELS_DIR / "encoder_custom_weights.weights.h5"
+
+#     # 1) Load ChatbotConfig
+#     config_path = MODELS_DIR / "config.json"
+#     if config_path.exists():
+#         with open(config_path, "r", encoding="utf-8") as f:
+#             config_dict = json.load(f)
+#         config = ChatbotConfig.from_dict(config_dict)
+#         logger.info(f"Loaded ChatbotConfig from {config_path}")
+#     else:
+#         config = ChatbotConfig()
+#         logger.warning(f"No config.json found at {config_path}. Using default ChatbotConfig.")
+
+#     # 2) Initialize the EncoderModel
+#     encoder = EncoderModel(config=config)
+#     logger.info("EncoderModel instantiated (empty).")
+
+#     # Overwrite the submodule from 'shared_encoder' directory
+#     if SHARED_ENCODER_DIR.exists():
+#         logger.info(f"Loading DistilBERT submodule from {SHARED_ENCODER_DIR}...")
+#         encoder.pretrained = TFAutoModel.from_pretrained(str(SHARED_ENCODER_DIR))
+#         logger.info("Loaded HF submodule into encoder.pretrained.")
+#     else:
+#         logger.warning(f"No shared_encoder directory at {SHARED_ENCODER_DIR}. Using default pretrained model.")
+
+#     # Build model once, then load custom weights (projection, etc.)
+#     dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
+#     _ = encoder(dummy_input, training=False)  # builds the layers
+
+#     if CUSTOM_WEIGHTS_PATH.exists():
+#         logger.info(f"Loading custom top-level weights from {CUSTOM_WEIGHTS_PATH}")
+#         encoder.load_weights(str(CUSTOM_WEIGHTS_PATH))
+#         logger.info("Custom top-level weights loaded successfully.")
+#     else:
+#         logger.warning(f"Custom weights file not found at {CUSTOM_WEIGHTS_PATH}.")
+
+#     # 3) Load tokenizer
+#     if TOKENIZER_DIR.exists():
+#         logger.info(f"Loading tokenizer from {TOKENIZER_DIR}")
+#         tokenizer = AutoTokenizer.from_pretrained(str(TOKENIZER_DIR))
+#     else:
+#         logger.warning(f"No tokenizer dir at {TOKENIZER_DIR}, falling back to default HF tokenizer.")
+#         tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
+
+#     # 4) Quick sanity check
+#     sanity_check(encoder, tokenizer, config)
+
+#     # 5) Prepare a TFDataPipeline
+#     pipeline = TFDataPipeline(
+#         config=config,
+#         tokenizer=tokenizer,
+#         encoder=encoder,
+#         index_file_path=str(FAISS_INDEX_PATH),
+#         response_pool=[],
+#         max_length=config.max_context_token_limit,
+#         query_embeddings_cache={},
+#         neg_samples=config.neg_samples,
+#         index_type='IndexFlatIP',
+#         nlist=100,
+#         max_retries=config.max_retries
+#     )
+
+#     # 6) Load the existing response pool
+#     if not RESPONSES_PATH.exists():
+#         logger.error(f"Response pool JSON file not found at {RESPONSES_PATH}")
+#         raise FileNotFoundError(f"No response pool JSON at {RESPONSES_PATH}")
+
+#     with open(RESPONSES_PATH, "r", encoding="utf-8") as f:
+#         response_pool = json.load(f)
+#     logger.info(f"Loaded {len(response_pool)} responses from {RESPONSES_PATH}")
+
+#     pipeline.response_pool = response_pool  # assign to pipeline
+
+#     # 7) Build (or rebuild) the FAISS index from pipeline method
+#     # This does all the compute-embeddings + index.add in one place
+#     logger.info("Starting to compute and index response embeddings via TFDataPipeline...")
+#     pipeline.compute_and_index_response_embeddings()
+
+#     # 8) Save the rebuilt FAISS index
+#     pipeline.save_faiss_index(str(FAISS_INDEX_PATH))
+
+#     # Verify
+#     loaded_index = faiss.read_index(str(FAISS_INDEX_PATH))
+#     logger.info(f"Verified the rebuilt FAISS index has {loaded_index.ntotal} vectors.")
+
+#     return loaded_index, pipeline.response_pool
+
+# if __name__ == "__main__":
+#     build_faiss_index()
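For context, steps 6)-8) of the archived script reduce to: embed every response, L2-normalize so inner product behaves as cosine similarity, add the vectors to an IndexFlatIP, and write the index to disk. A compressed sketch of that flow, with a stand-in encode() in place of the real EncoderModel:

import faiss
import numpy as np

def encode(texts):
    # Stand-in for the EncoderModel: random 256-dim float32 vectors.
    return np.random.rand(len(texts), 256).astype("float32")

responses = ["Sure, I can help with that.", "Your order has shipped."]
embeddings = encode(responses)
faiss.normalize_L2(embeddings)  # in-place; inner product now equals cosine similarity

index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings)
faiss.write_index(index, "faiss_index_production.index")

loaded = faiss.read_index("faiss_index_production.index")
assert loaded.ntotal == len(responses)  # mirrors the script's verification step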
gpu_monitor.py β unused/gpu_monitor.py
RENAMED
@@ -1,17 +1,8 @@
-import numpy as np
 import tensorflow as tf
-import
-import json
-from pathlib import Path
-from typing import List, Dict, Tuple, Optional, Generator
+from typing import List, Dict, Optional
 from dataclasses import dataclass
-
-from
-import gc
-try:
-    from tqdm.notebook import tqdm
-except ImportError:
-    from tqdm import tqdm
+
+from tqdm.auto import tqdm
 
 @dataclass
 class GPUMemoryStats:
@@ -63,6 +54,6 @@ class GPUMemoryMonitor:
     def can_increase_batch_size(self) -> bool:
         """Check if batch size can be increased based on memory usage."""
         if not self.has_gpu:
-            return True
+            return True
         usage = self.get_memory_usage()
         return usage < 0.70
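A hedged sketch of the decision rule in GPUMemoryMonitor.can_increase_batch_size(): with no GPU present it always allows growth; otherwise it requires usage below 70%. The usage probe is injected here because the monitor's internal measurement is not shown in this diff:

from typing import Callable

import tensorflow as tf

def can_increase_batch_size(get_usage: Callable[[], float]) -> bool:
    has_gpu = bool(tf.config.list_physical_devices("GPU"))
    if not has_gpu:
        return True  # no GPU: growth is always allowed, as in the renamed module
    return get_usage() < 0.70

# With a GPU present this allows growth at 50% usage but not at 90%:
print(can_increase_batch_size(lambda: 0.50), can_increase_batch_size(lambda: 0.90))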
validate_model.py
CHANGED
@@ -71,11 +71,6 @@ def validate_chatbot():
         logger.warning("No config.json found. Using default ChatbotConfig.")
 
     # Load RetrievalChatbot in 'inference' mode using the classmethod
-    # This:
-    # - Loads shared_encoder submodule
-    # - Loads encoder_custom_weights.weights.h5
-    # - Loads tokenizer
-    # - Prepares the model for inference
     try:
         chatbot = RetrievalChatbot.load_model(load_dir=MODEL_DIR, mode="inference")
         logger.info("RetrievalChatbot loaded in 'inference' mode successfully.")
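A hedged sketch of what the 'inference' load restores, per the comment block this commit deletes: the shared_encoder submodule, the custom head weights, and the tokenizer. The directory names match the archived build script above; treating them as load_model() internals is an assumption:

from transformers import AutoTokenizer, TFAutoModel

load_dir = "models"
pretrained = TFAutoModel.from_pretrained(f"{load_dir}/shared_encoder")
tokenizer = AutoTokenizer.from_pretrained(f"{load_dir}/tokenizer")
# The custom projection weights would then be restored onto the full encoder:
# encoder.load_weights(f"{load_dir}/encoder_custom_weights.weights.h5")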