import os
import re
import faiss
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from sentence_transformers import SentenceTransformer, InputExample, losses
from FlagEmbedding import FlagReranker
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from datetime import datetime
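
# NOTE: word_tokenize and stopwords rely on the NLTK "punkt" and "stopwords"
# data packages ("punkt_tab" on newer NLTK releases); run nltk.download("punkt")
# and nltk.download("stopwords") once if they are not already installed.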

class BookSearch:
    """Semantic book search over a CSV catalogue, using a fine-tuned
    SentenceTransformer encoder, a FAISS index over description embeddings,
    and an optional cross-encoder reranking step."""

    def __init__(
        self,
        file_path="../data/book_data.csv",
        bert_model="paraphrase-MiniLM-L6-v2",
        index_path="vectors/fine_tuned_faiss_index_t.index",
        model_path="out/fine_tuned_sbert_model_test",
        retrain_period_days=30,
    ):
        self.file_path = file_path
        self.bert_model = bert_model
        self.index_path = index_path
        self.model_path = model_path
        self.model = None
        self.index = None
        self.retrain_period_days = retrain_period_days

    def preprocess_text(self, text):
        """Lowercase, strip non-letter characters, and drop English stopwords."""
        text = text.lower()
        text = re.sub(r"[^a-zA-Z\s]", "", text)
        tokens = word_tokenize(text)
        stop_words = set(stopwords.words("english"))
        tokens = [token for token in tokens if token not in stop_words]
        return " ".join(tokens)

    def train_model(self):
        """Fine-tune the base SentenceTransformer on the book data and save it."""
        df = pd.read_csv(self.file_path)
        df = df.dropna(subset=["Title", "Description", "Genres"])
        df["Title"] = df["Title"].apply(self.preprocess_text)
        df["Description"] = df["Description"].apply(self.preprocess_text)
        df["Genres"] = df["Genres"].apply(self.preprocess_text)
        train, _ = train_test_split(df, test_size=0.2, random_state=42)
        # CosineSimilarityLoss only compares the first two texts of each
        # InputExample, so the title is pulled towards the genres with a target
        # similarity of 1.0; the description is encoded but not used in the loss.
        train_examples = [
            InputExample(
                texts=[row["Title"], row["Genres"], row["Description"]], label=1.0
            )
            for _, row in train.iterrows()
        ]
        self.model = SentenceTransformer(self.bert_model)
        train_loader = DataLoader(train_examples, shuffle=True, batch_size=32)
        train_loss = losses.CosineSimilarityLoss(model=self.model)
        self.model.fit(train_objectives=[(train_loader, train_loss)], epochs=1)
        self.model.save(self.model_path)

    def load_model(self):
        """Load the fine-tuned model, retraining if it is missing or stale."""
        if not os.path.exists(self.model_path) or self._is_file_older_than(
            self.model_path, self.retrain_period_days
        ):
            self.train_model()
        else:
            self.model = SentenceTransformer(self.model_path)

    def create_index(self):
        """Build the FAISS L2 index over description embeddings, or reload it."""
        if not os.path.exists(self.index_path) or self._is_file_older_than(
            self.index_path, self.retrain_period_days
        ):
            df = pd.read_csv(self.file_path)
            # fillna keeps row positions aligned with the raw CSV that
            # semantic_search reads, and avoids errors on missing descriptions.
            documents = df["Description"].fillna("").apply(self.preprocess_text).tolist()
            if not self.model:
                self.load_model()
            # encode() returns a 2-D numpy array; FAISS expects float32 input
            # and the embedding dimension is shape[1].
            document_embeddings = np.asarray(
                self.model.encode(documents, convert_to_tensor=False), dtype="float32"
            )
            self.index = faiss.IndexFlatL2(document_embeddings.shape[1])
            self.index.add(document_embeddings)
            os.makedirs(os.path.dirname(self.index_path), exist_ok=True)
            faiss.write_index(self.index, self.index_path)
        else:
            self.index = faiss.read_index(self.index_path)

    def semantic_search(self, query, k=5, rerank_k=3, flag_threshold=0):
        """Return the top matches for a free-text query, optionally reranked."""
        if not self.model:
            self.load_model()
        if not self.index:
            self.create_index()
        query_embedding = np.asarray(
            self.model.encode([query], convert_to_tensor=False), dtype="float32"
        )
        distances, indices = self.index.search(query_embedding, k + rerank_k)
        initial_indices = indices[0][:k]
        df = pd.read_csv(self.file_path)
        initial_documents = df.iloc[initial_indices][
            ["Title", "Description", "Genres"]
        ].copy()
        # Build one lowercased text per hit from its own title, description,
        # and genres.
        initial_documents["Text"] = (
            initial_documents["Title"].str.lower()
            + " "
            + initial_documents["Description"].str.lower()
            + " "
            + initial_documents["Genres"].str.lower()
        )
        initial_distances = distances[0][:k]
        initial_results = list(
            zip(
                initial_documents["Title"], initial_documents["Text"], initial_distances
            )
        )
        if flag_threshold:
            # Rerank the FAISS hits with a cross-encoder; FlagReranker expects a
            # reranker checkpoint (e.g. BAAI/bge-reranker-base) that scores
            # (query, passage) pairs directly.
            flag_reranker = FlagReranker("BAAI/bge-reranker-base", use_fp16=True)
            flag_scores = [
                flag_reranker.compute_score([query, text])
                for _, text, _ in initial_results
            ]
            # Keep only hits whose absolute reranker score clears the threshold.
            reranked_results = [
                (title, text, dist + flag_score)
                for (title, text, dist), flag_score in zip(initial_results, flag_scores)
                if abs(flag_score) > flag_threshold
            ]
            # Sort by the combined distance + reranker score, highest first.
            reranked_results = sorted(
                reranked_results, key=lambda x: x[2], reverse=True
            )[:rerank_k]
        else:
            reranked_results = initial_results[:rerank_k]
        return reranked_results

    def _is_file_older_than(self, file_path, days):
        """Return True if the file is missing or last modified more than `days` ago."""
        if os.path.exists(file_path):
            modification_time = datetime.fromtimestamp(os.path.getmtime(file_path))
            return (datetime.now() - modification_time).days > days
        return True


if __name__ == "__main__":
    book_search = BookSearch()
    query = "Love and Fiction"
    results = book_search.semantic_search(query)
    for rank, (title, text, score) in enumerate(results, start=1):
        print(f"Rank {rank}: {title} (Score: {score})")