File size: 3,571 Bytes
e92d5c0 0b85fad e92d5c0 51e559a e92d5c0 51e559a e92d5c0 0b85fad 2367cc3 0b85fad e92d5c0 2367cc3 e92d5c0 2367cc3 e92d5c0 0b85fad e92d5c0 51e559a e92d5c0 51e559a e92d5c0 0b85fad e92d5c0 0b85fad e92d5c0 0b85fad e92d5c0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
from contextlib import asynccontextmanager
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import OAuthInfo, attach_huggingface_oauth, parse_huggingface_oauth
from sqlmodel import select
from . import constants
from .database import get_session, init_db
from .schemas import UserCount
# Initialize database on startup
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan handler passed to the FastAPI application.

    Calls ``init_db()`` once before the application starts serving, then
    yields control to the application. Nothing runs after the ``yield``,
    so there is no shutdown step.
    """
    # Startup phase: everything before the yield runs before serving.
    init_db()

    yield
# FastAPI application; `lifespan` runs the startup initialization.
app = FastAPI(lifespan=lifespan)

# Origins allowed to call the API with credentials.
# Can't use "*" because frontend doesn't like it with "credentials: true"
_CORS_ALLOWED_ORIGINS = [
    "http://localhost:5173",
    "http://0.0.0.0:9481",
    "http://localhost:9481",
    "http://127.0.0.1:9481",
]

# Set CORS headers: explicit origins, any method, any header, with credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_CORS_ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Root route ("/"): either redirect to the dev frontend or serve the built one,
# depending on constants.SERVE_FRONTEND.
if not constants.SERVE_FRONTEND:
    # Development => Redirect to dev frontend
    @app.get("/")
    async def redirect_to_frontend():
        dev_frontend_url = "http://localhost:5173/"
        return RedirectResponse(dev_frontend_url)

else:
    # Production => Serve frontend from dist directory
    _frontend_assets = StaticFiles(directory=constants.FRONTEND_ASSETS_PATH)
    app.mount("/assets", _frontend_assets, name="assets")

    @app.get("/")
    async def serve_frontend():
        # The index file is served as-is; its assets come from the mount above.
        index_path = constants.FRONTEND_INDEX_PATH
        return FileResponse(index_path)
# Set up Hugging Face OAuth on the app.
# Endpoints then obtain the request's OAuthInfo through parse_huggingface_oauth
# (see the oauth_info_* dependencies).
attach_huggingface_oauth(app)
async def oauth_info_optional(request: Request) -> OAuthInfo | None:
    """Dependency for endpoints that work with or without a signed-in user.

    Args:
        request: The incoming request to inspect.

    Returns:
        Whatever ``parse_huggingface_oauth`` yields for the request: an
        ``OAuthInfo``, or ``None`` when no OAuth info is available.
    """
    maybe_oauth_info = parse_huggingface_oauth(request)
    return maybe_oauth_info
async def oauth_info_required(request: Request) -> OAuthInfo:
    """Dependency for endpoints that require a signed-in user.

    Args:
        request: The incoming request to inspect.

    Returns:
        The ``OAuthInfo`` parsed from the request.

    Raises:
        HTTPException: With status 401 when the request carries no OAuth info.
    """
    parsed = parse_huggingface_oauth(request)
    if parsed is not None:
        return parsed

    # No OAuth info on the request: reject it.
    raise HTTPException(
        status_code=401,
        detail="Unauthorized. Please Sign in with Hugging Face.",
    )
# Health check endpoint
@app.get("/api/health")
async def health():
    """Health check endpoint.

    Always responds with a constant ``{"status": "ok"}`` payload.
    """
    return dict(status="ok")
# User endpoints
@app.get("/api/user")
async def get_user(oauth_info: OAuthInfo | None = Depends(oauth_info_optional)):
    """Get user information.

    Returns a dict with two keys: ``connected`` (whether OAuth info was
    found for the request) and ``username`` (the user's
    ``preferred_username``, or ``None`` when not connected).
    """
    # Anonymous request: report "not connected" without a username.
    if oauth_info is None:
        return {"connected": False, "username": None}

    return {
        "connected": True,
        "username": oauth_info.user_info.preferred_username,
    }
def _load_or_new_user_count(session, username: str) -> UserCount:
    """Return the stored counter for ``username``, or an unsaved one at zero.

    Args:
        session: An open database session (as yielded by ``get_session()``).
        username: Value matched against ``UserCount.name``.

    Returns:
        The first matching ``UserCount`` row, or a new ``UserCount`` with
        ``count=0`` that has not been added to the session.
    """
    statement = select(UserCount).where(UserCount.name == username)
    stored = session.exec(statement).first()
    if stored is not None:
        return stored
    return UserCount(name=username, count=0)


@app.get("/api/user/count")
async def get_user_count(
    oauth_info: OAuthInfo = Depends(oauth_info_required),
) -> UserCount:
    """Get user count.

    The counter is keyed by the user's ``preferred_username``, the same
    identifier ``/api/user`` reports, rather than the display ``name``,
    so that users sharing a display name do not share a counter.
    A user without a stored counter gets a zero counter back; nothing is
    written to the database by this endpoint.

    NOTE: counters stored earlier under a display name are not matched by
    this lookup.
    """
    username = oauth_info.user_info.preferred_username
    with get_session() as session:
        user_count = _load_or_new_user_count(session, username)
    return user_count


@app.post("/api/user/count/increment")
async def increment_user_count(
    oauth_info: OAuthInfo = Depends(oauth_info_required),
) -> UserCount:
    """Increment user count.

    Loads (or creates) the counter keyed by the user's
    ``preferred_username``, adds one to it, commits, and returns the
    refreshed row. Uses the same key as ``get_user_count`` so both
    endpoints always address the same counter.
    """
    username = oauth_info.user_info.preferred_username
    with get_session() as session:
        user_count = _load_or_new_user_count(session, username)
        user_count.count += 1
        session.add(user_count)
        session.commit()
        # Refresh after the commit so the returned object carries the
        # persisted values.
        session.refresh(user_count)
    return user_count
|