|
import time |
|
from transformers import TFAutoModel, AutoTokenizer |
|
import tensorflow as tf |
|
import numpy as np |
|
from typing import Generator, List, Tuple, Dict, Optional, Union, Any |
|
import math |
|
from dataclasses import dataclass |
|
import json |
|
from pathlib import Path |
|
import datetime |
|
import faiss |
|
import gc |
|
from response_quality_checker import ResponseQualityChecker |
|
from cross_encoder_reranker import CrossEncoderReranker |
|
from conversation_summarizer import DeviceAwareModel, Summarizer |
|
from gpu_monitor import GPUMemoryMonitor |
|
import absl.logging |
|
from logger_config import config_logger |
|
from tqdm.auto import tqdm |
|
|
|
absl.logging.set_verbosity(absl.logging.WARNING) |
|
logger = config_logger(__name__) |
|
|
|
@dataclass |
|
class ChatbotConfig: |
|
"""Configuration for the RetrievalChatbot.""" |
|
vocab_size: int = 30526 |
|
max_context_token_limit: int = 512 |
|
embedding_dim: int = 768 |
|
encoder_units: int = 256 |
|
num_attention_heads: int = 8 |
|
dropout_rate: float = 0.2 |
|
l2_reg_weight: float = 0.001 |
|
margin: float = 0.3 |
|
learning_rate: float = 0.001 |
|
min_text_length: int = 3 |
|
max_context_turns: int = 5 |
|
warmup_steps: int = 200 |
|
pretrained_model: str = 'distilbert-base-uncased' |
|
dtype: str = 'float32' |
|
freeze_embeddings: bool = False |
|
embedding_batch_size: int = 128 |
|
|
|
|
|
def to_dict(self) -> dict: |
|
"""Convert config to dictionary.""" |
|
return {k: str(v) if isinstance(v, Path) else v |
|
for k, v in self.__dict__.items()} |
|
|
|
@classmethod |
|
def from_dict(cls, config_dict: dict) -> 'ChatbotConfig': |
|
"""Create config from dictionary.""" |
|
return cls(**{k: v for k, v in config_dict.items() |
|
if k in cls.__dataclass_fields__}) |
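
# Illustrative sketch (not called anywhere in this module): the config round-trips
# through a plain dict, which is how save_models()/load_models() persist it as JSON.
#
#   cfg = ChatbotConfig(embedding_dim=512)
#   restored = ChatbotConfig.from_dict(cfg.to_dict())
#   assert restored.embedding_dim == 512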
|
|
|
class EncoderModel(tf.keras.Model): |
|
"""Dual encoder model with pretrained embeddings.""" |
|
def __init__( |
|
self, |
|
config: ChatbotConfig, |
|
name: str = "encoder", |
|
shared_weights: bool = False, |
|
**kwargs |
|
): |
|
super().__init__(name=name, **kwargs) |
|
self.config = config |
|
self.shared_weights = shared_weights |
|
|
|
|
|
self.pretrained = TFAutoModel.from_pretrained(config.pretrained_model) |
|
|
|
|
|
        # Freeze the embedding layer and the first transformer block; fine-tune the rest.
        self.pretrained.distilbert.embeddings.trainable = False
        for i, layer_module in enumerate(self.pretrained.distilbert.transformer.layer):
            if i < 1:
                layer_module.trainable = False
            else:
                layer_module.trainable = True
|
|
|
|
|
self.pooler = tf.keras.layers.GlobalAveragePooling1D() |
|
|
|
|
|
self.projection = tf.keras.layers.Dense( |
|
config.embedding_dim, |
|
activation='tanh', |
|
name="projection" |
|
) |
|
|
|
|
|
self.dropout = tf.keras.layers.Dropout(config.dropout_rate) |
|
self.normalize = tf.keras.layers.Lambda( |
|
lambda x: tf.nn.l2_normalize(x, axis=1) |
|
) |
|
|
|
def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor: |
|
"""Forward pass.""" |
|
|
|
pretrained_outputs = self.pretrained(inputs, training=training) |
|
x = pretrained_outputs.last_hidden_state |
|
|
|
|
|
x = self.pooler(x) |
|
x = self.projection(x) |
|
x = self.dropout(x, training=training) |
|
x = self.normalize(x) |
|
|
|
return x |
|
|
|
def get_config(self) -> dict: |
|
"""Return the config of the model.""" |
|
config = super().get_config() |
|
config.update({ |
|
"config": self.config.to_dict(), |
|
"shared_weights": self.shared_weights, |
|
"name": self.name |
|
}) |
|
return config |
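
# Illustrative sketch (hypothetical usage, not part of the training path): the encoder
# maps a batch of token IDs to L2-normalized vectors of shape [batch, embedding_dim].
#
#   config = ChatbotConfig()
#   tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
#   encoder = EncoderModel(config)
#   ids = tokenizer(["hello there"], return_tensors='tf')['input_ids']
#   emb = encoder(ids)  # shape (1, 768); rows have unit L2 norm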
|
|
|
class RetrievalChatbot(DeviceAwareModel): |
|
"""Retrieval-based chatbot using pretrained embeddings and FAISS for similarity search.""" |
|
    def __init__(self, config: ChatbotConfig, dialogues: Optional[List[dict]] = None, device: str = None,
                 strategy=None, reranker: Optional[CrossEncoderReranker] = None,
                 summarizer: Optional[Summarizer] = None
                 ):
|
self.config = config |
|
self.strategy = strategy |
|
self.setup_device(device) |
|
|
|
if reranker is None: |
|
logger.info("Creating default CrossEncoderReranker...") |
|
reranker = CrossEncoderReranker(model_name="cross-encoder/ms-marco-MiniLM-L-12-v2") |
|
self.reranker = reranker |
|
|
|
if summarizer is None: |
|
logger.info("Creating default Summarizer...") |
|
summarizer = Summarizer(device=self.device) |
|
self.summarizer = summarizer |
|
|
|
|
|
self.special_tokens = { |
|
"user": "<USER>", |
|
"assistant": "<ASSISTANT>", |
|
"context": "<CONTEXT>", |
|
"sep": "<SEP>" |
|
} |
|
|
|
|
|
self.tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model) |
|
self.tokenizer.add_special_tokens( |
|
{'additional_special_tokens': list(self.special_tokens.values())} |
|
) |
|
|
|
self.memory_monitor = GPUMemoryMonitor() |
|
self.min_batch_size = 8 |
|
self.max_batch_size = 128 |
|
self.current_batch_size = 32 |
|
|
|
|
|
        self.response_pool, self.unique_responses = self._collect_responses(dialogues or [])
|
|
|
|
|
self.history = { |
|
"train_loss": [], |
|
"val_loss": [], |
|
"train_metrics": {}, |
|
"val_metrics": {} |
|
} |
|
|
|
def build_models(self): |
|
"""Initialize the shared encoder.""" |
|
logger.info("Building encoder model...") |
|
tf.keras.backend.clear_session() |
|
|
|
|
|
self.encoder = EncoderModel( |
|
self.config, |
|
name="shared_encoder", |
|
) |
|
|
|
|
|
new_vocab_size = len(self.tokenizer) |
|
self.encoder.pretrained.resize_token_embeddings(new_vocab_size) |
|
logger.info(f"Token embeddings resized to: {new_vocab_size}") |
|
|
|
|
|
self._initialize_faiss() |
|
|
|
self._compute_and_index_embeddings() |
|
|
|
|
|
try: |
|
|
|
embedding_dim = self.encoder.pretrained.config.dim |
|
logger.info("Got embedding dim from config") |
|
except AttributeError: |
|
try: |
|
|
|
embedding_dim = self.encoder.pretrained.distilbert.embeddings.word_embeddings.embedding_dim |
|
logger.info("Got embedding dim from word embeddings") |
|
except AttributeError: |
|
try: |
|
|
|
embedding_dim = self.encoder.pretrained.distilbert.embeddings.embedding_dim |
|
logger.info("Got embedding dim from embeddings module") |
|
except AttributeError: |
|
|
|
embedding_dim = self.config.embedding_dim |
|
logger.info("Using config embedding dim") |
|
|
|
vocab_size = len(self.tokenizer) |
|
|
|
logger.info(f"Encoder Embedding Dimension: {embedding_dim}") |
|
logger.info(f"Encoder Embedding Vocabulary Size: {vocab_size}") |
|
if vocab_size >= embedding_dim: |
|
logger.info("Encoder model built and embeddings resized successfully.") |
|
else: |
|
logger.error("Vocabulary size is less than embedding dimension.") |
|
raise ValueError("Vocabulary size is less than embedding dimension.") |
|
|
|
def _collect_responses(self, dialogues: List[dict]) -> Tuple[List[str], List[str]]: |
|
"""Collect all unique responses from dialogues.""" |
|
logger.info("Collecting responses from dialogues...") |
|
|
|
responses = [] |
|
try: |
|
progress_bar = tqdm(dialogues, desc="Collecting assistant responses") |
|
except ImportError: |
|
progress_bar = dialogues |
|
logger.info("Progress bar disabled - continuing without visual progress") |
|
|
|
for dialogue in progress_bar: |
|
turns = dialogue.get('turns', []) |
|
for turn in turns: |
|
if turn.get('speaker') == 'assistant' and 'text' in turn: |
|
responses.append(turn['text'].strip()) |
|
|
|
|
|
unique_responses = list(set(responses)) |
|
logger.info(f"Found {len(unique_responses)} unique responses.") |
|
|
|
return responses, unique_responses |
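
    # Expected dialogue schema, as assumed by the accessors above (illustrative):
    #   {"turns": [{"speaker": "user", "text": "Hi, I need help."},
    #              {"speaker": "assistant", "text": "Sure - what can I do for you?"}]}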
|
|
|
def _adjust_batch_size(self) -> None: |
|
"""Dynamically adjust batch size based on GPU memory usage.""" |
|
if self.memory_monitor.should_reduce_batch_size(): |
|
new_size = max(self.min_batch_size, self.current_batch_size // 2) |
|
if new_size != self.current_batch_size: |
|
logger.info(f"Reducing batch size to {new_size} due to high memory usage") |
|
self.current_batch_size = new_size |
|
gc.collect() |
|
if tf.config.list_physical_devices('GPU'): |
|
tf.keras.backend.clear_session() |
|
elif self.memory_monitor.can_increase_batch_size(): |
|
new_size = min(self.max_batch_size, self.current_batch_size * 2) |
|
if new_size != self.current_batch_size: |
|
logger.info(f"Increasing batch size to {new_size}") |
|
self.current_batch_size = new_size |
|
|
|
def _initialize_faiss(self): |
|
"""Initialize FAISS with safer GPU handling and memory monitoring.""" |
|
logger.info("Initializing FAISS index...") |
|
|
|
|
|
self.faiss_gpu = False |
|
self.gpu_resources = [] |
|
|
|
try: |
|
if hasattr(faiss, 'get_num_gpus'): |
|
ngpus = faiss.get_num_gpus() |
|
if ngpus > 0: |
|
|
|
for i in range(ngpus): |
|
res = faiss.StandardGpuResources() |
|
|
|
if self.memory_monitor.has_gpu: |
|
stats = self.memory_monitor.get_memory_stats() |
|
if stats: |
|
temp_memory = int(stats.total * 0.25) |
|
res.setTempMemory(temp_memory) |
|
self.gpu_resources.append(res) |
|
self.faiss_gpu = True |
|
logger.info(f"FAISS GPU resources initialized on {ngpus} GPUs") |
|
else: |
|
logger.info("Using CPU-only FAISS build") |
|
|
|
except Exception as e: |
|
logger.warning(f"Using CPU due to GPU initialization error: {e}") |
|
|
|
|
|
try: |
|
|
|
if len(self.unique_responses) < 1000: |
|
logger.info("Small dataset detected, using simple FlatIP index") |
|
self.index = faiss.IndexFlatIP(self.config.embedding_dim) |
|
            else:
                # Larger response pools could use an approximate index (e.g. IVF) here;
                # for now this falls back to the same exact inner-product index.
                self.index = faiss.IndexFlatIP(self.config.embedding_dim)
|
except Exception as e: |
|
logger.error(f"Error initializing FAISS: {e}") |
|
raise |
|
|
|
def encode_responses( |
|
self, |
|
responses: List[str], |
|
batch_size: int = 64 |
|
) -> tf.Tensor: |
|
""" |
|
Encodes responses with more conservative memory management. |
|
""" |
|
all_embeddings = [] |
|
self.current_batch_size = batch_size |
|
|
|
        if self.memory_monitor.has_gpu:
            self.current_batch_size = 128
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
total_processed = 0 |
|
|
|
with tqdm(total=len(responses), desc="Encoding responses") as pbar: |
|
while total_processed < len(responses): |
|
|
|
if self.memory_monitor.has_gpu: |
|
gpu_usage = self.memory_monitor.get_memory_usage() |
|
if gpu_usage > 0.8: |
|
                        self.current_batch_size = max(self.min_batch_size, self.current_batch_size // 2)
|
logger.info(f"High GPU memory usage ({gpu_usage:.1%}), reducing batch size to {self.current_batch_size}") |
|
gc.collect() |
|
tf.keras.backend.clear_session() |
|
|
|
|
|
end_idx = min(total_processed + self.current_batch_size, len(responses)) |
|
batch_texts = responses[total_processed:end_idx] |
|
|
|
try: |
|
|
|
encodings = self.tokenizer( |
|
batch_texts, |
|
padding='max_length', |
|
truncation=True, |
|
max_length=self.config.max_context_token_limit, |
|
return_tensors='tf' |
|
) |
|
|
|
|
|
embeddings_batch = self.encoder(encodings['input_ids'], training=False) |
|
|
|
|
|
if embeddings_batch.dtype != tf.float32: |
|
embeddings_batch = tf.cast(embeddings_batch, tf.float32) |
|
|
|
|
|
all_embeddings.append(embeddings_batch) |
|
|
|
|
|
batch_processed = len(batch_texts) |
|
total_processed += batch_processed |
|
|
|
|
|
if self.memory_monitor.has_gpu: |
|
gpu_usage = self.memory_monitor.get_memory_usage() |
|
pbar.set_postfix({ |
|
'GPU mem': f'{gpu_usage:.1%}', |
|
'batch_size': self.current_batch_size |
|
}) |
|
pbar.update(batch_processed) |
|
|
|
|
|
if total_processed % 1000 == 0: |
|
gc.collect() |
|
if tf.config.list_physical_devices('GPU'): |
|
tf.keras.backend.clear_session() |
|
|
|
except tf.errors.ResourceExhaustedError: |
|
logger.warning("GPU memory exhausted during encoding, reducing batch size") |
|
self.current_batch_size = max(8, self.current_batch_size // 2) |
|
continue |
|
|
|
except Exception as e: |
|
logger.error(f"Error during encoding: {str(e)}") |
|
raise |
|
|
|
|
|
|
|
if len(all_embeddings) == 1: |
|
final_embeddings = all_embeddings[0] |
|
else: |
|
final_embeddings = tf.concat(all_embeddings, axis=0) |
|
|
|
return final_embeddings |
|
|
|
def _train_faiss_index(self, response_embeddings: np.ndarray) -> None: |
|
"""Train FAISS index with better memory management and robust fallback mechanisms.""" |
|
if self.index.is_trained: |
|
logger.info("Index already trained, skipping training phase") |
|
return |
|
|
|
logger.info("Starting FAISS index training...") |
|
|
|
try: |
|
|
|
subset_size = min(5000, len(response_embeddings)) |
|
logger.info(f"Using {subset_size} samples for initial training attempt") |
|
subset_idx = np.random.choice(len(response_embeddings), subset_size, replace=False) |
|
training_embeddings = response_embeddings[subset_idx].copy() |
|
|
|
|
|
training_embeddings = np.ascontiguousarray(training_embeddings) |
|
|
|
|
|
gc.collect() |
|
if tf.config.list_physical_devices('GPU'): |
|
tf.keras.backend.clear_session() |
|
|
|
|
|
logger.info(f"FAISS training data shape: {training_embeddings.shape}") |
|
logger.info(f"FAISS training data dtype: {training_embeddings.dtype}") |
|
|
|
logger.info("Starting initial training attempt...") |
|
self.index.train(training_embeddings) |
|
logger.info("Training completed successfully") |
|
|
|
        except Exception as e:
|
logger.warning(f"Initial training attempt failed: {str(e)}") |
|
logger.info("Attempting fallback strategy...") |
|
|
|
try: |
|
|
|
if self.faiss_gpu: |
|
logger.info("Moving index to CPU for fallback training") |
|
cpu_index = faiss.index_gpu_to_cpu(self.index) |
|
else: |
|
cpu_index = self.index |
|
|
|
|
|
if isinstance(cpu_index, faiss.IndexIVFFlat): |
|
logger.info("Creating simpler FlatL2 index for fallback") |
|
cpu_index = faiss.IndexFlatL2(self.config.embedding_dim) |
|
|
|
|
|
subset_size = min(2000, len(response_embeddings)) |
|
subset_idx = np.random.choice(len(response_embeddings), subset_size, replace=False) |
|
fallback_embeddings = response_embeddings[subset_idx].copy() |
|
|
|
|
|
if not fallback_embeddings.flags['C_CONTIGUOUS']: |
|
fallback_embeddings = np.ascontiguousarray(fallback_embeddings) |
|
if fallback_embeddings.dtype != np.float32: |
|
fallback_embeddings = fallback_embeddings.astype(np.float32) |
|
|
|
|
|
logger.info("Training fallback index on CPU...") |
|
cpu_index.train(fallback_embeddings) |
|
|
|
|
|
if self.faiss_gpu: |
|
logger.info("Moving trained index back to GPU...") |
|
if len(self.gpu_resources) > 1: |
|
self.index = faiss.index_cpu_to_gpus_list(cpu_index, self.gpu_resources) |
|
else: |
|
self.index = faiss.index_cpu_to_gpu(self.gpu_resources[0], 0, cpu_index) |
|
else: |
|
self.index = cpu_index |
|
|
|
logger.info("Fallback training completed successfully") |
|
|
|
except Exception as e2: |
|
logger.error(f"Fallback training also failed: {str(e2)}") |
|
logger.warning("Creating basic brute-force index as last resort") |
|
|
|
try: |
|
|
|
dim = response_embeddings.shape[1] |
|
basic_index = faiss.IndexFlatL2(dim) |
|
|
|
if self.faiss_gpu: |
|
if len(self.gpu_resources) > 1: |
|
self.index = faiss.index_cpu_to_gpus_list(basic_index, self.gpu_resources) |
|
else: |
|
self.index = faiss.index_cpu_to_gpu(self.gpu_resources[0], 0, basic_index) |
|
else: |
|
self.index = basic_index |
|
|
|
logger.info("Basic index created as fallback") |
|
|
|
except Exception as e3: |
|
logger.error(f"All training attempts failed: {str(e3)}") |
|
raise RuntimeError("Unable to create working FAISS index") |
|
|
|
def _add_vectors_to_index(self, response_embeddings: np.ndarray) -> None: |
|
"""Add vectors to FAISS index with enhanced memory management.""" |
|
logger.info("Starting vector addition process...") |
|
|
|
|
|
initial_batch_size = 128 |
|
min_batch_size = 32 |
|
max_batch_size = 1024 |
|
|
|
total_added = 0 |
|
retry_count = 0 |
|
max_retries = 5 |
|
|
|
while total_added < len(response_embeddings): |
|
try: |
|
|
|
if self.memory_monitor.has_gpu: |
|
gpu_usage = self.memory_monitor.get_memory_usage() |
|
|
|
|
|
|
|
if gpu_usage > 0.7: |
|
logger.info("High memory usage detected, forcing cleanup") |
|
gc.collect() |
|
tf.keras.backend.clear_session() |
|
|
|
|
|
end_idx = min(total_added + initial_batch_size, len(response_embeddings)) |
|
batch = response_embeddings[total_added:end_idx] |
|
|
|
|
|
self.index.add(batch) |
|
|
|
|
|
batch_size = len(batch) |
|
total_added += batch_size |
|
|
|
|
|
if total_added % (initial_batch_size * 5) == 0: |
|
gc.collect() |
|
if tf.config.list_physical_devices('GPU'): |
|
tf.keras.backend.clear_session() |
|
|
|
|
|
if initial_batch_size < max_batch_size: |
|
initial_batch_size = min(initial_batch_size + 25, max_batch_size) |
|
|
|
except Exception as e: |
|
logger.warning(f"Error adding batch: {str(e)}") |
|
retry_count += 1 |
|
|
|
if retry_count > max_retries: |
|
logger.error("Max retries exceeded.") |
|
raise |
|
|
|
|
|
initial_batch_size = max(min_batch_size, initial_batch_size // 2) |
|
logger.info(f"Reducing batch size to {initial_batch_size} and retrying...") |
|
|
|
|
|
gc.collect() |
|
if tf.config.list_physical_devices('GPU'): |
|
tf.keras.backend.clear_session() |
|
|
|
time.sleep(1) |
|
|
|
logger.info(f"Successfully added all {total_added} vectors to index") |
|
|
|
def _add_vectors_cpu_fallback(self, remaining_embeddings: np.ndarray, already_added: int = 0) -> None: |
|
"""CPU fallback with extra safeguards and progress tracking.""" |
|
logger.info(f"CPU Fallback: Adding {len(remaining_embeddings)} remaining vectors...") |
|
|
|
try: |
|
|
|
if self.faiss_gpu: |
|
logger.info("Moving index to CPU...") |
|
cpu_index = faiss.index_gpu_to_cpu(self.index) |
|
else: |
|
cpu_index = self.index |
|
|
|
|
|
batch_size = 128 |
|
total_added = already_added |
|
|
|
for i in range(0, len(remaining_embeddings), batch_size): |
|
end_idx = min(i + batch_size, len(remaining_embeddings)) |
|
batch = remaining_embeddings[i:end_idx] |
|
|
|
|
|
cpu_index.add(batch) |
|
|
|
|
|
total_added += len(batch) |
|
if i % (batch_size * 10) == 0: |
|
logger.info(f"Added {total_added} vectors total " |
|
f"({i}/{len(remaining_embeddings)} in current phase)") |
|
|
|
|
|
if i % (batch_size * 20) == 0: |
|
gc.collect() |
|
|
|
|
|
if self.faiss_gpu: |
|
logger.info("Moving index back to GPU...") |
|
if len(self.gpu_resources) > 1: |
|
self.index = faiss.index_cpu_to_gpus_list(cpu_index, self.gpu_resources) |
|
else: |
|
self.index = faiss.index_cpu_to_gpu(self.gpu_resources[0], 0, cpu_index) |
|
else: |
|
self.index = cpu_index |
|
|
|
logger.info("CPU fallback completed successfully") |
|
|
|
except Exception as e: |
|
logger.error(f"Error during CPU fallback: {str(e)}") |
|
raise |
|
|
|
def _compute_and_index_embeddings(self): |
|
"""Compute embeddings and build FAISS index with simpler handling.""" |
|
logger.info("Computing embeddings and indexing with FAISS...") |
|
|
|
try: |
|
|
|
logger.info("Encoding unique responses") |
|
response_embeddings = self.encode_responses(self.unique_responses) |
|
response_embeddings = response_embeddings.numpy() |
|
|
|
|
|
gc.collect() |
|
if tf.config.list_physical_devices('GPU'): |
|
tf.keras.backend.clear_session() |
|
|
|
|
|
response_embeddings = response_embeddings.astype('float32') |
|
response_embeddings = np.ascontiguousarray(response_embeddings) |
|
|
|
|
|
if self.memory_monitor.has_gpu: |
|
stats = self.memory_monitor.get_memory_stats() |
|
if stats: |
|
logger.info(f"GPU memory before normalization: {stats.used/1e9:.2f}GB used") |
|
|
|
|
|
logger.info("Normalizing embeddings with FAISS") |
|
faiss.normalize_L2(response_embeddings) |
|
|
|
|
|
dim = response_embeddings.shape[1] |
|
if self.faiss_gpu: |
|
cpu_index = faiss.IndexFlatIP(dim) |
|
if len(self.gpu_resources) > 1: |
|
self.index = faiss.index_cpu_to_gpus_list(cpu_index, self.gpu_resources) |
|
else: |
|
self.index = faiss.index_cpu_to_gpu(self.gpu_resources[0], 0, cpu_index) |
|
else: |
|
self.index = faiss.IndexFlatIP(dim) |
|
|
|
|
|
self._add_vectors_to_index(response_embeddings) |
|
|
|
|
|
self.response_pool = self.unique_responses |
|
self.response_embeddings = response_embeddings |
|
|
|
|
|
gc.collect() |
|
if tf.config.list_physical_devices('GPU'): |
|
tf.keras.backend.clear_session() |
|
|
|
|
|
logger.info(f"Successfully indexed {self.index.ntotal} responses") |
|
if self.memory_monitor.has_gpu: |
|
stats = self.memory_monitor.get_memory_stats() |
|
if stats: |
|
logger.info(f"Final GPU memory usage: {stats.used/1e9:.2f}GB used") |
|
|
|
logger.info("Indexing completed successfully") |
|
|
|
except Exception as e: |
|
logger.error(f"Error during indexing: {e}") |
|
|
|
gc.collect() |
|
if tf.config.list_physical_devices('GPU'): |
|
tf.keras.backend.clear_session() |
|
raise |
|
|
|
def verify_faiss_index(self): |
|
"""Verify that FAISS index matches the response pool.""" |
|
indexed_size = self.index.ntotal |
|
pool_size = len(self.response_pool) |
|
logger.info(f"FAISS index size: {indexed_size}") |
|
logger.info(f"Response pool size: {pool_size}") |
|
if indexed_size != pool_size: |
|
logger.warning("Mismatch between FAISS index size and response pool size.") |
|
else: |
|
logger.info("FAISS index correctly matches the response pool.") |
|
|
|
def encode_query(self, query: str, context: Optional[List[Tuple[str, str]]] = None) -> tf.Tensor: |
|
"""Encode a query with optional conversation context.""" |
|
|
|
if context: |
|
context_str = ' '.join([ |
|
f"{self.special_tokens['user']} {q} " |
|
f"{self.special_tokens['assistant']} {r}" |
|
for q, r in context[-self.config.max_context_turns:] |
|
]) |
|
query = f"{context_str} {self.special_tokens['user']} {query}" |
|
else: |
|
query = f"{self.special_tokens['user']} {query}" |
|
|
|
|
|
encodings = self.tokenizer( |
|
[query], |
|
padding='max_length', |
|
truncation=True, |
|
max_length=self.config.max_context_token_limit, |
|
return_tensors='tf' |
|
) |
|
input_ids = encodings['input_ids'] |
|
|
|
|
|
max_id = tf.reduce_max(input_ids).numpy() |
|
new_vocab_size = len(self.tokenizer) |
|
|
|
if max_id >= new_vocab_size: |
|
logger.error(f"Token ID {max_id} exceeds the vocabulary size {new_vocab_size}.") |
|
raise ValueError("Token ID exceeds vocabulary size.") |
|
|
|
|
|
return self.encoder(input_ids, training=False) |
|
|
|
def retrieve_responses_cross_encoder( |
|
self, |
|
query: str, |
|
top_k: int, |
|
reranker: Optional[CrossEncoderReranker] = None, |
|
summarizer: Optional[Summarizer] = None, |
|
summarize_threshold: int = 512 |
|
) -> List[Tuple[str, float]]: |
|
""" |
|
Retrieve top-k from FAISS, then re-rank them with a cross-encoder. |
|
Optionally summarize the user query if it's too long. |
|
""" |
|
if reranker is None: |
|
reranker = self.reranker |
|
if summarizer is None: |
|
summarizer = self.summarizer |
|
|
|
|
|
if summarizer and len(query.split()) > summarize_threshold: |
|
logger.info(f"Query is long. Summarizing before cross-encoder. Original length: {len(query.split())}") |
|
query = summarizer.summarize_text(query) |
|
logger.info(f"Summarized query: {query}") |
|
|
|
|
|
dense_topk = self.retrieve_responses_faiss(query, top_k=top_k) |
|
|
|
if not dense_topk: |
|
return [] |
|
|
|
|
|
candidate_texts = [pair[0] for pair in dense_topk] |
|
cross_scores = reranker.rerank(query, candidate_texts, max_length=256) |
|
|
|
|
|
combined = [(text, score) for (text, _), score in zip(dense_topk, cross_scores)] |
|
|
|
combined.sort(key=lambda x: x[1], reverse=True) |
|
|
|
return combined |
|
|
|
def retrieve_responses_faiss(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]: |
|
"""Retrieve top-k responses using FAISS.""" |
|
|
|
q_emb = self.encode_query(query) |
|
q_emb_np = q_emb.numpy().astype('float32') |
|
|
|
|
|
faiss.normalize_L2(q_emb_np) |
|
|
|
|
|
distances, indices = self.index.search(q_emb_np, top_k) |
|
|
|
|
|
top_responses = [] |
|
for i, idx in enumerate(indices[0]): |
|
if idx < len(self.response_pool): |
|
top_responses.append((self.response_pool[idx], float(distances[0][i]))) |
|
else: |
|
logger.warning(f"FAISS returned invalid index {idx}. Skipping.") |
|
|
|
return top_responses |
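
    # Because both query and response embeddings are L2-normalized, the IndexFlatIP
    # scores returned above are cosine similarities. Hypothetical usage sketch
    # (`bot` stands for a RetrievalChatbot that has already run build_models()):
    #
    #   for text, score in bot.retrieve_responses_faiss("How do I reset my password?", top_k=3):
    #       print(f"{score:.3f}  {text[:60]}")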
|
|
|
def save_models(self, save_dir: Union[str, Path]): |
|
"""Save models and configuration.""" |
|
save_dir = Path(save_dir) |
|
save_dir.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
with open(save_dir / "config.json", "w") as f: |
|
json.dump(self.config.to_dict(), f, indent=2) |
|
|
|
|
|
self.encoder.pretrained.save_pretrained(save_dir / "shared_encoder") |
|
|
|
|
|
self.tokenizer.save_pretrained(save_dir / "tokenizer") |
|
|
|
logger.info(f"Models and tokenizer saved to {save_dir}.") |
|
|
|
@classmethod |
|
def load_models(cls, load_dir: Union[str, Path]) -> 'RetrievalChatbot': |
|
"""Load saved models and configuration.""" |
|
load_dir = Path(load_dir) |
|
|
|
|
|
with open(load_dir / "config.json", "r") as f: |
|
config = ChatbotConfig.from_dict(json.load(f)) |
|
|
|
|
|
        chatbot = cls(config)

        # The encoder is normally created in build_models(); construct it here
        # so the saved pretrained weights can be restored into it.
        chatbot.encoder = EncoderModel(config, name="shared_encoder")
        chatbot.encoder.pretrained = TFAutoModel.from_pretrained(
            load_dir / "shared_encoder"
        )
|
|
|
|
|
chatbot.tokenizer = AutoTokenizer.from_pretrained(load_dir / "tokenizer") |
|
|
|
logger.info(f"Models and tokenizer loaded from {load_dir}.") |
|
return chatbot |
|
|
|
@staticmethod |
|
def load_training_data(data_path: Union[str, Path], debug_samples: Optional[int] = None) -> List[dict]: |
|
""" |
|
Load training data from a JSON file. |
|
|
|
Args: |
|
data_path (Union[str, Path]): Path to the JSON file containing dialogues. |
|
debug_samples (Optional[int]): Number of samples to load for debugging. |
|
|
|
Returns: |
|
List[dict]: List of dialogue dictionaries. |
|
""" |
|
logger.info(f"Loading training data from {data_path}...") |
|
data_path = Path(data_path) |
|
if not data_path.exists(): |
|
logger.error(f"Data file {data_path} does not exist.") |
|
return [] |
|
|
|
with open(data_path, 'r', encoding='utf-8') as f: |
|
dialogues = json.load(f) |
|
|
|
if debug_samples is not None: |
|
dialogues = dialogues[:debug_samples] |
|
logger.info(f"Debug mode: Limited to {debug_samples} dialogues") |
|
|
|
logger.info(f"Loaded {len(dialogues)} dialogues.") |
|
return dialogues |
|
|
|
def train_streaming( |
|
self, |
|
dialogues: List[dict], |
|
epochs: int = 20, |
|
batch_size: int = 16, |
|
validation_split: float = 0.2, |
|
checkpoint_dir: str = "checkpoints/", |
|
use_lr_schedule: bool = True, |
|
peak_lr: float = 2e-5, |
|
warmup_steps_ratio: float = 0.1, |
|
early_stopping_patience: int = 3, |
|
min_delta: float = 1e-4, |
|
neg_samples: int = 1 |
|
) -> None: |
|
"""Streaming training with tf.data pipeline.""" |
|
logger.info("Starting streaming training pipeline with tf.data...") |
|
|
|
|
|
dataset_preparer = TFDataPipeline( |
|
embedding_batch_size=self.config.embedding_batch_size, |
|
tokenizer=self.tokenizer, |
|
encoder=self.encoder, |
|
index=self.index, |
|
response_pool=self.response_pool, |
|
max_length=self.config.max_context_token_limit, |
|
neg_samples=neg_samples |
|
) |
|
|
|
|
|
total_pairs = dataset_preparer.estimate_total_pairs(dialogues) |
|
train_size = int(total_pairs * (1 - validation_split)) |
|
val_size = int(total_pairs * validation_split) |
|
steps_per_epoch = int(math.ceil(train_size / batch_size)) |
|
val_steps = int(math.ceil(val_size / batch_size)) |
|
total_steps = steps_per_epoch * epochs |
|
|
|
logger.info(f"Total pairs: {total_pairs}") |
|
logger.info(f"Training pairs: {train_size}") |
|
logger.info(f"Validation pairs: {val_size}") |
|
logger.info(f"Steps per epoch: {steps_per_epoch}") |
|
logger.info(f"Validation steps: {val_steps}") |
|
logger.info(f"Total steps: {total_steps}") |
|
|
|
|
|
if use_lr_schedule: |
|
warmup_steps = int(total_steps * warmup_steps_ratio) |
|
lr_schedule = self._get_lr_schedule( |
|
total_steps=total_steps, |
|
peak_lr=peak_lr, |
|
warmup_steps=warmup_steps |
|
) |
|
self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule) |
|
logger.info("Using custom learning rate schedule.") |
|
else: |
|
self.optimizer = tf.keras.optimizers.Adam(learning_rate=peak_lr) |
|
logger.info("Using fixed learning rate.") |
|
|
|
|
|
checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, model=self.encoder) |
|
manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3) |
|
|
|
|
|
log_dir = Path(checkpoint_dir) / "tensorboard_logs" |
|
log_dir.mkdir(parents=True, exist_ok=True) |
|
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") |
|
train_log_dir = str(log_dir / f"train_{current_time}") |
|
val_log_dir = str(log_dir / f"val_{current_time}") |
|
train_summary_writer = tf.summary.create_file_writer(train_log_dir) |
|
val_summary_writer = tf.summary.create_file_writer(val_log_dir) |
|
logger.info(f"TensorBoard logs will be saved in {log_dir}") |
|
|
|
|
|
        # get_tf_dataset yields batches, so split by batch counts rather than pair counts.
        train_dataset = dataset_preparer.get_tf_dataset(dialogues, batch_size).take(steps_per_epoch)
        val_dataset = dataset_preparer.get_tf_dataset(dialogues, batch_size).skip(steps_per_epoch).take(val_steps)
|
|
|
|
|
best_val_loss = float("inf") |
|
epochs_no_improve = 0 |
|
|
|
for epoch in range(1, epochs + 1): |
|
|
|
epoch_loss_avg = tf.keras.metrics.Mean() |
|
batches_processed = 0 |
|
|
|
try: |
|
train_pbar = tqdm(total=steps_per_epoch, desc=f"Training Epoch {epoch}", unit="batch") |
|
is_tqdm_train = True |
|
except ImportError: |
|
train_pbar = None |
|
is_tqdm_train = False |
|
logger.info("Training progress bar disabled") |
|
|
|
for q_batch, p_batch, n_batch in train_dataset: |
|
|
|
loss = self.train_step(q_batch, p_batch, n_batch) |
|
epoch_loss_avg(loss) |
|
batches_processed += 1 |
|
|
|
|
|
with train_summary_writer.as_default(): |
|
tf.summary.scalar("loss", loss, step=(epoch - 1) * steps_per_epoch + batches_processed) |
|
|
|
|
|
if use_lr_schedule: |
|
current_lr = float(lr_schedule(self.optimizer.iterations)) |
|
else: |
|
current_lr = float(self.optimizer.learning_rate.numpy()) |
|
|
|
if is_tqdm_train: |
|
train_pbar.update(1) |
|
train_pbar.set_postfix({ |
|
"loss": f"{loss.numpy():.4f}", |
|
"lr": f"{current_lr:.2e}", |
|
"batches": f"{batches_processed}/{steps_per_epoch}" |
|
}) |
|
|
|
|
|
gc.collect() |
|
|
|
if batches_processed >= steps_per_epoch: |
|
break |
|
|
|
if is_tqdm_train and train_pbar: |
|
train_pbar.close() |
|
|
|
|
|
val_loss_avg = tf.keras.metrics.Mean() |
|
val_batches_processed = 0 |
|
|
|
try: |
|
val_pbar = tqdm(total=val_steps, desc="Validation", unit="batch") |
|
is_tqdm_val = True |
|
except ImportError: |
|
val_pbar = None |
|
is_tqdm_val = False |
|
logger.info("Validation progress bar disabled") |
|
|
|
for q_batch, p_batch, n_batch in val_dataset: |
|
|
|
val_loss = self.validation_step(q_batch, p_batch, n_batch) |
|
val_loss_avg(val_loss) |
|
val_batches_processed += 1 |
|
|
|
if is_tqdm_val: |
|
val_pbar.update(1) |
|
val_pbar.set_postfix({ |
|
"val_loss": f"{val_loss.numpy():.4f}", |
|
"batches": f"{val_batches_processed}/{val_steps}" |
|
}) |
|
|
|
|
|
gc.collect() |
|
|
|
|
|
if val_batches_processed >= val_steps: |
|
break |
|
|
|
if is_tqdm_val and val_pbar: |
|
val_pbar.close() |
|
|
|
|
|
train_loss = epoch_loss_avg.result().numpy() |
|
val_loss = val_loss_avg.result().numpy() |
|
logger.info(f"Epoch {epoch} Complete: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}") |
|
|
|
|
|
with train_summary_writer.as_default(): |
|
tf.summary.scalar("epoch_loss", train_loss, step=epoch) |
|
with val_summary_writer.as_default(): |
|
tf.summary.scalar("val_loss", val_loss, step=epoch) |
|
|
|
|
|
manager.save() |
|
|
|
|
|
self.history['train_loss'].append(train_loss) |
|
self.history['val_loss'].append(val_loss) |
|
|
|
if use_lr_schedule: |
|
current_lr = float(lr_schedule(self.optimizer.iterations)) |
|
else: |
|
current_lr = float(self.optimizer.learning_rate.numpy()) |
|
|
|
self.history.setdefault('learning_rate', []).append(current_lr) |
|
|
|
|
|
if val_loss < best_val_loss - min_delta: |
|
best_val_loss = val_loss |
|
epochs_no_improve = 0 |
|
logger.info(f"Validation loss improved to {val_loss:.4f}. Reset patience.") |
|
else: |
|
epochs_no_improve += 1 |
|
logger.info(f"No improvement this epoch. Patience: {epochs_no_improve}/{early_stopping_patience}") |
|
if epochs_no_improve >= early_stopping_patience: |
|
logger.info("Early stopping triggered.") |
|
break |
|
|
|
logger.info("Streaming training completed!") |
|
|
|
|
|
@tf.function |
|
def train_step( |
|
self, |
|
q_batch: tf.Tensor, |
|
p_batch: tf.Tensor, |
|
n_batch: tf.Tensor, |
|
attention_mask: Optional[tf.Tensor] = None |
|
) -> tf.Tensor: |
|
""" |
|
Single training step that uses queries, positives, and negatives in a |
|
contrastive/InfoNCE style. The label is always 0 (the positive) vs. |
|
the negative alternatives. |
|
""" |
|
with tf.GradientTape() as tape: |
|
|
|
q_enc = self.encoder(q_batch, training=True) |
|
|
|
|
|
p_enc = self.encoder(p_batch, training=True) |
|
|
|
|
|
|
|
shape = tf.shape(n_batch) |
|
bs = shape[0] |
|
neg_samples = shape[1] |
|
|
|
|
|
|
|
n_batch_flat = tf.reshape(n_batch, [bs * neg_samples, shape[2]]) |
|
n_enc_flat = self.encoder(n_batch_flat, training=True) |
|
|
|
|
|
n_enc = tf.reshape(n_enc_flat, [bs, neg_samples, -1]) |
|
|
|
|
|
|
|
|
|
combined_p_n = tf.concat( |
|
[tf.expand_dims(p_enc, axis=1), n_enc], |
|
axis=1 |
|
) |
|
|
|
|
|
|
|
|
|
dot_products = tf.einsum('bd,bkd->bk', q_enc, combined_p_n) |
|
|
|
|
|
labels = tf.zeros([bs], dtype=tf.int32) |
|
|
|
|
|
loss = tf.nn.sparse_softmax_cross_entropy_with_logits( |
|
labels=labels, |
|
logits=dot_products |
|
) |
|
loss = tf.reduce_mean(loss) |
|
|
|
|
|
|
|
if attention_mask is not None: |
|
loss = loss * attention_mask |
|
loss = tf.reduce_sum(loss) / tf.reduce_sum(attention_mask) |
|
|
|
|
|
gradients = tape.gradient(loss, self.encoder.trainable_variables) |
|
self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_variables)) |
|
return loss |
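
    # Shape sketch for the contrastive loss above (batch B, K negatives, embedding dim D):
    #   q_enc         -> [B, D]
    #   combined_p_n  -> [B, 1 + K, D]   (positive first, then the K negatives)
    #   dot_products  -> [B, 1 + K]      logits; label 0 marks the positive column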
|
|
|
@tf.function |
|
def validation_step( |
|
self, |
|
q_batch: tf.Tensor, |
|
p_batch: tf.Tensor, |
|
n_batch: tf.Tensor, |
|
attention_mask: Optional[tf.Tensor] = None |
|
) -> tf.Tensor: |
|
""" |
|
Single validation step with queries, positives, and negatives. |
|
Uses the same loss calculation as train_step, but `training=False`. |
|
""" |
|
q_enc = self.encoder(q_batch, training=False) |
|
p_enc = self.encoder(p_batch, training=False) |
|
|
|
shape = tf.shape(n_batch) |
|
bs = shape[0] |
|
neg_samples = shape[1] |
|
|
|
n_batch_flat = tf.reshape(n_batch, [bs * neg_samples, shape[2]]) |
|
n_enc_flat = self.encoder(n_batch_flat, training=False) |
|
n_enc = tf.reshape(n_enc_flat, [bs, neg_samples, -1]) |
|
|
|
combined_p_n = tf.concat( |
|
[tf.expand_dims(p_enc, axis=1), n_enc], |
|
axis=1 |
|
) |
|
|
|
dot_products = tf.einsum('bd,bkd->bk', q_enc, combined_p_n) |
|
labels = tf.zeros([bs], dtype=tf.int32) |
|
|
|
loss = tf.nn.sparse_softmax_cross_entropy_with_logits( |
|
labels=labels, |
|
logits=dot_products |
|
) |
|
loss = tf.reduce_mean(loss) |
|
|
|
if attention_mask is not None: |
|
loss = loss * attention_mask |
|
loss = tf.reduce_sum(loss) / tf.reduce_sum(attention_mask) |
|
|
|
return loss |
|
|
|
def _get_lr_schedule( |
|
self, |
|
total_steps: int, |
|
peak_lr: float, |
|
warmup_steps: int |
|
) -> tf.keras.optimizers.schedules.LearningRateSchedule: |
|
"""Create a custom learning rate schedule with warmup and cosine decay.""" |
|
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule): |
|
def __init__( |
|
self, |
|
total_steps: int, |
|
peak_lr: float, |
|
warmup_steps: int |
|
): |
|
super().__init__() |
|
self.total_steps = tf.cast(total_steps, tf.float32) |
|
self.peak_lr = tf.cast(peak_lr, tf.float32) |
|
|
|
|
|
adjusted_warmup_steps = min(warmup_steps, max(1, total_steps // 10)) |
|
self.warmup_steps = tf.cast(adjusted_warmup_steps, tf.float32) |
|
|
|
|
|
self.initial_lr = self.peak_lr * 0.1 |
|
self.min_lr = self.peak_lr * 0.01 |
|
|
|
logger.info(f"Learning rate schedule initialized:") |
|
logger.info(f" Initial LR: {float(self.initial_lr):.6f}") |
|
logger.info(f" Peak LR: {float(self.peak_lr):.6f}") |
|
logger.info(f" Min LR: {float(self.min_lr):.6f}") |
|
logger.info(f" Warmup steps: {int(self.warmup_steps)}") |
|
logger.info(f" Total steps: {int(self.total_steps)}") |
|
|
|
def __call__(self, step): |
|
step = tf.cast(step, tf.float32) |
|
|
|
|
|
warmup_factor = tf.minimum(1.0, step / self.warmup_steps) |
|
warmup_lr = self.initial_lr + (self.peak_lr - self.initial_lr) * warmup_factor |
|
|
|
|
|
decay_steps = tf.maximum(1.0, self.total_steps - self.warmup_steps) |
|
decay_factor = (step - self.warmup_steps) / decay_steps |
|
decay_factor = tf.minimum(tf.maximum(0.0, decay_factor), 1.0) |
|
|
|
cosine_decay = 0.5 * (1.0 + tf.cos(tf.constant(math.pi) * decay_factor)) |
|
decay_lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay |
|
|
|
|
|
final_lr = tf.where(step < self.warmup_steps, warmup_lr, decay_lr) |
|
|
|
|
|
final_lr = tf.maximum(self.min_lr, final_lr) |
|
final_lr = tf.where(tf.math.is_finite(final_lr), final_lr, self.min_lr) |
|
|
|
return final_lr |
|
|
|
def get_config(self): |
|
return { |
|
"total_steps": self.total_steps, |
|
"peak_lr": self.peak_lr, |
|
"warmup_steps": self.warmup_steps, |
|
} |
|
|
|
return CustomSchedule(total_steps, peak_lr, warmup_steps) |
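
    # Schedule sketch: with peak_lr=2e-5 the learning rate warms up linearly from
    # 2e-6 (0.1 * peak) to 2e-5, then cosine-decays toward 2e-7 (0.01 * peak).
    # Hypothetical probe of the schedule values:
    #
    #   sched = self._get_lr_schedule(total_steps=1000, peak_lr=2e-5, warmup_steps=100)
    #   lrs = [float(sched(step)) for step in range(0, 1000, 100)]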
|
|
|
def _cosine_similarity(self, emb1: np.ndarray, emb2: np.ndarray) -> np.ndarray: |
|
"""Compute cosine similarity between two numpy arrays.""" |
|
normalized_emb1 = emb1 / np.linalg.norm(emb1, axis=1, keepdims=True) |
|
normalized_emb2 = emb2 / np.linalg.norm(emb2, axis=1, keepdims=True) |
|
return np.dot(normalized_emb1, normalized_emb2.T) |
|
|
|
def chat( |
|
self, |
|
query: str, |
|
conversation_history: Optional[List[Tuple[str, str]]] = None, |
|
quality_checker: Optional['ResponseQualityChecker'] = None, |
|
top_k: int = 5, |
|
) -> Tuple[str, List[Tuple[str, float]], Dict[str, Any]]: |
|
""" |
|
Example chat method that always uses cross-encoder re-ranking |
|
if self.reranker is available. |
|
""" |
|
@self.run_on_device |
|
def get_response(self_arg, query_arg): |
|
|
|
conversation_str = self_arg._build_conversation_context(query_arg, conversation_history) |
|
|
|
|
|
results = self_arg.retrieve_responses_cross_encoder( |
|
query=conversation_str, |
|
top_k=top_k, |
|
reranker=self_arg.reranker, |
|
summarizer=self_arg.summarizer, |
|
summarize_threshold=512 |
|
) |
|
|
|
|
|
if not results: |
|
return ( |
|
"I'm sorry, but I couldn't find a relevant response.", |
|
[], |
|
{} |
|
) |
|
|
|
if quality_checker: |
|
metrics = quality_checker.check_response_quality(query_arg, results) |
|
if not metrics.get('is_confident', False): |
|
return ( |
|
"I need more information to provide a good answer. Could you please clarify?", |
|
results, |
|
metrics |
|
) |
|
return results[0][0], results, metrics |
|
|
|
return results[0][0], results, {} |
|
|
|
return get_response(self, query) |
|
|
|
def _build_conversation_context( |
|
self, |
|
query: str, |
|
conversation_history: Optional[List[Tuple[str, str]]] |
|
) -> str: |
|
"""Build conversation context with better memory management.""" |
|
if not conversation_history: |
|
return f"{self.special_tokens['user']} {query}" |
|
|
|
conversation_parts = [] |
|
for user_txt, assistant_txt in conversation_history: |
|
conversation_parts.extend([ |
|
f"{self.special_tokens['user']} {user_txt}", |
|
f"{self.special_tokens['assistant']} {assistant_txt}" |
|
]) |
|
|
|
conversation_parts.append(f"{self.special_tokens['user']} {query}") |
|
return "\n".join(conversation_parts) |
|
|
|
class TFDataPipeline: |
|
def __init__( |
|
self, |
|
embedding_batch_size, |
|
tokenizer, |
|
encoder, |
|
index, |
|
response_pool, |
|
max_length: int, |
|
neg_samples: int, |
|
): |
|
self.embedding_batch_size = embedding_batch_size |
|
self.tokenizer = tokenizer |
|
self.encoder = encoder |
|
self.index = index |
|
self.response_pool = response_pool |
|
self.max_length = max_length |
|
self.neg_samples = neg_samples |
|
        # Note: this overrides the embedding_batch_size passed to the constructor, scaling it to the pool size.
        self.embedding_batch_size = 16 if len(response_pool) < 100 else 64
|
self.search_batch_size = 8 if len(response_pool) < 100 else 32 |
|
self.max_batch_size = 32 if len(response_pool) < 100 else 256 |
|
self.memory_monitor = GPUMemoryMonitor() |
|
self.max_retries = 3 |
|
|
|
|
|
self.query_embeddings_cache = {} |
|
|
|
def _extract_pairs_from_dialogue(self, dialogue: dict) -> List[Tuple[str, str]]: |
|
"""Extract query-response pairs from a dialogue.""" |
|
pairs = [] |
|
turns = dialogue.get('turns', []) |
|
|
|
for i in range(len(turns) - 1): |
|
current_turn = turns[i] |
|
next_turn = turns[i+1] |
|
|
|
if (current_turn.get('speaker') == 'user' and |
|
next_turn.get('speaker') == 'assistant' and |
|
'text' in current_turn and |
|
'text' in next_turn): |
|
|
|
query = current_turn['text'].strip() |
|
positive = next_turn['text'].strip() |
|
pairs.append((query, positive)) |
|
|
|
return pairs |
|
|
|
def estimate_total_pairs(self, dialogues: List[dict]) -> int: |
|
"""Estimate total number of training pairs including hard negatives.""" |
|
base_pairs = sum( |
|
len([ |
|
1 for i in range(len(d.get('turns', [])) - 1) |
|
if (d['turns'][i].get('speaker') == 'user' and |
|
d['turns'][i+1].get('speaker') == 'assistant') |
|
]) |
|
for d in dialogues |
|
) |
|
|
|
return base_pairs * (1 + self.neg_samples) |
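
    # Example: a dialogue with 3 user -> assistant exchanges and neg_samples=1
    # contributes 3 * (1 + 1) = 6 to the estimate above.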
|
|
|
def _find_hard_negatives_batch(self, queries: List[str], positives: List[str]) -> List[List[str]]: |
|
"""Find hard negatives for a batch of queries with error handling and retries.""" |
|
retry_count = 0 |
|
total_responses = len(self.response_pool) |
|
|
|
while retry_count < self.max_retries: |
|
try: |
|
query_embeddings = np.vstack([ |
|
self.query_embeddings_cache[q] for q in queries |
|
]).astype(np.float32) |
|
|
|
query_embeddings = np.ascontiguousarray(query_embeddings) |
|
faiss.normalize_L2(query_embeddings) |
|
|
|
                k = self.neg_samples + 1  # retrieve one extra candidate so the gold positive can be skipped
|
|
|
|
|
distances, indices = self.index.search(query_embeddings, k) |
|
|
|
all_negatives = [] |
|
for query_indices, query, positive in zip(indices, queries, positives): |
|
negatives = [] |
|
positive_strip = positive.strip() |
|
seen = {positive_strip} |
|
|
|
for idx in query_indices: |
|
if idx >= 0 and idx < total_responses: |
|
candidate = self.response_pool[idx].strip() |
|
if candidate and candidate not in seen: |
|
seen.add(candidate) |
|
negatives.append(candidate) |
|
if len(negatives) >= self.neg_samples: |
|
break |
|
|
|
|
|
while len(negatives) < self.neg_samples: |
|
negatives.append("<EMPTY_NEGATIVE>") |
|
|
|
all_negatives.append(negatives) |
|
|
|
return all_negatives |
|
|
|
except Exception as e: |
|
retry_count += 1 |
|
logger.warning(f"Hard negative search attempt {retry_count} failed: {e}") |
|
if retry_count == self.max_retries: |
|
logger.error("Max retries reached for hard negative search") |
|
return [["<EMPTY_NEGATIVE>"] * self.neg_samples for _ in queries] |
|
gc.collect() |
|
if tf.config.list_physical_devices('GPU'): |
|
tf.keras.backend.clear_session() |
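
    # Sketch of the behavior above: for each query, the nearest indexed responses
    # (minus the gold positive) become hard negatives, and short lists are padded
    # with "<EMPTY_NEGATIVE>" so every example carries exactly self.neg_samples entries.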
|
|
|
def _tokenize_negatives_tf(self, negatives): |
|
"""Tokenizes negatives using tf.py_function.""" |
|
|
|
if tf.size(negatives) == 0: |
|
return tf.zeros([0, self.neg_samples, self.max_length], dtype=tf.int32) |
|
|
|
|
|
negatives_list = [] |
|
for neg_list in negatives.numpy(): |
|
decoded_negs = [neg.decode("utf-8") for neg in neg_list if neg] |
|
negatives_list.append(decoded_negs) |
|
|
|
|
|
flattened_negatives = [neg for sublist in negatives_list for neg in sublist] |
|
|
|
|
|
if flattened_negatives: |
|
n_tokens = self.tokenizer( |
|
flattened_negatives, |
|
padding='max_length', |
|
truncation=True, |
|
max_length=self.max_length, |
|
return_tensors='tf' |
|
) |
|
|
|
n_tokens_reshaped = tf.reshape(n_tokens['input_ids'], [-1, self.neg_samples, self.max_length]) |
|
return n_tokens_reshaped |
|
else: |
|
return tf.zeros([0, self.neg_samples, self.max_length], dtype=tf.int32) |
|
|
|
def _compute_embeddings(self, queries: List[str]) -> None: |
|
"""Computes and caches embeddings for new queries.""" |
|
new_queries = [q for q in queries if q not in self.query_embeddings_cache] |
|
if not new_queries: |
|
return |
|
|
|
new_embeddings = [] |
|
for i in range(0, len(new_queries), self.embedding_batch_size): |
|
batch_queries = new_queries[i:i + self.embedding_batch_size] |
|
|
|
encoded = self.tokenizer( |
|
batch_queries, |
|
padding=True, |
|
truncation=True, |
|
max_length=self.max_length, |
|
return_tensors='tf' |
|
) |
|
|
|
|
|
with tf.device('/CPU:0'): |
|
batch_embeddings = self.encoder(encoded['input_ids'], training=False).numpy() |
|
|
|
new_embeddings.extend(batch_embeddings) |
|
|
|
|
|
for query, emb in zip(new_queries, new_embeddings): |
|
self.query_embeddings_cache[query] = emb |
|
|
|
def data_generator(self, dialogues: List[dict]) -> Generator[Tuple[str, str, List[str]], None, None]: |
|
""" |
|
Generates training examples: (query, positive, hard_negatives). |
|
Wrapped the outer loop with tqdm for progress tracking. |
|
""" |
|
total_dialogues = len(dialogues) |
|
logger.debug(f"Total dialogues to process: {total_dialogues}") |
|
|
|
|
|
with tqdm(total=total_dialogues, desc="Processing Dialogues", unit="dialogue") as pbar: |
|
for dialogue in dialogues: |
|
pairs = self._extract_pairs_from_dialogue(dialogue) |
|
for query, positive in pairs: |
|
|
|
self._compute_embeddings([query]) |
|
hard_negatives = self._find_hard_negatives_batch([query], [positive])[0] |
|
yield (query, positive, hard_negatives) |
|
pbar.update(1) |
|
|
|
def _prepare_batch(self, queries: tf.Tensor, positives: tf.Tensor, negatives: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: |
|
"""Prepares a batch of data for training.""" |
|
|
|
|
|
queries_list = [query.decode("utf-8") for query in queries.numpy()] |
|
positives_list = [pos.decode("utf-8") for pos in positives.numpy()] |
|
|
|
|
|
q_tokens = self.tokenizer(queries_list, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf') |
|
p_tokens = self.tokenizer(positives_list, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf') |
|
|
|
|
|
negatives_list = [] |
|
for neg_list in negatives.numpy(): |
|
decoded_negs = [neg.decode("utf-8") for neg in neg_list if neg] |
|
negatives_list.append(decoded_negs) |
|
|
|
|
|
flattened_negatives = [neg for sublist in negatives_list for neg in sublist if neg] |
|
|
|
|
|
n_tokens_reshaped = None |
|
if flattened_negatives: |
|
n_tokens = self.tokenizer(flattened_negatives, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf') |
|
|
|
|
|
|
|
n_tokens_reshaped = tf.reshape(n_tokens['input_ids'], [len(queries_list), -1, self.max_length]) |
|
else: |
|
|
|
n_tokens_reshaped = tf.zeros([len(queries_list), 0, self.max_length], dtype=tf.int32) |
|
|
|
|
|
|
|
if n_tokens_reshaped.shape[1] != self.neg_samples: |
|
|
|
padding = tf.zeros([len(queries_list), tf.maximum(0, self.neg_samples - n_tokens_reshaped.shape[1]), self.max_length], dtype=tf.int32) |
|
n_tokens_reshaped = tf.concat([n_tokens_reshaped, padding], axis=1) |
|
n_tokens_reshaped = n_tokens_reshaped[:, :self.neg_samples, :] |
|
|
|
|
|
combined_p_n_tokens = tf.concat([tf.expand_dims(p_tokens['input_ids'], axis=1), n_tokens_reshaped], axis=1) |
|
|
|
return q_tokens['input_ids'], combined_p_n_tokens |
|
|
|
def get_tf_dataset(self, dialogues: List[dict], batch_size: int) -> tf.data.Dataset: |
|
""" |
|
Creates a tf.data.Dataset for streaming training that yields |
|
(input_ids_query, input_ids_positive, input_ids_negatives). |
|
""" |
|
|
|
dataset = tf.data.Dataset.from_generator( |
|
lambda: self.data_generator(dialogues), |
|
output_signature=( |
|
tf.TensorSpec(shape=(), dtype=tf.string), |
|
tf.TensorSpec(shape=(), dtype=tf.string), |
|
tf.TensorSpec(shape=(None,), dtype=tf.string) |
|
) |
|
) |
|
|
|
|
|
dataset = dataset.batch(batch_size) |
|
|
|
|
|
dataset = dataset.map( |
|
lambda q, p, n: self._tokenize_triple(q, p, n), |
|
num_parallel_calls=1 |
|
) |
|
|
|
dataset = dataset.prefetch(tf.data.AUTOTUNE) |
|
return dataset |
|
|
|
def _tokenize_triple( |
|
self, |
|
q: tf.Tensor, |
|
p: tf.Tensor, |
|
n: tf.Tensor |
|
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: |
|
""" |
|
Wraps a Python function via tf.py_function to convert tf.Tensors of strings |
|
-> Python lists of strings -> HF tokenizer -> Tensors of IDs. |
|
|
|
q is shape [batch_size], p is shape [batch_size], |
|
n is shape [batch_size, neg_samples] (i.e., each row is a list of negatives). |
|
""" |
|
|
|
q_ids, p_ids, n_ids = tf.py_function( |
|
func=self._tokenize_triple_py, |
|
inp=[q, p, n, tf.constant(self.max_length), tf.constant(self.neg_samples)], |
|
Tout=[tf.int32, tf.int32, tf.int32] |
|
) |
|
|
|
|
|
q_ids.set_shape([None, self.max_length]) |
|
p_ids.set_shape([None, self.max_length]) |
|
n_ids.set_shape([None, self.neg_samples, self.max_length]) |
|
|
|
return q_ids, p_ids, n_ids |
|
|
|
|
def _tokenize_triple_py( |
|
self, |
|
q: tf.Tensor, |
|
p: tf.Tensor, |
|
n: tf.Tensor, |
|
max_len: tf.Tensor, |
|
neg_samples: tf.Tensor |
|
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: |
|
""" |
|
Python function that: |
|
- Decodes each tf.string Tensor to a Python list of strings |
|
- Calls the HF tokenizer |
|
- Reshapes negatives |
|
- Returns np.array of int32s for (q_ids, p_ids, n_ids). |
|
|
|
q: shape [batch_size], p: shape [batch_size] |
|
n: shape [batch_size, neg_samples] |
|
max_len: scalar int |
|
neg_samples: scalar int |
|
""" |
|
max_len = int(max_len.numpy()) |
|
neg_samples = int(neg_samples.numpy()) |
|
|
|
|
|
q_list = [q_i.decode("utf-8") for q_i in q.numpy()] |
|
p_list = [p_i.decode("utf-8") for p_i in p.numpy()] |
|
|
|
|
|
n_list = [] |
|
for row in n.numpy(): |
|
|
|
decoded = [neg.decode("utf-8") for neg in row] |
|
n_list.append(decoded) |
|
|
|
|
|
q_enc = self.tokenizer( |
|
q_list, |
|
padding="max_length", |
|
truncation=True, |
|
max_length=max_len, |
|
return_tensors="np" |
|
) |
|
p_enc = self.tokenizer( |
|
p_list, |
|
padding="max_length", |
|
truncation=True, |
|
max_length=max_len, |
|
return_tensors="np" |
|
) |
|
|
|
|
|
|
|
flattened_negatives = [neg for row in n_list for neg in row] |
|
if len(flattened_negatives) == 0: |
|
|
|
n_ids = np.zeros((len(q_list), neg_samples, max_len), dtype=np.int32) |
|
else: |
|
n_enc = self.tokenizer( |
|
flattened_negatives, |
|
padding="max_length", |
|
truncation=True, |
|
max_length=max_len, |
|
return_tensors="np" |
|
) |
|
|
|
n_input_ids = n_enc["input_ids"] |
|
|
|
|
|
|
|
batch_size = len(q_list) |
|
n_ids_list = [] |
|
for i in range(batch_size): |
|
start_idx = i * neg_samples |
|
end_idx = start_idx + neg_samples |
|
row_negs = n_input_ids[start_idx:end_idx] |
|
|
|
|
|
if row_negs.shape[0] < neg_samples: |
|
deficit = neg_samples - row_negs.shape[0] |
|
pad_arr = np.zeros((deficit, max_len), dtype=np.int32) |
|
row_negs = np.concatenate([row_negs, pad_arr], axis=0) |
|
|
|
n_ids_list.append(row_negs) |
|
|
|
|
|
n_ids = np.stack(n_ids_list, axis=0) |
|
|
|
|
|
q_ids = q_enc["input_ids"].astype(np.int32) |
|
p_ids = p_enc["input_ids"].astype(np.int32) |
|
n_ids = n_ids.astype(np.int32) |
|
|
|
return q_ids, p_ids, n_ids |
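

if __name__ == "__main__":
    # Minimal end-to-end sketch. The data path and query below are placeholders
    # (assumptions for illustration), not files or strings the module provides.
    dialogues = RetrievalChatbot.load_training_data("data/dialogues.json", debug_samples=100)
    bot = RetrievalChatbot(ChatbotConfig(), dialogues=dialogues)
    bot.build_models()
    bot.verify_faiss_index()
    response, candidates, metrics = bot.chat("Hello! What can you help me with?")
    print(response)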
|