import os
import re
from datetime import datetime

import faiss
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from sentence_transformers import SentenceTransformer, InputExample, losses
from FlagEmbedding import FlagReranker
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
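
# word_tokenize and stopwords depend on NLTK data packages that may be missing
# in a fresh environment; this guarded download is a convenience sketch,
# assuming network access is available on first run.
import nltk

for _resource, _path in (("punkt", "tokenizers/punkt"), ("stopwords", "corpora/stopwords")):
    try:
        nltk.data.find(_path)
    except LookupError:
        nltk.download(_resource, quiet=True)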


class BookSearch:
    """Semantic book search over a CSV corpus using SBERT embeddings and FAISS."""

    def __init__(
        self,
        file_path="../data/book_data.csv",
        bert_model="paraphrase-MiniLM-L6-v2",
        index_path="vectors/fine_tuned_faiss_index_t.index",
        model_path="out/fine_tuned_sbert_model_test",
        retrain_period_days=30,
    ):
        self.file_path = file_path
        self.bert_model = bert_model
        self.index_path = index_path
        self.model_path = model_path
        self.model = None
        self.index = None
        self.retrain_period_days = retrain_period_days

    def preprocess_text(self, text):
        """Lowercase, strip non-letter characters, and remove English stopwords."""
        text = text.lower()
        text = re.sub(r"[^a-zA-Z\s]", "", text)
        tokens = word_tokenize(text)
        stop_words = set(stopwords.words("english"))
        tokens = [token for token in tokens if token not in stop_words]
        return " ".join(tokens)

    def train_model(self):
        """Fine-tune the SBERT model on (title + genres, description) pairs."""
        df = pd.read_csv(self.file_path)
        df = df.dropna(subset=["Title", "Description", "Genres"])

        df["Title"] = df["Title"].apply(self.preprocess_text)
        df["Description"] = df["Description"].apply(self.preprocess_text)
        df["Genres"] = df["Genres"].apply(self.preprocess_text)

        train, _ = train_test_split(df, test_size=0.2, random_state=42)

        # CosineSimilarityLoss consumes exactly two texts per example, so the
        # title and genres are joined into one anchor text paired with the
        # description; any third text would be silently ignored during training.
        train_examples = [
            InputExample(
                texts=[row["Title"] + " " + row["Genres"], row["Description"]],
                label=1.0,
            )
            for _, row in train.iterrows()
        ]

        self.model = SentenceTransformer(self.bert_model)
        train_loader = DataLoader(train_examples, shuffle=True, batch_size=32)
        train_loss = losses.CosineSimilarityLoss(model=self.model)
        self.model.fit(train_objectives=[(train_loader, train_loss)], epochs=1)
        self.model.save(self.model_path)

    def load_model(self):
        """Load the fine-tuned model from disk, retraining when missing or stale."""
        if not os.path.exists(self.model_path) or self._is_file_older_than(
            self.model_path, self.retrain_period_days
        ):
            self.train_model()
        else:
            self.model = SentenceTransformer(self.model_path)

    def create_index(self):
        """Build the FAISS index over document embeddings, or load a fresh one from disk."""
        if not os.path.exists(self.index_path) or self._is_file_older_than(
            self.index_path, self.retrain_period_days
        ):
            df = pd.read_csv(self.file_path)
            documents = df["Description"].apply(self.preprocess_text).tolist()
            if self.model is None:
                self.load_model()
            # encode() returns a NumPy array, so the embedding dimensionality
            # comes from .shape[1], not the tensor-style .size(1) call; FAISS
            # also requires float32 input.
            document_embeddings = np.asarray(
                self.model.encode(documents, convert_to_tensor=False),
                dtype="float32",
            )
            self.index = faiss.IndexFlatL2(document_embeddings.shape[1])
            self.index.add(document_embeddings)
            faiss.write_index(self.index, self.index_path)
        else:
            self.index = faiss.read_index(self.index_path)

    def semantic_search(self, query, k=5, rerank_k=3, flag_threshold=0):
        """Return top matches for a query, optionally reranked with a cross-encoder."""
        if self.model is None:
            self.load_model()
        if self.index is None:
            self.create_index()

        query_embedding = np.asarray(
            self.model.encode([query], convert_to_tensor=False), dtype="float32"
        )
        # Retrieve extra candidates; only the top k are kept for reranking.
        distances, indices = self.index.search(query_embedding, k + rerank_k)

        initial_indices = indices[0][:k]

        df = pd.read_csv(self.file_path)
        initial_documents = df.iloc[initial_indices][
            ["Title", "Description", "Genres"]
        ].copy()  # .copy() avoids pandas' SettingWithCopyWarning below

        # Build one searchable text per row from that row's own fields; joining
        # every candidate's genres into a single string would leak other rows'
        # genres into each document.
        initial_documents["Text"] = (
            initial_documents["Title"].str.lower()
            + " "
            + initial_documents["Description"].str.lower()
            + " "
            + initial_documents["Genres"].str.lower()
        )
        initial_distances = distances[0][:k]

        initial_results = list(
            zip(
                initial_documents["Title"],
                initial_documents["Text"],
                initial_distances,
            )
        )

        if flag_threshold:
            # FlagReranker expects a cross-encoder reranker checkpoint;
            # bge-small-en-v1.5 is an embedding model, so a bge-reranker
            # model is used here instead.
            flag_reranker = FlagReranker("BAAI/bge-reranker-base", use_fp16=True)
            flag_scores = [
                flag_reranker.compute_score([query, text])
                for _, text, _ in initial_results
            ]
            reranked_results = [
                (title, text, dist + flag_score)
                for (title, text, dist), flag_score in zip(
                    initial_results, flag_scores
                )
                if abs(flag_score) > flag_threshold
            ]
            reranked_results = sorted(
                reranked_results, key=lambda x: x[2], reverse=True
            )[:rerank_k]
        else:
            reranked_results = initial_results[:rerank_k]

        return reranked_results

    def _is_file_older_than(self, file_path, days):
        """Return True when the file is missing or last modified more than `days` ago."""
        if os.path.exists(file_path):
            modification_time = os.path.getmtime(file_path)
            modification_datetime = datetime.fromtimestamp(modification_time)
            return (datetime.now() - modification_datetime).days > days
        return True


if __name__ == "__main__":
    book_search = BookSearch()
    query = "Love and Fiction"
    results = book_search.semantic_search(query)
    for rank, (title, text, score) in enumerate(results, start=1):
        print(f"Rank {rank}: {title} (Score: {score})")
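
    # Illustrative only: a nonzero flag_threshold enables the cross-encoder
    # reranking pass inside semantic_search; the k, rerank_k, and 0.1 values
    # below are assumed example settings, not tuned ones.
    reranked = book_search.semantic_search(query, k=10, rerank_k=5, flag_threshold=0.1)
    for rank, (title, _, score) in enumerate(reranked, start=1):
        print(f"Reranked {rank}: {title} (Score: {score})")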