# rag-movie-api/app/retrieval/media_retriever.py
from collections import Counter
from typing import Dict, List, Tuple
import threading
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer
from nltk.tokenize import word_tokenize
from qdrant_client import QdrantClient, models
from qdrant_client.models import FieldCondition, Filter, MatchValue, Range
from sentence_transformers import SentenceTransformer
_stop_words_lock = threading.Lock()
class MediaRetriever:
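    """Hybrid dense + BM25 retriever over Qdrant movie and TV collections.

    Runs semantic (dense) and keyword (sparse BM25) searches, fuses the results,
    and reranks them with a weighted blend of match scores, rating, and popularity.
    """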
def __init__(
self,
embed_model: SentenceTransformer,
qdrant_client: QdrantClient,
bm25_models: Dict,
bm25_vocabs: Dict,
movie_collection_name: str,
tv_collection_name: str,
dense_weight: float = 0.4, # Weight of semantic match score for reranking
sparse_weight: float = 0.1, # Weight of BM25 match score for reranking
rating_weight: float = 0.3, # Weight of rating score for reranking
popularity_weight: float = 0.2, # Weight of popularity score for reranking
semantic_retrieval_limit: int = 300, # Number of movies to retrieve for reranking
        bm25_retrieval_limit: int = 20, # Number of BM25 matches to retrieve for fusion
top_k: int = 20, # Number of post-reranking movies to send to LLM
):
self.client = qdrant_client
self.movie_collection_name = movie_collection_name
self.tv_collection_name = tv_collection_name
self.embed_model = embed_model
self.bm25_models = bm25_models
self.bm25_vocabs = bm25_vocabs
self.dense_weight = dense_weight
self.sparse_weight = sparse_weight
self.rating_weight = rating_weight
self.popularity_weight = popularity_weight
self.semantic_retrieval_limit = semantic_retrieval_limit
self.bm25_retrieval_limit = bm25_retrieval_limit
self.top_k = top_k
def embed_dense(self, query: str) -> List[float]:
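        """Encode the query text into a dense embedding vector."""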
return self.embed_model.encode(query).tolist()
@staticmethod
def tokenize_and_preprocess(text: str) -> List[str]:
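        """Lowercase and tokenize the text, drop stopwords and non-alphanumeric tokens, and Porter-stem the rest."""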
with _stop_words_lock:
try:
stop_words = set(stopwords.words("english"))
except Exception as e:
print("⚠️ Failed to load NLTK stopwords:", e)
stop_words = set()
stemmer = PorterStemmer()
tokens = word_tokenize(text.lower())
filtered_tokens = [w for w in tokens if w not in stop_words and w.isalnum()]
processed_tokens = [stemmer.stem(w) for w in filtered_tokens]
return processed_tokens
def embed_sparse(self, query: str, media_type: str) -> Dict:
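        """Build a BM25-weighted sparse query vector (term indices and weights).

        Each query term found in the vocabulary is weighted as
        idf * tf * (k1 + 1) / (tf + k1 * (1 - b + b * query_len / avgdl)),
        i.e. the BM25 term weight with the query treated as the document.
        """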
bm25_model = (
self.bm25_models["movie"]
if media_type.lower() == "movie"
else self.bm25_models["tv"]
)
bm25_vocab = (
self.bm25_vocabs["movie"]
if media_type.lower() == "movie"
else self.bm25_vocabs["tv"]
)
tokens = self.tokenize_and_preprocess(query)
term_counts = Counter(tokens)
indices, values = [], []
avg_doc_length = bm25_model.avgdl
k1, b = bm25_model.k1, bm25_model.b
for term, tf in term_counts.items():
if term in bm25_vocab:
idx = bm25_vocab[term]
idf = bm25_model.idf.get(term, 0)
numerator = idf * tf * (k1 + 1)
denominator = tf + k1 * (1 - b + b * len(tokens) / avg_doc_length)
if denominator != 0:
weight = numerator / denominator
indices.append(idx)
values.append(float(weight))
sparse_vector = {"indices": indices, "values": values}
return sparse_vector
def retrieve_and_rerank(
self,
dense_vector: List[float],
sparse_vector: Dict,
media_type: str = "movie",
genres=None,
providers=None,
year_range=None,
    ) -> Tuple[List, Dict]:
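        """Query Qdrant with both vectors, fuse the results, and rerank them.

        Returns the top_k reranked points and a lookup of their component scores.
        """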
# Construct Qdrant filter based on user input
qdrant_filter = self._build_filter(genres, providers, year_range)
# Query Qdrant for semantic search with dense vector
dense_results = self._query_dense(
vector=dense_vector,
media_type=media_type,
qdrant_filter=qdrant_filter,
)
# Query Qdrant for BM25 search with sparse vector
sparse_results = self._query_sparse(
vector=sparse_vector,
media_type=media_type,
qdrant_filter=qdrant_filter,
)
        if not dense_results.points:
            return [], {}
# Fuse dense and sparse results and rerank
fused = self.fuse_dense_sparse(dense_results, sparse_results)
reranked, scored_lookup = self.rerank_fused_results(fused)
        # Debug logging of the top reranked results
        reranked_ids = [p.id for p in reranked[: self.top_k]]
        print(f"\nReranked Top-{self.top_k}:")
        for i, id_ in enumerate(reranked_ids):
            f = fused[id_]
            p = f["point"]
            print(
                f"#{i + 1} {p.payload.get('title', '')} | Score: {p.score} "
                f"Dense: {f['dense_score']:.3f}, Sparse: {f['sparse_score']:.3f}, "
                f"Pop: {p.payload.get('popularity', 0)}, Rating: {p.payload.get('vote_average', 0)}"
            )
return reranked[: self.top_k], scored_lookup
def _build_filter(
self, genres=None, providers=None, year_range=None
) -> Filter | None:
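        """Build a Qdrant Filter from optional genre, provider, and year-range constraints.

        Genres and providers are OR-ed within their group and AND-ed across groups.
        """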
must_clauses = []
if genres:
genre_conditions = [
FieldCondition(key="genres", match=MatchValue(value=genre))
for genre in genres
]
            must_clauses.append(Filter(should=genre_conditions))
if providers:
provider_conditions = [
FieldCondition(key="watch_providers", match=MatchValue(value=provider))
for provider in providers
]
            must_clauses.append(Filter(should=provider_conditions))
if year_range:
must_clauses.append(
FieldCondition(
key="release_year",
range=Range(gte=year_range[0], lte=year_range[1]),
)
)
return Filter(must=must_clauses) if must_clauses else None
def _query_dense(self, vector, media_type, qdrant_filter):
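        """Semantic search over the dense vector index of the matching collection."""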
collection = (
self.movie_collection_name
            if media_type.lower() == "movie"
else self.tv_collection_name
)
return self.client.query_points(
collection_name=collection,
query=vector,
using="dense_vector",
query_filter=qdrant_filter,
limit=self.semantic_retrieval_limit,
with_payload=["llm_context", "media_id", "title", "popularity", "vote_average"],
with_vectors=False,
)
def _query_sparse(self, vector, media_type, qdrant_filter):
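        """Keyword (BM25) search over the sparse vector index of the matching collection."""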
collection = (
self.movie_collection_name
            if media_type.lower() == "movie"
else self.tv_collection_name
)
return self.client.query_points(
collection_name=collection,
query=models.SparseVector(**vector),
using="sparse_vector",
query_filter=qdrant_filter,
limit=self.bm25_retrieval_limit,
with_payload=["llm_context", "media_id", "title", "popularity", "vote_average"],
with_vectors=False,
)
def fuse_dense_sparse(
self,
dense_results: List,
sparse_results: List,
) -> Dict[str, Dict]:
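        """Merge the dense and sparse query responses into one dict keyed by point id.

        Sparse scores are normalized by the maximum sparse score and capped at 0.8
        so that keyword matches cannot dominate the blended score.
        """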
fused = {}
# Add dense results
for point in dense_results.points:
fused[point.id] = {
"point": point,
"dense_score": point.score or 0.0,
"sparse_score": 0.0,
}
max_sparse_score = max((pt.score for pt in sparse_results.points), default=1e-6)
# Add sparse scores
for point in sparse_results.points:
if point.id in fused:
fused[point.id]["sparse_score"] = (
min(point.score / max_sparse_score, 0.8) or 0.0
)
else:
fused[point.id] = {
"point": point,
"dense_score": 0.0,
"sparse_score": min(point.score / max_sparse_score, 0.8) or 0.0,
}
return fused
def rerank_fused_results(
self,
fused: Dict[str, Dict],
) -> Tuple[List, Dict]:
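        """Blend dense, sparse, rating, and popularity signals into a reranking score.

        Returns the points sorted by blended score (descending) and a lookup of
        the component scores per point id.
        """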
        max_popularity = max(
            (float(f["point"].payload.get("popularity", 0)) for f in fused.values()),
            default=1.0,
        ) or 1.0  # Guard against division by zero when all popularities are 0
scored = {}
for id_, f in fused.items():
point = f["point"]
dense_score = f["dense_score"]
sparse_score = f["sparse_score"]
popularity = float(point.payload.get("popularity", 0)) / max_popularity
vote_average = float(point.payload.get("vote_average", 0)) / 10.0
reranked_score = (
self.dense_weight * dense_score
+ self.sparse_weight * sparse_score
+ self.rating_weight * vote_average
+ self.popularity_weight * popularity
)
scored[id_] = {
"point": point,
"dense_score": dense_score,
"sparse_score": sparse_score,
"reranked_score": reranked_score,
}
        sorted_ids = sorted(
            scored.items(), key=lambda x: x[1]["reranked_score"], reverse=True
        )
return [v["point"] for _, v in sorted_ids], scored
    def format_context(self, movies: list) -> str:
        # Format the retrieved documents as context for the LLM
return "\n\n".join(
[f" {movie.payload.get('llm_context', '')}" for movie in movies]
)
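

# Minimal usage sketch (illustrative only, not part of the original module).
# Assumptions: the model name, Qdrant URL, collection names, and the toy corpus
# below are placeholders; bm25_models/bm25_vocabs are assumed to hold per-media-type
# rank_bm25 BM25Okapi instances and term->index vocabularies, which matches the
# attributes (idf, k1, b, avgdl) that embed_sparse reads above. Requires the NLTK
# tokenizer and stopwords data to be available locally.
if __name__ == "__main__":
    from rank_bm25 import BM25Okapi

    embed_model = SentenceTransformer("all-MiniLM-L6-v2")  # hypothetical model choice
    client = QdrantClient(url="http://localhost:6333")  # hypothetical local instance

    # Toy BM25 model and vocabulary; in practice these are built from the full catalog
    corpus = ["a space opera about found family", "a gritty heist thriller"]
    corpus_tokens = [MediaRetriever.tokenize_and_preprocess(doc) for doc in corpus]
    bm25 = BM25Okapi(corpus_tokens)
    vocab = {term: i for i, term in enumerate(sorted({t for doc in corpus_tokens for t in doc}))}

    retriever = MediaRetriever(
        embed_model=embed_model,
        qdrant_client=client,
        bm25_models={"movie": bm25, "tv": bm25},
        bm25_vocabs={"movie": vocab, "tv": vocab},
        movie_collection_name="movies",  # hypothetical collection names
        tv_collection_name="tv_shows",
    )

    query = "feel-good sci-fi about friendship"
    dense = retriever.embed_dense(query)
    sparse = retriever.embed_sparse(query, media_type="movie")
    points, scores = retriever.retrieve_and_rerank(dense, sparse, media_type="movie")
    print(retriever.format_context(points))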