import os
import re
from datetime import datetime, timedelta
from functools import lru_cache

import faiss
import numpy as np
import pandas as pd
from FlagEmbedding import FlagReranker
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from sentence_transformers import InputExample, SentenceTransformer, losses
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

TEXT_COLUMNS = ["Title", "Description", "Genres"]


@lru_cache(maxsize=1)
def _english_stop_words():
    """Return NLTK's English stop words as a frozenset, loaded only once."""
    return frozenset(stopwords.words("english"))


class BookSearch:
    """Semantic book search: fine-tuned SBERT embeddings + FAISS + optional reranking.

    The CSV at ``file_path`` must provide ``Title``, ``Description`` and
    ``Genres`` columns. The fine-tuned model (``model_path``) and the FAISS
    index (``index_path``) are cached on disk and rebuilt when missing or older
    than ``retrain_period_days``.
    """

    def __init__(
        self,
        file_path="../data/book_data.csv",
        bert_model="paraphrase-MiniLM-L6-v2",
        index_path="vectors/fine_tuned_faiss_index_t.index",
        model_path="out/fine_tuned_sbert_model_test",
        retrain_period_days=30,
        reranker_model="BAAI/bge-reranker-base",
    ):
        self.file_path = file_path
        self.bert_model = bert_model
        self.index_path = index_path
        self.model_path = model_path
        self.model = None
        self.index = None
        self.retrain_period_days = retrain_period_days
        # Must be a cross-encoder reranker checkpoint; embedding models such as
        # bge-small-en-v1.5 have no trained classification head for FlagReranker.
        self.reranker_model = reranker_model
        self._reranker = None  # created lazily on the first reranked search
        self._df = None  # cached CSV; row positions must match index vector ids
        # Set after (re)training: the on-disk index holds old-model vectors.
        self._rebuild_index = False

    def preprocess_text(self, text):
        """Lowercase, keep only ASCII letters/whitespace, and drop English stop words.

        Non-string input (e.g. a NaN cell from pandas) yields "".
        """
        if not isinstance(text, str):
            return ""
        text = re.sub(r"[^a-zA-Z\s]", "", text.lower())
        stop_words = _english_stop_words()
        return " ".join(token for token in word_tokenize(text) if token not in stop_words)

    def _load_dataframe(self):
        """Read the book CSV once and cache it.

        Missing text cells are filled with "" instead of dropping rows, so that
        positional row numbers stay aligned with the vectors in the FAISS index.
        """
        if self._df is None:
            df = pd.read_csv(self.file_path)
            df[TEXT_COLUMNS] = df[TEXT_COLUMNS].fillna("").astype(str)
            self._df = df
        return self._df

    def train_model(self):
        """Fine-tune ``bert_model`` on (title + genres, description) pairs and save it."""
        df = pd.read_csv(self.file_path).dropna(subset=TEXT_COLUMNS)
        for column in TEXT_COLUMNS:
            df[column] = df[column].apply(self.preprocess_text)
        train, _ = train_test_split(df, test_size=0.2, random_state=42)
        # One positive pair per book. CosineSimilarityLoss only compares the
        # first two texts of an example and, with every label at 1.0, merely
        # pulls all embeddings together. MultipleNegativesRankingLoss treats
        # the other pairs in the batch as negatives, which suits positive-only data.
        train_examples = [
            InputExample(texts=[f"{row['Title']} {row['Genres']}", row["Description"]])
            for _, row in train.iterrows()
        ]
        self.model = SentenceTransformer(self.bert_model)
        train_loader = DataLoader(train_examples, shuffle=True, batch_size=32)
        train_loss = losses.MultipleNegativesRankingLoss(model=self.model)
        self.model.fit(train_objectives=[(train_loader, train_loss)], epochs=1)
        self.model.save(self.model_path)
        self._rebuild_index = True

    def load_model(self):
        """Load the fine-tuned model, retraining it first if missing or stale."""
        if self._is_file_older_than(self.model_path, self.retrain_period_days):
            self.train_model()
        else:
            self.model = SentenceTransformer(self.model_path)

    def create_index(self):
        """Load the FAISS index, or rebuild it from the CSV when missing/stale.

        A rebuild is also forced after the model was retrained in this process,
        because the stored vectors would come from the previous model.
        """
        if not self._rebuild_index and not self._is_file_older_than(
            self.index_path, self.retrain_period_days
        ):
            self.index = faiss.read_index(self.index_path)
            return
        df = self._load_dataframe()
        documents = df["Description"].apply(self.preprocess_text).tolist()
        if self.model is None:
            self.load_model()
        # encode() returns a numpy array here; FAISS needs contiguous float32 (n, dim).
        embeddings = np.ascontiguousarray(
            self.model.encode(documents, convert_to_tensor=False), dtype="float32"
        )
        self.index = faiss.IndexFlatL2(embeddings.shape[1])
        self.index.add(embeddings)
        index_dir = os.path.dirname(self.index_path)
        if index_dir:
            os.makedirs(index_dir, exist_ok=True)
        faiss.write_index(self.index, self.index_path)
        self._rebuild_index = False

    def _get_reranker(self):
        """Create the cross-encoder reranker on first use and reuse it afterwards."""
        if self._reranker is None:
            self._reranker = FlagReranker(self.reranker_model, use_fp16=True)
        return self._reranker

    def semantic_search(self, query, k=5, rerank_k=3, flag_threshold=0):
        """Return up to ``rerank_k`` ``(title, text, score)`` tuples for ``query``.

        The ``k`` nearest descriptions are fetched from the FAISS index. When
        ``flag_threshold`` is truthy, candidates are rescored by the reranker,
        those whose reranker score is not above ``flag_threshold`` are dropped,
        and the rest are ordered best-first.

        Score semantics:
            * without reranking: L2 distance (lower is closer; ascending order);
            * with reranking: ``reranker_score - l2_distance`` (higher is better;
              descending order).
        """
        if self.model is None:
            self.load_model()
        if self.index is None or self._rebuild_index:
            self.create_index()

        # Indexed descriptions were preprocessed, so embed the query the same way.
        query_embedding = np.asarray(
            self.model.encode([self.preprocess_text(query)], convert_to_tensor=False),
            dtype="float32",
        )
        distances, indices = self.index.search(query_embedding, k)
        # FAISS pads with -1 when the index holds fewer than k vectors;
        # iloc[-1] would silently return the last row.
        found = indices[0] >= 0
        initial_indices = indices[0][found]
        initial_distances = distances[0][found]

        df = self._load_dataframe()
        candidates = df.iloc[initial_indices][TEXT_COLUMNS].copy()
        # Each candidate's text uses its OWN genres (not every candidate's).
        candidates["Text"] = (
            candidates["Title"].str.lower()
            + " "
            + candidates["Description"].str.lower()
            + " "
            + candidates["Genres"].str.lower()
        )
        titles = candidates["Title"].tolist()
        texts = candidates["Text"].tolist()
        initial_results = list(zip(titles, texts, initial_distances))

        if not flag_threshold or not initial_results:
            return initial_results[:rerank_k]

        reranker = self._get_reranker()
        # compute_score may return a scalar for a single pair; normalize to 1-D.
        flag_scores = np.atleast_1d(
            np.asarray(reranker.compute_score([[query, text] for text in texts]), dtype=float)
        )
        # Reranker scores are relevance logits (large negative = irrelevant), so
        # filter on the signed value, and subtract the distance so that
        # "higher is better" holds for both terms of the combined score.
        reranked_results = [
            (title, text, float(flag_score) - float(dist))
            for title, text, dist, flag_score in zip(
                titles, texts, initial_distances, flag_scores
            )
            if flag_score > flag_threshold
        ]
        reranked_results.sort(key=lambda item: item[2], reverse=True)
        return reranked_results[:rerank_k]

    def _is_file_older_than(self, file_path, days):
        """Return True if ``file_path`` is missing or was last modified more than
        ``days`` whole days ago (by local wall-clock time)."""
        if os.path.exists(file_path):
            modification_time = os.path.getmtime(file_path)
            modification_datetime = datetime.fromtimestamp(modification_time)
            current_datetime = datetime.now()
            return (current_datetime - modification_datetime).days > days
        return True


if __name__ == "__main__":
    book_search = BookSearch()
    query = "Love and Fiction"
    results = book_search.semantic_search(query)
    for rank, (title, _text, score) in enumerate(results, start=1):
        print(f"Rank {rank}: {title} (Score: {score})")