from collections import Counter
from typing import Dict, List, Tuple
import threading

from nltk.corpus import stopwords
from nltk.stem import PorterStemmer
from nltk.tokenize import word_tokenize
from qdrant_client import QdrantClient
from qdrant_client.models import FieldCondition, Filter, MatchValue, Range, SparseVector
from sentence_transformers import SentenceTransformer

# NLTK's lazy corpus loading is not thread-safe, so guard stopword access
_stop_words_lock = threading.Lock()


class MediaRetriever:
    def __init__(
        self,
        embed_model: SentenceTransformer,
        qdrant_client: QdrantClient,
        bm25_models: Dict,
        bm25_vocabs: Dict,
        movie_collection_name: str,
        tv_collection_name: str,
        dense_weight: float = 0.4,  # Weight of semantic match score for reranking
        sparse_weight: float = 0.1,  # Weight of BM25 match score for reranking
        rating_weight: float = 0.3,  # Weight of rating score for reranking
        popularity_weight: float = 0.2,  # Weight of popularity score for reranking
        semantic_retrieval_limit: int = 300,  # Number of candidates to retrieve for reranking
        bm25_retrieval_limit: int = 20,  # Number of BM25 candidates to retrieve
        top_k: int = 20,  # Number of post-reranking results to send to the LLM
    ):
        self.client = qdrant_client
        self.movie_collection_name = movie_collection_name
        self.tv_collection_name = tv_collection_name
        self.embed_model = embed_model
        self.bm25_models = bm25_models
        self.bm25_vocabs = bm25_vocabs
        self.dense_weight = dense_weight
        self.sparse_weight = sparse_weight
        self.rating_weight = rating_weight
        self.popularity_weight = popularity_weight
        self.semantic_retrieval_limit = semantic_retrieval_limit
        self.bm25_retrieval_limit = bm25_retrieval_limit
        self.top_k = top_k

    def embed_dense(self, query: str) -> List[float]:
        # Encode the query with the sentence-transformer model
        return self.embed_model.encode(query).tolist()

    def tokenize_and_preprocess(self, text: str) -> List[str]:
        # Load stopwords under the lock; fall back to an empty set if the
        # NLTK data is unavailable
        with _stop_words_lock:
            try:
                stop_words = set(stopwords.words("english"))
            except Exception as e:
                print("⚠️ Failed to load NLTK stopwords:", e)
                stop_words = set()
        stemmer = PorterStemmer()
        tokens = word_tokenize(text.lower())
        filtered_tokens = [w for w in tokens if w not in stop_words and w.isalnum()]
        processed_tokens = [stemmer.stem(w) for w in filtered_tokens]
        return processed_tokens
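    # Example output (a sketch; assumes the NLTK "punkt"/"punkt_tab" and
    # "stopwords" data packages have been downloaded via nltk.download()):
    #   self.tokenize_and_preprocess("Movies about space exploration")
    #   -> ['movi', 'space', 'explor']  # stopword "about" dropped, Porter stems applied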

    def embed_sparse(self, query: str, media_type: str) -> Dict:
        # Pick the BM25 model and vocabulary fitted for the requested media type
        bm25_model = (
            self.bm25_models["movie"]
            if media_type.lower() == "movie"
            else self.bm25_models["tv"]
        )
        bm25_vocab = (
            self.bm25_vocabs["movie"]
            if media_type.lower() == "movie"
            else self.bm25_vocabs["tv"]
        )
        tokens = self.tokenize_and_preprocess(query)
        term_counts = Counter(tokens)
        indices, values = [], []
        avg_doc_length = bm25_model.avgdl
        k1, b = bm25_model.k1, bm25_model.b
        # Weight each in-vocabulary query term with the BM25 saturation formula,
        # treating the query itself as the "document" for length normalization
        for term, tf in term_counts.items():
            if term in bm25_vocab:
                idx = bm25_vocab[term]
                idf = bm25_model.idf.get(term, 0)
                numerator = idf * tf * (k1 + 1)
                denominator = tf + k1 * (1 - b + b * len(tokens) / avg_doc_length)
                if denominator != 0:
                    weight = numerator / denominator
                    indices.append(idx)
                    values.append(float(weight))
        sparse_vector = {"indices": indices, "values": values}
        return sparse_vector
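    # Shape sketch (indices depend on the fitted vocabulary; numbers illustrative):
    #   self.embed_sparse("space exploration", "movie")
    #   -> {"indices": [412, 907], "values": [1.83, 2.41]}
    # This dict unpacks directly into qdrant_client's SparseVector(**sparse_vector).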

    def retrieve_and_rerank(
        self,
        dense_vector: List[float],
        sparse_vector: Dict,
        media_type: str = "movie",
        genres=None,
        providers=None,
        year_range=None,
    ) -> Tuple[List, Dict]:
        # Construct Qdrant filter based on user input
        qdrant_filter = self._build_filter(genres, providers, year_range)
        # Query Qdrant for semantic search with dense vector
        dense_results = self._query_dense(
            vector=dense_vector,
            media_type=media_type,
            qdrant_filter=qdrant_filter,
        )
        # Query Qdrant for BM25 search with sparse vector
        sparse_results = self._query_sparse(
            vector=sparse_vector,
            media_type=media_type,
            qdrant_filter=qdrant_filter,
        )
        if not dense_results:
            return [], {}
        # Fuse dense and sparse results and rerank
        fused = self.fuse_dense_sparse(dense_results, sparse_results)
        reranked, scored_lookup = self.rerank_fused_results(fused)
        reranked_ids = [p.id for p in reranked[: self.top_k]]
        print(f"\nReranked Top-{self.top_k}:")
        for i, id_ in enumerate(reranked_ids):
            f = fused[id_]
            p = f["point"]
            print(
                f"#{i + 1} {p.payload.get('title', '')} | Score: {p.score} "
                f"Dense: {f['dense_score']:.3f}, Sparse: {f['sparse_score']:.3f}, "
                f"Pop: {p.payload.get('popularity', 0)}, "
                f"Rating: {p.payload.get('vote_average', 0)}"
            )
        return reranked[: self.top_k], scored_lookup

    def _build_filter(
        self, genres=None, providers=None, year_range=None
    ) -> Filter | None:
        must_clauses = []
        if genres:
            genre_conditions = [
                FieldCondition(key="genres", match=MatchValue(value=genre))
                for genre in genres
            ]
            # OR across requested genres, nested inside the top-level AND
            must_clauses.append(Filter(should=genre_conditions))
        if providers:
            provider_conditions = [
                FieldCondition(key="watch_providers", match=MatchValue(value=provider))
                for provider in providers
            ]
            # OR across requested providers, nested inside the top-level AND
            must_clauses.append(Filter(should=provider_conditions))
        if year_range:
            must_clauses.append(
                FieldCondition(
                    key="release_year",
                    range=Range(gte=year_range[0], lte=year_range[1]),
                )
            )
        return Filter(must=must_clauses) if must_clauses else None
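    # Semantics sketch (values illustrative):
    #   self._build_filter(genres=["Drama", "Comedy"], year_range=(1990, 1999))
    # matches points where
    #   (genres == "Drama" OR genres == "Comedy") AND (1990 <= release_year <= 1999)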

    def _query_dense(self, vector, media_type, qdrant_filter):
        collection = (
            self.movie_collection_name
            if media_type == "movie"
            else self.tv_collection_name
        )
        return self.client.query_points(
            collection_name=collection,
            query=vector,
            using="dense_vector",
            query_filter=qdrant_filter,
            limit=self.semantic_retrieval_limit,
            with_payload=["llm_context", "media_id", "title", "popularity", "vote_average"],
            with_vectors=False,
        )

    def _query_sparse(self, vector, media_type, qdrant_filter):
        collection = (
            self.movie_collection_name
            if media_type == "movie"
            else self.tv_collection_name
        )
        return self.client.query_points(
            collection_name=collection,
            query=SparseVector(**vector),
            using="sparse_vector",
            query_filter=qdrant_filter,
            limit=self.bm25_retrieval_limit,
            with_payload=["llm_context", "media_id", "title", "popularity", "vote_average"],
            with_vectors=False,
        )

    def fuse_dense_sparse(
        self,
        dense_results,
        sparse_results,
    ) -> Dict[str, Dict]:
        fused = {}
        # Add dense results
        for point in dense_results.points:
            fused[point.id] = {
                "point": point,
                "dense_score": point.score or 0.0,
                "sparse_score": 0.0,
            }
        # Normalize sparse scores by the max, capped at 0.8 so a strong keyword
        # match cannot drown out the semantic signal
        max_sparse_score = max((pt.score for pt in sparse_results.points), default=1e-6)
        for point in sparse_results.points:
            normalized = min(point.score / max_sparse_score, 0.8)
            if point.id in fused:
                fused[point.id]["sparse_score"] = normalized
            else:
                fused[point.id] = {
                    "point": point,
                    "dense_score": 0.0,
                    "sparse_score": normalized,
                }
        return fused

    def rerank_fused_results(
        self,
        fused: Dict[str, Dict],
    ) -> Tuple[List, Dict]:
        # Normalize popularity within this candidate set
        max_popularity = max(
            (float(f["point"].payload.get("popularity", 0)) for f in fused.values()),
            default=1.0,
        )
        scored = {}
        for id_, f in fused.items():
            point = f["point"]
            dense_score = f["dense_score"]
            sparse_score = f["sparse_score"]
            popularity = float(point.payload.get("popularity", 0)) / max_popularity
            vote_average = float(point.payload.get("vote_average", 0)) / 10.0  # 0-10 scale -> [0, 1]
            # Weighted blend of semantic, keyword, rating, and popularity signals
            reranked_score = (
                self.dense_weight * dense_score
                + self.sparse_weight * sparse_score
                + self.rating_weight * vote_average
                + self.popularity_weight * popularity
            )
            scored[id_] = {
                "point": point,
                "dense_score": dense_score,
                "sparse_score": sparse_score,
                "reranked_score": reranked_score,
            }
        sorted_ids = sorted(
            scored.items(), key=lambda x: x[1]["reranked_score"], reverse=True
        )
        return [v["point"] for _, v in sorted_ids], scored

    def format_context(self, movies: list) -> str:
        # Format the retrieved points' payloads as context for the LLM
        return "\n\n".join(
            [f" {movie.payload.get('llm_context', '')}" for movie in movies]
        )
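

# A minimal end-to-end usage sketch, not part of the class. The model name,
# Qdrant endpoint, collection names, and the toy BM25 corpus below are
# illustrative assumptions; real BM25 models and vocabularies would be fitted
# on the indexed documents. Requires a running Qdrant instance with matching
# collections and named vectors ("dense_vector", "sparse_vector").
if __name__ == "__main__":
    from rank_bm25 import BM25Okapi  # assumed BM25 backend; exposes .idf, .k1, .b, .avgdl

    # Toy corpus of pre-tokenized, stemmed documents (hypothetical)
    corpus = [["space", "explor", "astronaut"], ["sport", "drama", "true", "stori"]]
    bm25 = BM25Okapi(corpus)
    vocab = {t: i for i, t in enumerate(sorted({t for doc in corpus for t in doc}))}

    retriever = MediaRetriever(
        embed_model=SentenceTransformer("all-MiniLM-L6-v2"),  # hypothetical model choice
        qdrant_client=QdrantClient(url="http://localhost:6333"),  # hypothetical endpoint
        bm25_models={"movie": bm25, "tv": bm25},
        bm25_vocabs={"movie": vocab, "tv": vocab},
        movie_collection_name="movies",  # hypothetical collection name
        tv_collection_name="tv_shows",  # hypothetical collection name
    )

    query = "feel-good sports drama based on a true story"
    dense = retriever.embed_dense(query)
    sparse = retriever.embed_sparse(query, media_type="movie")
    results, scores = retriever.retrieve_and_rerank(
        dense, sparse, media_type="movie", genres=["Drama"], year_range=(2000, 2024)
    )
    print(retriever.format_context(results))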