"""FastAPI backend: CORS, frontend serving, Hugging Face OAuth and user counters."""

from contextlib import asynccontextmanager

from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import OAuthInfo, attach_huggingface_oauth, parse_huggingface_oauth
from sqlmodel import Session, select

from . import constants
from .database import get_session, init_db
from .schemas import UserCount

# Origin of the frontend dev server. It is both an allowed CORS origin and the
# redirect target for "/" when the frontend is not served by this app.
_DEV_FRONTEND_ORIGIN = "http://localhost:5173"


# Initialize database on startup
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run ``init_db()`` before the application starts handling requests.

    Nothing is done after the ``yield``, i.e. there is no shutdown step.
    """
    init_db()
    yield


# FastAPI app
app = FastAPI(lifespan=lifespan)

# Set CORS headers
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        # Origins are listed explicitly: "*" can't be used because the
        # frontend doesn't accept it together with "credentials: true".
        _DEV_FRONTEND_ORIGIN,
        "http://0.0.0.0:9481",
        "http://localhost:9481",
        "http://127.0.0.1:9481",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount frontend from dist directory (if configured)
if constants.SERVE_FRONTEND:
    # Production => Serve frontend from dist directory
    app.mount(
        "/assets",
        StaticFiles(directory=constants.FRONTEND_ASSETS_PATH),
        name="assets",
    )

    @app.get("/")
    async def serve_frontend():
        """Serve the frontend index file."""
        return FileResponse(constants.FRONTEND_INDEX_PATH)

else:
    # Development => Redirect to dev frontend
    @app.get("/")
    async def redirect_to_frontend():
        """Redirect to the frontend dev server."""
        return RedirectResponse(f"{_DEV_FRONTEND_ORIGIN}/")


# Set up Hugging Face OAuth
attach_huggingface_oauth(app)


# Dependencies to get OAuthInfo in an endpoint
async def oauth_info_optional(request: Request) -> OAuthInfo | None:
    """Return the OAuth info parsed from the request, or ``None`` if there is none."""
    return parse_huggingface_oauth(request)


async def oauth_info_required(request: Request) -> OAuthInfo:
    """Return the OAuth info parsed from the request.

    Raises:
        HTTPException: with status 401 when no OAuth info can be parsed
            from the request.
    """
    oauth_info = parse_huggingface_oauth(request)
    if oauth_info is None:
        raise HTTPException(
            status_code=401,
            detail="Unauthorized. Please Sign in with Hugging Face.",
        )
    return oauth_info


# Health check endpoint
@app.get("/api/health")
async def health():
    """Health check endpoint."""
    return {"status": "ok"}


# User endpoints
@app.get("/api/user")
async def get_user(oauth_info: OAuthInfo | None = Depends(oauth_info_optional)):
    """Get user information.

    Returns whether the caller is signed in and, if so, their username.
    """
    return {
        "connected": oauth_info is not None,
        "username": oauth_info.user_info.preferred_username if oauth_info else None,
    }


def _load_or_new_user_count(session: Session, name: str) -> UserCount:
    """Return the stored ``UserCount`` for *name*, or a new one with count 0.

    A newly created object is NOT added to *session*; the caller decides
    whether it gets persisted.
    """
    statement = select(UserCount).where(UserCount.name == name)
    user_count = session.exec(statement).first()
    if user_count is None:
        user_count = UserCount(name=name, count=0)
    return user_count


# NOTE(review): the counter endpoints below key rows on `user_info.name`, while
# `/api/user` reports `user_info.preferred_username`. Presumably `name` is a
# display name that may not be unique across users -- confirm it is the
# intended key before relying on per-user isolation of counts.
@app.get("/api/user/count")
async def get_user_count(
    oauth_info: OAuthInfo = Depends(oauth_info_required),
) -> UserCount:
    """Get user count.

    Read-only: when the user has no stored row, a count of 0 is returned
    without writing anything to the database.
    """
    with get_session() as session:
        return _load_or_new_user_count(session, oauth_info.user_info.name)


@app.post("/api/user/count/increment")
async def increment_user_count(
    oauth_info: OAuthInfo = Depends(oauth_info_required),
) -> UserCount:
    """Increment user count.

    Adds 1 to the user's count (starting from 0 when no row exists yet),
    commits it, and returns the stored row.
    """
    with get_session() as session:
        user_count = _load_or_new_user_count(session, oauth_info.user_info.name)
        user_count.count += 1
        session.add(user_count)
        session.commit()
        # Re-read the row after the commit so the returned object reflects
        # what was actually stored.
        session.refresh(user_count)
        return user_count