|
import time |
|
from transformers import TFAutoModel, AutoTokenizer |
|
import tensorflow as tf |
|
import numpy as np |
|
import threading |
|
from queue import Queue, Empty |
|
from typing import Generator, List, Tuple, Dict, Optional, Union, Any |
|
import math |
|
from dataclasses import dataclass |
|
import json |
|
from pathlib import Path |
|
import datetime |
|
import faiss |
|
import gc |
|
import random |
|
from response_quality_checker import ResponseQualityChecker |
|
from cross_encoder_reranker import CrossEncoderReranker |
|
from conversation_summarizer import DeviceAwareModel, Summarizer |
|
from gpu_monitor import GPUMemoryMonitor |
|
import absl.logging |
|
from logger_config import config_logger |
|
from tqdm.auto import tqdm |
|
|
|
absl.logging.set_verbosity(absl.logging.WARNING) |
|
logger = config_logger(__name__) |
|
|
|
@dataclass |
|
class ChatbotConfig: |
|
"""Configuration for the RetrievalChatbot.""" |
|
vocab_size: int = 30526 |
|
max_context_token_limit: int = 512 |
|
embedding_dim: int = 512 |
|
encoder_units: int = 256 |
|
num_attention_heads: int = 8 |
|
dropout_rate: float = 0.2 |
|
l2_reg_weight: float = 0.001 |
|
margin: float = 0.3 |
|
learning_rate: float = 0.001 |
|
min_text_length: int = 3 |
|
max_context_turns: int = 5 |
|
warmup_steps: int = 200 |
|
pretrained_model: str = 'distilbert-base-uncased' |
|
dtype: str = 'float32' |
|
freeze_embeddings: bool = False |
|
|
|
|
|
def to_dict(self) -> dict: |
|
"""Convert config to dictionary.""" |
|
return {k: str(v) if isinstance(v, Path) else v |
|
for k, v in self.__dict__.items()} |
|
|
|
@classmethod |
|
def from_dict(cls, config_dict: dict) -> 'ChatbotConfig': |
|
"""Create config from dictionary.""" |
|
return cls(**{k: v for k, v in config_dict.items() |
|
if k in cls.__dataclass_fields__}) |
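
    # Example (sketch): round-tripping the config through a plain dict, e.g. for
    # persisting to JSON alongside model weights. The field values below are
    # illustrative overrides, not project defaults.
    #
    #   cfg = ChatbotConfig(embedding_dim=256, learning_rate=5e-4)
    #   restored = ChatbotConfig.from_dict(cfg.to_dict())
    #   assert restored.embedding_dim == 256
    #   json.dumps(cfg.to_dict())  # safe: Path values are stringified by to_dict()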
|
|
|
class EncoderModel(tf.keras.Model): |
|
"""Dual encoder model with pretrained embeddings.""" |
|
def __init__( |
|
self, |
|
config: ChatbotConfig, |
|
name: str = "encoder", |
|
shared_weights: bool = False, |
|
**kwargs |
|
): |
|
super().__init__(name=name, **kwargs) |
|
self.config = config |
|
self.shared_weights = shared_weights |
|
|
|
|
|
self.pretrained = TFAutoModel.from_pretrained(config.pretrained_model) |
|
|
|
|
|
self.pretrained.distilbert.embeddings.trainable = False |
|
for i, layer_module in enumerate(self.pretrained.distilbert.transformer.layer): |
|
if i < 1: |
|
layer_module.trainable = False |
|
else: |
|
layer_module.trainable = True |
|
|
|
|
|
self.pooler = tf.keras.layers.GlobalAveragePooling1D() |
|
|
|
|
|
self.projection = tf.keras.layers.Dense( |
|
config.embedding_dim, |
|
activation='tanh', |
|
name="projection" |
|
) |
|
|
|
|
|
self.dropout = tf.keras.layers.Dropout(config.dropout_rate) |
|
self.normalize = tf.keras.layers.Lambda( |
|
lambda x: tf.nn.l2_normalize(x, axis=1) |
|
) |
|
|
|
def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor: |
|
"""Forward pass.""" |
|
|
|
pretrained_outputs = self.pretrained(inputs, training=training) |
|
x = pretrained_outputs.last_hidden_state |
|
|
|
|
|
x = self.pooler(x) |
|
x = self.projection(x) |
|
x = self.dropout(x, training=training) |
|
x = self.normalize(x) |
|
|
|
return x |
|
|
|
def get_config(self) -> dict: |
|
"""Return the config of the model.""" |
|
config = super().get_config() |
|
config.update({ |
|
"config": self.config.to_dict(), |
|
"shared_weights": self.shared_weights, |
|
"name": self.name |
|
}) |
|
return config |
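
    # Example (sketch): encoding a small batch of texts with the shared encoder.
    # `tokenizer` is assumed to be the AutoTokenizer created by RetrievalChatbot;
    # the text is illustrative.
    #
    #   enc = EncoderModel(ChatbotConfig())
    #   tokens = tokenizer(["hello there"], padding='max_length', truncation=True,
    #                      max_length=512, return_tensors='tf')
    #   embeddings = enc(tokens['input_ids'])  # shape (1, embedding_dim), L2-normalized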
|
|
|
class RetrievalChatbot(DeviceAwareModel): |
|
"""Retrieval-based chatbot using pretrained embeddings and FAISS for similarity search.""" |
|
    def __init__(self, config: ChatbotConfig, dialogues: Optional[List[dict]] = None, device: str = None,
                 strategy=None, reranker: Optional[CrossEncoderReranker] = None,
                 summarizer: Optional[Summarizer] = None
                 ):
|
self.config = config |
|
self.strategy = strategy |
|
self.setup_device(device) |
|
|
|
if reranker is None: |
|
logger.info("Creating default CrossEncoderReranker...") |
|
reranker = CrossEncoderReranker(model_name="cross-encoder/ms-marco-MiniLM-L-12-v2") |
|
self.reranker = reranker |
|
|
|
if summarizer is None: |
|
logger.info("Creating default Summarizer...") |
|
summarizer = Summarizer(device=self.device) |
|
self.summarizer = summarizer |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.special_tokens = { |
|
"user": "<USER>", |
|
"assistant": "<ASSISTANT>", |
|
"context": "<CONTEXT>", |
|
"sep": "<SEP>" |
|
} |
|
|
|
|
|
self.tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model) |
|
self.tokenizer.add_special_tokens( |
|
{'additional_special_tokens': list(self.special_tokens.values())} |
|
) |
|
|
|
self.memory_monitor = GPUMemoryMonitor() |
|
self.min_batch_size = 8 |
|
self.max_batch_size = 128 |
|
self.current_batch_size = 32 |
|
|
|
|
|
        self.response_pool, self.unique_responses = self._collect_responses(dialogues or [])
|
|
|
|
|
self.history = { |
|
"train_loss": [], |
|
"val_loss": [], |
|
"train_metrics": {}, |
|
"val_metrics": {} |
|
} |
|
|
|
def build_models(self): |
|
"""Initialize the shared encoder.""" |
|
logger.info("Building encoder model...") |
|
tf.keras.backend.clear_session() |
|
|
|
|
|
self.encoder = EncoderModel( |
|
self.config, |
|
name="shared_encoder", |
|
) |
|
|
|
|
|
new_vocab_size = len(self.tokenizer) |
|
self.encoder.pretrained.resize_token_embeddings(new_vocab_size) |
|
logger.info(f"Token embeddings resized to: {new_vocab_size}") |
|
|
|
|
|
self._initialize_faiss() |
|
|
|
self._compute_and_index_embeddings() |
|
|
|
|
|
try: |
|
|
|
embedding_dim = self.encoder.pretrained.config.dim |
|
logger.info("Got embedding dim from config") |
|
except AttributeError: |
|
try: |
|
|
|
embedding_dim = self.encoder.pretrained.distilbert.embeddings.word_embeddings.embedding_dim |
|
logger.info("Got embedding dim from word embeddings") |
|
except AttributeError: |
|
try: |
|
|
|
embedding_dim = self.encoder.pretrained.distilbert.embeddings.embedding_dim |
|
logger.info("Got embedding dim from embeddings module") |
|
except AttributeError: |
|
|
|
embedding_dim = self.config.embedding_dim |
|
logger.info("Using config embedding dim") |
|
|
|
vocab_size = len(self.tokenizer) |
|
|
|
logger.info(f"Encoder Embedding Dimension: {embedding_dim}") |
|
logger.info(f"Encoder Embedding Vocabulary Size: {vocab_size}") |
|
if vocab_size >= embedding_dim: |
|
logger.info("Encoder model built and embeddings resized successfully.") |
|
else: |
|
logger.error("Vocabulary size is less than embedding dimension.") |
|
raise ValueError("Vocabulary size is less than embedding dimension.") |
|
|
|
def _collect_responses(self, dialogues: List[dict]) -> Tuple[List[str], List[str]]: |
|
"""Collect all unique responses from dialogues.""" |
|
logger.info("Collecting responses from dialogues...") |
|
|
|
responses = [] |
|
try: |
|
progress_bar = tqdm(dialogues, desc="Collecting assistant responses") |
|
except ImportError: |
|
progress_bar = dialogues |
|
logger.info("Progress bar disabled - continuing without visual progress") |
|
|
|
for dialogue in progress_bar: |
|
turns = dialogue.get('turns', []) |
|
for turn in turns: |
|
if turn.get('speaker') == 'assistant' and 'text' in turn: |
|
responses.append(turn['text'].strip()) |
|
|
|
|
|
unique_responses = list(set(responses)) |
|
logger.info(f"Found {len(unique_responses)} unique responses.") |
|
|
|
return responses, unique_responses |
|
|
|
def _adjust_batch_size(self) -> None: |
|
"""Dynamically adjust batch size based on GPU memory usage.""" |
|
if self.memory_monitor.should_reduce_batch_size(): |
|
new_size = max(self.min_batch_size, self.current_batch_size // 2) |
|
if new_size != self.current_batch_size: |
|
logger.info(f"Reducing batch size to {new_size} due to high memory usage") |
|
self.current_batch_size = new_size |
|
gc.collect() |
|
if tf.config.list_physical_devices('GPU'): |
|
tf.keras.backend.clear_session() |
|
elif self.memory_monitor.can_increase_batch_size(): |
|
new_size = min(self.max_batch_size, self.current_batch_size * 2) |
|
if new_size != self.current_batch_size: |
|
logger.info(f"Increasing batch size to {new_size}") |
|
self.current_batch_size = new_size |
|
|
|
def _initialize_faiss(self): |
|
"""Initialize FAISS with safer GPU handling and memory monitoring.""" |
|
logger.info("Initializing FAISS index...") |
|
|
|
|
|
self.faiss_gpu = False |
|
self.gpu_resources = [] |
|
|
|
try: |
|
if hasattr(faiss, 'get_num_gpus'): |
|
ngpus = faiss.get_num_gpus() |
|
if ngpus > 0: |
|
|
|
for i in range(ngpus): |
|
res = faiss.StandardGpuResources() |
|
|
|
if self.memory_monitor.has_gpu: |
|
stats = self.memory_monitor.get_memory_stats() |
|
if stats: |
|
temp_memory = int(stats.total * 0.25) |
|
res.setTempMemory(temp_memory) |
|
self.gpu_resources.append(res) |
|
self.faiss_gpu = True |
|
logger.info(f"FAISS GPU resources initialized on {ngpus} GPUs") |
|
else: |
|
logger.info("Using CPU-only FAISS build") |
|
|
|
except Exception as e: |
|
logger.warning(f"Using CPU due to GPU initialization error: {e}") |
|
|
|
|
|
try: |
|
|
|
if len(self.unique_responses) < 1000: |
|
logger.info("Small dataset detected, using simple FlatIP index") |
|
self.index = faiss.IndexFlatIP(self.config.embedding_dim) |
|
            else:
                # Larger response pools could later use an approximate index (e.g. IVF);
                # for now this branch also builds an exact inner-product index.
                self.index = faiss.IndexFlatIP(self.config.embedding_dim)
|
except Exception as e: |
|
logger.error(f"Error initializing FAISS: {e}") |
|
raise |
|
|
|
def encode_responses( |
|
self, |
|
responses: List[str], |
|
batch_size: int = 64 |
|
) -> tf.Tensor: |
|
""" |
|
Encodes responses with more conservative memory management. |
|
""" |
|
all_embeddings = [] |
|
self.current_batch_size = batch_size |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
total_processed = 0 |
|
|
|
with tqdm(total=len(responses), desc="Encoding responses") as pbar: |
|
while total_processed < len(responses): |
|
|
|
if self.memory_monitor.has_gpu: |
|
gpu_usage = self.memory_monitor.get_memory_usage() |
|
if gpu_usage > 0.8: |
|
                        self.current_batch_size = max(self.min_batch_size, self.current_batch_size // 2)
|
logger.info(f"High GPU memory usage ({gpu_usage:.1%}), reducing batch size to {self.current_batch_size}") |
|
gc.collect() |
|
tf.keras.backend.clear_session() |
|
|
|
|
|
end_idx = min(total_processed + self.current_batch_size, len(responses)) |
|
batch_texts = responses[total_processed:end_idx] |
|
|
|
try: |
|
|
|
encodings = self.tokenizer( |
|
batch_texts, |
|
padding='max_length', |
|
truncation=True, |
|
max_length=self.config.max_context_token_limit, |
|
return_tensors='tf' |
|
) |
|
|
|
|
|
embeddings_batch = self.encoder(encodings['input_ids'], training=False) |
|
|
|
|
|
if embeddings_batch.dtype != tf.float32: |
|
embeddings_batch = tf.cast(embeddings_batch, tf.float32) |
|
|
|
|
|
all_embeddings.append(embeddings_batch) |
|
|
|
|
|
batch_processed = len(batch_texts) |
|
total_processed += batch_processed |
|
|
|
|
|
if self.memory_monitor.has_gpu: |
|
gpu_usage = self.memory_monitor.get_memory_usage() |
|
pbar.set_postfix({ |
|
'GPU mem': f'{gpu_usage:.1%}', |
|
'batch_size': self.current_batch_size |
|
}) |
|
pbar.update(batch_processed) |
|
|
|
|
|
if total_processed % 1000 == 0: |
|
gc.collect() |
|
if tf.config.list_physical_devices('GPU'): |
|
tf.keras.backend.clear_session() |
|
|
|
except tf.errors.ResourceExhaustedError: |
|
logger.warning("GPU memory exhausted during encoding, reducing batch size") |
|
self.current_batch_size = max(8, self.current_batch_size // 2) |
|
continue |
|
|
|
except Exception as e: |
|
logger.error(f"Error during encoding: {str(e)}") |
|
raise |
|
|
|
|
|
|
|
if len(all_embeddings) == 1: |
|
final_embeddings = all_embeddings[0] |
|
else: |
|
final_embeddings = tf.concat(all_embeddings, axis=0) |
|
|
|
return final_embeddings |
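
    # Example (sketch): encoding the response pool in bulk. The batch size is only a
    # starting point; the method halves it automatically under memory pressure.
    #
    #   embeddings = chatbot.encode_responses(chatbot.unique_responses, batch_size=64)
    #   embeddings.shape  # (num_unique_responses, config.embedding_dim)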
|
|
|
def _train_faiss_index(self, response_embeddings: np.ndarray) -> None: |
|
"""Train FAISS index with better memory management and robust fallback mechanisms.""" |
|
if self.index.is_trained: |
|
logger.info("Index already trained, skipping training phase") |
|
return |
|
|
|
logger.info("Starting FAISS index training...") |
|
|
|
try: |
|
|
|
subset_size = min(5000, len(response_embeddings)) |
|
logger.info(f"Using {subset_size} samples for initial training attempt") |
|
subset_idx = np.random.choice(len(response_embeddings), subset_size, replace=False) |
|
training_embeddings = response_embeddings[subset_idx].copy() |
|
|
|
|
|
training_embeddings = np.ascontiguousarray(training_embeddings) |
|
|
|
|
|
gc.collect() |
|
if tf.config.list_physical_devices('GPU'): |
|
tf.keras.backend.clear_session() |
|
|
|
|
|
logger.info(f"FAISS training data shape: {training_embeddings.shape}") |
|
logger.info(f"FAISS training data dtype: {training_embeddings.dtype}") |
|
|
|
logger.info("Starting initial training attempt...") |
|
self.index.train(training_embeddings) |
|
logger.info("Training completed successfully") |
|
|
|
        except Exception as e:
|
logger.warning(f"Initial training attempt failed: {str(e)}") |
|
logger.info("Attempting fallback strategy...") |
|
|
|
try: |
|
|
|
if self.faiss_gpu: |
|
logger.info("Moving index to CPU for fallback training") |
|
cpu_index = faiss.index_gpu_to_cpu(self.index) |
|
else: |
|
cpu_index = self.index |
|
|
|
|
|
if isinstance(cpu_index, faiss.IndexIVFFlat): |
|
logger.info("Creating simpler FlatL2 index for fallback") |
|
cpu_index = faiss.IndexFlatL2(self.config.embedding_dim) |
|
|
|
|
|
subset_size = min(2000, len(response_embeddings)) |
|
subset_idx = np.random.choice(len(response_embeddings), subset_size, replace=False) |
|
fallback_embeddings = response_embeddings[subset_idx].copy() |
|
|
|
|
|
if not fallback_embeddings.flags['C_CONTIGUOUS']: |
|
fallback_embeddings = np.ascontiguousarray(fallback_embeddings) |
|
if fallback_embeddings.dtype != np.float32: |
|
fallback_embeddings = fallback_embeddings.astype(np.float32) |
|
|
|
|
|
logger.info("Training fallback index on CPU...") |
|
cpu_index.train(fallback_embeddings) |
|
|
|
|
|
if self.faiss_gpu: |
|
logger.info("Moving trained index back to GPU...") |
|
if len(self.gpu_resources) > 1: |
|
self.index = faiss.index_cpu_to_gpus_list(cpu_index, self.gpu_resources) |
|
else: |
|
self.index = faiss.index_cpu_to_gpu(self.gpu_resources[0], 0, cpu_index) |
|
else: |
|
self.index = cpu_index |
|
|
|
logger.info("Fallback training completed successfully") |
|
|
|
except Exception as e2: |
|
logger.error(f"Fallback training also failed: {str(e2)}") |
|
logger.warning("Creating basic brute-force index as last resort") |
|
|
|
try: |
|
|
|
dim = response_embeddings.shape[1] |
|
basic_index = faiss.IndexFlatL2(dim) |
|
|
|
if self.faiss_gpu: |
|
if len(self.gpu_resources) > 1: |
|
self.index = faiss.index_cpu_to_gpus_list(basic_index, self.gpu_resources) |
|
else: |
|
self.index = faiss.index_cpu_to_gpu(self.gpu_resources[0], 0, basic_index) |
|
else: |
|
self.index = basic_index |
|
|
|
logger.info("Basic index created as fallback") |
|
|
|
except Exception as e3: |
|
logger.error(f"All training attempts failed: {str(e3)}") |
|
raise RuntimeError("Unable to create working FAISS index") |
|
|
|
def _add_vectors_to_index(self, response_embeddings: np.ndarray) -> None: |
|
"""Add vectors to FAISS index with enhanced memory management.""" |
|
logger.info("Starting vector addition process...") |
|
|
|
|
|
initial_batch_size = 50 |
|
min_batch_size = 10 |
|
max_batch_size = 500 |
|
|
|
total_added = 0 |
|
retry_count = 0 |
|
max_retries = 5 |
|
|
|
while total_added < len(response_embeddings): |
|
try: |
|
|
|
if self.memory_monitor.has_gpu: |
|
gpu_usage = self.memory_monitor.get_memory_usage() |
|
|
|
|
|
|
|
if gpu_usage > 0.7: |
|
logger.info("High memory usage detected, forcing cleanup") |
|
gc.collect() |
|
tf.keras.backend.clear_session() |
|
|
|
|
|
end_idx = min(total_added + initial_batch_size, len(response_embeddings)) |
|
batch = response_embeddings[total_added:end_idx] |
|
|
|
|
|
self.index.add(batch) |
|
|
|
|
|
batch_size = len(batch) |
|
total_added += batch_size |
|
|
|
|
|
|
|
if total_added % (initial_batch_size * 5) == 0: |
|
gc.collect() |
|
if tf.config.list_physical_devices('GPU'): |
|
tf.keras.backend.clear_session() |
|
|
|
|
|
if initial_batch_size < max_batch_size: |
|
initial_batch_size = min(initial_batch_size + 25, max_batch_size) |
|
|
|
except Exception as e: |
|
logger.warning(f"Error adding batch: {str(e)}") |
|
retry_count += 1 |
|
|
|
if retry_count > max_retries: |
|
logger.error("Max retries exceeded.") |
|
raise |
|
|
|
|
|
initial_batch_size = max(min_batch_size, initial_batch_size // 2) |
|
logger.info(f"Reducing batch size to {initial_batch_size} and retrying...") |
|
|
|
|
|
gc.collect() |
|
if tf.config.list_physical_devices('GPU'): |
|
tf.keras.backend.clear_session() |
|
|
|
time.sleep(1) |
|
|
|
logger.info(f"Successfully added all {total_added} vectors to index") |
|
|
|
def _add_vectors_cpu_fallback(self, remaining_embeddings: np.ndarray, already_added: int = 0) -> None: |
|
"""CPU fallback with extra safeguards and progress tracking.""" |
|
logger.info(f"CPU Fallback: Adding {len(remaining_embeddings)} remaining vectors...") |
|
|
|
try: |
|
|
|
if self.faiss_gpu: |
|
logger.info("Moving index to CPU...") |
|
cpu_index = faiss.index_gpu_to_cpu(self.index) |
|
else: |
|
cpu_index = self.index |
|
|
|
|
|
batch_size = 50 |
|
total_added = already_added |
|
|
|
for i in range(0, len(remaining_embeddings), batch_size): |
|
end_idx = min(i + batch_size, len(remaining_embeddings)) |
|
batch = remaining_embeddings[i:end_idx] |
|
|
|
|
|
cpu_index.add(batch) |
|
|
|
|
|
total_added += len(batch) |
|
if i % (batch_size * 10) == 0: |
|
logger.info(f"Added {total_added} vectors total " |
|
f"({i}/{len(remaining_embeddings)} in current phase)") |
|
|
|
|
|
if i % (batch_size * 20) == 0: |
|
gc.collect() |
|
|
|
|
|
if self.faiss_gpu: |
|
logger.info("Moving index back to GPU...") |
|
if len(self.gpu_resources) > 1: |
|
self.index = faiss.index_cpu_to_gpus_list(cpu_index, self.gpu_resources) |
|
else: |
|
self.index = faiss.index_cpu_to_gpu(self.gpu_resources[0], 0, cpu_index) |
|
else: |
|
self.index = cpu_index |
|
|
|
logger.info("CPU fallback completed successfully") |
|
|
|
except Exception as e: |
|
logger.error(f"Error during CPU fallback: {str(e)}") |
|
raise |
|
|
|
def _compute_and_index_embeddings(self): |
|
"""Compute embeddings and build FAISS index with simpler handling.""" |
|
logger.info("Computing embeddings and indexing with FAISS...") |
|
|
|
try: |
|
|
|
logger.info("Encoding unique responses") |
|
response_embeddings = self.encode_responses(self.unique_responses) |
|
response_embeddings = response_embeddings.numpy() |
|
|
|
|
|
gc.collect() |
|
if tf.config.list_physical_devices('GPU'): |
|
tf.keras.backend.clear_session() |
|
|
|
|
|
response_embeddings = response_embeddings.astype('float32') |
|
response_embeddings = np.ascontiguousarray(response_embeddings) |
|
|
|
|
|
if self.memory_monitor.has_gpu: |
|
stats = self.memory_monitor.get_memory_stats() |
|
if stats: |
|
logger.info(f"GPU memory before normalization: {stats.used/1e9:.2f}GB used") |
|
|
|
|
|
logger.info("Normalizing embeddings with FAISS") |
|
faiss.normalize_L2(response_embeddings) |
|
|
|
|
|
dim = response_embeddings.shape[1] |
|
if self.faiss_gpu: |
|
cpu_index = faiss.IndexFlatIP(dim) |
|
if len(self.gpu_resources) > 1: |
|
self.index = faiss.index_cpu_to_gpus_list(cpu_index, self.gpu_resources) |
|
else: |
|
self.index = faiss.index_cpu_to_gpu(self.gpu_resources[0], 0, cpu_index) |
|
else: |
|
self.index = faiss.IndexFlatIP(dim) |
|
|
|
|
|
self._add_vectors_to_index(response_embeddings) |
|
|
|
|
|
self.response_pool = self.unique_responses |
|
self.response_embeddings = response_embeddings |
|
|
|
|
|
gc.collect() |
|
if tf.config.list_physical_devices('GPU'): |
|
tf.keras.backend.clear_session() |
|
|
|
|
|
logger.info(f"Successfully indexed {self.index.ntotal} responses") |
|
if self.memory_monitor.has_gpu: |
|
stats = self.memory_monitor.get_memory_stats() |
|
if stats: |
|
logger.info(f"Final GPU memory usage: {stats.used/1e9:.2f}GB used") |
|
|
|
logger.info("Indexing completed successfully") |
|
|
|
except Exception as e: |
|
logger.error(f"Error during indexing: {e}") |
|
|
|
gc.collect() |
|
if tf.config.list_physical_devices('GPU'): |
|
tf.keras.backend.clear_session() |
|
raise |
|
|
|
def verify_faiss_index(self): |
|
"""Verify that FAISS index matches the response pool.""" |
|
indexed_size = self.index.ntotal |
|
pool_size = len(self.response_pool) |
|
logger.info(f"FAISS index size: {indexed_size}") |
|
logger.info(f"Response pool size: {pool_size}") |
|
if indexed_size != pool_size: |
|
logger.warning("Mismatch between FAISS index size and response pool size.") |
|
else: |
|
logger.info("FAISS index correctly matches the response pool.") |
|
|
|
def encode_query(self, query: str, context: Optional[List[Tuple[str, str]]] = None) -> tf.Tensor: |
|
"""Encode a query with optional conversation context.""" |
|
|
|
if context: |
|
context_str = ' '.join([ |
|
f"{self.special_tokens['user']} {q} " |
|
f"{self.special_tokens['assistant']} {r}" |
|
for q, r in context[-self.config.max_context_turns:] |
|
]) |
|
query = f"{context_str} {self.special_tokens['user']} {query}" |
|
else: |
|
query = f"{self.special_tokens['user']} {query}" |
|
|
|
|
|
encodings = self.tokenizer( |
|
[query], |
|
padding='max_length', |
|
truncation=True, |
|
max_length=self.config.max_context_token_limit, |
|
return_tensors='tf' |
|
) |
|
input_ids = encodings['input_ids'] |
|
|
|
|
|
max_id = tf.reduce_max(input_ids).numpy() |
|
new_vocab_size = len(self.tokenizer) |
|
|
|
if max_id >= new_vocab_size: |
|
logger.error(f"Token ID {max_id} exceeds the vocabulary size {new_vocab_size}.") |
|
raise ValueError("Token ID exceeds vocabulary size.") |
|
|
|
|
|
return self.encoder(input_ids, training=False) |
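
    # Example (sketch): encoding a follow-up query with two prior turns of context.
    # The history pairs are illustrative.
    #
    #   history = [("hi", "hello, how can I help?"), ("what are your hours?", "9-5 weekdays")]
    #   q_emb = chatbot.encode_query("and on weekends?", context=history)
    #   q_emb.shape  # (1, config.embedding_dim)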
|
|
|
def retrieve_responses_cross_encoder( |
|
self, |
|
query: str, |
|
top_k: int, |
|
reranker: Optional[CrossEncoderReranker] = None, |
|
summarizer: Optional[Summarizer] = None, |
|
summarize_threshold: int = 512 |
|
) -> List[Tuple[str, float]]: |
|
""" |
|
Retrieve top-k from FAISS, then re-rank them with a cross-encoder. |
|
Optionally summarize the user query if it's too long. |
|
""" |
|
if reranker is None: |
|
reranker = self.reranker |
|
if summarizer is None: |
|
summarizer = self.summarizer |
|
|
|
|
|
if summarizer and len(query.split()) > summarize_threshold: |
|
logger.info(f"Query is long. Summarizing before cross-encoder. Original length: {len(query.split())}") |
|
query = summarizer.summarize_text(query) |
|
logger.info(f"Summarized query: {query}") |
|
|
|
|
|
dense_topk = self.retrieve_responses_faiss(query, top_k=top_k) |
|
|
|
if not dense_topk: |
|
return [] |
|
|
|
|
|
candidate_texts = [pair[0] for pair in dense_topk] |
|
cross_scores = reranker.rerank(query, candidate_texts, max_length=256) |
|
|
|
|
|
combined = [(text, score) for (text, _), score in zip(dense_topk, cross_scores)] |
|
|
|
combined.sort(key=lambda x: x[1], reverse=True) |
|
|
|
return combined |
|
|
|
def retrieve_responses_faiss(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]: |
|
"""Retrieve top-k responses using FAISS.""" |
|
|
|
q_emb = self.encode_query(query) |
|
q_emb_np = q_emb.numpy().astype('float32') |
|
|
|
|
|
faiss.normalize_L2(q_emb_np) |
|
|
|
|
|
distances, indices = self.index.search(q_emb_np, top_k) |
|
|
|
|
|
top_responses = [] |
|
for i, idx in enumerate(indices[0]): |
|
if idx < len(self.response_pool): |
|
top_responses.append((self.response_pool[idx], float(distances[0][i]))) |
|
else: |
|
logger.warning(f"FAISS returned invalid index {idx}. Skipping.") |
|
|
|
return top_responses |
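
    # Example (sketch): dense retrieval followed by cross-encoder re-ranking. The
    # query string is illustrative.
    #
    #   dense_hits = chatbot.retrieve_responses_faiss("how do I reset my password?", top_k=10)
    #   reranked = chatbot.retrieve_responses_cross_encoder("how do I reset my password?", top_k=10)
    #   best_response, best_score = reranked[0]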
|
|
|
def save_models(self, save_dir: Union[str, Path]): |
|
"""Save models and configuration.""" |
|
save_dir = Path(save_dir) |
|
save_dir.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
with open(save_dir / "config.json", "w") as f: |
|
json.dump(self.config.to_dict(), f, indent=2) |
|
|
|
|
|
self.encoder.pretrained.save_pretrained(save_dir / "shared_encoder") |
|
|
|
|
|
self.tokenizer.save_pretrained(save_dir / "tokenizer") |
|
|
|
logger.info(f"Models and tokenizer saved to {save_dir}.") |
|
|
|
@classmethod |
|
def load_models(cls, load_dir: Union[str, Path]) -> 'RetrievalChatbot': |
|
"""Load saved models and configuration.""" |
|
load_dir = Path(load_dir) |
|
|
|
|
|
with open(load_dir / "config.json", "r") as f: |
|
config = ChatbotConfig.from_dict(json.load(f)) |
|
|
|
|
|
chatbot = cls(config) |
|
|
|
|
|
        # Rebuild the encoder, then restore the saved pretrained weights. Note that
        # ChatbotConfig is not a transformers config, so it is not passed to
        # from_pretrained.
        chatbot.encoder = EncoderModel(config, name="shared_encoder")
        chatbot.encoder.pretrained = TFAutoModel.from_pretrained(
            load_dir / "shared_encoder"
        )
|
|
|
|
|
chatbot.tokenizer = AutoTokenizer.from_pretrained(load_dir / "tokenizer") |
|
|
|
logger.info(f"Models and tokenizer loaded from {load_dir}.") |
|
return chatbot |
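
    # Example (sketch): saving and restoring a trained chatbot. The directory name is
    # illustrative; load_models restores the encoder and tokenizer, but the FAISS index
    # still has to be rebuilt (e.g. via build_models) before retrieval will work.
    #
    #   chatbot.save_models("saved_chatbot/")
    #   restored = RetrievalChatbot.load_models("saved_chatbot/")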
|
|
|
@staticmethod |
|
def load_training_data(data_path: Union[str, Path], debug_samples: Optional[int] = None) -> List[dict]: |
|
""" |
|
Load training data from a JSON file. |
|
|
|
Args: |
|
data_path (Union[str, Path]): Path to the JSON file containing dialogues. |
|
debug_samples (Optional[int]): Number of samples to load for debugging. |
|
|
|
Returns: |
|
List[dict]: List of dialogue dictionaries. |
|
""" |
|
logger.info(f"Loading training data from {data_path}...") |
|
data_path = Path(data_path) |
|
if not data_path.exists(): |
|
logger.error(f"Data file {data_path} does not exist.") |
|
return [] |
|
|
|
with open(data_path, 'r', encoding='utf-8') as f: |
|
dialogues = json.load(f) |
|
|
|
if debug_samples is not None: |
|
dialogues = dialogues[:debug_samples] |
|
logger.info(f"Debug mode: Limited to {debug_samples} dialogues") |
|
|
|
logger.info(f"Loaded {len(dialogues)} dialogues.") |
|
return dialogues |
|
|
|
def train_streaming( |
|
self, |
|
dialogues: List[dict], |
|
epochs: int = 20, |
|
batch_size: int = 16, |
|
validation_split: float = 0.2, |
|
checkpoint_dir: str = "checkpoints/", |
|
use_lr_schedule: bool = True, |
|
peak_lr: float = 2e-5, |
|
warmup_steps_ratio: float = 0.1, |
|
early_stopping_patience: int = 3, |
|
min_delta: float = 1e-4, |
|
buffer_size: int = 10, |
|
neg_samples: int = 1 |
|
) -> None: |
|
""" |
|
Streaming version of training that interleaves training/val batches by |
|
giving priority to training until we meet `steps_per_epoch`, then |
|
sending leftover batches to validation. |
|
""" |
|
logger.info("Starting streaming training pipeline...") |
|
|
|
|
|
dataset_preparer = StreamingDataPipeline( |
|
tokenizer=self.tokenizer, |
|
encoder=self.encoder, |
|
index=self.index, |
|
response_pool=self.response_pool, |
|
max_length=self.config.max_context_token_limit, |
|
batch_size=batch_size, |
|
neg_samples=neg_samples |
|
) |
|
|
|
|
|
total_pairs = dataset_preparer.estimate_total_pairs(dialogues) |
|
        train_size = int(total_pairs * (1 - validation_split))
|
steps_per_epoch = int(math.ceil(train_size / batch_size)) |
|
val_steps = int(math.ceil((total_pairs * validation_split) / batch_size)) |
|
total_steps = steps_per_epoch * epochs |
|
|
|
logger.info(f"Total pairs: {total_pairs}") |
|
logger.info(f"Training pairs: {train_size}") |
|
logger.info(f"Steps per epoch: {steps_per_epoch}") |
|
logger.info(f"Validation steps: {val_steps}") |
|
logger.info(f"Total steps: {total_steps}") |
|
|
|
|
|
if use_lr_schedule: |
|
warmup_steps = int(total_steps * warmup_steps_ratio) |
|
lr_schedule = self._get_lr_schedule( |
|
total_steps=total_steps, |
|
peak_lr=peak_lr, |
|
warmup_steps=warmup_steps |
|
) |
|
self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule) |
|
logger.info("Using custom learning rate schedule.") |
|
else: |
|
self.optimizer = tf.keras.optimizers.Adam(learning_rate=peak_lr) |
|
logger.info("Using fixed learning rate.") |
|
|
|
|
|
checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, model=self.encoder) |
|
manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3) |
|
|
|
|
|
log_dir = Path(checkpoint_dir) / "tensorboard_logs" |
|
log_dir.mkdir(parents=True, exist_ok=True) |
|
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") |
|
train_log_dir = str(log_dir / f"train_{current_time}") |
|
val_log_dir = str(log_dir / f"val_{current_time}") |
|
train_summary_writer = tf.summary.create_file_writer(train_log_dir) |
|
val_summary_writer = tf.summary.create_file_writer(val_log_dir) |
|
|
|
logger.info(f"TensorBoard logs will be saved in {log_dir}") |
|
|
|
|
|
best_val_loss = float("inf") |
|
epochs_no_improve = 0 |
|
|
|
try: |
|
epoch_pbar = tqdm(range(1, epochs + 1), desc="Training", unit="epoch") |
|
is_tqdm_epoch = True |
|
except ImportError: |
|
epoch_pbar = range(1, epochs + 1) |
|
is_tqdm_epoch = False |
|
logger.info("Epoch progress bar disabled - continuing without visual progress") |
|
|
|
for epoch in epoch_pbar: |
|
|
|
train_queue = Queue(maxsize=buffer_size) |
|
val_queue = Queue(maxsize=buffer_size) |
|
stop_flag = threading.Event() |
|
|
|
def data_pipeline_worker(): |
|
"""Thread function that processes dialogues and sends batches to train or val.""" |
|
try: |
|
train_batches_needed = steps_per_epoch |
|
val_batches_needed = val_steps |
|
train_batches_sent = 0 |
|
val_batches_sent = 0 |
|
|
|
logger.info(f"Pipeline starting: need {train_batches_needed} train batches, {val_batches_needed} val batches") |
|
|
|
|
|
|
|
random.shuffle(dataset_preparer.processed_pairs) |
|
|
|
while (train_batches_sent < train_batches_needed or |
|
val_batches_sent < val_batches_needed): |
|
|
|
|
|
for batch in dataset_preparer.process_dialogues(dialogues): |
|
if stop_flag.is_set(): |
|
logger.warning("Pipeline stopped early") |
|
break |
|
|
|
if train_batches_sent < train_batches_needed: |
|
train_queue.put(batch) |
|
train_batches_sent += 1 |
|
elif val_batches_sent < val_batches_needed: |
|
val_queue.put(batch) |
|
val_batches_sent += 1 |
|
else: |
|
|
|
break |
|
|
|
|
|
if train_batches_sent < train_batches_needed or val_batches_sent < val_batches_needed: |
|
logger.info("Data exhausted, repeating since we still need more batches...") |
|
|
|
random.shuffle(dataset_preparer.processed_pairs) |
|
else: |
|
|
|
break |
|
|
|
logger.info( |
|
f"Pipeline complete: sent {train_batches_sent}/{train_batches_needed} train batches, " |
|
f"{val_batches_sent}/{val_batches_needed} val batches" |
|
) |
|
|
|
except Exception as e: |
|
logger.error(f"Error in pipeline worker: {str(e)}") |
|
raise e |
|
finally: |
|
train_queue.put(None) |
|
val_queue.put(None) |
|
|
|
|
|
pipeline_thread = threading.Thread(target=data_pipeline_worker) |
|
pipeline_thread.start() |
|
|
|
try: |
|
|
|
epoch_loss_avg = tf.keras.metrics.Mean() |
|
batches_processed = 0 |
|
|
|
try: |
|
train_pbar = tqdm(total=steps_per_epoch, desc=f"Training Epoch {epoch}") |
|
is_tqdm_train = True |
|
except ImportError: |
|
train_pbar = None |
|
is_tqdm_train = False |
|
logger.info("Training progress bar disabled") |
|
|
|
while batches_processed < steps_per_epoch: |
|
try: |
|
batch = train_queue.get(timeout=1200) |
|
if batch is None: |
|
logger.warning(f"Received end signal after only {batches_processed}/{steps_per_epoch} batches") |
|
break |
|
|
|
q_batch, p_batch = batch[0], batch[1] |
|
attention_mask = batch[2] if len(batch) > 2 else None |
|
|
|
loss = self.train_step(q_batch, p_batch, attention_mask) |
|
epoch_loss_avg(loss) |
|
batches_processed += 1 |
|
|
|
|
|
with train_summary_writer.as_default(): |
|
tf.summary.scalar("loss", loss, step=epoch) |
|
|
|
|
|
if use_lr_schedule: |
|
current_lr = float(lr_schedule(self.optimizer.iterations)) |
|
else: |
|
current_lr = float(self.optimizer.learning_rate.numpy()) |
|
|
|
if is_tqdm_train: |
|
train_pbar.update(1) |
|
train_pbar.set_postfix({ |
|
"loss": f"{loss.numpy():.4f}", |
|
"lr": f"{current_lr:.2e}", |
|
"batches": f"{batches_processed}/{steps_per_epoch}" |
|
}) |
|
|
|
except Empty: |
|
logger.warning(f"Queue timeout after {batches_processed}/{steps_per_epoch} batches") |
|
break |
|
|
|
if is_tqdm_train and train_pbar: |
|
train_pbar.close() |
|
|
|
|
|
val_loss_avg = tf.keras.metrics.Mean() |
|
val_batches_processed = 0 |
|
|
|
try: |
|
val_pbar = tqdm(total=val_steps, desc="Validation") |
|
is_tqdm_val = True |
|
except ImportError: |
|
val_pbar = None |
|
is_tqdm_val = False |
|
logger.info("Validation progress bar disabled") |
|
|
|
while val_batches_processed < val_steps: |
|
try: |
|
batch = val_queue.get(timeout=30) |
|
if batch is None: |
|
logger.warning( |
|
f"Received end signal after {val_batches_processed}/{val_steps} validation batches" |
|
) |
|
break |
|
|
|
q_batch, p_batch = batch[0], batch[1] |
|
attention_mask = batch[2] if len(batch) > 2 else None |
|
|
|
val_loss = self.validation_step(q_batch, p_batch, attention_mask) |
|
val_loss_avg(val_loss) |
|
val_batches_processed += 1 |
|
|
|
if is_tqdm_val: |
|
val_pbar.update(1) |
|
val_pbar.set_postfix({ |
|
"val_loss": f"{val_loss.numpy():.4f}", |
|
"batches": f"{val_batches_processed}/{val_steps}" |
|
}) |
|
|
|
except Empty: |
|
logger.warning( |
|
f"Validation queue timeout after {val_batches_processed}/{val_steps} batches" |
|
) |
|
break |
|
|
|
if is_tqdm_val and val_pbar: |
|
val_pbar.close() |
|
|
|
|
|
train_loss = epoch_loss_avg.result().numpy() |
|
val_loss = val_loss_avg.result().numpy() |
|
logger.info(f"Epoch {epoch} Complete: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}") |
|
|
|
|
|
with val_summary_writer.as_default(): |
|
tf.summary.scalar("val_loss", val_loss, step=epoch) |
|
|
|
|
|
manager.save() |
|
|
|
|
|
self.history['train_loss'].append(train_loss) |
|
self.history['val_loss'].append(val_loss) |
|
|
|
if use_lr_schedule: |
|
current_lr = float(lr_schedule(self.optimizer.iterations)) |
|
else: |
|
current_lr = float(self.optimizer.learning_rate.numpy()) |
|
|
|
self.history.setdefault('learning_rate', []).append(current_lr) |
|
|
|
|
|
if val_loss < best_val_loss - min_delta: |
|
best_val_loss = val_loss |
|
epochs_no_improve = 0 |
|
logger.info(f"Validation loss improved to {val_loss:.4f}. Reset patience.") |
|
else: |
|
epochs_no_improve += 1 |
|
logger.info(f"No improvement this epoch. Patience: {epochs_no_improve}/{early_stopping_patience}") |
|
if epochs_no_improve >= early_stopping_patience: |
|
logger.info("Early stopping triggered.") |
|
break |
|
|
|
except Exception as e: |
|
logger.error(f"Error during training: {str(e)}") |
|
stop_flag.set() |
|
raise e |
|
finally: |
|
|
|
stop_flag.set() |
|
pipeline_thread.join() |
|
|
|
logger.info("Streaming training completed!") |
|
|
|
|
|
@tf.function |
|
def train_step(self, q_batch: tf.Tensor, p_batch: tf.Tensor, attention_mask: Optional[tf.Tensor] = None) -> tf.Tensor: |
|
"""Single training step with tf.function optimization and partial batch handling.""" |
|
with tf.GradientTape() as tape: |
|
q_enc = self.encoder(q_batch, training=True) |
|
p_enc = self.encoder(p_batch, training=True) |
|
|
|
sim_matrix = tf.matmul(q_enc, p_enc, transpose_b=True) |
|
|
|
|
|
batch_size = tf.shape(q_enc)[0] |
|
labels = tf.range(batch_size, dtype=tf.int32) |
|
|
|
loss = tf.nn.sparse_softmax_cross_entropy_with_logits( |
|
labels=labels, logits=sim_matrix |
|
) |
|
|
|
|
|
if attention_mask is not None: |
|
loss = loss * attention_mask |
|
|
|
loss = tf.reduce_sum(loss) / tf.reduce_sum(attention_mask) |
|
else: |
|
loss = tf.reduce_mean(loss) |
|
|
|
gradients = tape.gradient(loss, self.encoder.trainable_variables) |
|
self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_variables)) |
|
return loss |
|
|
|
@tf.function |
|
def validation_step(self, q_batch: tf.Tensor, p_batch: tf.Tensor, attention_mask: Optional[tf.Tensor] = None) -> tf.Tensor: |
|
"""Single validation step with partial batch handling.""" |
|
q_enc = self.encoder(q_batch, training=False) |
|
p_enc = self.encoder(p_batch, training=False) |
|
|
|
sim_matrix = tf.matmul(q_enc, p_enc, transpose_b=True) |
|
batch_size = tf.shape(q_enc)[0] |
|
labels = tf.range(batch_size, dtype=tf.int32) |
|
|
|
loss = tf.nn.sparse_softmax_cross_entropy_with_logits( |
|
labels=labels, logits=sim_matrix |
|
) |
|
|
|
if attention_mask is not None: |
|
loss = loss * attention_mask |
|
loss = tf.reduce_sum(loss) / tf.reduce_sum(attention_mask) |
|
else: |
|
loss = tf.reduce_mean(loss) |
|
|
|
return loss |
|
|
|
def _get_lr_schedule( |
|
self, |
|
total_steps: int, |
|
peak_lr: float, |
|
warmup_steps: int |
|
) -> tf.keras.optimizers.schedules.LearningRateSchedule: |
|
"""Create a custom learning rate schedule with warmup and cosine decay.""" |
|
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule): |
|
def __init__( |
|
self, |
|
total_steps: int, |
|
peak_lr: float, |
|
warmup_steps: int |
|
): |
|
super().__init__() |
|
self.total_steps = tf.cast(total_steps, tf.float32) |
|
self.peak_lr = tf.cast(peak_lr, tf.float32) |
|
|
|
|
|
adjusted_warmup_steps = min(warmup_steps, max(1, total_steps // 10)) |
|
self.warmup_steps = tf.cast(adjusted_warmup_steps, tf.float32) |
|
|
|
|
|
self.initial_lr = self.peak_lr * 0.1 |
|
self.min_lr = self.peak_lr * 0.01 |
|
|
|
logger.info(f"Learning rate schedule initialized:") |
|
logger.info(f" Initial LR: {float(self.initial_lr):.6f}") |
|
logger.info(f" Peak LR: {float(self.peak_lr):.6f}") |
|
logger.info(f" Min LR: {float(self.min_lr):.6f}") |
|
logger.info(f" Warmup steps: {int(self.warmup_steps)}") |
|
logger.info(f" Total steps: {int(self.total_steps)}") |
|
|
|
def __call__(self, step): |
|
step = tf.cast(step, tf.float32) |
|
|
|
|
|
warmup_factor = tf.minimum(1.0, step / self.warmup_steps) |
|
warmup_lr = self.initial_lr + (self.peak_lr - self.initial_lr) * warmup_factor |
|
|
|
|
|
decay_steps = tf.maximum(1.0, self.total_steps - self.warmup_steps) |
|
decay_factor = (step - self.warmup_steps) / decay_steps |
|
decay_factor = tf.minimum(tf.maximum(0.0, decay_factor), 1.0) |
|
|
|
cosine_decay = 0.5 * (1.0 + tf.cos(tf.constant(math.pi) * decay_factor)) |
|
decay_lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay |
|
|
|
|
|
final_lr = tf.where(step < self.warmup_steps, warmup_lr, decay_lr) |
|
|
|
|
|
final_lr = tf.maximum(self.min_lr, final_lr) |
|
final_lr = tf.where(tf.math.is_finite(final_lr), final_lr, self.min_lr) |
|
|
|
return final_lr |
|
|
|
def get_config(self): |
|
return { |
|
"total_steps": self.total_steps, |
|
"peak_lr": self.peak_lr, |
|
"warmup_steps": self.warmup_steps, |
|
} |
|
|
|
return CustomSchedule(total_steps, peak_lr, warmup_steps) |
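
    # Example (sketch): inspecting the warmup + cosine-decay schedule at a few steps.
    # The numbers are illustrative.
    #
    #   schedule = chatbot._get_lr_schedule(total_steps=1000, peak_lr=2e-5, warmup_steps=100)
    #   float(schedule(0))     # ~= peak_lr * 0.1  (initial LR)
    #   float(schedule(100))   # ~= peak_lr        (end of warmup)
    #   float(schedule(1000))  # ~= peak_lr * 0.01 (min LR after cosine decay)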
|
|
|
def _cosine_similarity(self, emb1: np.ndarray, emb2: np.ndarray) -> np.ndarray: |
|
"""Compute cosine similarity between two numpy arrays.""" |
|
normalized_emb1 = emb1 / np.linalg.norm(emb1, axis=1, keepdims=True) |
|
normalized_emb2 = emb2 / np.linalg.norm(emb2, axis=1, keepdims=True) |
|
return np.dot(normalized_emb1, normalized_emb2.T) |
|
|
|
def chat( |
|
self, |
|
query: str, |
|
conversation_history: Optional[List[Tuple[str, str]]] = None, |
|
quality_checker: Optional['ResponseQualityChecker'] = None, |
|
top_k: int = 5, |
|
) -> Tuple[str, List[Tuple[str, float]], Dict[str, Any]]: |
|
""" |
|
Example chat method that always uses cross-encoder re-ranking |
|
if self.reranker is available. |
|
""" |
|
@self.run_on_device |
|
def get_response(self_arg, query_arg): |
|
|
|
conversation_str = self_arg._build_conversation_context(query_arg, conversation_history) |
|
|
|
|
|
results = self_arg.retrieve_responses_cross_encoder( |
|
query=conversation_str, |
|
top_k=top_k, |
|
reranker=self_arg.reranker, |
|
summarizer=self_arg.summarizer, |
|
summarize_threshold=512 |
|
) |
|
|
|
|
|
if not results: |
|
return ( |
|
"I'm sorry, but I couldn't find a relevant response.", |
|
[], |
|
{} |
|
) |
|
|
|
if quality_checker: |
|
metrics = quality_checker.check_response_quality(query_arg, results) |
|
if not metrics.get('is_confident', False): |
|
return ( |
|
"I need more information to provide a good answer. Could you please clarify?", |
|
results, |
|
metrics |
|
) |
|
return results[0][0], results, metrics |
|
|
|
return results[0][0], results, {} |
|
|
|
return get_response(self, query) |
|
|
|
def _build_conversation_context( |
|
self, |
|
query: str, |
|
conversation_history: Optional[List[Tuple[str, str]]] |
|
) -> str: |
|
"""Build conversation context with better memory management.""" |
|
if not conversation_history: |
|
return f"{self.special_tokens['user']} {query}" |
|
|
|
conversation_parts = [] |
|
for user_txt, assistant_txt in conversation_history: |
|
conversation_parts.extend([ |
|
f"{self.special_tokens['user']} {user_txt}", |
|
f"{self.special_tokens['assistant']} {assistant_txt}" |
|
]) |
|
|
|
conversation_parts.append(f"{self.special_tokens['user']} {query}") |
|
return "\n".join(conversation_parts) |
|
|
|
class StreamingDataPipeline: |
|
"""Helper class to manage the streaming data preparation pipeline with optimized caching and GPU usage.""" |
|
def __init__( |
|
self, |
|
tokenizer, |
|
encoder, |
|
index, |
|
response_pool, |
|
max_length: int, |
|
batch_size: int, |
|
neg_samples: int |
|
): |
|
self.tokenizer = tokenizer |
|
self.encoder = encoder |
|
self.index = index |
|
self.response_pool = response_pool |
|
self.max_length = max_length |
|
self.base_batch_size = batch_size |
|
self.neg_samples = neg_samples |
|
self.memory_monitor = GPUMemoryMonitor() |
|
|
|
|
|
self.hard_negatives_cache = {} |
|
self.processed_pairs = [] |
|
self.query_embeddings_cache = {} |
|
|
|
|
|
self.error_count = 0 |
|
self.max_retries = 3 |
|
|
|
|
|
self.current_batch_size = batch_size |
|
self.batch_increase_factor = 1.25 |
|
|
|
|
|
if len(response_pool) < 100: |
|
self.embedding_batch_size = 16 |
|
self.search_batch_size = 8 |
|
self.max_batch_size = 32 |
|
self.min_batch_size = 4 |
|
else: |
|
self.embedding_batch_size = 64 |
|
self.search_batch_size = 32 |
|
self.min_batch_size = max(8, batch_size // 4) |
|
self.max_batch_size = 64 |
|
|
|
def save_cache(self, cache_dir: Path) -> None: |
|
"""Save all cached data for future runs.""" |
|
cache_dir = Path(cache_dir) |
|
cache_dir.mkdir(parents=True, exist_ok=True) |
|
|
|
logger.info(f"Saving cache to {cache_dir}") |
|
|
|
|
|
embeddings_path = cache_dir / "query_embeddings.npy" |
|
np.save( |
|
embeddings_path, |
|
{k: v.numpy() if hasattr(v, 'numpy') else v |
|
for k, v in self.query_embeddings_cache.items()} |
|
) |
|
|
|
|
|
with open(cache_dir / "hard_negatives.json", 'w') as f: |
|
json.dump(self.hard_negatives_cache, f) |
|
|
|
with open(cache_dir / "processed_pairs.json", 'w') as f: |
|
json.dump(self.processed_pairs, f) |
|
|
|
logger.info("Cache saved successfully") |
|
|
|
def load_cache(self, cache_dir: Path) -> bool: |
|
"""Load cached data if available.""" |
|
cache_dir = Path(cache_dir) |
|
required_files = [ |
|
"query_embeddings.npy", |
|
"hard_negatives.json", |
|
"processed_pairs.json" |
|
] |
|
|
|
if not all((cache_dir / f).exists() for f in required_files): |
|
logger.info("Cache files not found") |
|
return False |
|
|
|
try: |
|
logger.info("Loading cache...") |
|
|
|
|
|
self.query_embeddings_cache = np.load( |
|
cache_dir / "query_embeddings.npy", |
|
allow_pickle=True |
|
).item() |
|
|
|
|
|
with open(cache_dir / "hard_negatives.json", 'r') as f: |
|
self.hard_negatives_cache = json.load(f) |
|
|
|
with open(cache_dir / "processed_pairs.json", 'r') as f: |
|
self.processed_pairs = json.load(f) |
|
|
|
logger.info(f"Cache loaded successfully: {len(self.processed_pairs)} pairs") |
|
return True |
|
|
|
except Exception as e: |
|
logger.error(f"Error loading cache: {e}") |
|
return False |
|
|
|
def _adjust_batch_size(self) -> None: |
|
"""Dynamically adjust batch size based on GPU memory usage.""" |
|
if self.memory_monitor: |
|
            if self.memory_monitor.should_reduce_batch_size():
                new_size = max(self.min_batch_size, self.current_batch_size // 2)
                if new_size != self.current_batch_size:
                    logger.info(f"Reducing batch size to {new_size} due to high memory usage")
                    self.current_batch_size = new_size
                    gc.collect()
                    if tf.config.list_physical_devices('GPU'):
                        tf.keras.backend.clear_session()

            elif self.memory_monitor.can_increase_batch_size():
                new_size = min(self.max_batch_size, int(self.current_batch_size * self.batch_increase_factor))
                if new_size != self.current_batch_size:
                    logger.info(f"Increasing batch size to {new_size}")
                    self.current_batch_size = new_size
|
|
|
def _add_progress_metrics(self, pbar, **metrics) -> None: |
|
"""Add memory and batch size metrics to progress bars.""" |
|
if self.memory_monitor: |
|
gpu_usage = self.memory_monitor.get_memory_usage() |
|
metrics['gpu_mem'] = f"{gpu_usage:.1%}" |
|
metrics['batch_size'] = self.current_batch_size |
|
pbar.set_postfix(**metrics) |
|
|
|
def preprocess_dialogues(self, dialogues: List[dict]) -> None: |
|
"""Preprocess all dialogues with error recovery and caching.""" |
|
retry_count = 0 |
|
|
|
while retry_count < self.max_retries: |
|
try: |
|
self._preprocess_dialogues_internal(dialogues) |
|
break |
|
except Exception as e: |
|
retry_count += 1 |
|
logger.warning(f"Preprocessing attempt {retry_count} failed: {e}") |
|
if retry_count == self.max_retries: |
|
logger.error("Max retries reached. Falling back to CPU processing") |
|
self._fallback_to_cpu_processing(dialogues) |
|
|
|
def _preprocess_dialogues_internal(self, dialogues: List[dict]) -> None: |
|
"""Internal preprocessing implementation with progress tracking.""" |
|
logger.info("Starting dialogue preprocessing...") |
|
|
|
|
|
unique_queries = set() |
|
query_positive_pairs = [] |
|
|
|
with tqdm(total=len(dialogues), desc="Collecting dialogue pairs") as pbar: |
|
for dialogue in dialogues: |
|
pairs = self._extract_pairs_from_dialogue(dialogue) |
|
for query, positive in pairs: |
|
unique_queries.add(query) |
|
query_positive_pairs.append((query, positive)) |
|
pbar.update(1) |
|
self._add_progress_metrics(pbar, pairs=len(query_positive_pairs)) |
|
|
|
|
|
logger.info("Precomputing query embeddings...") |
|
self.precompute_query_embeddings(list(unique_queries)) |
|
|
|
|
|
logger.info("Finding hard negatives for all pairs...") |
|
self._find_hard_negatives_for_pairs(query_positive_pairs) |
|
|
|
def precompute_query_embeddings(self, queries: List[str]) -> None: |
|
"""Precompute embeddings for all unique queries in batches.""" |
|
unique_queries = list(set(queries)) |
|
|
|
with tqdm(total=len(unique_queries), desc="Precomputing query embeddings") as pbar: |
|
for i in range(0, len(unique_queries), self.embedding_batch_size): |
|
|
|
self._adjust_batch_size() |
|
batch_size = min(self.embedding_batch_size, len(unique_queries) - i) |
|
|
|
|
|
batch_queries = unique_queries[i:i + batch_size] |
|
|
|
try: |
|
|
|
encoded = self.tokenizer( |
|
batch_queries, |
|
padding=True, |
|
truncation=True, |
|
max_length=self.max_length, |
|
return_tensors='tf' |
|
) |
|
|
|
|
|
embeddings = self.encoder(encoded['input_ids'], training=False) |
|
embeddings_np = embeddings.numpy().astype('float32') |
|
|
|
|
|
faiss.normalize_L2(embeddings_np) |
|
|
|
|
|
for query, emb in zip(batch_queries, embeddings_np): |
|
self.query_embeddings_cache[query] = emb |
|
|
|
pbar.update(len(batch_queries)) |
|
self._add_progress_metrics( |
|
pbar, |
|
cached=len(self.query_embeddings_cache), |
|
batch_size=batch_size |
|
) |
|
|
|
except Exception as e: |
|
logger.warning(f"Error processing batch: {e}") |
|
|
|
self.embedding_batch_size = max(self.min_batch_size, self.embedding_batch_size // 2) |
|
continue |
|
|
|
|
|
if i % (self.embedding_batch_size * 10) == 0: |
|
gc.collect() |
|
if tf.config.list_physical_devices('GPU'): |
|
tf.keras.backend.clear_session() |
|
|
|
logger.info(f"Cached embeddings for {len(self.query_embeddings_cache)} unique queries") |
|
|
|
def _extract_pairs_from_dialogue(self, dialogue: dict) -> List[Tuple[str, str]]: |
|
"""Extract query-response pairs from a dialogue.""" |
|
pairs = [] |
|
turns = dialogue.get('turns', []) |
|
|
|
for i in range(len(turns) - 1): |
|
current_turn = turns[i] |
|
next_turn = turns[i+1] |
|
|
|
if (current_turn.get('speaker') == 'user' and |
|
next_turn.get('speaker') == 'assistant' and |
|
'text' in current_turn and |
|
'text' in next_turn): |
|
|
|
query = current_turn['text'].strip() |
|
positive = next_turn['text'].strip() |
|
pairs.append((query, positive)) |
|
|
|
return pairs |
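
    # Example (sketch): the dialogue schema this pipeline expects. Field values are
    # illustrative; only 'turns', 'speaker', and 'text' are read by the code.
    #
    #   dialogue = {
    #       "turns": [
    #           {"speaker": "user", "text": "What are your opening hours?"},
    #           {"speaker": "assistant", "text": "We are open 9am-5pm on weekdays."},
    #       ]
    #   }
    #   pipeline._extract_pairs_from_dialogue(dialogue)
    #   # -> [("What are your opening hours?", "We are open 9am-5pm on weekdays.")]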
|
|
|
def _fallback_to_cpu_processing(self, dialogues: List[dict]) -> None: |
|
"""Fallback processing method using CPU only.""" |
|
logger.info("Falling back to CPU-only processing") |
|
|
|
self.current_batch_size = self.min_batch_size |
|
self.embedding_batch_size = 32 |
|
self.search_batch_size = 16 |
|
|
|
|
|
self._preprocess_dialogues_internal(dialogues) |
|
|
|
def process_dialogues(self, dialogues: List[dict]) -> Generator[Tuple[tf.Tensor, tf.Tensor, Optional[tf.Tensor]], None, None]: |
|
""" |
|
Process dialogues using cached data with dynamic batch sizing. |
|
Yields (q_tokens['input_ids'], p_tokens['input_ids'], attention_mask) tuples. |
|
""" |
|
|
|
if not self.processed_pairs: |
|
self.preprocess_dialogues(dialogues) |
|
|
|
|
|
current_queries = [] |
|
current_positives = [] |
|
|
|
|
|
total_examples_yielded = 0 |
|
total_batches_yielded = 0 |
|
|
|
with tqdm(total=len(self.processed_pairs), desc="Generating training batches", leave=False) as pbar: |
|
for i, (query, positive) in enumerate(self.processed_pairs): |
|
|
|
if i % 10 == 0: |
|
self._adjust_batch_size() |
|
|
|
|
|
current_queries.append(query) |
|
current_positives.append(positive) |
|
|
|
|
|
hard_negatives = self.hard_negatives_cache.get((query, positive), []) |
|
for neg_text in hard_negatives: |
|
current_queries.append(query) |
|
current_positives.append(neg_text) |
|
|
|
|
|
while len(current_queries) >= self.current_batch_size: |
|
batch_queries = current_queries[:self.current_batch_size] |
|
batch_positives = current_positives[:self.current_batch_size] |
|
|
|
|
|
batch_size_to_yield = len(batch_queries) |
|
total_examples_yielded += batch_size_to_yield |
|
total_batches_yielded += 1 |
|
|
|
yield self._prepare_batch(batch_queries, batch_positives, pad_to_batch_size=False) |
|
|
|
|
|
current_queries = current_queries[self.current_batch_size:] |
|
current_positives = current_positives[self.current_batch_size:] |
|
|
|
|
|
pbar.update(1) |
|
self._add_progress_metrics( |
|
pbar, |
|
pairs_processed=pbar.n, |
|
pending_pairs=len(current_queries) |
|
) |
|
|
|
|
|
if current_queries: |
|
leftover_size = len(current_queries) |
|
total_examples_yielded += leftover_size |
|
total_batches_yielded += 1 |
|
|
|
yield self._prepare_batch( |
|
current_queries, |
|
current_positives, |
|
pad_to_batch_size=True |
|
) |
|
|
|
def _find_hard_negatives_for_pairs(self, query_positive_pairs: List[Tuple[str, str]]) -> None: |
|
"""Process pairs in batches to find hard negatives with GPU acceleration.""" |
|
total_pairs = len(query_positive_pairs) |
|
|
|
|
|
if len(self.response_pool) < 1000: |
|
batch_size = min(8, self.search_batch_size) |
|
else: |
|
batch_size = self.search_batch_size |
|
|
|
try: |
|
pbar = tqdm(total=total_pairs, desc="Finding hard negatives") |
|
is_tqdm = True |
|
except ImportError: |
|
pbar = None |
|
is_tqdm = False |
|
logger.info("Progress bar disabled - continuing without visual progress") |
|
|
|
for i in range(0, total_pairs, batch_size): |
|
self._adjust_batch_size() |
|
|
|
batch_pairs = query_positive_pairs[i:i + batch_size] |
|
batch_queries, batch_positives = zip(*batch_pairs) |
|
|
|
batch_negatives = self._find_hard_negatives_batch( |
|
list(batch_queries), |
|
list(batch_positives) |
|
) |
|
|
|
for query, positive, negatives in zip(batch_queries, batch_positives, batch_negatives): |
|
self.hard_negatives_cache[(query, positive)] = negatives |
|
self.processed_pairs.append((query, positive)) |
|
|
|
if is_tqdm: |
|
pbar.update(len(batch_pairs)) |
|
self._add_progress_metrics( |
|
pbar, |
|
cached=len(self.processed_pairs), |
|
progress=f"{i+len(batch_pairs)}/{total_pairs}" |
|
) |
|
|
|
if is_tqdm: |
|
pbar.close() |
|
|
|
def _find_hard_negatives_batch(self, queries: List[str], positives: List[str]) -> List[List[str]]: |
|
"""Find hard negatives for a batch of queries with error handling and retries.""" |
|
retry_count = 0 |
|
total_responses = len(self.response_pool) |
|
|
|
|
|
if total_responses < 100: |
|
all_negatives = [] |
|
for positive in positives: |
|
available = [r for r in self.response_pool if r.strip() != positive.strip()] |
|
if available: |
|
negatives = list(np.random.choice( |
|
available, |
|
size=min(self.neg_samples, len(available)), |
|
replace=False |
|
)) |
|
else: |
|
negatives = [] |
|
|
|
while len(negatives) < self.neg_samples: |
|
negatives.append("") |
|
all_negatives.append(negatives) |
|
return all_negatives |
|
|
|
while retry_count < self.max_retries: |
|
try: |
|
|
|
query_embeddings = np.vstack([ |
|
self.query_embeddings_cache[q] for q in queries |
|
]).astype(np.float32) |
|
|
|
if not query_embeddings.flags['C_CONTIGUOUS']: |
|
query_embeddings = np.ascontiguousarray(query_embeddings) |
|
|
|
|
|
faiss.normalize_L2(query_embeddings) |
|
|
|
k = 1 |
|
|
|
|
|
assert query_embeddings.dtype == np.float32, f"Embeddings are not float32: {query_embeddings.dtype}" |
|
|
|
try: |
|
distances, indices = self.index.search(query_embeddings, k) |
|
except RuntimeError as e: |
|
logger.error(f"FAISS search failed: {e}") |
|
return self._fallback_random_negatives(queries, positives) |
|
|
|
|
|
all_negatives = [] |
|
for i, (query_indices, query, positive) in enumerate(zip(indices, queries, positives)): |
|
negatives = [] |
|
positive_strip = positive.strip() |
|
|
|
|
|
seen = {positive_strip} |
|
for idx in query_indices: |
|
if idx >= 0 and idx < total_responses: |
|
candidate = self.response_pool[idx].strip() |
|
if candidate and candidate not in seen: |
|
seen.add(candidate) |
|
negatives.append(candidate) |
|
if len(negatives) >= self.neg_samples: |
|
break |
|
|
|
|
|
if len(negatives) < self.neg_samples: |
|
available = [r for r in self.response_pool if r.strip() not in seen and r.strip()] |
|
if available: |
|
additional = np.random.choice( |
|
available, |
|
size=min(self.neg_samples - len(negatives), len(available)), |
|
replace=False |
|
) |
|
negatives.extend(additional) |
|
|
|
|
|
while len(negatives) < self.neg_samples: |
|
negatives.append("") |
|
|
|
all_negatives.append(negatives) |
|
|
|
return all_negatives |
|
|
|
except Exception as e: |
|
retry_count += 1 |
|
logger.warning(f"Hard negative search attempt {retry_count} failed: {e}") |
|
if retry_count == self.max_retries: |
|
logger.error("Max retries reached for hard negative search") |
|
return [[] for _ in queries] |
|
gc.collect() |
|
if tf.config.list_physical_devices('GPU'): |
|
tf.keras.backend.clear_session() |
|
|
|
def _fallback_random_negatives(self, queries: List[str], positives: List[str]) -> List[List[str]]: |
|
"""Fallback to random sampling when similarity search fails.""" |
|
logger.warning("Falling back to random negative sampling") |
|
all_negatives = [] |
|
for positive in positives: |
|
available = [r for r in self.response_pool if r.strip() != positive.strip()] |
|
negatives = list(np.random.choice( |
|
available, |
|
size=min(self.neg_samples, len(available)), |
|
replace=False |
|
)) if available else [] |
|
while len(negatives) < self.neg_samples: |
|
negatives.append("") |
|
all_negatives.append(negatives) |
|
return all_negatives |
|
|
|
def _prepare_batch( |
|
self, |
|
queries: List[str], |
|
positives: List[str], |
|
pad_to_batch_size: bool = False |
|
) -> Tuple[tf.Tensor, tf.Tensor, Optional[tf.Tensor]]: |
|
"""Prepare a batch with dynamic padding and memory optimization.""" |
|
actual_size = len(queries) |
|
|
|
|
|
if pad_to_batch_size and actual_size < self.current_batch_size: |
|
padding_needed = self.current_batch_size - actual_size |
|
queries.extend([queries[0]] * padding_needed) |
|
positives.extend([positives[0]] * padding_needed) |
|
|
|
attention_mask = tf.concat([ |
|
tf.ones((actual_size,), dtype=tf.float32), |
|
tf.zeros((padding_needed,), dtype=tf.float32) |
|
], axis=0) |
|
else: |
|
attention_mask = None |
|
|
|
try: |
|
|
|
q_tokens = self.tokenizer( |
|
queries, |
|
padding='max_length', |
|
truncation=True, |
|
max_length=self.max_length, |
|
return_tensors='tf' |
|
) |
|
p_tokens = self.tokenizer( |
|
positives, |
|
padding='max_length', |
|
truncation=True, |
|
max_length=self.max_length, |
|
return_tensors='tf' |
|
) |
|
|
|
return q_tokens['input_ids'], p_tokens['input_ids'], attention_mask |
|
|
|
except Exception as e: |
|
logger.error(f"Error preparing batch: {e}") |
|
|
|
gc.collect() |
|
if tf.config.list_physical_devices('GPU'): |
|
tf.keras.backend.clear_session() |
|
raise |
|
|
|
def estimate_total_pairs(self, dialogues: List[dict]) -> int: |
|
"""Estimate total number of training pairs including hard negatives.""" |
|
base_pairs = sum( |
|
len([ |
|
1 for i in range(len(d.get('turns', [])) - 1) |
|
if (d['turns'][i].get('speaker') == 'user' and |
|
d['turns'][i+1].get('speaker') == 'assistant') |
|
]) |
|
for d in dialogues |
|
) |
|
|
|
return base_pairs * (1 + self.neg_samples) |
|
|
|
def cleanup(self): |
|
"""Cleanup resources and memory.""" |
|
self.query_embeddings_cache.clear() |
|
gc.collect() |
|
if tf.config.list_physical_devices('GPU'): |
|
tf.keras.backend.clear_session() |
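

if __name__ == "__main__":
    # Minimal end-to-end sketch. The data path and query below are illustrative
    # assumptions, not part of the library; adjust them to your environment.
    config = ChatbotConfig()
    dialogues = RetrievalChatbot.load_training_data("training_data.json", debug_samples=100)
    chatbot = RetrievalChatbot(config, dialogues=dialogues)
    chatbot.build_models()
    chatbot.verify_faiss_index()

    response, candidates, metrics = chatbot.chat("How do I reset my password?")
    print(response)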
|
|