""" | |
Build FAISS index from movie embeddings | |
This script should be run once to create the data files needed by the API | |
""" | |
import os
import json
import numpy as np
import faiss
from openai import OpenAI
import requests
from typing import Dict, List, Optional
import time
import argparse
from concurrent.futures import ThreadPoolExecutor, as_completed
import logging
import pickle

# Try different import patterns to handle both direct execution and module execution
try:
    from .settings import get_settings
except ImportError:
    try:
        from app.settings import get_settings
    except ImportError:
        from settings import get_settings
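# Whichever settings implementation is resolved above, build_index() expects it to
# expose tmdb_api_key, openai_api_key, and filter_adult_content (or the
# filter_adult_content_bool convenience attribute).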

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Checkpoint file paths - default to the system temp directory; override with CHECKPOINT_DIR
import tempfile
CHECKPOINT_DIR = os.environ.get('CHECKPOINT_DIR', tempfile.gettempdir())
MOVIE_DATA_CHECKPOINT = f"{CHECKPOINT_DIR}/movie_data.pkl"
EMBEDDINGS_CHECKPOINT = f"{CHECKPOINT_DIR}/embeddings_progress.pkl"
METADATA_CHECKPOINT = f"{CHECKPOINT_DIR}/metadata_progress.pkl"

def save_checkpoint(data, filepath: str):
    """Save checkpoint data to file - skip if permissions denied"""
    try:
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        with open(filepath, 'wb') as f:
            pickle.dump(data, f)
        logger.info(f"Checkpoint saved: {filepath}")
    except PermissionError:
        logger.warning(f"Cannot save checkpoint due to permissions: {filepath}")
    except Exception as e:
        logger.warning(f"Failed to save checkpoint {filepath}: {e}")

def load_checkpoint(filepath: str):
    """Load checkpoint data from file"""
    try:
        if os.path.exists(filepath):
            with open(filepath, 'rb') as f:
                data = pickle.load(f)
            logger.info(f"Checkpoint loaded: {filepath}")
            return data
    except Exception as e:
        logger.warning(f"Failed to load checkpoint {filepath}: {e}")
    return None

def cleanup_checkpoints():
    """Remove checkpoint files after successful completion"""
    try:
        # Remove the individual checkpoint files rather than the whole directory,
        # since CHECKPOINT_DIR defaults to the shared system temp directory
        for path in (MOVIE_DATA_CHECKPOINT, EMBEDDINGS_CHECKPOINT, METADATA_CHECKPOINT):
            if os.path.exists(path):
                os.remove(path)
        logger.info("Checkpoint files cleaned up")
    except Exception as e:
        logger.warning(f"Failed to cleanup checkpoints: {e}")

class TMDBClient:
    """Client for TMDB API with retry and backoff"""

    def __init__(self, api_key: str):
        self.api_key = api_key
        self.base_url = "https://api.themoviedb.org/3"
        self.session = requests.Session()

    def _make_request(self, endpoint: str, params: dict = None, max_retries: int = 3) -> Optional[dict]:
        """Make API request with retry and backoff"""
        if params is None:
            params = {}
        params['api_key'] = self.api_key
        url = f"{self.base_url}{endpoint}"
        for attempt in range(max_retries):
            try:
                response = self.session.get(url, params=params, timeout=10)
                if response.status_code == 200:
                    return response.json()
                elif response.status_code == 429:
                    # Rate limit - wait and retry
                    wait_time = 2 ** attempt
                    logger.warning(f"Rate limited, waiting {wait_time}s before retry...")
                    time.sleep(wait_time)
                    continue
                elif response.status_code == 404:
                    logger.warning(f"Resource not found: {url}")
                    return None
                else:
                    logger.error(f"API error {response.status_code}: {response.text}")
            except requests.exceptions.RequestException as e:
                logger.error(f"Request failed (attempt {attempt + 1}): {e}")
            if attempt < max_retries - 1:
                time.sleep(2 ** attempt)
        return None

    def get_popular_movies(self, max_pages: int = 100, filter_adult: bool = True) -> List[int]:
        """Get movie IDs from popular movies pagination"""
        movie_ids = []
        for page in range(1, max_pages + 1):
            logger.info(f"Fetching popular movies page {page}/{max_pages}")
            data = self._make_request("/movie/popular", {"page": page})
            if not data:
                logger.error(f"Failed to fetch page {page}")
                break
            # Check if we've exceeded total pages
            if page > data.get('total_pages', 0):
                logger.info(f"Reached last page ({data.get('total_pages')})")
                break
            # Extract movie IDs, filtering adult content if requested
            for movie in data.get('results', []):
                # Skip adult movies if filtering is enabled
                if filter_adult and movie.get('adult', False):
                    logger.debug(f"Skipping adult movie: {movie.get('title', 'Unknown')} (ID: {movie.get('id')})")
                    continue
                movie_ids.append(movie['id'])
            # Rate limiting
            time.sleep(0.25)  # 4 requests per second max
        logger.info(f"Collected {len(movie_ids)} movie IDs from {page} pages (adult filter: {'ON' if filter_adult else 'OFF'})")
        return movie_ids

    def get_movie_details(self, movie_id: int) -> Optional[dict]:
        """Get detailed movie information"""
        return self._make_request(f"/movie/{movie_id}")

    def get_movie_credits(self, movie_id: int) -> Optional[dict]:
        """Get movie cast and crew"""
        return self._make_request(f"/movie/{movie_id}/credits")

def fetch_movie_data(tmdb_client: TMDBClient, movie_ids: List[int], max_workers: int = 5) -> Dict[int, dict]:
    """Fetch detailed data for all movies with controlled parallelization"""
    movies_data = {}

    def fetch_single_movie(movie_id: int) -> tuple:
        """Fetch details and credits for a single movie"""
        try:
            # Get basic details
            details = tmdb_client.get_movie_details(movie_id)
            if not details:
                return movie_id, None
            # Get credits
            credits = tmdb_client.get_movie_credits(movie_id)
            if credits:
                details['credits'] = credits
            return movie_id, details
        except Exception as e:
            logger.error(f"Error fetching movie {movie_id}: {e}")
            return movie_id, None

    # Process movies in batches with controlled parallelization
    batch_size = 50
    total_movies = len(movie_ids)
    for i in range(0, total_movies, batch_size):
        batch = movie_ids[i:i + batch_size]
        logger.info(f"Processing batch {i//batch_size + 1}/{(total_movies-1)//batch_size + 1} ({len(batch)} movies)")
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = {executor.submit(fetch_single_movie, movie_id): movie_id for movie_id in batch}
            for future in as_completed(futures):
                movie_id, movie_data = future.result()
                if movie_data:
                    movies_data[movie_id] = movie_data
        # Sleep between batches to be respectful to the API
        time.sleep(1)
    logger.info(f"Successfully fetched data for {len(movies_data)}/{total_movies} movies")
    return movies_data

def create_composite_text(movie_data: Dict) -> str:
    """Create composite text for embedding from movie data"""
    parts = []
    # Title
    if movie_data.get('title'):
        parts.append(f"Title: {movie_data['title']}")
    # Tagline
    if movie_data.get('tagline'):
        parts.append(f"Tagline: {movie_data['tagline']}")
    # Overview
    if movie_data.get('overview'):
        parts.append(f"Overview: {movie_data['overview']}")
    # Release date
    if movie_data.get('release_date'):
        parts.append(f"Release Date: {movie_data['release_date']}")
    # Original language
    if movie_data.get('original_language'):
        parts.append(f"Language: {movie_data['original_language']}")
    # Spoken languages
    if movie_data.get('spoken_languages'):
        languages = [lang.get('iso_639_1', '') for lang in movie_data['spoken_languages'] if lang.get('iso_639_1')]
        if languages:
            parts.append(f"Spoken Languages: {', '.join(languages)}")
    # Genres
    if movie_data.get('genres'):
        genres = [genre['name'] for genre in movie_data['genres']]
        parts.append(f"Genres: {', '.join(genres)}")
    # Production companies
    if movie_data.get('production_companies'):
        companies = [company['name'] for company in movie_data['production_companies']]
        if companies:
            parts.append(f"Production Companies: {', '.join(companies)}")
    # Production countries
    if movie_data.get('production_countries'):
        countries = [country['name'] for country in movie_data['production_countries']]
        if countries:
            parts.append(f"Production Countries: {', '.join(countries)}")
    # Budget (only if > 0)
    if movie_data.get('budget') and movie_data['budget'] > 0:
        parts.append(f"Budget: ${movie_data['budget']:,}")
    # Popularity
    if movie_data.get('popularity'):
        parts.append(f"Popularity: {movie_data['popularity']}")
    # Vote average
    if movie_data.get('vote_average'):
        parts.append(f"Vote Average: {movie_data['vote_average']}")
    # Vote count
    if movie_data.get('vote_count'):
        parts.append(f"Vote Count: {movie_data['vote_count']}")
    # Director(s)
    if movie_data.get('credits', {}).get('crew'):
        directors = [person['name'] for person in movie_data['credits']['crew'] if person['job'] == 'Director']
        if directors:
            parts.append(f"Director: {', '.join(directors)}")
    # Top 5 cast
    if movie_data.get('credits', {}).get('cast'):
        top_cast = [person['name'] for person in movie_data['credits']['cast'][:5]]
        if top_cast:
            parts.append(f"Cast: {', '.join(top_cast)}")
    return " / ".join(parts)

def get_embeddings_batch(texts: List[str], client: OpenAI, model: str = "text-embedding-3-small") -> List[List[float]]:
    """Get embeddings for a batch of texts with retry"""
    max_retries = 3
    for attempt in range(max_retries):
        try:
            response = client.embeddings.create(
                input=texts,
                model=model
            )
            return [item.embedding for item in response.data]
        except Exception as e:
            logger.error(f"Error getting embeddings (attempt {attempt + 1}): {e}")
            if attempt < max_retries - 1:
                time.sleep(2 ** attempt)
            else:
                raise
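
# Note: text-embedding-3-small returns 1536-dimensional vectors by default. Nothing
# below hard-codes this; the FAISS dimension is taken from the embeddings array shape.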

def build_index(max_pages: int = 10, model: str = "text-embedding-3-small", use_faiss: bool = True, override_adult_filter: bool = None):
    """Main function to build the FAISS index and data files"""
    settings = get_settings()

    # Determine adult filtering setting
    filter_adult = settings.filter_adult_content_bool if hasattr(settings, 'filter_adult_content_bool') else settings.filter_adult_content
    if override_adult_filter is not None:
        filter_adult = not override_adult_filter  # --include-adult means don't filter
        logger.info(f"Adult filter override: {'DISABLED' if override_adult_filter else 'ENABLED'}")

    # Initialize clients with error handling for version compatibility
    tmdb_client = TMDBClient(settings.tmdb_api_key)
    try:
        # Try to create OpenAI client with different approaches for version compatibility
        try:
            openai_client = OpenAI(api_key=settings.openai_api_key)
        except TypeError as e:
            if "proxies" in str(e):
                # Fallback for version compatibility issues
                logger.warning(f"OpenAI client compatibility issue: {e}")
                logger.info("Trying alternative OpenAI client initialization...")
                import httpx
                # Create a basic httpx client without proxies
                http_client = httpx.Client(timeout=60.0)
                openai_client = OpenAI(api_key=settings.openai_api_key, http_client=http_client)
            else:
                raise
    except Exception as e:
        logger.error(f"Failed to initialize OpenAI client: {e}")
        logger.error("Please check your OpenAI API key and ensure compatible versions are installed")
        return

    # Create data directory with absolute path
    script_dir = os.path.dirname(os.path.abspath(__file__))
    data_dir = os.path.join(script_dir, "data")
    try:
        os.makedirs(data_dir, exist_ok=True)
        # Test write permissions
        test_file = os.path.join(data_dir, ".write_test")
        with open(test_file, 'w') as f:
            f.write("test")
        os.remove(test_file)
        logger.info(f"Data directory ready: {data_dir}")
    except PermissionError as e:
        logger.error(f"Permission denied when creating data directory: {e}")
        logger.error("Make sure the data directory has write permissions")
        return
    except Exception as e:
        logger.error(f"Failed to create or write to data directory: {e}")
        return

    # Check for existing movie data checkpoint
    movies_data = load_checkpoint(MOVIE_DATA_CHECKPOINT)
    if movies_data is not None:
        logger.info(f"Resuming from checkpoint: data for {len(movies_data)} movies found")
    else:
        # Step 1: Get movie IDs
        logger.info(f"Fetching movie IDs from TMDB (max {max_pages} pages)...")
        movie_ids = tmdb_client.get_popular_movies(
            max_pages=max_pages,
            filter_adult=filter_adult
        )
        if not movie_ids:
            logger.error("No movie IDs retrieved from TMDB")
            return

        # Step 2: Fetch detailed movie data
        logger.info(f"Fetching detailed data for {len(movie_ids)} movies...")
        movies_data = fetch_movie_data(tmdb_client, movie_ids)
        if not movies_data:
            logger.error("No movie data retrieved")
            return

        # Additional filtering at the detail level (double-check)
        if filter_adult:
            original_count = len(movies_data)
            movies_data = {k: v for k, v in movies_data.items() if not v.get('adult', False)}
            filtered_count = original_count - len(movies_data)
            if filtered_count > 0:
                logger.info(f"Filtered out {filtered_count} adult movies at detail level")

        # Save movie data checkpoint
        save_checkpoint(movies_data, MOVIE_DATA_CHECKPOINT)

    # Step 3: Create composite texts and process embeddings in batches
    logger.info("Creating embeddings...")
    embeddings = []
    id_map = {}
    movie_metadata = {}
    processed_movie_ids = set()
    batch_size = 20  # Process 20 texts at a time

    # Check for existing embedding progress
    embedding_checkpoint = load_checkpoint(EMBEDDINGS_CHECKPOINT)
    metadata_checkpoint = load_checkpoint(METADATA_CHECKPOINT)
    if embedding_checkpoint is not None and metadata_checkpoint is not None:
        embeddings = embedding_checkpoint['embeddings']
        id_map = embedding_checkpoint['id_map']
        processed_movie_ids = set(embedding_checkpoint['processed_movie_ids'])
        movie_metadata = metadata_checkpoint
        logger.info(f"Resuming embeddings from checkpoint: {len(embeddings)} embeddings found")
    else:
        logger.info("Starting embeddings from scratch")

    # Process remaining movies
    remaining_movies = {k: v for k, v in movies_data.items() if k not in processed_movie_ids}
    logger.info(f"Processing {len(remaining_movies)} remaining movies")

    composite_texts = []
    current_movie_ids = []
    for movie_id, movie_data in remaining_movies.items():
        # Create composite text
        composite_text = create_composite_text(movie_data)
        composite_texts.append(composite_text)
        current_movie_ids.append(movie_id)

        # Store metadata
        release_year = 0
        if movie_data.get("release_date"):
            try:
                release_year = int(movie_data["release_date"][:4])
            except (ValueError, IndexError):
                release_year = 0
        movie_metadata[str(movie_id)] = {
            "id": movie_id,
            "title": movie_data.get("title", ""),
            "year": release_year,
            "poster_path": movie_data.get("poster_path"),
            "release_date": movie_data.get("release_date"),
            "genres": [g["name"] for g in movie_data.get("genres", [])]
        }

        # Process batch when full
        if len(composite_texts) >= batch_size:
            logger.info(f"Processing embedding batch ({len(embeddings)} done, {len(composite_texts)} in batch)")
            try:
                batch_embeddings = get_embeddings_batch(composite_texts, openai_client, model)
                embeddings.extend(batch_embeddings)
                # Update ID mapping and processed set
                for mid in current_movie_ids:
                    id_map[str(mid)] = len(id_map)
                    processed_movie_ids.add(mid)
                # Save progress checkpoints
                embedding_data = {
                    'embeddings': embeddings,
                    'id_map': id_map,
                    'processed_movie_ids': list(processed_movie_ids)
                }
                save_checkpoint(embedding_data, EMBEDDINGS_CHECKPOINT)
                save_checkpoint(movie_metadata, METADATA_CHECKPOINT)
                # Clear batch
                composite_texts = []
                current_movie_ids = []
                # Sleep between batches
                time.sleep(0.5)
            except Exception as e:
                logger.error(f"Failed to process batch: {e}")
                logger.info("Progress has been saved, you can restart the script to resume")
                return

    # Process remaining texts
    if composite_texts:
        logger.info(f"Processing final embedding batch ({len(composite_texts)} texts)")
        try:
            batch_embeddings = get_embeddings_batch(composite_texts, openai_client, model)
            embeddings.extend(batch_embeddings)
            for mid in current_movie_ids:
                id_map[str(mid)] = len(id_map)
                processed_movie_ids.add(mid)
            # Save final progress
            embedding_data = {
                'embeddings': embeddings,
                'id_map': id_map,
                'processed_movie_ids': list(processed_movie_ids)
            }
            save_checkpoint(embedding_data, EMBEDDINGS_CHECKPOINT)
            save_checkpoint(movie_metadata, METADATA_CHECKPOINT)
        except Exception as e:
            logger.error(f"Failed to process final batch: {e}")
            logger.info("Progress has been saved, you can restart the script to resume")
            return

    if not embeddings:
        logger.error("No embeddings generated")
        return
    logger.info(f"Generated {len(embeddings)} embeddings")

    # Step 4: Save embeddings as numpy array
    embeddings_array = np.array(embeddings, dtype=np.float32)
    embeddings_path = os.path.join(data_dir, "movies.npy")
    try:
        np.save(embeddings_path, embeddings_array)
        logger.info(f"Saved embeddings matrix: {embeddings_array.shape}")
    except Exception as e:
        logger.error(f"Failed to save embeddings: {e}")
        return

    # Step 5: Build and save FAISS index
    if use_faiss:
        logger.info("Building FAISS index...")
        dimension = embeddings_array.shape[1]
        # Choose index type based on size
        if len(embeddings) < 10000:
            # For smaller datasets, use flat index
            index = faiss.IndexFlatL2(dimension)
        else:
            # For larger datasets, use IVF index
            nlist = min(int(np.sqrt(len(embeddings))), 1000)
            quantizer = faiss.IndexFlatL2(dimension)
            index = faiss.IndexIVFFlat(quantizer, dimension, nlist)
        # Train the index (required for IVF, a no-op for the flat index) and add vectors
        index.train(embeddings_array)
        index.add(embeddings_array)
        index_path = os.path.join(data_dir, "faiss.index")
        try:
            faiss.write_index(index, index_path)
            logger.info(f"FAISS index saved (type: {type(index).__name__}, dimension: {dimension})")
        except Exception as e:
            logger.error(f"Failed to save FAISS index: {e}")
            return
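
    # How the artifacts fit together at query time (a sketch, not the API code itself):
    # the query text is embedded with the same model, index.search() returns matrix
    # positions, and id_map.json is inverted to map those positions back to TMDB IDs
    # before looking titles up in movie_metadata.json. For the IVF variant, recall can
    # be tuned via index.nprobe on the loaded index.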

    # Step 6: Save metadata files
    id_map_path = os.path.join(data_dir, "id_map.json")
    metadata_path = os.path.join(data_dir, "movie_metadata.json")
    try:
        with open(id_map_path, "w") as f:
            json.dump(id_map, f)
        with open(metadata_path, "w") as f:
            json.dump(movie_metadata, f)
        logger.info("Index built successfully!")
        logger.info(f"  - {len(embeddings)} movies indexed")
        logger.info(f"  - Embedding model: {model}")
        logger.info(f"  - Files saved in {data_dir}")
        logger.info("    * movies.npy: embeddings matrix")
        logger.info("    * id_map.json: TMDB ID to matrix position mapping")
        logger.info("    * movie_metadata.json: movie metadata")
        if use_faiss:
            logger.info("    * faiss.index: FAISS search index")
        # Cleanup checkpoints
        cleanup_checkpoints()
    except Exception as e:
        logger.error(f"Failed to save metadata files: {e}")
        return
if __name__ == "__main__": | |
parser = argparse.ArgumentParser(description="Build movie embeddings index from TMDB data") | |
parser.add_argument("--max-pages", type=int, default=10, | |
help="Maximum pages to fetch from TMDB popular movies (default: 10)") | |
parser.add_argument("--model", type=str, default="text-embedding-3-small", | |
help="OpenAI embedding model to use (default: text-embedding-3-small)") | |
parser.add_argument("--no-faiss", action="store_true", | |
help="Skip building FAISS index") | |
parser.add_argument("--include-adult", action="store_true", | |
help="Include adult movies (overrides FILTER_ADULT_CONTENT setting)") | |
args = parser.parse_args() | |
build_index( | |
max_pages=args.max_pages, | |
model=args.model, | |
use_faiss=not args.no_faiss, | |
override_adult_filter=args.include_adult | |
) |