File size: 3,571 Bytes
e92d5c0
 
 
0b85fad
e92d5c0
51e559a
e92d5c0
 
51e559a
 
e92d5c0
 
 
 
 
 
 
 
 
 
 
 
 
0b85fad
 
2367cc3
0b85fad
 
e92d5c0
2367cc3
e92d5c0
2367cc3
 
 
e92d5c0
0b85fad
 
 
 
 
e92d5c0
51e559a
e92d5c0
51e559a
 
 
 
 
 
 
 
 
 
e92d5c0
 
 
 
 
 
 
 
 
 
 
 
 
 
0b85fad
e92d5c0
 
 
 
 
 
 
 
 
 
 
0b85fad
 
e92d5c0
0b85fad
e92d5c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from contextlib import asynccontextmanager

from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import OAuthInfo, attach_huggingface_oauth, parse_huggingface_oauth
from sqlmodel import select

from . import constants
from .database import get_session, init_db
from .schemas import UserCount


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan hook: prepare the database, then let the app serve requests.

    Everything before ``yield`` runs once at startup. Nothing follows the
    ``yield``, so no shutdown-time cleanup is performed here.
    """
    init_db()
    yield


# Every allowed origin is spelled out because browsers reject a wildcard "*"
# origin for requests the frontend sends with `credentials: true`.
# 5173 is the dev frontend; 9481 presumably is where this backend serves the
# built frontend (TODO confirm).
_CORS_ALLOWED_ORIGINS = [
    "http://localhost:5173",
    "http://0.0.0.0:9481",
    "http://localhost:9481",
    "http://127.0.0.1:9481",
]

# FastAPI app, initialised through the `lifespan` startup hook.
app = FastAPI(lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=_CORS_ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Frontend: either point browsers at the dev server or serve the built bundle,
# depending on constants.SERVE_FRONTEND.
if not constants.SERVE_FRONTEND:
    # Development => "/" bounces to the separately running dev frontend.
    @app.get("/")
    async def redirect_to_frontend():
        return RedirectResponse("http://localhost:5173/")

else:
    # Production => static files under /assets, index.html at "/".
    _frontend_assets = StaticFiles(directory=constants.FRONTEND_ASSETS_PATH)
    app.mount("/assets", _frontend_assets, name="assets")

    @app.get("/")
    async def serve_frontend():
        return FileResponse(constants.FRONTEND_INDEX_PATH)


# Enable Hugging Face OAuth on the app. This is what lets request handlers
# obtain the signed-in user's OAuthInfo via parse_huggingface_oauth(request).
attach_huggingface_oauth(app=app)


async def oauth_info_optional(request: Request) -> OAuthInfo | None:
    """Dependency returning the caller's Hugging Face OAuth info, if any.

    Yields ``None`` when the request carries no OAuth session, so endpoints
    using it serve both signed-in and anonymous callers.
    """
    info = parse_huggingface_oauth(request)
    return info


async def oauth_info_required(request: Request) -> OAuthInfo:
    """Dependency returning the caller's Hugging Face OAuth info.

    Raises:
        HTTPException: 401 when the request has no OAuth session.
    """
    if (info := parse_huggingface_oauth(request)) is not None:
        return info
    raise HTTPException(
        status_code=401, detail="Unauthorized. Please Sign in with Hugging Face."
    )


# Liveness probe: always answers with a fixed payload and touches no state.
@app.get("/api/health")
async def health():
    """Health check endpoint."""
    return dict(status="ok")


# User endpoints
@app.get("/api/user")
async def get_user(oauth_info: OAuthInfo | None = Depends(oauth_info_optional)):
    """Get user information."""
    # Anonymous callers are reported as disconnected with no username.
    if oauth_info is None:
        return {"connected": False, "username": None}
    return {
        "connected": True,
        "username": oauth_info.user_info.preferred_username,
    }


def _user_count_key(oauth_info: OAuthInfo) -> str:
    """Return the key under which the signed-in user's counter is stored.

    Uses the Hugging Face username (``preferred_username``), which is the same
    identity ``/api/user`` reports. ``user_info.name`` is the user's free-form
    display name; it is not unique, so keying on it let distinct accounts
    share one counter.

    NOTE(review): rows written before this change are keyed by display name
    and will no longer be matched. Consider ``user_info.sub`` instead if
    counters must survive username renames.
    """
    return oauth_info.user_info.preferred_username


def _load_user_count(session, name: str) -> UserCount:
    """Fetch the stored counter for ``name``, or a fresh unsaved one at 0.

    ``session`` is the object yielded by ``get_session()``. The default
    instance is NOT added to the session; callers decide whether to persist it.
    """
    statement = select(UserCount).where(UserCount.name == name)
    user_count = session.exec(statement).first()
    if user_count is None:
        user_count = UserCount(name=name, count=0)
    return user_count


# NOTE(review): get_session() is used synchronously inside `async def`, so
# these DB calls block the event loop. As a side effect, read-modify-write
# increments are not interleaved within a single worker process. Moving to
# plain `def` handlers would need an atomic UPDATE to avoid lost increments.
@app.get("/api/user/count")
async def get_user_count(
    oauth_info: OAuthInfo = Depends(oauth_info_required),
) -> UserCount:
    """Get user count."""
    with get_session() as session:
        # Read-only: an unknown user gets a transient count=0 record.
        return _load_user_count(session, _user_count_key(oauth_info))


@app.post("/api/user/count/increment")
async def increment_user_count(
    oauth_info: OAuthInfo = Depends(oauth_info_required),
) -> UserCount:
    """Increment user count."""
    with get_session() as session:
        user_count = _load_user_count(session, _user_count_key(oauth_info))
        user_count.count += 1
        # add() covers both a newly created record and an existing one.
        session.add(user_count)
        session.commit()
        # Reload DB-side values after commit before returning the model.
        session.refresh(user_count)
        return user_count