Spaces:
Sleeping
Sleeping
yonnel
committed on
Commit
·
0236fb6
1
Parent(s):
7bec29d
Add adult content filtering option and update related functionality in TMDBClient and settings
Browse files- .env.example +4 -1
- app/build_index.py +30 -6
- app/main.py +58 -4
- app/settings.py +3 -0
.env.example
CHANGED
@@ -11,4 +11,7 @@ API_TOKEN=your_api_token_here
|
|
11 |
ENV=dev
|
12 |
|
13 |
# Logging level
|
14 |
-
LOG_LEVEL=INFO
|
|
|
|
|
|
|
|
11 |
ENV=dev
|
12 |
|
13 |
# Logging level
|
14 |
+
LOG_LEVEL=INFO
|
15 |
+
|
16 |
+
# Remove adult content from TMDB results
|
17 |
+
FILTER_ADULT_CONTENT=true # Set to true to filter out adult content
|
app/build_index.py
CHANGED
@@ -110,7 +110,7 @@ class TMDBClient:
|
|
110 |
|
111 |
return None
|
112 |
|
113 |
-
def get_popular_movies(self, max_pages: int = 100) -> List[int]:
|
114 |
"""Get movie IDs from popular movies pagination"""
|
115 |
movie_ids = []
|
116 |
|
@@ -127,14 +127,18 @@ class TMDBClient:
|
|
127 |
logger.info(f"Reached last page ({data.get('total_pages')})")
|
128 |
break
|
129 |
|
130 |
-
# Extract movie IDs
|
131 |
for movie in data.get('results', []):
|
|
|
|
|
|
|
|
|
132 |
movie_ids.append(movie['id'])
|
133 |
|
134 |
# Rate limiting
|
135 |
time.sleep(0.25) # 4 requests per second max
|
136 |
|
137 |
-
logger.info(f"Collected {len(movie_ids)} movie IDs from {page} pages")
|
138 |
return movie_ids
|
139 |
|
140 |
def get_movie_details(self, movie_id: int) -> Optional[dict]:
|
@@ -285,10 +289,16 @@ def get_embeddings_batch(texts: List[str], client: OpenAI, model: str = "text-em
|
|
285 |
else:
|
286 |
raise
|
287 |
|
288 |
-
def build_index(max_pages: int = 10, model: str = "text-embedding-3-small", use_faiss: bool = True):
|
289 |
"""Main function to build the FAISS index and data files"""
|
290 |
settings = get_settings()
|
291 |
|
|
|
|
|
|
|
|
|
|
|
|
|
292 |
# Initialize clients
|
293 |
tmdb_client = TMDBClient(settings.tmdb_api_key)
|
294 |
openai_client = OpenAI(api_key=settings.openai_api_key)
|
@@ -321,7 +331,10 @@ def build_index(max_pages: int = 10, model: str = "text-embedding-3-small", use_
|
|
321 |
else:
|
322 |
# Step 1: Get movie IDs
|
323 |
logger.info(f"Fetching movie IDs from TMDB (max {max_pages} pages)...")
|
324 |
-
movie_ids = tmdb_client.get_popular_movies(
|
|
|
|
|
|
|
325 |
|
326 |
if not movie_ids:
|
327 |
logger.error("❌ No movie IDs retrieved from TMDB")
|
@@ -335,6 +348,14 @@ def build_index(max_pages: int = 10, model: str = "text-embedding-3-small", use_
|
|
335 |
logger.error("❌ No movie data retrieved")
|
336 |
return
|
337 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
338 |
# Save movie data checkpoint
|
339 |
save_checkpoint(movies_data, MOVIE_DATA_CHECKPOINT)
|
340 |
|
@@ -530,11 +551,14 @@ if __name__ == "__main__":
|
|
530 |
help="OpenAI embedding model to use (default: text-embedding-3-small)")
|
531 |
parser.add_argument("--no-faiss", action="store_true",
|
532 |
help="Skip building FAISS index")
|
|
|
|
|
533 |
|
534 |
args = parser.parse_args()
|
535 |
|
536 |
build_index(
|
537 |
max_pages=args.max_pages,
|
538 |
model=args.model,
|
539 |
-
use_faiss=not args.no_faiss
|
|
|
540 |
)
|
|
|
110 |
|
111 |
return None
|
112 |
|
113 |
+
def get_popular_movies(self, max_pages: int = 100, filter_adult: bool = True) -> List[int]:
|
114 |
"""Get movie IDs from popular movies pagination"""
|
115 |
movie_ids = []
|
116 |
|
|
|
127 |
logger.info(f"Reached last page ({data.get('total_pages')})")
|
128 |
break
|
129 |
|
130 |
+
# Extract movie IDs, filtering adult content if requested
|
131 |
for movie in data.get('results', []):
|
132 |
+
# Skip adult movies if filtering is enabled
|
133 |
+
if filter_adult and movie.get('adult', False):
|
134 |
+
logger.debug(f"Skipping adult movie: {movie.get('title', 'Unknown')} (ID: {movie.get('id')})")
|
135 |
+
continue
|
136 |
movie_ids.append(movie['id'])
|
137 |
|
138 |
# Rate limiting
|
139 |
time.sleep(0.25) # 4 requests per second max
|
140 |
|
141 |
+
logger.info(f"Collected {len(movie_ids)} movie IDs from {page} pages (adult filter: {'ON' if filter_adult else 'OFF'})")
|
142 |
return movie_ids
|
143 |
|
144 |
def get_movie_details(self, movie_id: int) -> Optional[dict]:
|
|
|
289 |
else:
|
290 |
raise
|
291 |
|
292 |
+
def build_index(max_pages: int = 10, model: str = "text-embedding-3-small", use_faiss: bool = True, override_adult_filter: bool = None):
|
293 |
"""Main function to build the FAISS index and data files"""
|
294 |
settings = get_settings()
|
295 |
|
296 |
+
# Determine adult filtering setting
|
297 |
+
filter_adult = settings.filter_adult_content
|
298 |
+
if override_adult_filter is not None:
|
299 |
+
filter_adult = not override_adult_filter # --include-adult means don't filter
|
300 |
+
logger.info(f"Adult filter override: {'DISABLED' if override_adult_filter else 'ENABLED'}")
|
301 |
+
|
302 |
# Initialize clients
|
303 |
tmdb_client = TMDBClient(settings.tmdb_api_key)
|
304 |
openai_client = OpenAI(api_key=settings.openai_api_key)
|
|
|
331 |
else:
|
332 |
# Step 1: Get movie IDs
|
333 |
logger.info(f"Fetching movie IDs from TMDB (max {max_pages} pages)...")
|
334 |
+
movie_ids = tmdb_client.get_popular_movies(
|
335 |
+
max_pages=max_pages,
|
336 |
+
filter_adult=filter_adult
|
337 |
+
)
|
338 |
|
339 |
if not movie_ids:
|
340 |
logger.error("❌ No movie IDs retrieved from TMDB")
|
|
|
348 |
logger.error("❌ No movie data retrieved")
|
349 |
return
|
350 |
|
351 |
+
# Additional filtering at the detail level (double-check)
|
352 |
+
if filter_adult:
|
353 |
+
original_count = len(movies_data)
|
354 |
+
movies_data = {k: v for k, v in movies_data.items() if not v.get('adult', False)}
|
355 |
+
filtered_count = original_count - len(movies_data)
|
356 |
+
if filtered_count > 0:
|
357 |
+
logger.info(f"Filtered out {filtered_count} adult movies at detail level")
|
358 |
+
|
359 |
# Save movie data checkpoint
|
360 |
save_checkpoint(movies_data, MOVIE_DATA_CHECKPOINT)
|
361 |
|
|
|
551 |
help="OpenAI embedding model to use (default: text-embedding-3-small)")
|
552 |
parser.add_argument("--no-faiss", action="store_true",
|
553 |
help="Skip building FAISS index")
|
554 |
+
parser.add_argument("--include-adult", action="store_true",
|
555 |
+
help="Include adult movies (overrides FILTER_ADULT_CONTENT setting)")
|
556 |
|
557 |
args = parser.parse_args()
|
558 |
|
559 |
build_index(
    max_pages=args.max_pages,
    model=args.model,
    use_faiss=not args.no_faiss,
    # --include-adult is a store_true flag, so it is False (never None) when
    # omitted. build_index treats any non-None value as an explicit override,
    # so passing the raw flag would always force filtering on and ignore
    # FILTER_ADULT_CONTENT. Pass None when the flag is absent so the
    # configured setting is honored.
    override_adult_filter=True if args.include_adult else None,
)
|
app/main.py
CHANGED
@@ -10,6 +10,15 @@ from typing import List, Optional
|
|
10 |
import logging
|
11 |
import time
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
# Configure logging
|
14 |
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO").upper())
|
15 |
logger = logging.getLogger(__name__)
|
@@ -208,6 +217,25 @@ async def health_check():
|
|
208 |
"""Health check endpoint"""
|
209 |
return {"status": "healthy", "vectors_loaded": vectors is not None}
|
210 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
@app.post("/explore", response_model=ExploreResponse)
|
212 |
async def explore(
|
213 |
request: ExploreRequest,
|
@@ -219,15 +247,32 @@ async def explore(
|
|
219 |
start_time = time.time()
|
220 |
|
221 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
222 |
# Convert TMDB IDs to internal indices
|
223 |
liked_indices = []
|
224 |
disliked_indices = []
|
|
|
225 |
|
226 |
for tmdb_id in request.liked_ids:
|
227 |
if str(tmdb_id) in id_map:
|
228 |
liked_indices.append(id_map[str(tmdb_id)])
|
229 |
else:
|
230 |
logger.warning(f"TMDB ID {tmdb_id} not found in index")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
231 |
|
232 |
for tmdb_id in request.disliked_ids:
|
233 |
if str(tmdb_id) in id_map:
|
@@ -235,6 +280,10 @@ async def explore(
|
|
235 |
else:
|
236 |
logger.warning(f"TMDB ID {tmdb_id} not found in index")
|
237 |
|
|
|
|
|
|
|
|
|
238 |
# Get embedding vectors
|
239 |
liked_vectors = vectors[liked_indices] if liked_indices else None
|
240 |
disliked_vectors = vectors[disliked_indices] if disliked_indices else None
|
@@ -251,9 +300,14 @@ async def explore(
|
|
251 |
# Compute distances to subspace (residuals)
|
252 |
residuals = np.linalg.norm(vectors - reconstructed, axis=1)
|
253 |
|
254 |
-
# Get top-k closest movies
|
255 |
-
|
256 |
-
|
|
|
|
|
|
|
|
|
|
|
257 |
|
258 |
# Assign spiral coordinates
|
259 |
spiral_coords = assign_spiral_coords(len(top_k_indices))
|
@@ -290,7 +344,7 @@ async def explore(
|
|
290 |
)
|
291 |
|
292 |
elapsed = time.time() - start_time
|
293 |
-
logger.info(f"Explore request processed in {elapsed:.3f}s - {len(request.liked_ids)} likes, {len(request.disliked_ids)} dislikes, {len(movies)} results")
|
294 |
|
295 |
return response
|
296 |
|
|
|
10 |
import logging
|
11 |
import time
|
12 |
|
13 |
+
# Try different import patterns to handle both direct execution and module execution
|
14 |
+
try:
|
15 |
+
from .settings import get_settings
|
16 |
+
except ImportError:
|
17 |
+
try:
|
18 |
+
from app.settings import get_settings
|
19 |
+
except ImportError:
|
20 |
+
from settings import get_settings
|
21 |
+
|
22 |
# Configure logging
|
23 |
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO").upper())
|
24 |
logger = logging.getLogger(__name__)
|
|
|
217 |
"""Health check endpoint"""
|
218 |
return {"status": "healthy", "vectors_loaded": vectors is not None}
|
219 |
|
220 |
+
async def get_movie_from_tmdb(tmdb_id: int):
    """Fetch a single movie from the TMDB API if it is not in the local index.

    Args:
        tmdb_id: TMDB identifier of the movie to look up.

    Returns:
        The decoded JSON payload when TMDB answers with HTTP 200; ``None`` for
        any other status code or when any error occurs (errors are logged,
        never raised to the caller).
    """
    try:
        import asyncio
        import functools

        import requests

        settings = get_settings()

        url = f"https://api.themoviedb.org/3/movie/{tmdb_id}"
        params = {"api_key": settings.tmdb_api_key}

        # requests is synchronous: calling it directly inside this coroutine
        # would stall the event loop for up to the 10 s timeout. Run it on the
        # default executor so other requests keep being served meanwhile.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            functools.partial(requests.get, url, params=params, timeout=10),
        )

        if response.status_code == 200:
            return response.json()

        logger.warning(f"TMDB API returned {response.status_code} for movie {tmdb_id}")
        return None
    except Exception as e:
        # Deliberate best-effort boundary: this lookup only feeds diagnostics,
        # so a failure here must not break the calling request.
        logger.error(f"Error fetching movie {tmdb_id} from TMDB: {e}")
        return None
|
238 |
+
|
239 |
@app.post("/explore", response_model=ExploreResponse)
|
240 |
async def explore(
|
241 |
request: ExploreRequest,
|
|
|
247 |
start_time = time.time()
|
248 |
|
249 |
try:
|
250 |
+
# Ensure top_k doesn't exceed available movies
|
251 |
+
total_movies = len(vectors) if vectors is not None else 0
|
252 |
+
actual_top_k = min(request.top_k, total_movies)
|
253 |
+
|
254 |
+
if actual_top_k <= 0:
|
255 |
+
raise HTTPException(status_code=400, detail="No movies available")
|
256 |
+
|
257 |
# Convert TMDB IDs to internal indices
|
258 |
liked_indices = []
|
259 |
disliked_indices = []
|
260 |
+
missing_movies = []
|
261 |
|
262 |
for tmdb_id in request.liked_ids:
|
263 |
if str(tmdb_id) in id_map:
|
264 |
liked_indices.append(id_map[str(tmdb_id)])
|
265 |
else:
|
266 |
logger.warning(f"TMDB ID {tmdb_id} not found in index")
|
267 |
+
# Optionally fetch movie info for debugging
|
268 |
+
movie_info = await get_movie_from_tmdb(tmdb_id)
|
269 |
+
if movie_info:
|
270 |
+
missing_movies.append({
|
271 |
+
"id": tmdb_id,
|
272 |
+
"title": movie_info.get("title", "Unknown"),
|
273 |
+
"release_date": movie_info.get("release_date", "Unknown")
|
274 |
+
})
|
275 |
+
logger.info(f"Missing movie: {movie_info.get('title')} ({movie_info.get('release_date', 'Unknown')})")
|
276 |
|
277 |
for tmdb_id in request.disliked_ids:
|
278 |
if str(tmdb_id) in id_map:
|
|
|
280 |
else:
|
281 |
logger.warning(f"TMDB ID {tmdb_id} not found in index")
|
282 |
|
283 |
+
# Log missing movies for debugging
|
284 |
+
if missing_movies:
|
285 |
+
logger.info(f"Missing {len(missing_movies)} movies from index: {[m['title'] for m in missing_movies]}")
|
286 |
+
|
287 |
# Get embedding vectors
|
288 |
liked_vectors = vectors[liked_indices] if liked_indices else None
|
289 |
disliked_vectors = vectors[disliked_indices] if disliked_indices else None
|
|
|
300 |
# Compute distances to subspace (residuals)
|
301 |
residuals = np.linalg.norm(vectors - reconstructed, axis=1)
|
302 |
|
303 |
+
# Get top-k closest movies - use proper bounds checking
|
304 |
+
if actual_top_k >= len(residuals):
|
305 |
+
# If we want all movies, just sort them
|
306 |
+
top_k_indices = np.argsort(residuals)
|
307 |
+
else:
|
308 |
+
# Use argpartition for efficiency when we want a subset
|
309 |
+
top_k_indices = np.argpartition(residuals, actual_top_k-1)[:actual_top_k]
|
310 |
+
top_k_indices = top_k_indices[np.argsort(residuals[top_k_indices])]
|
311 |
|
312 |
# Assign spiral coordinates
|
313 |
spiral_coords = assign_spiral_coords(len(top_k_indices))
|
|
|
344 |
)
|
345 |
|
346 |
elapsed = time.time() - start_time
|
347 |
+
logger.info(f"Explore request processed in {elapsed:.3f}s - {len(request.liked_ids)} likes ({len(liked_indices)} found), {len(request.disliked_ids)} dislikes ({len(disliked_indices)} found), {len(movies)} results")
|
348 |
|
349 |
return response
|
350 |
|
app/settings.py
CHANGED
@@ -24,6 +24,9 @@ class Settings(BaseSettings):
|
|
24 |
# Logging level
|
25 |
log_level: str = "INFO"
|
26 |
|
|
|
|
|
|
|
27 |
class Config:
|
28 |
env_file = ".env"
|
29 |
env_file_encoding = "utf-8"
|
|
|
24 |
# Logging level
|
25 |
log_level: str = "INFO"
|
26 |
|
27 |
+
# Filter adult content (True = exclude adult films, False = include all)
|
28 |
+
filter_adult_content: bool = True
|
29 |
+
|
30 |
class Config:
|
31 |
env_file = ".env"
|
32 |
env_file_encoding = "utf-8"
|