yonnel
Enhance environment configuration; implement lazy initialization for vector updater and improve error handling in imports
b8ca8ae
import os
import json
import numpy as np
import faiss
from fastapi import FastAPI, HTTPException, Depends, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
from typing import List, Optional
import logging
import time
# Try different import patterns to handle both direct execution and module execution
try:
from .settings import get_settings
except ImportError:
try:
from app.settings import get_settings
except ImportError:
from settings import get_settings
# Configure logging
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO").upper())
logger = logging.getLogger(__name__)
# Security: bearer-token scheme consumed by verify_token(). auto_error=True is
# spelled out explicitly (it is HTTPBearer's default), so a missing or
# malformed Authorization header is rejected before verify_token() runs.
security = HTTPBearer(auto_error=True)
def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Check the request's bearer token against the ``API_TOKEN`` env var.

    Args:
        credentials: Parsed ``Authorization: Bearer <token>`` header, supplied
            by the ``security`` HTTPBearer dependency.

    Returns:
        The token string when it matches.

    Raises:
        HTTPException: 500 if ``API_TOKEN`` is unset or empty, 401 if the
            supplied token does not match it.
    """
    import secrets

    expected_token = os.getenv("API_TOKEN")
    if not expected_token:
        raise HTTPException(status_code=500, detail="API token not configured")
    # Constant-time comparison: a plain `!=` returns as soon as a character
    # differs, which lets an attacker recover the token byte by byte from
    # response timing. Comparing bytes also avoids compare_digest's TypeError
    # on non-ASCII str input.
    if not secrets.compare_digest(
        credentials.credentials.encode("utf-8"), expected_token.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid token")
    return credentials.credentials
# Pydantic models
# Request body of POST /explore.
class ExploreRequest(BaseModel):
    # TMDB ids of movies the user liked; ids missing from the local index are
    # skipped by the endpoint (and logged).
    liked_ids: List[int]
    # TMDB ids of movies the user disliked; optional.
    # NOTE(review): pydantic copies mutable field defaults per instance, so
    # this `[]` is not shared between requests — confirm for the pinned version.
    disliked_ids: List[int] = []
    # Requested number of results; the endpoint caps it at the number of
    # loaded movie vectors.
    top_k: int = 400
# One movie in the /explore response.
class MovieResult(BaseModel):
    # TMDB id of the movie.
    id: int
    # Title from the movie metadata (the endpoint substitutes "Movie <id>"
    # when the metadata has none).
    title: str
    # Release year from the metadata; 0 when unknown.
    year: int
    # Poster path from the metadata; may be None.
    poster_path: Optional[str]
    # Genre labels from the metadata; empty list when unknown.
    genres: List[str]
    # [x, y] cell on the 2D spiral grid, shifted so that the barycenter of
    # the liked movies sits at the origin (values may be fractional).
    coords: List[float]
# Response body of POST /explore.
class ExploreResponse(BaseModel):
    # Selected movies, ordered by distance to the user's subspace (nearest first).
    movies: List[MovieResult]
    # Barycenter of the liked movies on the grid; the endpoint always sends
    # [0.0, 0.0] because it re-centers all coords on that barycenter.
    bary: List[float]
    # Center point of the user's subspace in the original embedding space.
    center: List[float]
# Module-level state, filled in by startup_event() with the tuple returned by
# load_data(); every slot is None until the data files have been loaded.
vectors = id_map = faiss_index = movie_metadata = None
def load_data(data_dir: Optional[str] = None):
    """Load the movie vectors, ID map, FAISS index and metadata from disk.

    Args:
        data_dir: Directory containing ``movies.npy``, ``id_map.json``,
            ``faiss.index`` and ``movie_metadata.json``. When omitted, the
            ``DATA_DIR`` environment variable is used, falling back to
            ``"app/data"`` (relative to the working directory, as before).

    Returns:
        Tuple ``(vectors, id_map, faiss_index, movie_metadata)``.

    Raises:
        Whatever the underlying loaders raise (missing file, bad JSON, ...);
        the failure is logged with its traceback before being re-raised.
    """
    if data_dir is None:
        data_dir = os.getenv("DATA_DIR", "app/data")
    try:
        # Load vectors
        vectors = np.load(os.path.join(data_dir, "movies.npy"))
        logger.info(f"Loaded {vectors.shape[0]} movie vectors of dimension {vectors.shape[1]}")

        # Load ID mapping (TMDB id string -> row index in `vectors`)
        with open(os.path.join(data_dir, "id_map.json"), "r", encoding="utf-8") as f:
            id_map = json.load(f)
        logger.info(f"Loaded ID mapping for {len(id_map)} movies")

        # Load FAISS index
        faiss_index = faiss.read_index(os.path.join(data_dir, "faiss.index"))
        logger.info(f"Loaded FAISS index with {faiss_index.ntotal} vectors")

        # Load movie metadata
        with open(os.path.join(data_dir, "movie_metadata.json"), "r", encoding="utf-8") as f:
            movie_metadata = json.load(f)
        logger.info(f"Loaded metadata for {len(movie_metadata)} movies")
    except Exception as e:
        logger.exception(f"Failed to load data: {e}")
        raise

    # Every returned vector row is later mapped back to a TMDB id through
    # id_map, so a size mismatch would otherwise only surface as a KeyError
    # on a live request.
    if len(id_map) != vectors.shape[0]:
        logger.warning(
            f"ID map has {len(id_map)} entries but {vectors.shape[0]} vectors were loaded"
        )
    if faiss_index.ntotal != vectors.shape[0]:
        logger.warning(
            f"FAISS index holds {faiss_index.ntotal} vectors but {vectors.shape[0]} were loaded"
        )
    return vectors, id_map, faiss_index, movie_metadata
def _random_orthonormal_completion(basis, count, d, rng):
    """Return ``count`` random unit vectors of length ``d``, mutually orthogonal
    and orthogonal to every (unit-norm) row already in ``basis``."""
    if len(basis) + count > d:
        raise ValueError(
            f"Cannot build {len(basis) + count} orthonormal axes in {d} dimensions"
        )
    extra = []
    while len(extra) < count:
        v = rng.standard_normal(d)
        # Gram-Schmidt: remove the components along every axis chosen so far.
        for axis in list(basis) + extra:
            v = v - np.dot(v, axis) * axis
        norm = np.linalg.norm(v)
        # A random draw landing (numerically) in the existing span is
        # vanishingly unlikely; redraw rather than divide by ~0.
        if norm > 1e-10:
            extra.append(v / norm)
    return extra


def build_plane(likes: np.ndarray, dislikes: np.ndarray = None, dim: int = 2, rng=None):
    """
    Build the user's preference subspace from liked/disliked movie vectors.

    Args:
        likes: ``(n_likes, D)`` array of liked-movie vectors, or ``None``.
        dislikes: ``(n_dislikes, D)`` array of disliked-movie vectors, or ``None``.
        dim: Number of axes to return (2 for the explore grid).
        rng: Random source exposing ``standard_normal`` (e.g.
            ``np.random.default_rng(seed)``); defaults to the global
            ``np.random`` state.

    Returns:
        ``(axes, center)``: ``axes`` is a ``(dim, D)`` matrix with orthonormal
        rows, ``center`` a ``(D,)`` point on the plane. With likes,
        ``center = mean(likes) - 0.5 * mean(dislikes)`` (dislikes optional);
        with no likes it is the mean of all loaded ``vectors``.

    Axis choice by number of likes:
        0 (cold start) or 1 -> random orthonormal axes
        2  -> direction between the two likes, completed with random axes
        3+ -> leading principal components of the likes (PCA)
    Random axes also complete the basis whenever the likes do not provide
    ``dim`` independent directions (e.g. two identical likes, which
    previously produced NaN axes).

    Raises:
        ValueError: if ``dim`` < 1 or larger than the vector dimension.
    """
    if dim < 1:
        raise ValueError("dim must be at least 1")
    rng = np.random if rng is None else rng
    n_likes = likes.shape[0] if likes is not None else 0

    if n_likes == 0:
        # Cold start: center on the global average, random orientation.
        center = vectors.mean(0)
        d = center.shape[0]
        axes = np.vstack(_random_orthonormal_completion([], dim, d, rng))
        return axes, center

    d = likes.shape[1]
    # Composite preference point: +liked - 0.5 * disliked, for every number
    # of likes (dislikes used to be ignored once there were 2+ likes).
    center = likes.mean(0)
    if dislikes is not None and dislikes.shape[0] > 0:
        center = center - 0.5 * dislikes.mean(0)

    if n_likes == 1:
        basis = []
    elif n_likes == 2:
        # Line through the two likes; skip it if they coincide.
        direction = likes[1] - likes[0]
        norm = np.linalg.norm(direction)
        basis = [direction / norm] if norm > 1e-12 else []
    else:
        # PCA of the likes around their own mean gives the best-fit plane's
        # orientation; rows of vt are orthonormal.
        _, _, vt = np.linalg.svd(likes - likes.mean(0), full_matrices=False)
        basis = list(vt[:dim])

    basis = basis[:dim]
    basis.extend(_random_orthonormal_completion(basis, dim - len(basis), d, rng))
    return np.vstack(basis), center
def assign_spiral_coords(n_movies: int):
    """
    Place ``n_movies`` points on integer grid cells along a square spiral.

    The first point is the origin. The walk then turns right, up, left,
    down, ... with leg lengths 1, 1, 2, 2, 3, 3, ..., so consecutive points
    are grid neighbours and the spiral grows outward.

    Returns:
        Integer ``np.ndarray`` of shape ``(n_movies, 2)``, one ``[x, y]`` per point.
    """
    grid = np.zeros((n_movies, 2), dtype=int)
    if n_movies <= 1:
        # Empty, or a single point already sitting at (0, 0).
        return grid

    headings = ((1, 0), (0, 1), (-1, 0), (0, -1))  # right, up, left, down
    cx = cy = 0
    heading = 0
    leg_length = 1
    placed = 1
    while placed < n_movies:
        # Two legs share each length; then the spiral widens by one.
        for _leg in range(2):
            step_x, step_y = headings[heading]
            for _ in range(min(leg_length, n_movies - placed)):
                cx += step_x
                cy += step_y
                grid[placed] = (cx, cy)
                placed += 1
            heading = (heading + 1) % 4
        leg_length += 1
    return grid
def compute_barycenter(liked_indices: List[int], coords: np.ndarray):
    """Return the mean position of the ``coords`` rows picked by ``liked_indices``.

    An empty selection yields ``[0.0, 0.0]``; otherwise the result is a plain
    Python list with one float per coordinate column.
    """
    if liked_indices:
        selected = coords[liked_indices]
        return selected.mean(axis=0).tolist()
    return [0.0, 0.0]
# FastAPI app setup
app = FastAPI(title="Karl-Movie Vector Backend", version="1.0.0")

# Import the admin router using the same three import patterns as the
# settings import at the top of the file (package-relative, `app.` package,
# then bare module), so this module can be loaded the same ways.
try:
    from .routers import admin
except ImportError:
    try:
        from app.routers import admin
    except ImportError:
        from routers import admin

# Register the admin router's endpoints on the app.
app.include_router(admin.router)
# CORS configuration
DEV_ORIGINS = [
    "http://localhost:5173",
    "http://127.0.0.1:5173",
    "http://localhost:8888",
    "https://*.bolt.run",
    "https://*.stackblitz.io",
]
PROD_ORIGINS = ["https://karl.movie"]
origins = DEV_ORIGINS if os.getenv("ENV") != "prod" else PROD_ORIGINS


def _split_cors_origins(patterns):
    """Split origin patterns into exact origins plus one ``allow_origin_regex``.

    Starlette's ``CORSMiddleware`` compares ``allow_origins`` entries
    literally (only a bare ``"*"`` is special), so an entry such as
    ``"https://*.bolt.run"`` never matches a real origin. Entries containing
    ``*`` are turned into a regex where each ``*`` matches one or more host
    characters (letters, digits, ``.`` and ``-``).

    Returns:
        ``(exact_origins, origin_regex)``; ``origin_regex`` is ``None`` when
        no pattern contains a wildcard.
    """
    import re

    exact = [p for p in patterns if p == "*" or "*" not in p]
    wildcard = [p for p in patterns if p != "*" and "*" in p]
    if not wildcard:
        return exact, None
    alternatives = "|".join(
        re.escape(p).replace(r"\*", r"[A-Za-z0-9.-]+") for p in wildcard
    )
    return exact, f"^(?:{alternatives})$"


_cors_exact_origins, _cors_origin_regex = _split_cors_origins(origins)

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_exact_origins,
    allow_origin_regex=_cors_origin_regex,
    allow_methods=["POST", "GET"],
    allow_headers=["*"],
)
@app.on_event("startup")
async def startup_event():
    """Load the vector data and optionally start a background vector update.

    The tuple from ``load_data()`` is published through the module-level
    globals read by the request handlers. If ``AUTO_UPDATE_VECTORS`` is
    ``"true"`` (case-insensitive), a ``VectorUpdater`` is created and its
    ``update_vectors_if_needed()`` coroutine is scheduled as a background
    task; startup does not wait for it.
    """
    global vectors, id_map, faiss_index, movie_metadata
    vectors, id_map, faiss_index, movie_metadata = load_data()

    # Check for and apply vector updates at startup when enabled.
    if os.getenv('AUTO_UPDATE_VECTORS', 'false').lower() == 'true':
        import asyncio

        # Imported lazily so the updater is only loaded when auto-update is on.
        try:
            from .services.vector_updater import VectorUpdater
        except ImportError:
            try:
                from app.services.vector_updater import VectorUpdater
            except ImportError:
                from services.vector_updater import VectorUpdater

        def _log_update_result(done_task):
            # Without this, a failure would only surface as "Task exception
            # was never retrieved" when the task object is collected.
            if done_task.cancelled():
                return
            exc = done_task.exception()
            if exc is not None:
                logger.error("Background vector update failed", exc_info=exc)

        vector_updater = VectorUpdater()
        # Run in the background without awaiting it.
        task = asyncio.create_task(vector_updater.update_vectors_if_needed())
        # The event loop only keeps a weak reference to tasks; hold a strong
        # one so the update cannot be garbage-collected mid-execution.
        app.state.vector_update_task = task
        task.add_done_callback(_log_update_result)
@app.get("/health")
async def health_check():
    """Liveness probe that also reports whether the movie vectors are in memory."""
    vectors_loaded = vectors is not None
    return {"status": "healthy", "vectors_loaded": vectors_loaded}
async def get_movie_from_tmdb(tmdb_id: int):
    """Fetch a single movie from the TMDB API (used for ids absent from the local index).

    Returns:
        The decoded JSON body on HTTP 200, otherwise ``None``. Every error
        (including a missing ``requests`` package, network failures or bad
        JSON) is logged and turned into ``None``.
    """
    api_key = None
    try:
        import asyncio
        import functools

        settings = get_settings()
        import requests

        api_key = settings.tmdb_api_key
        url = f"https://api.themoviedb.org/3/movie/{tmdb_id}"
        params = {"api_key": api_key}
        # requests is blocking: run it in the default thread pool so a slow
        # TMDB call (up to the 10 s timeout) does not freeze the event loop
        # and every other in-flight request with it.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None, functools.partial(requests.get, url, params=params, timeout=10)
        )
        if response.status_code == 200:
            return response.json()
        logger.warning(f"TMDB API returned {response.status_code} for movie {tmdb_id}")
        return None
    except Exception as e:
        # Connection errors from requests can embed the full request URL,
        # including the api_key query parameter; keep the key out of the logs.
        message = str(e)
        if api_key:
            message = message.replace(str(api_key), "***")
        logger.error(f"Error fetching movie {tmdb_id} from TMDB: {message}")
        return None
def _lookup_indices(tmdb_ids):
    """Map TMDB ids to row indices of ``vectors`` through ``id_map``.

    Duplicate ids are collapsed (first occurrence kept, order preserved), so
    a movie sent twice is neither double-weighted nor turned into a
    degenerate zero-length axis. Ids absent from ``id_map`` are logged and
    returned separately.

    Returns:
        ``(indices, unknown_ids)``.
    """
    indices = []
    unknown_ids = []
    seen = set()
    for tmdb_id in tmdb_ids:
        key = str(tmdb_id)
        if key in seen:
            continue
        seen.add(key)
        if key in id_map:
            indices.append(id_map[key])
        else:
            logger.warning(f"TMDB ID {tmdb_id} not found in index")
            unknown_ids.append(tmdb_id)
    return indices, unknown_ids


def _nearest_to_plane(axes, center, k):
    """Return indices of the ``k`` rows of ``vectors`` closest to the affine
    subspace ``center + span(axes)``, nearest first.

    Assumes ``axes`` has orthonormal rows (the ``build_plane`` contract), so
    the residual of the orthogonal projection is the distance to the plane.
    """
    offsets = vectors - center
    projections = offsets @ axes.T  # (N, dim) coordinates within the plane
    residuals = np.linalg.norm(offsets - projections @ axes, axis=1)
    if k >= len(residuals):
        return np.argsort(residuals)
    # argpartition selects the k smallest in O(N); only those get sorted.
    candidates = np.argpartition(residuals, k - 1)[:k]
    return candidates[np.argsort(residuals[candidates])]


@app.post("/explore", response_model=ExploreResponse)
async def explore(
    request: ExploreRequest,
    token: str = Depends(verify_token)
):
    """
    Main endpoint: find movies closest to user's preference subspace.

    Liked/disliked TMDB ids are mapped to vectors, a plane is built with
    ``build_plane``, the ``top_k`` vectors nearest to that plane are laid out
    on a spiral grid, and the grid is shifted so the liked movies'
    barycenter is at the origin.

    Raises:
        HTTPException: 400 if ``top_k`` is not positive or no vectors are
            loaded; 500 for any other unexpected failure.
    """
    start_time = time.time()
    try:
        if request.top_k <= 0:
            raise HTTPException(status_code=400, detail="top_k must be a positive integer")
        # Ensure top_k doesn't exceed available movies
        total_movies = len(vectors) if vectors is not None else 0
        actual_top_k = min(request.top_k, total_movies)
        if actual_top_k <= 0:
            raise HTTPException(status_code=400, detail="No movies available")

        # Convert TMDB IDs to internal indices
        liked_indices, missing_liked_ids = _lookup_indices(request.liked_ids)
        disliked_indices, _ = _lookup_indices(request.disliked_ids)

        # Fetch info on missing liked movies, for debugging only
        missing_movies = []
        for tmdb_id in missing_liked_ids:
            movie_info = await get_movie_from_tmdb(tmdb_id)
            if movie_info:
                missing_movies.append({
                    "id": tmdb_id,
                    "title": movie_info.get("title", "Unknown"),
                    "release_date": movie_info.get("release_date", "Unknown")
                })
                logger.info(f"Missing movie: {movie_info.get('title')} ({movie_info.get('release_date', 'Unknown')})")
        if missing_movies:
            logger.info(f"Missing {len(missing_movies)} movies from index: {[m['title'] for m in missing_movies]}")

        # Build the user subspace from the embedding vectors
        liked_vectors = vectors[liked_indices] if liked_indices else None
        disliked_vectors = vectors[disliked_indices] if disliked_indices else None
        axes, center = build_plane(liked_vectors, disliked_vectors)

        top_k_indices = _nearest_to_plane(axes, center, actual_top_k)

        # Lay results out on the spiral and re-center on the liked movies
        spiral_coords = assign_spiral_coords(len(top_k_indices))
        liked_set = set(liked_indices)
        liked_positions = [pos for pos, idx in enumerate(top_k_indices) if idx in liked_set]
        bary = compute_barycenter(liked_positions, spiral_coords)
        spiral_coords = spiral_coords - np.array(bary)

        # Build response
        reverse_id_map = {v: k for k, v in id_map.items()}
        movies = []
        for coords, movie_idx in zip(spiral_coords, top_k_indices):
            tmdb_id = int(reverse_id_map[movie_idx])
            metadata = movie_metadata.get(str(tmdb_id), {})
            # `or` also covers keys present with a null value, which would
            # otherwise fail MovieResult validation.
            movies.append(MovieResult(
                id=tmdb_id,
                title=metadata.get("title") or f"Movie {tmdb_id}",
                year=metadata.get("year") or 0,
                poster_path=metadata.get("poster_path"),
                genres=metadata.get("genres") or [],
                coords=coords.tolist()
            ))

        response = ExploreResponse(
            movies=movies,
            bary=[0.0, 0.0],  # Always [0,0] since we translated
            center=center.tolist()
        )

        elapsed = time.time() - start_time
        logger.info(f"Explore request processed in {elapsed:.3f}s - {len(request.liked_ids)} likes ({len(liked_indices)} found), {len(request.disliked_ids)} dislikes ({len(disliked_indices)} found), {len(movies)} results")
        return response
    except HTTPException:
        # Deliberate client errors (e.g. 400) must not be rewrapped as 500s.
        raise
    except Exception as e:
        logger.exception(f"Error processing explore request: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    import uvicorn

    # HOST / PORT override the bind address (e.g. on platforms that assign
    # the port); the defaults are the previous hard-coded values.
    uvicorn.run(
        app,
        host=os.getenv("HOST", "0.0.0.0"),
        port=int(os.getenv("PORT", "8000")),
    )