import tempfile
from contextlib import asynccontextmanager, contextmanager
from typing import Annotated, Generator

from apscheduler.schedulers.background import BackgroundScheduler
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import (
    OAuthInfo,
    attach_huggingface_oauth,
    create_repo,
    parse_huggingface_oauth,
    snapshot_download,
    upload_folder,
)
from sqlalchemy.engine import Engine
from sqlmodel import Session, SQLModel, create_engine

from . import constants
from .parquet import export_to_parquet, import_from_parquet

# Lazily-created, process-wide engine; access it through get_engine().
_ENGINE_SINGLETON: Engine | None = None


# OAuth utilities
async def _oauth_info_optional(request: Request) -> OAuthInfo | None:
    """Return the request's Hugging Face OAuth info, or None if not signed in."""
    return parse_huggingface_oauth(request)


async def _oauth_info_required(request: Request) -> OAuthInfo:
    """Return the request's Hugging Face OAuth info.

    Raises:
        HTTPException: 401 if the user is not signed in.
    """
    oauth_info = parse_huggingface_oauth(request)
    if oauth_info is None:
        raise HTTPException(
            status_code=401, detail="Unauthorized. Please Sign in with Hugging Face."
        )
    return oauth_info


# FastAPI dependency annotations for endpoints that need (optional) auth.
OptionalOAuth = Annotated[OAuthInfo | None, Depends(_oauth_info_optional)]
RequiredOAuth = Annotated[OAuthInfo, Depends(_oauth_info_required)]


def get_engine() -> Engine:
    """Get the process-wide engine, creating it on first use."""
    global _ENGINE_SINGLETON
    if _ENGINE_SINGLETON is None:
        _ENGINE_SINGLETON = create_engine(constants.DATABASE_URL)
    return _ENGINE_SINGLETON


@contextmanager
def get_session() -> Generator[Session, None, None]:
    """Get a session from the engine."""
    engine = get_engine()
    with Session(engine) as session:
        yield session


@asynccontextmanager
async def _database_lifespan(app: FastAPI):
    """Handle database lifespan.

    1. If backup is enabled,
        a. Try to load backup from remote dataset. If it fails, delete local
           database for a fresh start.
        b. Start back-up scheduler.
    2. If disabled, create local database file or reuse existing one.
    3. Initialize database.
    4. Yield control to FastAPI app.
    5. Close database + force push backup to remote dataset.
    """
    scheduler = BackgroundScheduler()
    engine = get_engine()
    SQLModel.metadata.create_all(engine)

    if constants.BACKUP_DB:
        print("Back-up database is enabled")

        # Create remote dataset if it doesn't exist
        repo_url = create_repo(
            repo_id=constants.BACKUP_DATASET_ID,  # type: ignore[arg-type]
            repo_type="dataset",
            token=constants.HF_TOKEN,
            private=True,
            exist_ok=True,
        )
        print(f"Backup dataset: {repo_url}")
        repo_id = repo_url.repo_id

        # Try to load backup from remote dataset
        print("Trying to load backup from remote dataset...")
        try:
            backup_dir = snapshot_download(
                repo_id=repo_id,
                repo_type="dataset",
                token=constants.HF_TOKEN,
                allow_patterns="*.parquet",
            )
        except Exception:
            # If backup is enabled but no backup is found, delete local
            # database to prevent confusion.
            print("Couldn't find backup in remote dataset.")
            print("Deleting local database for a fresh start.")
            SQLModel.metadata.drop_all(engine)
            SQLModel.metadata.create_all(engine)
        else:
            # BUGFIX: import only when the snapshot download succeeded.
            # Previously this ran unconditionally after the try/except, so the
            # fresh-start path raised NameError on the unbound `backup_dir`.
            import_from_parquet(get_engine(), backup_dir)

        # Start back-up scheduler (every 5 minutes)
        scheduler.add_job(
            _backup_to_hub, args=[repo_id], trigger="interval", minutes=5
        )
        scheduler.start()

    yield

    print("Closing database...")
    global _ENGINE_SINGLETON
    if _ENGINE_SINGLETON is not None:
        _ENGINE_SINGLETON.dispose()
        _ENGINE_SINGLETON = None

    if constants.BACKUP_DB:
        # Stop the interval job first so it cannot race the final push below.
        scheduler.shutdown(wait=False)
        print("Pushing backup to remote dataset...")
        _backup_to_hub(repo_id)


def _backup_to_hub(repo_id: str) -> None:
    """Export backup to remote dataset as parquet files.

    Args:
        repo_id: Target Hugging Face dataset repository.
    """
    # Export into a temp dir, upload, then let the temp dir be cleaned up.
    with tempfile.TemporaryDirectory() as tmp_dir:
        export_to_parquet(get_engine(), tmp_dir)
        upload_folder(
            repo_id=repo_id,
            folder_path=tmp_dir,
            token=constants.HF_TOKEN,
            repo_type="dataset",
            allow_patterns="*.parquet",
            commit_message="Backup database as parquet",
            # Remove stale parquet files from the repo that are no longer
            # present locally (e.g. dropped tables).
            delete_patterns=["*.parquet"],
        )


def create_app() -> FastAPI:
    """Build and configure the FastAPI application."""
    # FastAPI app
    app = FastAPI(lifespan=_database_lifespan)

    # Set CORS headers
    app.add_middleware(
        CORSMiddleware,
        allow_origins=[
            # Can't use "*" because frontend doesn't like it with "credentials: true"
            "http://localhost:5173",
            "http://0.0.0.0:9481",
            "http://localhost:9481",
            "http://127.0.0.1:9481",
        ],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Mount frontend from dist directory (if configured)
    if constants.SERVE_FRONTEND:
        # Production => Serve frontend from dist directory
        app.mount(
            "/assets",
            StaticFiles(directory=constants.FRONTEND_ASSETS_PATH),  # type: ignore[invalid-argument-type]
            name="assets",
        )

        @app.get("/")
        async def serve_frontend():
            return FileResponse(constants.FRONTEND_INDEX_PATH)  # type: ignore[invalid-argument-type]

    else:
        # Development => Redirect to dev frontend
        @app.get("/")
        async def redirect_to_frontend():
            return RedirectResponse("http://localhost:5173/")

    # Set up Hugging Face OAuth
    # To get OAuthInfo in an endpoint
    attach_huggingface_oauth(app)

    return app