import os
import json
import pickle
from pathlib import Path

import faiss
from tqdm.auto import tqdm
from sentence_transformers import SentenceTransformer

from tf_data_pipeline import TFDataPipeline
from chatbot_config import ChatbotConfig
from logger_config import config_logger

logger = config_logger(__name__)
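
# Silence the HuggingFace tokenizers fork/parallelism warning; the encoder is
# driven from a single process here, so disabling parallelism is safe.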
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def main():
    MODELS_DIR = 'models'
    PROCESSED_DATA_DIR = 'processed_outputs'
    CACHE_DIR = os.path.join(MODELS_DIR, 'query_embeddings_cache')
    TOKENIZER_DIR = os.path.join(MODELS_DIR, 'tokenizer')
    FAISS_INDICES_DIR = os.path.join(MODELS_DIR, 'faiss_indices')
    TF_RECORD_DIR = 'training_data'
    FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_production.index')
    JSON_TRAINING_DATA_PATH = os.path.join(PROCESSED_DATA_DIR, 'taskmaster_only.json')
    CACHE_FILE = os.path.join(CACHE_DIR, 'query_embeddings_cache.pkl')
    TF_RECORD_PATH = os.path.join(TF_RECORD_DIR, 'training_data_3.tfrecord')

    # Ensure output directories exist
    os.makedirs(MODELS_DIR, exist_ok=True)
    os.makedirs(PROCESSED_DATA_DIR, exist_ok=True)
    os.makedirs(CACHE_DIR, exist_ok=True)
    os.makedirs(TOKENIZER_DIR, exist_ok=True)
    os.makedirs(FAISS_INDICES_DIR, exist_ok=True)
    os.makedirs(TF_RECORD_DIR, exist_ok=True)

    # Load ChatbotConfig
    config_json = Path(MODELS_DIR) / "config.json"
    if config_json.exists():
        with open(config_json, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        config = ChatbotConfig.from_dict(config_dict)
        logger.info(f"Loaded ChatbotConfig from {config_json}")
    else:
        config = ChatbotConfig()
        logger.warning("No config.json found. Using default ChatbotConfig.")
        try:
            with open(config_json, "w", encoding="utf-8") as f:
                json.dump(config.to_dict(), f, indent=2)
            logger.info(f"Default ChatbotConfig saved to {config_json}")
        except Exception as e:
            logger.error(f"Failed to save default ChatbotConfig: {e}")
            raise

    # Init SentenceTransformer
    encoder = SentenceTransformer(config.pretrained_model)
    logger.info(f"Initialized SentenceTransformer model: {config.pretrained_model}")

    # Load dialogues
    if Path(JSON_TRAINING_DATA_PATH).exists():
        dialogues = TFDataPipeline.load_json_training_data(JSON_TRAINING_DATA_PATH)
        logger.info(f"Loaded {len(dialogues)} dialogues.")
    else:
        logger.warning(f"No dialogues found at {JSON_TRAINING_DATA_PATH}.")
        dialogues = []

    # Load or init query embeddings cache
    query_embeddings_cache = {}
    if os.path.exists(CACHE_FILE):
        with open(CACHE_FILE, 'rb') as f:
            query_embeddings_cache = pickle.load(f)
        logger.info(f"Loaded query embeddings cache with {len(query_embeddings_cache)} entries.")
    else:
        logger.info("No existing query embeddings cache found. Starting fresh.")

    # Init FAISS index
    dimension = encoder.get_sentence_embedding_dimension()
    if Path(FAISS_INDEX_PRODUCTION_PATH).exists():
        faiss_index = faiss.read_index(FAISS_INDEX_PRODUCTION_PATH)
        logger.info(f"Loaded FAISS index from {FAISS_INDEX_PRODUCTION_PATH}.")
    else:
        faiss_index = faiss.IndexFlatIP(dimension)
        logger.info(f"Initialized new FAISS index with dimension {dimension}.")

    # Init TFDataPipeline
    data_pipeline = TFDataPipeline(
        config=config,
        tokenizer=encoder.tokenizer,
        encoder=encoder,
        response_pool=[],
        query_embeddings_cache=query_embeddings_cache,
        index_type='IndexFlatIP',
        faiss_index_file_path=FAISS_INDEX_PRODUCTION_PATH
    )
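    # response_pool starts empty here; it is populated below once responses have
    # been collected from the loaded dialogues.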

    # Collect and embed responses
    if dialogues:
        response_pool = data_pipeline.collect_responses_with_domain(dialogues)
        data_pipeline.response_pool = response_pool

        # Save the response pool
        response_pool_path = FAISS_INDEX_PRODUCTION_PATH.replace('.index', '_responses.json')
        with open(response_pool_path, 'w', encoding='utf-8') as f:
            json.dump(response_pool, f, indent=2)
        logger.info(f"Response pool saved to {response_pool_path}.")

        data_pipeline.compute_and_index_response_embeddings()
        data_pipeline.save_faiss_index(FAISS_INDEX_PRODUCTION_PATH)
        logger.info(f"FAISS index saved at {FAISS_INDEX_PRODUCTION_PATH}.")
    else:
        logger.warning("No responses to embed. Skipping FAISS indexing.")

    # Save query embeddings cache
    with open(CACHE_FILE, 'wb') as f:
        pickle.dump(query_embeddings_cache, f)
    logger.info(f"Query embeddings cache saved at {CACHE_FILE}.")

    logger.info("Pipeline completed successfully.")


if __name__ == "__main__":
    main()
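
# A minimal sketch of how the saved artifacts could be consumed later (assumes the
# same SentenceTransformer model is used and that embeddings were L2-normalized
# before indexing; the model name and query text below are illustrative):
#
#     import json
#     import faiss
#     from sentence_transformers import SentenceTransformer
#
#     encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")  # assumed model
#     index = faiss.read_index("models/faiss_indices/faiss_index_production.index")
#     with open("models/faiss_indices/faiss_index_production_responses.json", encoding="utf-8") as f:
#         responses = json.load(f)
#
#     query_vec = encoder.encode(["Book a table for two tonight"], normalize_embeddings=True)
#     scores, ids = index.search(query_vec.astype("float32"), 5)
#     top_responses = [responses[i] for i in ids[0]]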