File size: 5,908 Bytes
21a86d3
3ca3e6a
 
 
21a86d3
3ca3e6a
 
 
 
21a86d3
 
 
 
 
 
 
 
3ca3e6a
 
 
 
21a86d3
3ca3e6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21a86d3
3ca3e6a
 
 
 
 
 
 
 
 
 
21a86d3
3ca3e6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21a86d3
 
 
3ca3e6a
21a86d3
3ca3e6a
 
21a86d3
 
 
 
 
 
 
 
 
 
 
3ca3e6a
 
 
21a86d3
 
3ca3e6a
 
21a86d3
3ca3e6a
 
 
 
 
 
21a86d3
 
 
3ca3e6a
21a86d3
 
 
 
 
 
3ca3e6a
 
 
 
 
 
 
 
 
 
 
21a86d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ca3e6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import tempfile
from contextlib import asynccontextmanager, contextmanager
from typing import Annotated, Generator

from apscheduler.schedulers.background import BackgroundScheduler
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import (
    OAuthInfo,
    attach_huggingface_oauth,
    create_repo,
    parse_huggingface_oauth,
    snapshot_download,
    upload_folder,
)
from sqlalchemy.engine import Engine
from sqlmodel import Session, SQLModel, create_engine

from . import constants
from .parquet import export_to_parquet, import_from_parquet

# Module-wide engine shared by every caller of `get_engine()`. It starts as
# None, is created on the first `get_engine()` call, and is disposed and reset
# to None when the app lifespan shuts down.
_ENGINE_SINGLETON: Engine | None = None

# OAuth utilities


async def _oauth_info_optional(request: Request) -> OAuthInfo | None:
    """Dependency returning whatever OAuth info can be parsed from `request`.

    The result of `parse_huggingface_oauth` is passed through untouched, so a
    request without OAuth info yields `None` instead of an error.
    """
    oauth_info = parse_huggingface_oauth(request)
    return oauth_info


async def _oauth_info_required(request: Request) -> OAuthInfo:
    """Dependency returning the OAuth info parsed from `request`.

    Raises:
        HTTPException: with status 401 when no OAuth info can be parsed from
            the request.
    """
    if (oauth_info := parse_huggingface_oauth(request)) is not None:
        return oauth_info

    # No OAuth info on the request => reject it.
    raise HTTPException(
        status_code=401,
        detail="Unauthorized. Please Sign in with Hugging Face.",
    )


# FastAPI dependency aliases: annotate an endpoint parameter with one of these
# to have the OAuth info injected.
# - `OptionalOAuth` resolves via `_oauth_info_optional` (may be None).
# - `RequiredOAuth` resolves via `_oauth_info_required` (raises 401 if missing).
OptionalOAuth = Annotated[OAuthInfo | None, Depends(_oauth_info_optional)]
RequiredOAuth = Annotated[OAuthInfo, Depends(_oauth_info_required)]


def get_engine() -> Engine:
    """Return the shared engine, creating it on first use.

    The engine is built from `constants.DATABASE_URL` and cached in the
    module-level `_ENGINE_SINGLETON`, so later calls return the same object.
    """
    global _ENGINE_SINGLETON

    # Fast path: an engine has already been created.
    if _ENGINE_SINGLETON is not None:
        return _ENGINE_SINGLETON

    _ENGINE_SINGLETON = create_engine(constants.DATABASE_URL)
    return _ENGINE_SINGLETON


@contextmanager
def get_session() -> Generator[Session, None, None]:
    """Yield a session bound to the shared engine.

    The session is opened with a `with` block, so it is released when the
    caller's own `with get_session()` block exits.
    """
    with Session(get_engine()) as db_session:
        yield db_session


@asynccontextmanager
async def _database_lifespan(app: FastAPI):
    """Handle database lifespan.

    1. Initialize database (create tables on the shared engine).
    2. If backup is enabled,
       a. Create the remote backup dataset if it doesn't exist.
       b. Try to load backup from remote dataset. If it fails, reset the local
          database (drop + recreate tables) for a fresh start.
       c. Start back-up scheduler (runs `_backup_to_hub` every 5 minutes).
    3. If disabled, create local database file or reuse existing one.
    4. Yield control to FastAPI app.
    5. On shutdown (also when an error is propagated into the lifespan):
       stop the scheduler, force push backup to remote dataset, then close
       the database.
    """
    global _ENGINE_SINGLETON

    scheduler = BackgroundScheduler()
    engine = get_engine()
    SQLModel.metadata.create_all(engine)

    # Stays None when backup is disabled so the shutdown code never touches
    # an unbound name.
    repo_id: str | None = None

    if constants.BACKUP_DB:
        print("Back-up database is enabled")

        # Create remote dataset if it doesn't exist
        repo_url = create_repo(
            repo_id=constants.BACKUP_DATASET_ID,  # type: ignore[arg-type]
            repo_type="dataset",
            token=constants.HF_TOKEN,
            private=True,
            exist_ok=True,
        )
        print(f"Backup dataset: {repo_url}")
        repo_id = repo_url.repo_id

        # Try to load backup from remote dataset
        print("Trying to load backup from remote dataset...")
        try:
            backup_dir = snapshot_download(
                repo_id=repo_id,
                repo_type="dataset",
                token=constants.HF_TOKEN,
                allow_patterns="*.parquet",
            )
        except Exception:
            # If backup is enabled but no backup is found, delete local database to prevent confusion.
            print("Couldn't find backup in remote dataset.")
            print("Deleting local database for a fresh start.")

            SQLModel.metadata.drop_all(engine)
            SQLModel.metadata.create_all(engine)
        else:
            # Import parquet files to database. Only done when the download
            # succeeded: `backup_dir` does not exist otherwise.
            import_from_parquet(engine, backup_dir)

        # Start back-up scheduler
        scheduler.add_job(_backup_to_hub, args=[repo_id], trigger="interval", minutes=5)
        scheduler.start()

    try:
        yield
    finally:
        try:
            # Stop periodic backups first so none of them runs concurrently
            # with the final push or against a disposed engine.
            if scheduler.running:
                scheduler.shutdown()

            # Push while the engine is still alive; doing it after dispose
            # would make `get_engine()` silently create a second engine that
            # nobody disposes.
            if repo_id is not None:
                print("Pushing backup to remote dataset...")
                _backup_to_hub(repo_id)
        finally:
            # Always release the engine, even if the final backup failed.
            print("Closing database...")
            if _ENGINE_SINGLETON is not None:
                _ENGINE_SINGLETON.dispose()
                _ENGINE_SINGLETON = None


def _backup_to_hub(repo_id: str) -> None:
    """Upload a parquet export of the database to the dataset `repo_id`.

    The export is written to a temporary directory that is removed once the
    upload call returns. Parquet files already present in the remote dataset
    are matched by `delete_patterns` for the same commit.
    """
    with tempfile.TemporaryDirectory() as export_dir:
        # Dump the database as parquet files into the scratch directory.
        export_to_parquet(get_engine(), export_dir)

        # Push only the parquet files, as a single commit.
        upload_folder(
            folder_path=export_dir,
            repo_id=repo_id,
            repo_type="dataset",
            token=constants.HF_TOKEN,
            commit_message="Backup database as parquet",
            allow_patterns="*.parquet",
            delete_patterns=["*.parquet"],
        )


def create_app() -> FastAPI:
    """Build the FastAPI application.

    The app uses `_database_lifespan` as its lifespan handler, gets CORS
    middleware for a fixed list of local origins, registers a `/` route
    (serving the built frontend or redirecting to the dev server depending on
    `constants.SERVE_FRONTEND`), and finally has Hugging Face OAuth attached.
    """
    app = FastAPI(lifespan=_database_lifespan)

    # CORS: origins are listed explicitly.
    # Can't use "*" because frontend doesn't like it with "credentials: true"
    allowed_origins = [
        "http://localhost:5173",
        "http://0.0.0.0:9481",
        "http://localhost:9481",
        "http://127.0.0.1:9481",
    ]
    app.add_middleware(
        CORSMiddleware,
        allow_origins=allowed_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    if not constants.SERVE_FRONTEND:
        # Development => Redirect to dev frontend
        @app.get("/")
        async def redirect_to_frontend():
            return RedirectResponse("http://localhost:5173/")

    else:
        # Production => Serve frontend from dist directory
        assets_app = StaticFiles(directory=constants.FRONTEND_ASSETS_PATH)  # type: ignore[invalid-argument-type]
        app.mount("/assets", assets_app, name="assets")

        @app.get("/")
        async def serve_frontend():
            return FileResponse(constants.FRONTEND_INDEX_PATH)  # type: ignore[invalid-argument-type]

    # Set up Hugging Face OAuth last, once the app's own routes are registered.
    attach_huggingface_oauth(app)

    return app