# ladka6's picture
# commit from ladka6
# 4aa3246
import os
import re
import faiss
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from sentence_transformers import SentenceTransformer, InputExample, losses
from FlagEmbedding import FlagReranker
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from datetime import datetime, timedelta
class BookSearch:
    """Semantic search over a CSV of books.

    A SentenceTransformer model is fine-tuned on the CSV (or loaded from
    ``model_path``), the preprocessed ``Description`` column is embedded into
    a FAISS ``IndexFlatL2`` (or the index is loaded from ``index_path``), and
    queries are answered by nearest-neighbour search with an optional
    FlagReranker pass. Model and index are rebuilt when their files are
    missing or older than ``retrain_period_days``.
    """

    # English stop-word set, loaded from NLTK on first use and shared by all
    # instances so it is not rebuilt for every processed string.
    _stop_words = None

    def __init__(
        self,
        file_path="../data/book_data.csv",
        bert_model="paraphrase-MiniLM-L6-v2",
        index_path="vectors/fine_tuned_faiss_index_t.index",
        model_path="out/fine_tuned_sbert_model_test",
        retrain_period_days=30,
    ):
        """Store configuration; the model and index are loaded lazily.

        Args:
            file_path: CSV read with ``Title``, ``Description`` and
                ``Genres`` columns.
            bert_model: Name of the base SentenceTransformer to fine-tune.
            index_path: Where the FAISS index is written and read.
            model_path: Where the fine-tuned model is saved and loaded.
            retrain_period_days: Age (in whole days) after which the model
                or index file is considered stale and rebuilt.
        """
        self.file_path = file_path
        self.bert_model = bert_model
        self.index_path = index_path
        self.model_path = model_path
        self.model = None
        self.index = None
        self.retrain_period_days = retrain_period_days

    def preprocess_text(self, text):
        """Lower-case ``text``, drop non-letters and English stop words.

        Args:
            text: The raw string. ``None`` and NaN (as pandas produces for
                empty CSV cells) are treated as an empty string; any other
                non-string value is converted with ``str``.

        Returns:
            The remaining tokens joined by single spaces.
        """
        if not isinstance(text, str):
            # NaN is the only value that is unequal to itself.
            if text is None or text != text:
                return ""
            text = str(text)

        if BookSearch._stop_words is None:
            BookSearch._stop_words = frozenset(stopwords.words("english"))
        stop_words = BookSearch._stop_words

        letters_only = re.sub(r"[^a-zA-Z\s]", "", text.lower())
        kept = [tok for tok in word_tokenize(letters_only) if tok not in stop_words]
        return " ".join(kept)

    def train_model(self):
        """Fine-tune the base model on the CSV and save it to ``model_path``.

        Rows missing ``Title``, ``Description`` or ``Genres`` are dropped,
        the three columns are preprocessed, and 80% of the rows (fixed
        ``random_state=42``) are used for one training epoch.
        """
        df = pd.read_csv(self.file_path)
        df = df.dropna(subset=["Title", "Description", "Genres"])
        for column in ("Title", "Description", "Genres"):
            df[column] = df[column].apply(self.preprocess_text)

        train, _ = train_test_split(df, test_size=0.2, random_state=42)
        # NOTE(review): CosineSimilarityLoss is documented for sentence
        # pairs; confirm that the third text (Description) actually
        # contributes to the loss.
        train_examples = [
            InputExample(
                texts=[row["Title"], row["Genres"], row["Description"]], label=1.0
            )
            for _, row in train.iterrows()
        ]

        self.model = SentenceTransformer(self.bert_model)
        train_loader = DataLoader(train_examples, shuffle=True, batch_size=32)
        train_loss = losses.CosineSimilarityLoss(model=self.model)
        self.model.fit(train_objectives=[(train_loader, train_loss)], epochs=1)
        self.model.save(self.model_path)

    def load_model(self):
        """Load the saved model, retraining first if it is missing or stale."""
        # _is_file_older_than already reports a missing path as stale.
        if self._is_file_older_than(self.model_path, self.retrain_period_days):
            self.train_model()
        else:
            self.model = SentenceTransformer(self.model_path)

    def create_index(self):
        """Load the FAISS index, rebuilding it if it is missing or stale.

        A rebuild embeds the preprocessed ``Description`` of every CSV row
        (rows are not dropped, so index positions match CSV row positions)
        and writes the index to ``index_path``.

        Raises:
            ValueError: If the CSV yields no embeddings to index.
        """
        if not self._is_file_older_than(self.index_path, self.retrain_period_days):
            self.index = faiss.read_index(self.index_path)
            return

        df = pd.read_csv(self.file_path)
        documents = df["Description"].apply(self.preprocess_text).tolist()
        if self.model is None:
            self.load_model()

        # encode(..., convert_to_tensor=False) yields a NumPy array, so the
        # embedding width is shape[1]; FAISS expects contiguous float32.
        document_embeddings = np.ascontiguousarray(
            self.model.encode(documents, convert_to_tensor=False), dtype="float32"
        )
        if document_embeddings.ndim != 2 or document_embeddings.shape[0] == 0:
            raise ValueError(f"No documents to index in {self.file_path!r}")

        self.index = faiss.IndexFlatL2(document_embeddings.shape[1])
        self.index.add(document_embeddings)

        # Make sure the target directory exists before writing the index.
        index_dir = os.path.dirname(self.index_path)
        if index_dir:
            os.makedirs(index_dir, exist_ok=True)
        faiss.write_index(self.index, self.index_path)

    @staticmethod
    def _build_search_text(documents):
        """Return one text per row: lower-cased title and description, then genres.

        Missing cells are treated as empty strings.
        """
        title = documents["Title"].fillna("").astype(str).str.lower()
        description = documents["Description"].fillna("").astype(str).str.lower()
        genres = documents["Genres"].fillna("").astype(str)
        return title + " " + description + " " + genres

    def semantic_search(self, query, k=5, rerank_k=3, flag_threshold=0):
        """Return the best-matching books for ``query``.

        Args:
            query: Free-text search string.
            k: Number of nearest neighbours kept from the FAISS search.
            rerank_k: Maximum number of results returned.
            flag_threshold: When truthy, candidates are scored with a
                FlagReranker and only those whose absolute score exceeds
                this value are kept.

        Returns:
            A list of ``(title, text, score)`` tuples. Without reranking the
            score is the FAISS L2 distance and results are in FAISS order;
            with reranking it is distance plus reranker score, sorted in
            descending order.
        """
        if self.model is None:
            self.load_model()
        if self.index is None:
            self.create_index()

        # NOTE(review): indexed descriptions go through preprocess_text but
        # the query is encoded as-is — confirm this asymmetry is intended.
        query_embedding = np.ascontiguousarray(
            self.model.encode([query], convert_to_tensor=False), dtype="float32"
        )
        distances, indices = self.index.search(query_embedding, k + rerank_k)

        top_indices = indices[0][:k]
        top_distances = distances[0][:k]
        # FAISS pads with -1 when the index holds fewer vectors than asked
        # for; those must not reach iloc, where -1 would mean "last row".
        found = top_indices >= 0
        top_indices = top_indices[found]
        top_distances = top_distances[found]

        df = pd.read_csv(self.file_path)
        # .copy() so that adding the Text column does not write into a slice.
        initial_documents = df.iloc[top_indices][
            ["Title", "Description", "Genres"]
        ].copy()
        initial_documents["Text"] = self._build_search_text(initial_documents)

        initial_results = list(
            zip(initial_documents["Title"], initial_documents["Text"], top_distances)
        )
        if not flag_threshold or not initial_results:
            return initial_results[:rerank_k]

        # NOTE(review): "BAAI/bge-small-en-v1.5" appears to be an embedding
        # checkpoint rather than a reranker one — confirm the model name.
        flag_reranker = FlagReranker("BAAI/bge-small-en-v1.5", use_fp16=True)
        raw_scores = flag_reranker.compute_score(
            [[query, text] for _, text, _ in initial_results]
        )
        # Flatten so both a single float and a list of floats are accepted.
        flag_scores = [float(score) for score in np.ravel(raw_scores)]

        # NOTE(review): this adds an L2 distance (smaller is closer) to a
        # reranker score and sorts descending — confirm the intended mix.
        reranked_results = [
            (title, text, dist + flag_score)
            for (title, text, dist), flag_score in zip(initial_results, flag_scores)
            if abs(flag_score) > flag_threshold
        ]
        reranked_results.sort(key=lambda item: item[2], reverse=True)
        return reranked_results[:rerank_k]

    def _is_file_older_than(self, file_path, days):
        """Return True if ``file_path`` is missing or modified more than ``days`` whole days ago."""
        try:
            modified_at = datetime.fromtimestamp(os.path.getmtime(file_path))
        except OSError:
            # A path that cannot be stat'ed is treated as stale.
            return True
        return (datetime.now() - modified_at).days > days
# Demo run: build the searcher with its default paths and print the ranked
# hits for one sample query.
# NOTE(review): this runs whenever the module is executed or imported —
# consider an `if __name__ == "__main__":` guard.
book_search = BookSearch()
query = "Love and Fiction"
results = book_search.semantic_search(query)

# Each result is a (title, text, score) tuple; ranks are 1-based.
for rank, (title, text, score) in enumerate(results, 1):
    print(f"Rank {rank}: {title} (Score: {score})")