Wauplin's picture
Wauplin HF Staff
Database backup
3ca3e6a verified
raw
history blame
5.09 kB
import os
from contextlib import asynccontextmanager, contextmanager
from typing import Annotated, Generator
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import OAuthInfo, attach_huggingface_oauth, parse_huggingface_oauth
from sqlalchemy.engine import Engine
from sqlmodel import Session, SQLModel, create_engine
from . import constants
# Module-wide cache for the SQLAlchemy engine. It stays `None` until `_get_engine()` creates the engine
# on first use, and `_database_lifespan` resets it to `None` after disposing the engine on shutdown.
_ENGINE_SINGLETON: Engine | None = None
# OAuth utilities
async def _oauth_info_optional(request: Request) -> OAuthInfo | None:
    """Return whatever `parse_huggingface_oauth` extracts from the request, which may be `None`."""
    maybe_info = parse_huggingface_oauth(request)
    return maybe_info
async def _oauth_info_required(request: Request) -> OAuthInfo:
    """Return the OAuth info parsed from the request.

    Raises:
        HTTPException: with status 401 when `parse_huggingface_oauth` returns `None`.
    """
    if (info := parse_huggingface_oauth(request)) is not None:
        return info

    # No OAuth info could be parsed from the request => refuse access.
    raise HTTPException(
        status_code=401, detail="Unauthorized. Please Sign in with Hugging Face."
    )
# FastAPI dependency aliases: annotate an endpoint parameter with one of these to have the OAuth info injected.
# - `OptionalOAuth` resolves to `None` when `parse_huggingface_oauth` finds nothing in the request.
# - `RequiredOAuth` raises a 401 `HTTPException` in that case instead.
OptionalOAuth = Annotated[OAuthInfo | None, Depends(_oauth_info_optional)]
RequiredOAuth = Annotated[OAuthInfo, Depends(_oauth_info_required)]
# Database utilities
def _get_engine() -> Engine:
    """Return the shared engine, creating it from `constants.DATABASE_URL` on first call."""
    global _ENGINE_SINGLETON

    # Fast path: an engine has already been created and cached.
    if _ENGINE_SINGLETON is not None:
        return _ENGINE_SINGLETON

    _ENGINE_SINGLETON = create_engine(constants.DATABASE_URL)
    return _ENGINE_SINGLETON
@contextmanager
def get_session() -> Generator[Session, None, None]:
    """Yield a `Session` bound to the shared engine; the session's context is exited when the caller is done."""
    with Session(_get_engine()) as db_session:
        yield db_session
@asynccontextmanager
async def _database_lifespan(app: FastAPI):
    """Handle database lifespan.

    1. If backup is enabled,
        a. Try to load backup from remote dataset. If it fails, delete local database (if any) for a fresh start.
        b. Start back-up scheduler.
    2. If disabled, create local database file or reuse existing one.
    3. Initialize database.
    4. Yield control to FastAPI app.
    5. Close database + force push backup to remote dataset (also when the app exits with an error).

    Args:
        app: The FastAPI application this lifespan is attached to (required by the lifespan protocol, unused here).
    """
    global _ENGINE_SINGLETON

    # Stays `None` when backups are disabled, so the shutdown step knows whether there is anything to push.
    scheduler = None

    if constants.BACKUP_DB:
        from huggingface_hub import CommitScheduler, hf_hub_download

        print("Back-up database is enabled")

        # Try to load backup from remote dataset
        print("Trying to load backup from remote dataset...")
        try:
            hf_hub_download(
                repo_id=constants.BACKUP_DATASET_ID,  # type: ignore[arg-type]
                repo_type="dataset",
                filename="database.db",
                token=constants.HF_TOKEN,
                local_dir=constants.DATABASE_PATH,
                force_download=True,
            )
        except Exception:
            # If backup is enabled but no backup is found, delete local database to prevent confusion.
            print("Couldn't find backup in remote dataset.")
            print("Deleting local database for a fresh start.")
            try:
                os.remove(constants.DATABASE_FILE)
            except FileNotFoundError:
                # No local database yet (e.g. very first start): there is nothing to delete, which is
                # exactly the "fresh start" we want. Don't crash the app startup because of it.
                pass

        # Start back-up scheduler
        print("Starting back-up scheduler (save every 5 minutes)...")
        scheduler = CommitScheduler(
            repo_id=constants.BACKUP_DATASET_ID,  # type: ignore[arg-type]
            folder_path=constants.DATABASE_PATH,
            allow_patterns="database.db",
            token=constants.HF_TOKEN,
            private=True,
            repo_type="dataset",
            every=5,
        )

    # Create the tables registered on the SQLModel metadata (no-op for tables that already exist).
    engine = _get_engine()
    SQLModel.metadata.create_all(engine)

    try:
        yield
    finally:
        # Run the cleanup even if an exception is thrown into the lifespan, so the engine is released
        # and the latest database state still gets pushed.
        print("Closing database...")
        if _ENGINE_SINGLETON is not None:
            _ENGINE_SINGLETON.dispose()
            _ENGINE_SINGLETON = None

        if scheduler is not None:
            print("Pushing backup to remote dataset...")
            scheduler.push_to_hub()
def create_app() -> FastAPI:
    """Build the FastAPI application.

    The returned app uses `_database_lifespan` as its lifespan, accepts credentialed CORS requests from a
    fixed list of local origins, answers "/" with either the built frontend or a redirect to the dev
    frontend (depending on `constants.SERVE_FRONTEND`) and has Hugging Face OAuth attached.
    """
    # Can't use "*" because frontend doesn't like it with "credentials: true"
    local_origins = [
        "http://localhost:5173",
        "http://0.0.0.0:9481",
        "http://localhost:9481",
        "http://127.0.0.1:9481",
    ]

    app = FastAPI(lifespan=_database_lifespan)
    app.add_middleware(
        CORSMiddleware,
        allow_origins=local_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    if not constants.SERVE_FRONTEND:
        # Development => the frontend runs on its own dev server, so send visitors there.
        @app.get("/")
        async def redirect_to_frontend():
            return RedirectResponse("http://localhost:5173/")

    else:
        # Production => expose the built assets and serve the index page from the dist directory.
        static_assets = StaticFiles(directory=constants.FRONTEND_ASSETS_PATH)  # type: ignore[invalid-argument-type]
        app.mount("/assets", static_assets, name="assets")

        @app.get("/")
        async def serve_frontend():
            return FileResponse(constants.FRONTEND_INDEX_PATH)  # type: ignore[invalid-argument-type]

    # Hugging Face OAuth is attached last; endpoints read the resulting OAuthInfo through
    # `parse_huggingface_oauth` (see the `OptionalOAuth` / `RequiredOAuth` dependencies in this module).
    attach_huggingface_oauth(app)
    return app