# JoeArmani
# sentence transformer
# 64e7c31
import os
import json
import pickle
import faiss
from tqdm.auto import tqdm
from pathlib import Path
from sentence_transformers import SentenceTransformer
from tf_data_pipeline import TFDataPipeline
from chatbot_config import ChatbotConfig
from logger_config import config_logger
# Disable the HuggingFace `tokenizers` internal parallelism for this process
# (controlled via the TOKENIZERS_PARALLELISM environment variable).
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Module-level logger obtained from the project's logging helper.
logger = config_logger(__name__)
def _load_or_create_config(config_path: Path) -> "ChatbotConfig":
    """Load ChatbotConfig from ``config_path``, or build a default one and save it.

    If ``config_path`` exists it is parsed as JSON and passed to
    ``ChatbotConfig.from_dict``. Otherwise a default ``ChatbotConfig()`` is
    created and written to ``config_path`` as indented JSON.

    Raises:
        Exception: whatever error occurred while writing the default config is
            logged and re-raised.
    """
    if config_path.exists():
        with open(config_path, "r", encoding="utf-8") as f:
            config = ChatbotConfig.from_dict(json.load(f))
        logger.info(f"Loaded ChatbotConfig from {config_path}")
        return config

    config = ChatbotConfig()
    logger.warning("No config.json found. Using default ChatbotConfig.")
    try:
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config.to_dict(), f, indent=2)
        logger.info(f"Default ChatbotConfig saved to {config_path}")
    except Exception as e:
        logger.error(f"Failed to save default ChatbotConfig: {e}")
        raise
    return config


def _load_query_cache(cache_file: str) -> dict:
    """Return the pickled query-embedding cache stored at ``cache_file``, or ``{}``.

    A missing file starts a fresh cache. An unreadable or corrupt/truncated file
    is also treated as empty (with a warning) rather than aborting the run.

    SECURITY: ``pickle.load`` can execute arbitrary code. This script writes the
    file itself; never point ``cache_file`` at untrusted data.
    """
    if not os.path.exists(cache_file):
        logger.info("No existing query embeddings cache found. Starting fresh.")
        return {}
    try:
        with open(cache_file, 'rb') as f:
            cache = pickle.load(f)
    except (OSError, EOFError, pickle.UnpicklingError) as e:
        logger.warning(
            f"Could not read query embeddings cache at {cache_file} ({e}). Starting fresh."
        )
        return {}
    logger.info(f"Loaded query embeddings cache with {len(cache)} entries.")
    return cache


def _save_query_cache(cache: dict, cache_file: str) -> None:
    """Pickle ``cache`` to ``cache_file`` via a temp file + ``os.replace``.

    Writing to a sibling temp file first means an interrupted run cannot leave
    a half-written cache at ``cache_file``.
    """
    tmp_path = f"{cache_file}.tmp"
    with open(tmp_path, 'wb') as f:
        pickle.dump(cache, f)
    os.replace(tmp_path, cache_file)
    logger.info(f"Query embeddings cache saved at {cache_file}.")


def _load_or_create_faiss_index(index_path: str, dimension: int) -> "faiss.Index":
    """Read the FAISS index at ``index_path`` or create an empty ``IndexFlatIP``.

    Logs a warning if an existing index's dimension differs from ``dimension``
    (e.g. the configured encoder model changed since the index was built).
    """
    if Path(index_path).exists():
        index = faiss.read_index(index_path)
        logger.info(f"Loaded FAISS index from {index_path}.")
        if index.d != dimension:
            logger.warning(
                f"FAISS index at {index_path} has dimension {index.d}, "
                f"but the encoder produces dimension {dimension}."
            )
        return index
    index = faiss.IndexFlatIP(dimension)
    logger.info(f"Initialized new FAISS index with dimension {dimension}.")
    return index


def _build_response_index(data_pipeline, dialogues, index_path: str) -> None:
    """Collect responses from ``dialogues``, save them, embed them, and save the index.

    The response pool is written as JSON next to the index file, named
    ``<index stem>_responses.json``. Does nothing (besides a warning) when
    ``dialogues`` is empty.
    """
    if not dialogues:
        logger.warning("No responses to embed. Skipping FAISS indexing.")
        return

    response_pool = data_pipeline.collect_responses_with_domain(dialogues)
    data_pipeline.response_pool = response_pool

    # Derive the sidecar path from the file name only. The previous
    # str.replace('.index', ...) also rewrote matching directory names, and
    # returned the index path itself when the name had no '.index'.
    index_file = Path(index_path)
    response_pool_path = index_file.with_name(f"{index_file.stem}_responses.json")
    with open(response_pool_path, 'w', encoding='utf-8') as f:
        json.dump(response_pool, f, indent=2)
    logger.info(f"Response pool saved to {response_pool_path}.")

    data_pipeline.compute_and_index_response_embeddings()
    data_pipeline.save_faiss_index(index_path)
    logger.info(f"FAISS index saved at {index_path}.")


def main():
    """Build the production FAISS response index and refresh the query-embedding cache.

    Steps:
      1. Create the output directories under ``models/``, ``processed_outputs/``
         and ``training_data/``.
      2. Load ``models/config.json`` into ChatbotConfig (or write a default).
      3. Load the SentenceTransformer named by ``config.pretrained_model``.
      4. Load dialogues from ``processed_outputs/taskmaster_only.json`` if present.
      5. Load the pickled query-embedding cache.
      6. Build a TFDataPipeline and, when dialogues exist, save the response pool,
         embed/index the responses and save the FAISS index.
      7. Write the query-embedding cache back to disk.
    """
    MODELS_DIR = 'models'
    PROCESSED_DATA_DIR = 'processed_outputs'
    CACHE_DIR = os.path.join(MODELS_DIR, 'query_embeddings_cache')
    TOKENIZER_DIR = os.path.join(MODELS_DIR, 'tokenizer')
    FAISS_INDICES_DIR = os.path.join(MODELS_DIR, 'faiss_indices')
    TF_RECORD_DIR = 'training_data'
    FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_production.index')
    JSON_TRAINING_DATA_PATH = os.path.join(PROCESSED_DATA_DIR, 'taskmaster_only.json')
    CACHE_FILE = os.path.join(CACHE_DIR, 'query_embeddings_cache.pkl')

    # Ensure output directories exist
    for directory in (
        MODELS_DIR,
        PROCESSED_DATA_DIR,
        CACHE_DIR,
        TOKENIZER_DIR,
        FAISS_INDICES_DIR,
        TF_RECORD_DIR,
    ):
        os.makedirs(directory, exist_ok=True)

    config = _load_or_create_config(Path(MODELS_DIR) / "config.json")

    encoder = SentenceTransformer(config.pretrained_model)
    logger.info(f"Initialized SentenceTransformer model: {config.pretrained_model}")

    if Path(JSON_TRAINING_DATA_PATH).exists():
        dialogues = TFDataPipeline.load_json_training_data(JSON_TRAINING_DATA_PATH)
        logger.info(f"Loaded {len(dialogues)} dialogues.")
    else:
        logger.warning(f"No dialogues found at {JSON_TRAINING_DATA_PATH}.")
        dialogues = []

    query_embeddings_cache = _load_query_cache(CACHE_FILE)

    dimension = encoder.get_sentence_embedding_dimension()
    # NOTE(review): the index loaded/created here is never handed to
    # TFDataPipeline, which only receives faiss_index_file_path. This call
    # currently only reads/validates and logs the on-disk index. Confirm whether
    # the pipeline should receive it.
    _load_or_create_faiss_index(FAISS_INDEX_PRODUCTION_PATH, dimension)

    data_pipeline = TFDataPipeline(
        config=config,
        tokenizer=encoder.tokenizer,
        encoder=encoder,
        response_pool=[],
        query_embeddings_cache=query_embeddings_cache,
        index_type='IndexFlatIP',
        faiss_index_file_path=FAISS_INDEX_PRODUCTION_PATH,
    )

    _build_response_index(data_pipeline, dialogues, FAISS_INDEX_PRODUCTION_PATH)

    # Assumes TFDataPipeline updates the dict it was given in place rather than
    # replacing it — TODO confirm against TFDataPipeline.
    _save_query_cache(query_embeddings_cache, CACHE_FILE)
    logger.info("Pipeline completed successfully.")
# Run the pipeline only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()