Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| """create_faiss_index.py | |
| """ | |
| import pandas as pd | |
| import numpy as np | |
| import faiss | |
| from sentence_transformers import InputExample, SentenceTransformer | |
# CSV read by load_data(); the code below accesses its "Questions" and
# "Answers" columns.
DATA_FILE_PATH = "omdena_qna_dataset/omdena_faq_training_data.csv"
# Model name handed to SentenceTransformer(...) in load_transformer_model().
TRANSFORMER_MODEL_NAME = "all-distilroberta-v1"
# Passed as `cache_folder` when the model is instantiated.
CACHE_DIR_PATH = "../working/cache/"
# Target passed to `transformer_model.save(...)`.
# NOTE(review): SentenceTransformer.save() is documented as writing a
# directory, so the ".pkl" suffix looks misleading — confirm before renaming,
# since other code may load the model from this exact path.
MODEL_SAVE_PATH = "all-distilroberta-v1-model.pkl"
# File the FAISS index is written to (and read back from) in main().
FAISS_INDEX_FILE_PATH = "index.faiss"
def load_data(file_path):
    """Read the Q&A CSV and return only the rows that have an answer.

    An ``id`` column holding each row's position in the file is added
    *before* filtering, so the ids of the surviving rows still refer to
    their original row numbers.

    Args:
        file_path: Path of the CSV file; it must contain an ``Answers`` column.

    Returns:
        A new DataFrame (a copy) without the rows whose ``Answers`` is missing.
    """
    frame = pd.read_csv(file_path)
    frame["id"] = frame.index
    has_answer = frame["Answers"].notna()
    return frame.loc[has_answer].copy()
def create_input_examples(qna_dataset):
    """Turn every question/answer row into a single-text ``InputExample``.

    Side effect: a ``QNA`` column containing the combined
    ``"Question: ..., Answer: ..."`` text is added to ``qna_dataset`` in place.

    Args:
        qna_dataset: DataFrame with ``Questions`` and ``Answers`` columns.

    Returns:
        A list with one ``InputExample`` per row, in row order, each wrapping
        that row's ``QNA`` text.
    """

    def _combine(row):
        return f"Question: {row['Questions']}, Answer: {row['Answers']}"

    qna_dataset["QNA"] = qna_dataset.apply(_combine, axis=1)
    return [InputExample(texts=[qna_text]) for qna_text in qna_dataset["QNA"]]
def load_transformer_model(model_name, cache_folder):
    """Instantiate a ``SentenceTransformer``.

    Args:
        model_name: Model identifier forwarded to ``SentenceTransformer``.
        cache_folder: Forwarded as the ``cache_folder`` keyword argument.

    Returns:
        The constructed ``SentenceTransformer`` instance.
    """
    return SentenceTransformer(model_name, cache_folder=cache_folder)
def save_transformer_model(transformer_model, model_file):
    """Persist the model by delegating to its own ``save`` method.

    Args:
        transformer_model: Object exposing ``save(path)``.
        model_file: Destination handed unchanged to ``save``.
    """
    persist = transformer_model.save
    persist(model_file)
def create_faiss_index(transformer_model, qna_dataset):
    """Build an ID-mapped inner-product FAISS index over the answer embeddings.

    Each value of the ``Answers`` column is encoded with ``transformer_model``,
    the vectors are L2-normalised (so the inner product computed by
    ``IndexFlatIP`` equals cosine similarity) and stored under the matching
    value of the ``id`` column.

    Args:
        transformer_model: Object whose ``encode(list_of_texts)`` returns one
            embedding per text.
        qna_dataset: DataFrame with ``Answers`` and ``id`` columns.

    Returns:
        A ``faiss.IndexIDMap`` wrapping a ``faiss.IndexFlatIP``.

    Raises:
        ValueError: If ``qna_dataset`` has no rows; the embedding dimension
            cannot be derived from an empty batch.
    """
    answers = qna_dataset["Answers"].tolist()
    if not answers:
        raise ValueError("qna_dataset contains no answers to index")

    # FAISS only accepts C-contiguous float32 matrices. np.array() always
    # copies, so the in-place normalisation below never touches the array
    # returned by encode().
    embeddings = np.array(
        transformer_model.encode(answers), dtype=np.float32, order="C"
    )

    # FAISS ids must be 64-bit; a plain "int" cast is only 32 bits wide on
    # some platforms (e.g. Windows with NumPy < 2).
    ids = qna_dataset["id"].to_numpy(dtype=np.int64)

    faiss.normalize_L2(embeddings)
    faiss_index = faiss.IndexIDMap(faiss.IndexFlatIP(embeddings.shape[1]))
    faiss_index.add_with_ids(embeddings, ids)
    return faiss_index
def save_faiss_index(faiss_index, filename):
    """Write ``faiss_index`` to ``filename`` via ``faiss.write_index``.

    Args:
        faiss_index: The index object to serialise.
        filename: Destination path handed unchanged to FAISS.
    """
    write = faiss.write_index
    write(faiss_index, filename)
def load_faiss_index(filename):
    """Read a FAISS index back from ``filename`` via ``faiss.read_index``.

    Args:
        filename: Path of a previously written index file.

    Returns:
        Whatever ``faiss.read_index`` returns for that file.
    """
    restored_index = faiss.read_index(filename)
    return restored_index
def main():
    """Build the FAISS index for the Q&A dataset and persist model and index.

    Steps: load the CSV, load and save the sentence-transformer model, embed
    the answers into a FAISS index, write the index to disk, then read it back
    as a sanity check.

    The index is built from the ``Answers`` column only, so the previously
    computed (and never used) ``InputExample`` list is no longer created here.

    Raises:
        RuntimeError: If the index read back from disk does not hold the same
            number of vectors as the one that was written.
    """
    qna_dataset = load_data(DATA_FILE_PATH)

    transformer_model = load_transformer_model(TRANSFORMER_MODEL_NAME, CACHE_DIR_PATH)
    save_transformer_model(transformer_model, MODEL_SAVE_PATH)

    faiss_index = create_faiss_index(transformer_model, qna_dataset)
    save_faiss_index(faiss_index, FAISS_INDEX_FILE_PATH)

    # Round-trip check: make sure the file just written is actually loadable
    # and complete instead of silently discarding the reloaded index.
    reloaded_index = load_faiss_index(FAISS_INDEX_FILE_PATH)
    if reloaded_index.ntotal != faiss_index.ntotal:
        raise RuntimeError(
            f"FAISS index written to {FAISS_INDEX_FILE_PATH!r} holds "
            f"{reloaded_index.ntotal} vectors, expected {faiss_index.ntotal}"
        )
# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()