import os
import sys
import json
import pickle

import faiss
from transformers import AutoTokenizer
from tqdm.auto import tqdm

from chatbot_model import ChatbotConfig, EncoderModel
from environment_setup import EnvironmentSetup
from tf_data_pipeline import TFDataPipeline
from logger_config import config_logger

logger = config_logger(__name__)

# Silence the Hugging Face tokenizers fork/parallelism warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
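

# Helper for clearing out leftover test FAISS index files; not invoked in the main flow below.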
def cleanup_test_indices(faiss_dir, test_prefix='test_'):
    """Remove leftover test FAISS index files from `faiss_dir`."""
    test_files = [f for f in os.listdir(faiss_dir) if f.startswith(test_prefix)]
    for file in test_files:
        file_path = os.path.join(faiss_dir, file)
        os.remove(file_path)
        logger.info(f"Removed test FAISS index file: {file_path}")


def main():
    """Build the FAISS response index and the TFRecord training data for the chatbot."""
    # Directory and file layout for models, processed data, caches, and training artifacts.
    MODELS_DIR = 'models'
    PROCESSED_DATA_DIR = 'processed_outputs'
    CACHE_DIR = 'cache'
    TOKENIZER_DIR = os.path.join(MODELS_DIR, 'tokenizer')
    FAISS_INDICES_DIR = os.path.join(MODELS_DIR, 'faiss_indices')
    TF_RECORD_DIR = 'training_data'
    FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_production.index')
    FAISS_INDEX_TEST_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_test.index')

    # Pick the FAISS index file for the target environment.
    ENVIRONMENT = 'production'
    if ENVIRONMENT == 'test':
        FAISS_INDEX_PATH = FAISS_INDEX_TEST_PATH
    else:
        FAISS_INDEX_PATH = FAISS_INDEX_PRODUCTION_PATH

    JSON_TRAINING_DATA_PATH = os.path.join(PROCESSED_DATA_DIR, 'augmented_dialogues.json')
    CACHE_FILE = os.path.join(CACHE_DIR, 'query_embeddings_cache.pkl')
    TF_RECORD_PATH = os.path.join(TF_RECORD_DIR, 'training_data.tfrecord')
    DEBUG_SAMPLES = None  # Set to an integer to load only a subset of dialogues for debugging.
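
    # Make sure every output directory exists before writing artifacts.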
    os.makedirs(MODELS_DIR, exist_ok=True)
    os.makedirs(PROCESSED_DATA_DIR, exist_ok=True)
    os.makedirs(CACHE_DIR, exist_ok=True)
    os.makedirs(TOKENIZER_DIR, exist_ok=True)
    os.makedirs(FAISS_INDICES_DIR, exist_ok=True)
    os.makedirs(TF_RECORD_DIR, exist_ok=True)
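
    # Load the chatbot configuration that drives the encoder and the data pipeline.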
    config = ChatbotConfig()
    logger.info(f"Chatbot Configuration: {config}")
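
    # Load the pretrained tokenizer and register the '<EMPTY_NEGATIVE>' special token.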
    try:
        tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
        logger.info(f"Tokenizer '{config.pretrained_model}' loaded successfully.")
        tokenizer.add_special_tokens({'additional_special_tokens': ['<EMPTY_NEGATIVE>']})
        logger.info("Added special tokens to tokenizer.")
    except Exception as e:
        logger.error(f"Failed to load tokenizer: {e}")
        sys.exit(1)
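
    # Initialize the encoder and resize its token embeddings to match the extended tokenizer.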
    try:
        encoder = EncoderModel(config=config)
        logger.info("EncoderModel initialized successfully.")
        encoder.pretrained.resize_token_embeddings(len(tokenizer))
        logger.info(f"Token embeddings resized to: {len(tokenizer)}")
    except Exception as e:
        logger.error(f"Failed to initialize EncoderModel: {e}")
        sys.exit(1)
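
    # Load the augmented training dialogues from JSON.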
    try:
        dialogues = TFDataPipeline.load_json_training_data(JSON_TRAINING_DATA_PATH, DEBUG_SAMPLES)
        logger.info(f"Loaded {len(dialogues)} dialogues from {JSON_TRAINING_DATA_PATH}.")
    except Exception as e:
        logger.error(f"Failed to load dialogues: {e}")
        sys.exit(1)
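
    # Reuse the query embeddings cache from a previous run when it exists.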
    try:
        if os.path.exists(CACHE_FILE):
            with open(CACHE_FILE, 'rb') as f:
                query_embeddings_cache = pickle.load(f)
            logger.info(f"Loaded {len(query_embeddings_cache)} query embeddings from {CACHE_FILE}.")
        else:
            query_embeddings_cache = {}
            logger.info("Initialized empty query embeddings cache.")
    except Exception as e:
        logger.error(f"Failed to load or initialize query embeddings cache: {e}")
        sys.exit(1)
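
    # Build the data pipeline that ties together the tokenizer, encoder, FAISS index, and embedding cache.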
    try:
        data_pipeline = TFDataPipeline(
            config=config,
            tokenizer=tokenizer,
            encoder=encoder,
            index_file_path=FAISS_INDEX_PATH,
            response_pool=[],
            max_length=config.max_context_token_limit,
            neg_samples=config.neg_samples,
            query_embeddings_cache=query_embeddings_cache,
            index_type='IndexFlatIP',
            nlist=100,
            max_retries=config.max_retries
        )
        logger.info("TFDataPipeline initialized successfully.")
    except Exception as e:
        logger.error(f"Failed to initialize TFDataPipeline: {e}")
        sys.exit(1)
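
    # Collect the pool of assistant responses that the FAISS index will serve.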
    try:
        response_pool = data_pipeline.collect_responses(dialogues)
        data_pipeline.response_pool = response_pool
        logger.info(f"Collected {len(response_pool)} unique assistant responses from dialogues.")
    except Exception as e:
        logger.error(f"Failed to collect responses: {e}")
        sys.exit(1)
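
    # Embed the collected responses and add the vectors to the FAISS index.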
    try:
        logger.info("Computing and adding response embeddings to FAISS index...")
        data_pipeline.compute_and_index_response_embeddings()
        logger.info("Response embeddings computed and added to FAISS index.")
    except Exception as e:
        logger.error(f"Failed to compute or add response embeddings: {e}")
        sys.exit(1)
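
    # Persist the FAISS index together with the response pool it was built from.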
    try:
        logger.info(f"Saving FAISS index to {FAISS_INDEX_PATH}...")
        faiss.write_index(data_pipeline.index, FAISS_INDEX_PATH)
        logger.info("FAISS index saved successfully.")

        # Save the response pool next to the index file (same name with a _responses.json suffix).
        response_pool_path = FAISS_INDEX_PATH.replace('.index', '_responses.json')
        with open(response_pool_path, 'w', encoding='utf-8') as f:
            json.dump(data_pipeline.response_pool, f, indent=2)
        logger.info(f"Response pool saved to {response_pool_path}.")
    except Exception as e:
        logger.error(f"Failed to save FAISS index or response pool: {e}")
        sys.exit(1)
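
    # Prepare the training examples and write them out as a TFRecord file.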
    try:
        logger.info("Starting data preparation and saving as TFRecord...")
        data_pipeline.prepare_and_save_data(dialogues, TF_RECORD_PATH)
        logger.info(f"Data saved as TFRecord at {TF_RECORD_PATH}.")
    except Exception as e:
        logger.error(f"Failed during data preparation and saving: {e}")
        sys.exit(1)
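
    # Persist the query embeddings cache so later runs can skip recomputing embeddings.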
    try:
        with open(CACHE_FILE, 'wb') as f:
            pickle.dump(data_pipeline.query_embeddings_cache, f)
        logger.info(f"Saved {len(data_pipeline.query_embeddings_cache)} query embeddings to {CACHE_FILE}.")
    except Exception as e:
        logger.error(f"Failed to save query embeddings cache: {e}")
        sys.exit(1)
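
    # Save the tokenizer (including the added special token) with the other model artifacts.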
    try:
        tokenizer.save_pretrained(TOKENIZER_DIR)
        logger.info(f"Tokenizer saved to {TOKENIZER_DIR}.")
    except Exception as e:
        logger.error(f"Failed to save tokenizer: {e}")
        sys.exit(1)
logger.info("Data preparation pipeline completed successfully.") |
|
|
|


if __name__ == "__main__":
    main()