import os
import gc
import json
import math
import datetime
from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple, Dict, Optional, Union, Any

import faiss
import tensorflow as tf
import absl.logging
from tqdm.auto import tqdm
from transformers import TFAutoModel, AutoTokenizer

from tf_data_pipeline import TFDataPipeline
from response_quality_checker import ResponseQualityChecker
from cross_encoder_reranker import CrossEncoderReranker
from conversation_summarizer import DeviceAwareModel, Summarizer
from gpu_monitor import GPUMemoryMonitor
from logger_config import config_logger

absl.logging.set_verbosity(absl.logging.WARNING)
logger = config_logger(__name__)

@dataclass
class ChatbotConfig:
    """Configuration for the RetrievalChatbot."""
    max_context_token_limit: int = 512
    embedding_dim: int = 768
    encoder_units: int = 256
    num_attention_heads: int = 8
    dropout_rate: float = 0.2
    l2_reg_weight: float = 0.001
    learning_rate: float = 0.001
    min_text_length: int = 3
    max_context_turns: int = 5
    warmup_steps: int = 200
    pretrained_model: str = 'distilbert-base-uncased'
    dtype: str = 'float32'
    freeze_embeddings: bool = False
    embedding_batch_size: int = 64
    search_batch_size: int = 64
    max_batch_size: int = 64
    neg_samples: int = 3
    max_retries: int = 3

    def to_dict(self) -> Dict:
        """Convert config to dictionary."""
        return {k: (str(v) if isinstance(v, Path) else v)
                for k, v in self.__dict__.items()}

    @classmethod
    def from_dict(cls, config_dict: Dict) -> 'ChatbotConfig':
        """Create config from dictionary."""
        return cls(**{k: v for k, v in config_dict.items()
                      if k in cls.__dataclass_fields__})
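
# A minimal, illustrative round-trip sketch (not executed here): to_dict()/from_dict()
# are what save_models()/load_model() below rely on for config persistence.
#
#   cfg = ChatbotConfig(embedding_dim=768, neg_samples=3)
#   restored = ChatbotConfig.from_dict(json.loads(json.dumps(cfg.to_dict())))
#   assert restored == cfg
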
class EncoderModel(tf.keras.Model):
    """Dual encoder model with pretrained embeddings."""

    def __init__(
        self,
        config: ChatbotConfig,
        name: str = "encoder",
        **kwargs
    ):
        super().__init__(name=name, **kwargs)
        self.config = config

        # Pretrained transformer backbone, with the configured freezing policy applied.
        self.pretrained = TFAutoModel.from_pretrained(config.pretrained_model)
        self._freeze_layers()

        # Pooling, projection, dropout, and L2-normalization head.
        self.pooler = tf.keras.layers.GlobalAveragePooling1D()
        self.projection = tf.keras.layers.Dense(
            config.embedding_dim,
            activation='tanh',
            name="projection"
        )
        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
        self.normalize = tf.keras.layers.Lambda(
            lambda x: tf.nn.l2_normalize(x, axis=1),
            name="l2_normalize"
        )

    def _freeze_layers(self):
        """Freeze layers of the pretrained model based on configuration."""
        if self.config.freeze_embeddings:
            self.pretrained.trainable = False
            logger.info("All pretrained layers frozen.")
        else:
            # Freeze only the first layer; keep the remaining layers trainable.
            for i, layer in enumerate(self.pretrained.layers):
                if isinstance(layer, tf.keras.layers.Layer) and hasattr(layer, 'trainable'):
                    if i < 1:
                        layer.trainable = False
                        logger.info(f"Layer {i} frozen.")
                    else:
                        layer.trainable = True

    def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Forward pass: token IDs -> L2-normalized embedding of size embedding_dim."""
        pretrained_outputs = self.pretrained(inputs, training=training)
        x = pretrained_outputs.last_hidden_state

        x = self.pooler(x)
        x = self.projection(x)
        x = self.dropout(x, training=training)
        x = self.normalize(x)
        return x

    def get_config(self) -> dict:
        """Return the config of the model."""
        config = super().get_config()
        config.update({
            "config": self.config.to_dict(),
            "name": self.name
        })
        return config
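
# Minimal usage sketch (illustrative only; assumes `config` is a ChatbotConfig and the
# tokenizer matches config.pretrained_model):
#
#   tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
#   encoder = EncoderModel(config)
#   tokens = tokenizer(["hello there"], padding="max_length", truncation=True,
#                      max_length=config.max_context_token_limit, return_tensors="tf")
#   emb = encoder(tokens["input_ids"])   # shape (1, config.embedding_dim), L2-normalized
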
class RetrievalChatbot(DeviceAwareModel):
    """Retrieval-based chatbot using pretrained embeddings and FAISS for similarity search."""

    def __init__(
        self,
        config: ChatbotConfig,
        device: str = None,
        strategy=None,
        reranker: Optional[CrossEncoderReranker] = None,
        summarizer: Optional[Summarizer] = None,
        mode: str = 'training'
    ):
        super().__init__()
        self.config = config
        self.strategy = strategy
        self.device = device or self._setup_default_device()
        self.mode = mode.lower()

        # Core components: reranker, tokenizer, encoder, summarizer, GPU monitor.
        self.reranker = reranker or self._initialize_reranker()
        self.tokenizer = self._initialize_tokenizer()
        self.encoder = self._initialize_encoder()
        self.summarizer = summarizer or self._initialize_summarizer()
        self.memory_monitor = GPUMemoryMonitor()

        # The data pipeline handles tokenization, query embedding, and FAISS indexing.
        logger.info("Initializing TFDataPipeline.")
        self.data_pipeline = TFDataPipeline(
            config=self.config,
            tokenizer=self.tokenizer,
            encoder=self.encoder,
            index_file_path='path/to/index',
            response_pool=[],
            max_length=self.config.max_context_token_limit,
            query_embeddings_cache={},
            neg_samples=self.config.neg_samples,
            index_type='IndexFlatIP',
            nlist=100,
            max_retries=self.config.max_retries
        )

        # In inference mode, load the FAISS index and response pool up front.
        if self.mode == 'inference':
            logger.info("Mode set to 'inference'. Loading FAISS index and response pool.")
            self._load_faiss_index_and_responses()
        elif self.mode != 'training':
            logger.error(f"Unsupported mode in RetrievalChatbot init: {self.mode}")
            raise ValueError(f"Unsupported mode in RetrievalChatbot init: {self.mode}")

        # Training history, populated by train_model().
        self.history = {
            "train_loss": [],
            "val_loss": [],
            "train_metrics": {},
            "val_metrics": {}
        }

    def _setup_default_device(self) -> str:
        """Set up the default device if none is provided."""
        if tf.config.list_physical_devices('GPU'):
            return 'GPU'
        return 'CPU'

    def _initialize_reranker(self) -> CrossEncoderReranker:
        """Initialize the CrossEncoderReranker."""
        logger.info("Initializing default CrossEncoderReranker...")
        return CrossEncoderReranker(model_name="cross-encoder/ms-marco-MiniLM-L-12-v2")

    def _initialize_summarizer(self) -> Summarizer:
        """Initialize the Summarizer."""
        return Summarizer(
            tokenizer=self.tokenizer,
            model_name="t5-small",
            max_summary_length=self.config.max_context_token_limit // 4,
            device=self.device,
            max_summary_rounds=2
        )

    def _initialize_tokenizer(self) -> AutoTokenizer:
        """Initialize the tokenizer and add special tokens."""
        logger.info("Initializing tokenizer and adding special tokens...")
        tokenizer = AutoTokenizer.from_pretrained(self.config.pretrained_model)
        # Dialogue-structure markers; the encoder's embedding matrix is resized
        # to match in _initialize_encoder().
        special_tokens = {
            "user": "<USER>",
            "assistant": "<ASSISTANT>",
            "context": "<CONTEXT>",
            "sep": "<SEP>"
        }
        tokenizer.add_special_tokens(
            {'additional_special_tokens': list(special_tokens.values())}
        )
        return tokenizer

    def _initialize_encoder(self) -> EncoderModel:
        """Initialize the EncoderModel and resize token embeddings."""
        logger.info("Initializing encoder model...")
        encoder = EncoderModel(
            self.config,
            name="shared_encoder",
        )

        # Resize the embedding matrix to account for the added special tokens.
        new_vocab_size = len(self.tokenizer)
        encoder.pretrained.resize_token_embeddings(new_vocab_size)
        logger.info(f"Token embeddings resized to: {new_vocab_size}")
        return encoder

    def _load_faiss_index_and_responses(self) -> None:
        """Load FAISS index and response pool for inference."""
        try:
            logger.info(f"Loading FAISS index from {self.data_pipeline.index_file_path}...")
            self.data_pipeline.load_faiss_index(self.data_pipeline.index_file_path)
            logger.info("FAISS index loaded successfully.")

            # The response pool is stored alongside the index as a JSON file.
            response_pool_path = self.data_pipeline.index_file_path.replace('.index', '_responses.json')
            if os.path.exists(response_pool_path):
                with open(response_pool_path, 'r', encoding='utf-8') as f:
                    self.data_pipeline.response_pool = json.load(f)
                logger.info(f"Loaded {len(self.data_pipeline.response_pool)} responses from {response_pool_path}.")
            else:
                logger.error(f"Response pool file not found at {response_pool_path}.")
                raise FileNotFoundError(f"Response pool file not found at {response_pool_path}.")

            # Sanity-check that the index and response pool are consistent.
            self.data_pipeline.validate_faiss_index()
            logger.info("FAISS index and response pool validated successfully.")

        except Exception as e:
            logger.error(f"Failed to load FAISS index and response pool: {e}")
            raise

    @classmethod
    def load_model(cls, load_dir: Union[str, Path], mode: str = 'training') -> 'RetrievalChatbot':
        """
        Load saved models and configuration.

        Args:
            load_dir (Union[str, Path]): Directory containing saved model files.
            mode (str): Either 'training' or 'inference'. In inference mode,
                also loads the FAISS index and response pool.
        """
        load_dir = Path(load_dir)

        # Rebuild the chatbot configuration from the saved JSON.
        with open(load_dir / "config.json", "r") as f:
            config = ChatbotConfig.from_dict(json.load(f))

        chatbot = cls(config, mode=mode)

        # Restore the shared encoder weights. The transformer's own config inside
        # "shared_encoder" is used; ChatbotConfig is not a transformers config object.
        chatbot.encoder.pretrained = TFAutoModel.from_pretrained(
            load_dir / "shared_encoder"
        )

        chatbot.tokenizer = AutoTokenizer.from_pretrained(load_dir / "tokenizer")
        logger.info(f"Models and tokenizer loaded from {load_dir}")

        if mode == 'inference':
            cls._prepare_model_for_inference(chatbot, load_dir)

        return chatbot

    @classmethod
    def _prepare_model_for_inference(cls, chatbot: 'RetrievalChatbot', load_dir: Path) -> None:
        """Internal method to load inference components."""
        try:
            # Load the saved FAISS index.
            faiss_path = load_dir / 'faiss_index.bin'
            if faiss_path.exists():
                chatbot.index = faiss.read_index(str(faiss_path))
                logger.info("FAISS index loaded successfully")
            else:
                raise FileNotFoundError(f"FAISS index not found at {faiss_path}")

            # Load the response pool that maps index rows back to response text.
            response_pool_path = load_dir / 'response_pool.json'
            if response_pool_path.exists():
                with open(response_pool_path, 'r') as f:
                    chatbot.response_pool = json.load(f)
                logger.info(f"Loaded {len(chatbot.response_pool)} responses")
            else:
                raise FileNotFoundError(f"Response pool not found at {response_pool_path}")

            # The index dimensionality must match the encoder's embedding size.
            if chatbot.index.d != chatbot.config.embedding_dim:
                raise ValueError(
                    f"FAISS index dimension {chatbot.index.d} doesn't match "
                    f"model dimension {chatbot.config.embedding_dim}"
                )

        except Exception as e:
            logger.error(f"Error loading inference components: {e}")
            raise

    def save_models(self, save_dir: Union[str, Path]):
        """Save models and configuration."""
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

        # Save the chatbot configuration as JSON.
        with open(save_dir / "config.json", "w") as f:
            json.dump(self.config.to_dict(), f, indent=2)

        # Save the shared encoder and the tokenizer.
        self.encoder.pretrained.save_pretrained(save_dir / "shared_encoder")
        self.tokenizer.save_pretrained(save_dir / "tokenizer")

        logger.info(f"Models and tokenizer saved to {save_dir}.")

    def retrieve_responses_cross_encoder(
        self,
        query: str,
        top_k: int,
        reranker: Optional[CrossEncoderReranker] = None,
        summarizer: Optional[Summarizer] = None,
        summarize_threshold: int = 512
    ) -> List[Tuple[str, float]]:
        """
        Retrieve top-k responses from FAISS, then re-rank them with a cross-encoder.
        Optionally summarize the query first if it exceeds summarize_threshold words.
        """
        if reranker is None:
            reranker = self.reranker
        if summarizer is None:
            summarizer = self.summarizer

        # Summarize overly long queries before retrieval and re-ranking.
        if summarizer and len(query.split()) > summarize_threshold:
            logger.info(f"Query is long. Summarizing before cross-encoder. Original length: {len(query.split())}")
            query = summarizer.summarize_text(query)
            logger.info(f"Summarized query: {query}")

        # Dense retrieval with FAISS.
        dense_topk = self.retrieve_responses_faiss(query, top_k=top_k)
        if not dense_topk:
            return []

        # Re-rank the dense candidates with the cross-encoder.
        candidate_texts = [pair[0] for pair in dense_topk]
        cross_scores = reranker.rerank(query, candidate_texts, max_length=256)

        # Pair each candidate with its cross-encoder score and sort best-first.
        combined = [(text, score) for (text, _), score in zip(dense_topk, cross_scores)]
        combined.sort(key=lambda x: x[1], reverse=True)

        return combined
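
    # Illustrative call (a hedged sketch; assumes `bot` is a RetrievalChatbot whose
    # FAISS index and response pool are already loaded, e.g. in 'inference' mode):
    #
    #   results = bot.retrieve_responses_cross_encoder("How do I reset my password?", top_k=10)
    #   for text, score in results[:3]:
    #       print(f"{score:.3f}  {text}")
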
    def retrieve_responses_faiss(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Retrieve top-k responses using FAISS."""
        if not hasattr(self.data_pipeline, 'index') or self.data_pipeline.index is None:
            logger.warning("FAISS index not initialized. Cannot retrieve responses.")
            return []

        # Embed the query with the shared encoder.
        q_emb = self.data_pipeline.encode_query(query)
        q_emb_np = q_emb.numpy().astype('float32')

        # Normalize so that inner-product search behaves like cosine similarity.
        faiss.normalize_L2(q_emb_np)

        # Search the index for the top-k nearest responses.
        distances, indices = self.data_pipeline.index.search(q_emb_np, top_k)

        # Map index rows back to response texts, skipping invalid (e.g. -1) indices.
        top_responses = []
        for i, idx in enumerate(indices[0]):
            if 0 <= idx < len(self.data_pipeline.response_pool):
                top_responses.append((self.data_pipeline.response_pool[idx], float(distances[0][i])))
            else:
                logger.warning(f"FAISS returned invalid index {idx}. Skipping.")

        return top_responses
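
    # Descriptive note: with an IndexFlatIP index and L2-normalized query and response
    # embeddings, the returned inner-product "distances" are cosine similarities in
    # [-1, 1], so larger values mean closer matches.
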
    def chat(
        self,
        query: str,
        conversation_history: Optional[List[Tuple[str, str]]] = None,
        quality_checker: Optional['ResponseQualityChecker'] = None,
        top_k: int = 5,
    ) -> Tuple[str, List[Tuple[str, float]], Dict[str, Any]]:
        """
        Example chat method that always uses cross-encoder re-ranking
        if self.reranker is available.
        """
        @self.run_on_device
        def get_response(self_arg, query_arg):
            # Fold the conversation history and the new query into one context string.
            conversation_str = self_arg._build_conversation_context(query_arg, conversation_history)

            # Retrieve and re-rank candidate responses.
            results = self_arg.retrieve_responses_cross_encoder(
                query=conversation_str,
                top_k=top_k,
                reranker=self_arg.reranker,
                summarizer=self_arg.summarizer,
                summarize_threshold=512
            )

            # Fall back to a canned reply when nothing is retrieved.
            if not results:
                return (
                    "I'm sorry, but I couldn't find a relevant response.",
                    [],
                    {}
                )

            # Optionally gate the answer on response-quality metrics.
            if quality_checker:
                metrics = quality_checker.check_response_quality(query_arg, results)
                if not metrics.get('is_confident', False):
                    return (
                        "I need more information to provide a good answer. Could you please clarify?",
                        results,
                        metrics
                    )
                return results[0][0], results, metrics

            return results[0][0], results, {}

        return get_response(self, query)
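
    # Hedged usage sketch (paths and variable names are illustrative assumptions):
    #
    #   bot = RetrievalChatbot.load_model("saved_model/", mode="inference")
    #   history = [("Hi", "Hello! How can I help?")]
    #   reply, candidates, metrics = bot.chat("Where is my order?", conversation_history=history)
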
    def _build_conversation_context(
        self,
        query: str,
        conversation_history: Optional[List[Tuple[str, str]]]
    ) -> str:
        """Build the conversation context string from history plus the new query."""
        # These literals match the additional special tokens added in _initialize_tokenizer().
        user_token = "<USER>"
        assistant_token = "<ASSISTANT>"

        if not conversation_history:
            return f"{user_token} {query}"

        conversation_parts = []
        for user_txt, assistant_txt in conversation_history:
            conversation_parts.extend([
                f"{user_token} {user_txt}",
                f"{assistant_token} {assistant_txt}"
            ])

        conversation_parts.append(f"{user_token} {query}")
        return "\n".join(conversation_parts)
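
    # For example (illustrative), one prior exchange plus a new query produces:
    #
    #   <USER> Hi
    #   <ASSISTANT> Hello! How can I help?
    #   <USER> Where is my order?
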
    def train_model(
        self,
        tfrecord_file_path: str,
        epochs: int = 20,
        batch_size: int = 16,
        validation_split: float = 0.2,
        checkpoint_dir: str = "checkpoints/",
        use_lr_schedule: bool = True,
        peak_lr: float = 1e-5,
        warmup_steps_ratio: float = 0.1,
        early_stopping_patience: int = 3,
        min_delta: float = 1e-4,
        test_mode: bool = False,
        initial_epoch: int = 0
    ) -> None:
        """Training using a pre-prepared TFRecord dataset."""
        logger.info("Starting training with pre-prepared TFRecord dataset...")

        def parse_tfrecord_fn(example_proto, max_length, neg_samples):
            """
            Parse a single TFRecord example.

            Args:
                example_proto: A serialized TFRecord example.
                max_length: The maximum sequence length for tokenization.
                neg_samples: The number of hard negatives per query.

            Returns:
                A tuple of (query_ids, positive_ids, negative_ids).
            """
            feature_description = {
                'query_ids': tf.io.FixedLenFeature([max_length], tf.int64),
                'positive_ids': tf.io.FixedLenFeature([max_length], tf.int64),
                'negative_ids': tf.io.FixedLenFeature([neg_samples * max_length], tf.int64),
            }
            parsed_features = tf.io.parse_single_example(example_proto, feature_description)

            query_ids = tf.cast(parsed_features['query_ids'], tf.int32)
            positive_ids = tf.cast(parsed_features['positive_ids'], tf.int32)
            negative_ids = tf.cast(parsed_features['negative_ids'], tf.int32)
            negative_ids = tf.reshape(negative_ids, [neg_samples, max_length])

            return query_ids, positive_ids, negative_ids

        # Count the records with a full pass so splits and step counts can be derived.
        raw_dataset = tf.data.TFRecordDataset(tfrecord_file_path)
        total_pairs = sum(1 for _ in raw_dataset)
        logger.info(f"Total pairs in TFRecord: {total_pairs}")

        train_size = int(total_pairs * (1 - validation_split))
        val_size = total_pairs - train_size
        steps_per_epoch = math.ceil(train_size / batch_size)
        val_steps = math.ceil(val_size / batch_size)
        total_steps = steps_per_epoch * epochs
        buffer_size = total_pairs // 10

        logger.info(f"Training pairs: {train_size}")
        logger.info(f"Validation pairs: {val_size}")
        logger.info(f"Steps per epoch: {steps_per_epoch}")
        logger.info(f"Validation steps: {val_steps}")
        logger.info(f"Total steps: {total_steps}")

        # Optimizer with either a warmup + cosine-decay schedule or a fixed learning rate.
        if use_lr_schedule:
            warmup_steps = int(total_steps * warmup_steps_ratio)
            lr_schedule = self._get_lr_schedule(
                total_steps=total_steps,
                peak_lr=peak_lr,
                warmup_steps=warmup_steps
            )
            self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
            logger.info("Using custom learning rate schedule.")
        else:
            self.optimizer = tf.keras.optimizers.Adam(learning_rate=peak_lr)
            logger.info("Using fixed learning rate.")

        # Checkpointing for optimizer and encoder weights.
        checkpoint = tf.train.Checkpoint(
            epoch=tf.Variable(0),
            optimizer=self.optimizer,
            model=self.encoder,
            variables=self.encoder.variables
        )
        manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3, checkpoint_name='ckpt')

        # Resume from the latest checkpoint if one exists; otherwise start fresh.
        history_path = Path(checkpoint_dir) / 'training_history.json'
        latest_checkpoint = manager.latest_checkpoint
        if latest_checkpoint:
            if history_path.exists():
                try:
                    with open(history_path, 'r') as f:
                        self.history = json.load(f)
                    logger.info(f"Loaded previous training history from {history_path}")
                except Exception as e:
                    logger.warning(f"Could not load history, starting fresh: {e}")
                    self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
            else:
                self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}

            status = checkpoint.restore(latest_checkpoint)
            status.expect_partial()

            logger.info(f"Restored from checkpoint: {latest_checkpoint}")

            ckpt_number = int(latest_checkpoint.split('ckpt-')[-1])
            if initial_epoch == 0:
                initial_epoch = ckpt_number
            logger.info(f"Resuming from epoch {initial_epoch}")
        else:
            logger.info("Starting training from scratch")
            initial_epoch = 0

        # TensorBoard writers for per-step and per-epoch metrics.
        log_dir = Path(checkpoint_dir) / "tensorboard_logs"
        log_dir.mkdir(parents=True, exist_ok=True)
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        train_log_dir = str(log_dir / f"train_{current_time}")
        val_log_dir = str(log_dir / f"val_{current_time}")
        train_summary_writer = tf.summary.create_file_writer(train_log_dir)
        val_summary_writer = tf.summary.create_file_writer(val_log_dir)
        logger.info(f"TensorBoard logs will be saved in {log_dir}")

        # Build the tf.data pipeline from the TFRecord file.
        parse_fn = lambda x: parse_tfrecord_fn(x, self.config.max_context_token_limit, self.config.neg_samples)
        dataset = tf.data.TFRecordDataset(tfrecord_file_path)

        # In test mode, train on a small subset and recompute the derived sizes.
        if test_mode:
            subset_size = 200
            dataset = dataset.take(subset_size)
            logger.info(f"TEST MODE: Using only {subset_size} examples")

            total_pairs = subset_size
            train_size = int(total_pairs * (1 - validation_split))
            val_size = total_pairs - train_size
            steps_per_epoch = math.ceil(train_size / batch_size)
            val_steps = math.ceil(val_size / batch_size)
            total_steps = steps_per_epoch * epochs
            buffer_size = total_pairs // 10
            epochs = min(epochs, 5)
            early_stopping_patience = 2
            logger.info(f"New training pairs: {train_size}")
            logger.info(f"New validation pairs: {val_size}")

        dataset = dataset.map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)

        # Sequential train/validation split over the parsed records.
        train_dataset = dataset.take(train_size)
        val_dataset = dataset.skip(train_size).take(val_size)

        # Shuffle, batch, and prefetch the training set.
        train_dataset = train_dataset.shuffle(buffer_size=buffer_size)
        train_dataset = train_dataset.batch(batch_size, drop_remainder=True)
        train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)

        # Batch, prefetch, and cache the validation set.
        val_dataset = val_dataset.batch(batch_size, drop_remainder=True)
        val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE)
        val_dataset = val_dataset.cache()

        # Early-stopping bookkeeping.
        best_val_loss = float("inf")
        epochs_no_improve = 0

        for epoch in range(initial_epoch + 1, epochs + 1):
            # Training pass.
            epoch_loss_avg = tf.keras.metrics.Mean()
            batches_processed = 0

            try:
                train_pbar = tqdm(total=steps_per_epoch, desc=f"Training Epoch {epoch}", unit="batch")
                is_tqdm_train = True
            except ImportError:
                train_pbar = None
                is_tqdm_train = False
                logger.info("Training progress bar disabled")

            for q_batch, p_batch, n_batch in train_dataset:
                loss, grad_norm, post_clip_norm = self.train_step(q_batch, p_batch, n_batch)

                # Basic gradient-health diagnostics.
                grad_norm_value = float(grad_norm.numpy())
                post_clip_value = float(post_clip_norm.numpy())
                if grad_norm_value < 1e-7:
                    logger.warning(f"Potential vanishing gradient detected: norm = {grad_norm_value:.2e}")
                elif grad_norm_value > 100:
                    logger.warning(f"Potential exploding gradient detected: norm = {grad_norm_value:.2e}")

                if grad_norm_value != post_clip_value:
                    logger.info(f"Gradient clipped: {grad_norm_value:.2e} -> {post_clip_value:.2e}")

                epoch_loss_avg(loss)
                batches_processed += 1

                # Per-step TensorBoard logging.
                with train_summary_writer.as_default():
                    step = (epoch - 1) * steps_per_epoch + batches_processed
                    tf.summary.scalar("loss", loss, step=step)
                    tf.summary.scalar("gradient_norm_pre_clip", grad_norm, step=step)
                    tf.summary.scalar("gradient_norm_post_clip", post_clip_norm, step=step)

                # Current learning rate for the progress bar.
                if use_lr_schedule:
                    current_lr = float(lr_schedule(self.optimizer.iterations))
                else:
                    current_lr = float(self.optimizer.learning_rate.numpy())

                if is_tqdm_train:
                    train_pbar.update(1)
                    train_pbar.set_postfix({
                        "loss": f"{loss.numpy():.4f}",
                        "pre_clip": f"{grad_norm_value:.2e}",
                        "post_clip": f"{post_clip_value:.2e}",
                        "lr": f"{current_lr:.2e}",
                        "batches": f"{batches_processed}/{steps_per_epoch}"
                    })

                gc.collect()

                if batches_processed >= steps_per_epoch:
                    break

            if is_tqdm_train and train_pbar:
                train_pbar.close()

            # Validation pass.
            val_loss_avg = tf.keras.metrics.Mean()
            val_batches_processed = 0

            try:
                val_pbar = tqdm(total=val_steps, desc="Validation", unit="batch")
                is_tqdm_val = True
            except ImportError:
                val_pbar = None
                is_tqdm_val = False
                logger.info("Validation progress bar disabled")

            for q_batch, p_batch, n_batch in val_dataset:
                val_loss = self.validation_step(q_batch, p_batch, n_batch)
                val_loss_avg(val_loss)
                val_batches_processed += 1

                if is_tqdm_val:
                    val_pbar.update(1)
                    val_pbar.set_postfix({
                        "val_loss": f"{val_loss.numpy():.4f}",
                        "batches": f"{val_batches_processed}/{val_steps}"
                    })

                gc.collect()

                if val_batches_processed >= val_steps:
                    break

            if is_tqdm_val and val_pbar:
                val_pbar.close()

            # Epoch summary.
            train_loss = epoch_loss_avg.result().numpy()
            val_loss = val_loss_avg.result().numpy()
            logger.info(f"Epoch {epoch} Complete: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")

            # Per-epoch TensorBoard logging.
            with train_summary_writer.as_default():
                tf.summary.scalar("epoch_loss", train_loss, step=epoch)
            with val_summary_writer.as_default():
                tf.summary.scalar("val_loss", val_loss, step=epoch)

            # Save checkpoint and a full model export for this epoch.
            manager.save()

            model_save_path = Path(checkpoint_dir) / f"model_epoch_{epoch}"
            self.save_models(model_save_path)
            logger.info(f"Saved model for epoch {epoch} at {model_save_path}")

            # Update and persist the training history (cast numpy scalars so json.dump works).
            self.history['train_loss'].append(float(train_loss))
            self.history['val_loss'].append(float(val_loss))

            if use_lr_schedule:
                current_lr = float(lr_schedule(self.optimizer.iterations))
            else:
                current_lr = float(self.optimizer.learning_rate.numpy())

            self.history.setdefault('learning_rate', []).append(current_lr)

            with open(history_path, 'w') as f:
                json.dump(self.history, f)
            logger.info(f"Saved training history to {history_path}")

            # Early stopping on validation loss.
            if val_loss < best_val_loss - min_delta:
                best_val_loss = val_loss
                epochs_no_improve = 0
                logger.info(f"Validation loss improved to {val_loss:.4f}. Reset patience.")
            else:
                epochs_no_improve += 1
                logger.info(f"No improvement this epoch. Patience: {epochs_no_improve}/{early_stopping_patience}")
                if epochs_no_improve >= early_stopping_patience:
                    logger.info("Early stopping triggered.")
                    break

        logger.info("Training completed!")

    @tf.function
    def train_step(
        self,
        q_batch: tf.Tensor,
        p_batch: tf.Tensor,
        n_batch: tf.Tensor
    ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """
        Single training step using queries, positives, and hard negatives.
        Returns (loss, pre-clip gradient norm, post-clip gradient norm).
        """
        with tf.GradientTape() as tape:
            # Encode queries and positives.
            q_enc = self.encoder(q_batch, training=True)
            p_enc = self.encoder(p_batch, training=True)

            # Flatten the negatives to encode them in one pass, then restore the
            # [batch, neg_samples, embedding_dim] shape.
            shape = tf.shape(n_batch)
            bs = shape[0]
            neg_samples = shape[1]

            n_batch_flat = tf.reshape(n_batch, [bs * neg_samples, shape[2]])
            n_enc_flat = self.encoder(n_batch_flat, training=True)
            n_enc = tf.reshape(n_enc_flat, [bs, neg_samples, -1])

            # Stack the positive in front of the negatives: [batch, 1 + neg_samples, dim].
            combined_p_n = tf.concat(
                [tf.expand_dims(p_enc, axis=1), n_enc],
                axis=1
            )

            # Similarity of each query against its positive and negatives.
            dot_products = tf.einsum('bd,bkd->bk', q_enc, combined_p_n)

            # The positive always sits at column 0.
            labels = tf.zeros([bs], dtype=tf.int32)

            loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=labels,
                logits=dot_products
            )
            loss = tf.reduce_mean(loss)

        # Compute, clip, and apply gradients.
        gradients = tape.gradient(loss, self.encoder.trainable_variables)
        gradients_norm = tf.linalg.global_norm(gradients)

        max_grad_norm = 1.0
        gradients, _ = tf.clip_by_global_norm(gradients, max_grad_norm, gradients_norm)
        post_clip_norm = tf.linalg.global_norm(gradients)

        self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_variables))

        return loss, gradients_norm, post_clip_norm
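
    # Descriptive note: for each query the logits row is
    # [sim(q, positive), sim(q, neg_1), ..., sim(q, neg_k)], i.e. shape
    # [batch_size, 1 + neg_samples], and label 0 marks the positive column, so the
    # softmax cross-entropy above acts as a contrastive loss over hard negatives.
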

    @tf.function
    def validation_step(
        self,
        q_batch: tf.Tensor,
        p_batch: tf.Tensor,
        n_batch: tf.Tensor
    ) -> tf.Tensor:
        """
        Single validation step using queries, positives, and hard negatives.
        Mirrors train_step but without gradient updates.
        """
        q_enc = self.encoder(q_batch, training=False)
        p_enc = self.encoder(p_batch, training=False)

        shape = tf.shape(n_batch)
        bs = shape[0]
        neg_samples = shape[1]

        n_batch_flat = tf.reshape(n_batch, [bs * neg_samples, shape[2]])
        n_enc_flat = self.encoder(n_batch_flat, training=False)
        n_enc = tf.reshape(n_enc_flat, [bs, neg_samples, -1])

        combined_p_n = tf.concat(
            [tf.expand_dims(p_enc, axis=1), n_enc],
            axis=1
        )

        dot_products = tf.einsum('bd,bkd->bk', q_enc, combined_p_n)
        labels = tf.zeros([bs], dtype=tf.int32)

        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels,
            logits=dot_products
        )
        loss = tf.reduce_mean(loss)

        return loss

    def _get_lr_schedule(
        self,
        total_steps: int,
        peak_lr: float,
        warmup_steps: int
    ) -> tf.keras.optimizers.schedules.LearningRateSchedule:
        """Create a custom learning rate schedule with warmup and cosine decay."""
        class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
            def __init__(
                self,
                total_steps: int,
                peak_lr: float,
                warmup_steps: int
            ):
                super().__init__()
                self.total_steps = tf.cast(total_steps, tf.float32)
                self.peak_lr = tf.cast(peak_lr, tf.float32)

                # Cap warmup at 10% of total steps (and at least one step).
                adjusted_warmup_steps = min(warmup_steps, max(1, total_steps // 10))
                self.warmup_steps = tf.cast(adjusted_warmup_steps, tf.float32)

                # Warmup starts at 10% of the peak; decay bottoms out at 1% of the peak.
                self.initial_lr = self.peak_lr * 0.1
                self.min_lr = self.peak_lr * 0.01

                logger.info("Learning rate schedule initialized:")
                logger.info(f"  Initial LR: {float(self.initial_lr):.6f}")
                logger.info(f"  Peak LR: {float(self.peak_lr):.6f}")
                logger.info(f"  Min LR: {float(self.min_lr):.6f}")
                logger.info(f"  Warmup steps: {int(self.warmup_steps)}")
                logger.info(f"  Total steps: {int(self.total_steps)}")

            def __call__(self, step):
                step = tf.cast(step, tf.float32)

                # Linear warmup from initial_lr up to peak_lr.
                warmup_factor = tf.minimum(1.0, step / self.warmup_steps)
                warmup_lr = self.initial_lr + (self.peak_lr - self.initial_lr) * warmup_factor

                # Cosine decay from peak_lr down to min_lr after warmup.
                decay_steps = tf.maximum(1.0, self.total_steps - self.warmup_steps)
                decay_factor = (step - self.warmup_steps) / decay_steps
                decay_factor = tf.minimum(tf.maximum(0.0, decay_factor), 1.0)

                cosine_decay = 0.5 * (1.0 + tf.cos(tf.constant(math.pi) * decay_factor))
                decay_lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay

                # Pick the warmup or decay branch, then clamp and guard against NaN/Inf.
                final_lr = tf.where(step < self.warmup_steps, warmup_lr, decay_lr)
                final_lr = tf.maximum(self.min_lr, final_lr)
                final_lr = tf.where(tf.math.is_finite(final_lr), final_lr, self.min_lr)

                return final_lr

            def get_config(self):
                return {
                    "total_steps": self.total_steps,
                    "peak_lr": self.peak_lr,
                    "warmup_steps": self.warmup_steps,
                }

        return CustomSchedule(total_steps, peak_lr, warmup_steps)
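
# Hedged end-to-end sketch (the paths and filenames below are illustrative assumptions,
# not files shipped with this module):
#
#   config = ChatbotConfig()
#   bot = RetrievalChatbot(config, mode='training')
#   bot.train_model("data/training_pairs.tfrecord", epochs=10, batch_size=16,
#                   checkpoint_dir="checkpoints/")
#   bot.save_models("saved_model/")
#
#   # Later, for serving:
#   bot = RetrievalChatbot.load_model("saved_model/", mode="inference")
#   reply, candidates, metrics = bot.chat("Hello!")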