|
import tempfile |
|
from contextlib import asynccontextmanager, contextmanager |
|
from typing import Annotated, Generator |
|
|
|
from apscheduler.schedulers.background import BackgroundScheduler |
|
from fastapi import Depends, FastAPI, HTTPException, Request |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from fastapi.responses import FileResponse, RedirectResponse |
|
from fastapi.staticfiles import StaticFiles |
|
from huggingface_hub import ( |
|
OAuthInfo, |
|
attach_huggingface_oauth, |
|
create_repo, |
|
parse_huggingface_oauth, |
|
snapshot_download, |
|
upload_folder, |
|
) |
|
from sqlalchemy.engine import Engine |
|
from sqlmodel import Session, SQLModel, create_engine |
|
|
|
from . import constants |
|
from .parquet import export_to_parquet, import_from_parquet |
|
|
|
_ENGINE_SINGLETON: Engine | None = None |
|
|
|
|
|
|
|
|
|
async def _oauth_info_optional(request: Request) -> OAuthInfo | None:
    """Return the OAuth info attached to *request*, or ``None`` when the
    user is not signed in. Never raises; use as a soft auth dependency."""
    oauth = parse_huggingface_oauth(request)
    return oauth
|
|
|
|
|
async def _oauth_info_required(request: Request) -> OAuthInfo:
    """Return the OAuth info attached to *request*.

    Raises:
        HTTPException: 401 when the request carries no OAuth session.
    """
    info = parse_huggingface_oauth(request)
    if info is not None:
        return info
    raise HTTPException(
        status_code=401, detail="Unauthorized. Please Sign in with Hugging Face."
    )
|
|
|
|
|
# FastAPI dependency aliases for route signatures.
# OptionalOAuth resolves to the signed-in user's OAuth info, or None when absent.
OptionalOAuth = Annotated[OAuthInfo | None, Depends(_oauth_info_optional)]
# RequiredOAuth resolves to the OAuth info, raising 401 when the user is not signed in.
RequiredOAuth = Annotated[OAuthInfo, Depends(_oauth_info_required)]
|
|
|
|
|
def get_engine() -> Engine:
    """Return the process-wide engine, creating it lazily on first use."""
    global _ENGINE_SINGLETON
    if _ENGINE_SINGLETON is not None:
        return _ENGINE_SINGLETON
    _ENGINE_SINGLETON = create_engine(constants.DATABASE_URL)
    return _ENGINE_SINGLETON
|
|
|
|
|
@contextmanager
def get_session() -> Generator[Session, None, None]:
    """Yield a short-lived session bound to the shared engine."""
    # The `with` block guarantees the session is closed on exit.
    with Session(get_engine()) as session:
        yield session
|
|
|
|
|
@asynccontextmanager
async def _database_lifespan(app: FastAPI):
    """Handle database lifespan.

    1. If backup is enabled,
        a. Try to load backup from remote dataset. If it fails, delete local database for a fresh start.
        b. Start back-up scheduler.
    2. If disabled, create local database file or reuse existing one.
    3. Initialize database.
    4. Yield control to FastAPI app.
    5. Close database + force push backup to remote dataset.
    """
    scheduler = BackgroundScheduler()
    engine = get_engine()
    SQLModel.metadata.create_all(engine)

    if constants.BACKUP_DB:
        print("Back-up database is enabled")

        # Ensure the (private) backup dataset exists on the Hub.
        repo_url = create_repo(
            repo_id=constants.BACKUP_DATASET_ID,
            repo_type="dataset",
            token=constants.HF_TOKEN,
            private=True,
            exist_ok=True,
        )
        print(f"Backup dataset: {repo_url}")
        repo_id = repo_url.repo_id

        print("Trying to load backup from remote dataset...")
        try:
            backup_dir = snapshot_download(
                repo_id=repo_id,
                repo_type="dataset",
                token=constants.HF_TOKEN,
                allow_patterns="*.parquet",
            )
        except Exception:
            # No usable backup on the Hub: wipe any stale local data so we
            # start from a clean schema.
            print("Couldn't find backup in remote dataset.")
            print("Deleting local database for a fresh start.")
            SQLModel.metadata.drop_all(engine)
            SQLModel.metadata.create_all(engine)
        else:
            # BUGFIX: restore only when the download succeeded. The original
            # code fell through to this call even after the except branch,
            # raising NameError on `backup_dir` for every fresh start.
            import_from_parquet(engine, backup_dir)

        # Periodically push the current database state to the Hub.
        scheduler.add_job(_backup_to_hub, args=[repo_id], trigger="interval", minutes=5)
        scheduler.start()

    yield

    # Stop the periodic backup job first so it cannot run concurrently with
    # engine disposal or the final push below.
    if scheduler.running:
        scheduler.shutdown(wait=False)

    print("Closing database...")
    global _ENGINE_SINGLETON
    if _ENGINE_SINGLETON is not None:
        _ENGINE_SINGLETON.dispose()
        _ENGINE_SINGLETON = None

    if constants.BACKUP_DB:
        print("Pushing backup to remote dataset...")
        _backup_to_hub(repo_id)
|
|
|
|
|
def _backup_to_hub(repo_id: str) -> None:
    """Snapshot the local database and push it to the Hub dataset *repo_id*
    as parquet files, replacing any parquet files already there."""
    with tempfile.TemporaryDirectory() as staging_dir:
        # Dump every table to parquet, then mirror the folder to the dataset.
        export_to_parquet(get_engine(), staging_dir)
        upload_folder(
            repo_id=repo_id,
            folder_path=staging_dir,
            token=constants.HF_TOKEN,
            repo_type="dataset",
            allow_patterns="*.parquet",
            commit_message="Backup database as parquet",
            delete_patterns=["*.parquet"],
        )
|
|
|
|
|
def create_app() -> FastAPI:
    """Build the FastAPI application: CORS, frontend serving, and OAuth.

    The database lifespan handler is attached so the DB is initialized on
    startup and flushed on shutdown.
    """
    application = FastAPI(lifespan=_database_lifespan)

    # Allow the Vite dev server and the locally served app to call the API
    # with credentials.
    application.add_middleware(
        CORSMiddleware,
        allow_origins=[
            "http://localhost:5173",
            "http://0.0.0.0:9481",
            "http://localhost:9481",
            "http://127.0.0.1:9481",
        ],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    if constants.SERVE_FRONTEND:
        # Serve the built frontend bundle straight from this process.
        application.mount(
            "/assets",
            StaticFiles(directory=constants.FRONTEND_ASSETS_PATH),
            name="assets",
        )

        @application.get("/")
        async def serve_frontend():
            return FileResponse(constants.FRONTEND_INDEX_PATH)

    else:
        # Development mode: the frontend runs on its own dev server.
        @application.get("/")
        async def redirect_to_frontend():
            return RedirectResponse("http://localhost:5173/")

    attach_huggingface_oauth(application)

    return application
|
|