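"""Builds the production FAISS retrieval index for the chatbot.

Steps:
  1. Load (or create and persist) the ChatbotConfig from models/config.json.
  2. Load Taskmaster dialogues from processed_outputs/taskmaster_only.json.
  3. Collect candidate responses, embed them with the configured
     SentenceTransformer, and index them with FAISS.
  4. Save the response pool, the FAISS index, and the query embeddings cache.

Run as a standalone script; it takes no command-line arguments and creates
its output directories on startup.
"""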
import os
import json
import pickle
import faiss
from tqdm.auto import tqdm
from pathlib import Path
from sentence_transformers import SentenceTransformer
from tf_data_pipeline import TFDataPipeline
from chatbot_config import ChatbotConfig
from logger_config import config_logger

logger = config_logger(__name__)

os.environ["TOKENIZERS_PARALLELISM"] = "false"

def main():
    MODELS_DIR = 'models'
    PROCESSED_DATA_DIR = 'processed_outputs'
    CACHE_DIR = os.path.join(MODELS_DIR, 'query_embeddings_cache')
    TOKENIZER_DIR = os.path.join(MODELS_DIR, 'tokenizer')
    FAISS_INDICES_DIR = os.path.join(MODELS_DIR, 'faiss_indices')
    TF_RECORD_DIR = 'training_data'
    FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_production.index')
    JSON_TRAINING_DATA_PATH = os.path.join(PROCESSED_DATA_DIR, 'taskmaster_only.json')
    CACHE_FILE = os.path.join(CACHE_DIR, 'query_embeddings_cache.pkl')
    TF_RECORD_PATH = os.path.join(TF_RECORD_DIR, 'training_data_3.tfrecord')
    
    # Ensure output directories exist
    os.makedirs(MODELS_DIR, exist_ok=True)
    os.makedirs(PROCESSED_DATA_DIR, exist_ok=True)
    os.makedirs(CACHE_DIR, exist_ok=True)
    os.makedirs(TOKENIZER_DIR, exist_ok=True)
    os.makedirs(FAISS_INDICES_DIR, exist_ok=True)
    os.makedirs(TF_RECORD_DIR, exist_ok=True)
    
    # Load ChatbotConfig
    config_json = Path(MODELS_DIR) / "config.json"
    if config_json.exists():
        with open(config_json, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        config = ChatbotConfig.from_dict(config_dict)
        logger.info(f"Loaded ChatbotConfig from {config_json}")
    else:
        config = ChatbotConfig()
        logger.warning("No config.json found. Using default ChatbotConfig.")
        try:
            with open(config_json, "w", encoding="utf-8") as f:
                json.dump(config.to_dict(), f, indent=2)
            logger.info(f"Default ChatbotConfig saved to {config_json}")
        except Exception as e:
            logger.error(f"Failed to save default ChatbotConfig: {e}")
            raise
    
    # Init SentenceTransformer
    encoder = SentenceTransformer(config.pretrained_model)
    logger.info(f"Initialized SentenceTransformer model: {config.pretrained_model}")
    
    # Load dialogues
    if Path(JSON_TRAINING_DATA_PATH).exists():
        dialogues = TFDataPipeline.load_json_training_data(JSON_TRAINING_DATA_PATH)
        logger.info(f"Loaded {len(dialogues)} dialogues.")
    else:
        logger.warning(f"No dialogues found at {JSON_TRAINING_DATA_PATH}.")
        dialogues = []
    
    # Load or init query embeddings cache
    query_embeddings_cache = {}
    if os.path.exists(CACHE_FILE):
        with open(CACHE_FILE, 'rb') as f:
            query_embeddings_cache = pickle.load(f)
        logger.info(f"Loaded query embeddings cache with {len(query_embeddings_cache)} entries.")
    else:
        logger.info("No existing query embeddings cache found. Starting fresh.")
    
    # Init FAISS index
    dimension = encoder.get_sentence_embedding_dimension()
    if Path(FAISS_INDEX_PRODUCTION_PATH).exists():
        faiss_index = faiss.read_index(FAISS_INDEX_PRODUCTION_PATH)
        logger.info(f"Loaded FAISS index from {FAISS_INDEX_PRODUCTION_PATH}.")
    else:
        # IndexFlatIP = exact inner-product search (cosine similarity for L2-normalized embeddings)
        faiss_index = faiss.IndexFlatIP(dimension)
        logger.info(f"Initialized new FAISS index with dimension {dimension}.")
    
    # Init TFDataPipeline
    data_pipeline = TFDataPipeline(
        config=config,
        tokenizer=encoder.tokenizer,                    # reuse the SentenceTransformer's tokenizer
        encoder=encoder,
        response_pool=[],                               # populated after responses are collected below
        query_embeddings_cache=query_embeddings_cache,  # shared cache, persisted again at the end
        index_type='IndexFlatIP',
        faiss_index_file_path=FAISS_INDEX_PRODUCTION_PATH
    )
    
    # Collect and embed responses
    if dialogues:
        response_pool = data_pipeline.collect_responses_with_domain(dialogues)
        data_pipeline.response_pool = response_pool
        
        # Save the response pool
        response_pool_path = FAISS_INDEX_PRODUCTION_PATH.replace('.index', '_responses.json')
        with open(response_pool_path, 'w', encoding='utf-8') as f:
            json.dump(response_pool, f, indent=2)
        logger.info(f"Response pool saved to {response_pool_path}.")
        data_pipeline.compute_and_index_response_embeddings()
        data_pipeline.save_faiss_index(FAISS_INDEX_PRODUCTION_PATH)
        logger.info(f"FAISS index saved at {FAISS_INDEX_PRODUCTION_PATH}.")
    else:
        logger.warning("No responses to embed. Skipping FAISS indexing.")
    
    # Persist the query embeddings cache (the pipeline may have added new entries in place)
    with open(CACHE_FILE, 'wb') as f:
        pickle.dump(query_embeddings_cache, f)
    logger.info(f"Query embeddings cache saved at {CACHE_FILE}.")

    logger.info("Pipeline completed successfully.")

if __name__ == "__main__":
    main()