yonnel committed on
Commit
0236fb6
·
1 Parent(s): 7bec29d

Add adult content filtering option and update related functionality in TMDBClient and settings

Browse files
Files changed (4) hide show
  1. .env.example +4 -1
  2. app/build_index.py +30 -6
  3. app/main.py +58 -4
  4. app/settings.py +3 -0
.env.example CHANGED
@@ -11,4 +11,7 @@ API_TOKEN=your_api_token_here
11
  ENV=dev
12
 
13
  # Logging level
14
- LOG_LEVEL=INFO
 
 
 
 
11
  ENV=dev
12
 
13
  # Logging level
14
+ LOG_LEVEL=INFO
15
+
16
+ # Remove adult content from TMDB results
17
+ FILTER_ADULT_CONTENT=true # Set to true to filter out adult content
app/build_index.py CHANGED
@@ -110,7 +110,7 @@ class TMDBClient:
110
 
111
  return None
112
 
113
- def get_popular_movies(self, max_pages: int = 100) -> List[int]:
114
  """Get movie IDs from popular movies pagination"""
115
  movie_ids = []
116
 
@@ -127,14 +127,18 @@ class TMDBClient:
127
  logger.info(f"Reached last page ({data.get('total_pages')})")
128
  break
129
 
130
- # Extract movie IDs
131
  for movie in data.get('results', []):
 
 
 
 
132
  movie_ids.append(movie['id'])
133
 
134
  # Rate limiting
135
  time.sleep(0.25) # 4 requests per second max
136
 
137
- logger.info(f"Collected {len(movie_ids)} movie IDs from {page} pages")
138
  return movie_ids
139
 
140
  def get_movie_details(self, movie_id: int) -> Optional[dict]:
@@ -285,10 +289,16 @@ def get_embeddings_batch(texts: List[str], client: OpenAI, model: str = "text-em
285
  else:
286
  raise
287
 
288
- def build_index(max_pages: int = 10, model: str = "text-embedding-3-small", use_faiss: bool = True):
289
  """Main function to build the FAISS index and data files"""
290
  settings = get_settings()
291
 
 
 
 
 
 
 
292
  # Initialize clients
293
  tmdb_client = TMDBClient(settings.tmdb_api_key)
294
  openai_client = OpenAI(api_key=settings.openai_api_key)
@@ -321,7 +331,10 @@ def build_index(max_pages: int = 10, model: str = "text-embedding-3-small", use_
321
  else:
322
  # Step 1: Get movie IDs
323
  logger.info(f"Fetching movie IDs from TMDB (max {max_pages} pages)...")
324
- movie_ids = tmdb_client.get_popular_movies(max_pages=max_pages)
 
 
 
325
 
326
  if not movie_ids:
327
  logger.error("❌ No movie IDs retrieved from TMDB")
@@ -335,6 +348,14 @@ def build_index(max_pages: int = 10, model: str = "text-embedding-3-small", use_
335
  logger.error("❌ No movie data retrieved")
336
  return
337
 
 
 
 
 
 
 
 
 
338
  # Save movie data checkpoint
339
  save_checkpoint(movies_data, MOVIE_DATA_CHECKPOINT)
340
 
@@ -530,11 +551,14 @@ if __name__ == "__main__":
530
  help="OpenAI embedding model to use (default: text-embedding-3-small)")
531
  parser.add_argument("--no-faiss", action="store_true",
532
  help="Skip building FAISS index")
 
 
533
 
534
  args = parser.parse_args()
535
 
536
  build_index(
537
  max_pages=args.max_pages,
538
  model=args.model,
539
- use_faiss=not args.no_faiss
 
540
  )
 
110
 
111
  return None
112
 
113
+ def get_popular_movies(self, max_pages: int = 100, filter_adult: bool = True) -> List[int]:
114
  """Get movie IDs from popular movies pagination"""
115
  movie_ids = []
116
 
 
127
  logger.info(f"Reached last page ({data.get('total_pages')})")
128
  break
129
 
130
+ # Extract movie IDs, filtering adult content if requested
131
  for movie in data.get('results', []):
132
+ # Skip adult movies if filtering is enabled
133
+ if filter_adult and movie.get('adult', False):
134
+ logger.debug(f"Skipping adult movie: {movie.get('title', 'Unknown')} (ID: {movie.get('id')})")
135
+ continue
136
  movie_ids.append(movie['id'])
137
 
138
  # Rate limiting
139
  time.sleep(0.25) # 4 requests per second max
140
 
141
+ logger.info(f"Collected {len(movie_ids)} movie IDs from {page} pages (adult filter: {'ON' if filter_adult else 'OFF'})")
142
  return movie_ids
143
 
144
  def get_movie_details(self, movie_id: int) -> Optional[dict]:
 
289
  else:
290
  raise
291
 
292
+ def build_index(max_pages: int = 10, model: str = "text-embedding-3-small", use_faiss: bool = True, override_adult_filter: bool = None):
293
  """Main function to build the FAISS index and data files"""
294
  settings = get_settings()
295
 
296
+ # Determine adult filtering setting
297
+ filter_adult = settings.filter_adult_content
298
+ if override_adult_filter is not None:
299
+ filter_adult = not override_adult_filter # --include-adult means don't filter
300
+ logger.info(f"Adult filter override: {'DISABLED' if override_adult_filter else 'ENABLED'}")
301
+
302
  # Initialize clients
303
  tmdb_client = TMDBClient(settings.tmdb_api_key)
304
  openai_client = OpenAI(api_key=settings.openai_api_key)
 
331
  else:
332
  # Step 1: Get movie IDs
333
  logger.info(f"Fetching movie IDs from TMDB (max {max_pages} pages)...")
334
+ movie_ids = tmdb_client.get_popular_movies(
335
+ max_pages=max_pages,
336
+ filter_adult=filter_adult
337
+ )
338
 
339
  if not movie_ids:
340
  logger.error("❌ No movie IDs retrieved from TMDB")
 
348
  logger.error("❌ No movie data retrieved")
349
  return
350
 
351
+ # Additional filtering at the detail level (double-check)
352
+ if filter_adult:
353
+ original_count = len(movies_data)
354
+ movies_data = {k: v for k, v in movies_data.items() if not v.get('adult', False)}
355
+ filtered_count = original_count - len(movies_data)
356
+ if filtered_count > 0:
357
+ logger.info(f"Filtered out {filtered_count} adult movies at detail level")
358
+
359
  # Save movie data checkpoint
360
  save_checkpoint(movies_data, MOVIE_DATA_CHECKPOINT)
361
 
 
551
  help="OpenAI embedding model to use (default: text-embedding-3-small)")
552
  parser.add_argument("--no-faiss", action="store_true",
553
  help="Skip building FAISS index")
554
+ parser.add_argument("--include-adult", action="store_true",
555
+ help="Include adult movies (overrides FILTER_ADULT_CONTENT setting)")
556
 
557
  args = parser.parse_args()
558
 
559
  build_index(
560
  max_pages=args.max_pages,
561
  model=args.model,
562
+ use_faiss=not args.no_faiss,
563
+ override_adult_filter=args.include_adult
564
  )
app/main.py CHANGED
@@ -10,6 +10,15 @@ from typing import List, Optional
10
  import logging
11
  import time
12
 
 
 
 
 
 
 
 
 
 
13
  # Configure logging
14
  logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO").upper())
15
  logger = logging.getLogger(__name__)
@@ -208,6 +217,25 @@ async def health_check():
208
  """Health check endpoint"""
209
  return {"status": "healthy", "vectors_loaded": vectors is not None}
210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  @app.post("/explore", response_model=ExploreResponse)
212
  async def explore(
213
  request: ExploreRequest,
@@ -219,15 +247,32 @@ async def explore(
219
  start_time = time.time()
220
 
221
  try:
 
 
 
 
 
 
 
222
  # Convert TMDB IDs to internal indices
223
  liked_indices = []
224
  disliked_indices = []
 
225
 
226
  for tmdb_id in request.liked_ids:
227
  if str(tmdb_id) in id_map:
228
  liked_indices.append(id_map[str(tmdb_id)])
229
  else:
230
  logger.warning(f"TMDB ID {tmdb_id} not found in index")
 
 
 
 
 
 
 
 
 
231
 
232
  for tmdb_id in request.disliked_ids:
233
  if str(tmdb_id) in id_map:
@@ -235,6 +280,10 @@ async def explore(
235
  else:
236
  logger.warning(f"TMDB ID {tmdb_id} not found in index")
237
 
 
 
 
 
238
  # Get embedding vectors
239
  liked_vectors = vectors[liked_indices] if liked_indices else None
240
  disliked_vectors = vectors[disliked_indices] if disliked_indices else None
@@ -251,9 +300,14 @@ async def explore(
251
  # Compute distances to subspace (residuals)
252
  residuals = np.linalg.norm(vectors - reconstructed, axis=1)
253
 
254
- # Get top-k closest movies
255
- top_k_indices = np.argpartition(residuals, min(request.top_k, len(residuals)))[:request.top_k]
256
- top_k_indices = top_k_indices[np.argsort(residuals[top_k_indices])]
 
 
 
 
 
257
 
258
  # Assign spiral coordinates
259
  spiral_coords = assign_spiral_coords(len(top_k_indices))
@@ -290,7 +344,7 @@ async def explore(
290
  )
291
 
292
  elapsed = time.time() - start_time
293
- logger.info(f"Explore request processed in {elapsed:.3f}s - {len(request.liked_ids)} likes, {len(request.disliked_ids)} dislikes, {len(movies)} results")
294
 
295
  return response
296
 
 
10
  import logging
11
  import time
12
 
13
+ # Try different import patterns to handle both direct execution and module execution
14
+ try:
15
+ from .settings import get_settings
16
+ except ImportError:
17
+ try:
18
+ from app.settings import get_settings
19
+ except ImportError:
20
+ from settings import get_settings
21
+
22
  # Configure logging
23
  logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO").upper())
24
  logger = logging.getLogger(__name__)
 
217
  """Health check endpoint"""
218
  return {"status": "healthy", "vectors_loaded": vectors is not None}
219
 
220
async def get_movie_from_tmdb(tmdb_id: int):
    """Fetch a single movie from the TMDB API if it is not in the local index.

    The lookup is best-effort: every failure is logged and reported as
    ``None`` so callers can keep processing the remaining IDs.

    Args:
        tmdb_id: TMDB identifier of the movie to look up.

    Returns:
        The decoded JSON payload when TMDB answers with HTTP 200,
        otherwise ``None`` (non-200 status or any error while fetching
        or decoding the response).
    """
    try:
        import asyncio
        import functools

        import requests

        settings = get_settings()

        url = f"https://api.themoviedb.org/3/movie/{tmdb_id}"
        params = {"api_key": settings.tmdb_api_key}

        # requests is a blocking client; run it in the default executor so a
        # slow TMDB response (up to the 10 s timeout) does not stall the
        # event loop and every other in-flight request with it.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            functools.partial(requests.get, url, params=params, timeout=10),
        )

        if response.status_code != 200:
            logger.warning(f"TMDB API returned {response.status_code} for movie {tmdb_id}")
            return None

        return response.json()
    except Exception as e:
        # Log only the exception class: requests error messages usually embed
        # the full request URL, which here contains the api_key query
        # parameter and must not end up in the logs.
        logger.error(f"Error fetching movie {tmdb_id} from TMDB: {type(e).__name__}")
        return None
238
+
239
  @app.post("/explore", response_model=ExploreResponse)
240
  async def explore(
241
  request: ExploreRequest,
 
247
  start_time = time.time()
248
 
249
  try:
250
+ # Ensure top_k doesn't exceed available movies
251
+ total_movies = len(vectors) if vectors is not None else 0
252
+ actual_top_k = min(request.top_k, total_movies)
253
+
254
+ if actual_top_k <= 0:
255
+ raise HTTPException(status_code=400, detail="No movies available")
256
+
257
  # Convert TMDB IDs to internal indices
258
  liked_indices = []
259
  disliked_indices = []
260
+ missing_movies = []
261
 
262
  for tmdb_id in request.liked_ids:
263
  if str(tmdb_id) in id_map:
264
  liked_indices.append(id_map[str(tmdb_id)])
265
  else:
266
  logger.warning(f"TMDB ID {tmdb_id} not found in index")
267
+ # Optionally fetch movie info for debugging
268
+ movie_info = await get_movie_from_tmdb(tmdb_id)
269
+ if movie_info:
270
+ missing_movies.append({
271
+ "id": tmdb_id,
272
+ "title": movie_info.get("title", "Unknown"),
273
+ "release_date": movie_info.get("release_date", "Unknown")
274
+ })
275
+ logger.info(f"Missing movie: {movie_info.get('title')} ({movie_info.get('release_date', 'Unknown')})")
276
 
277
  for tmdb_id in request.disliked_ids:
278
  if str(tmdb_id) in id_map:
 
280
  else:
281
  logger.warning(f"TMDB ID {tmdb_id} not found in index")
282
 
283
+ # Log missing movies for debugging
284
+ if missing_movies:
285
+ logger.info(f"Missing {len(missing_movies)} movies from index: {[m['title'] for m in missing_movies]}")
286
+
287
  # Get embedding vectors
288
  liked_vectors = vectors[liked_indices] if liked_indices else None
289
  disliked_vectors = vectors[disliked_indices] if disliked_indices else None
 
300
  # Compute distances to subspace (residuals)
301
  residuals = np.linalg.norm(vectors - reconstructed, axis=1)
302
 
303
+ # Get top-k closest movies - use proper bounds checking
304
+ if actual_top_k >= len(residuals):
305
+ # If we want all movies, just sort them
306
+ top_k_indices = np.argsort(residuals)
307
+ else:
308
+ # Use argpartition for efficiency when we want a subset
309
+ top_k_indices = np.argpartition(residuals, actual_top_k-1)[:actual_top_k]
310
+ top_k_indices = top_k_indices[np.argsort(residuals[top_k_indices])]
311
 
312
  # Assign spiral coordinates
313
  spiral_coords = assign_spiral_coords(len(top_k_indices))
 
344
  )
345
 
346
  elapsed = time.time() - start_time
347
+ logger.info(f"Explore request processed in {elapsed:.3f}s - {len(request.liked_ids)} likes ({len(liked_indices)} found), {len(request.disliked_ids)} dislikes ({len(disliked_indices)} found), {len(movies)} results")
348
 
349
  return response
350
 
app/settings.py CHANGED
@@ -24,6 +24,9 @@ class Settings(BaseSettings):
24
  # Logging level
25
  log_level: str = "INFO"
26
 
 
 
 
27
  class Config:
28
  env_file = ".env"
29
  env_file_encoding = "utf-8"
 
24
  # Logging level
25
  log_level: str = "INFO"
26
 
27
+ # Filter adult content (True = exclude adult films, False = include all)
28
+ filter_adult_content: bool = True
29
+
30
  class Config:
31
  env_file = ".env"
32
  env_file_encoding = "utf-8"