# CHI-tography / main_fast.py
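"""FastAPI backend for the CHI-tography app.

Serves the static frontend, the precomputed cluster positions used to draw the
document map, metadata lookups by DOI, and hybrid (dense + sparse) semantic
search over the indexed collection.
"""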
import os
from contextlib import asynccontextmanager
from logging import getLogger
from fastapi import FastAPI
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from sqlalchemy.orm import load_only
from models_semantic_search import MetadataDB
from models_semantic_search import MetadataFull
from models_semantic_search import MetadataPosition
from models_semantic_search import SearchRequest
from models_semantic_search import SemanticSearchResults
from settings import COLLECTION_HYBRID_NAME
from settings import SQL_BASE
from utils_semantic_search import HybridSearcher
from utils_semantic_search import build_year_filter
logger = getLogger()
client_resources = {}
persistent_data = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
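    """Initialise shared resources at startup and release them at shutdown."""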
    # Initialise the hybrid search client once at startup
    client_resources["hybrid_searcher"] = HybridSearcher(collection_name=COLLECTION_HYBRID_NAME)
    # Create the database engine and keep it for per-request sessions
    engine = create_engine(SQL_BASE)
    client_resources["engine_sql"] = engine
    # Preload the cluster positions served by the cartography endpoints
    session = Session(engine)
results = session.query(MetadataDB).options(
load_only(
MetadataDB.x, MetadataDB.y, MetadataDB.cluster, MetadataDB.doi, MetadataDB.title, MetadataDB.year,
),
).all()
cluster_positions = [
MetadataPosition(
doi=row.doi,
cluster=row.cluster,
x=row.x,
y=row.y,
title=row.title,
year=row.year,
)
for row in results
]
persistent_data["cluster_positions"] = cluster_positions
session.close()
yield
    # Release the shared resources on shutdown
client_resources.clear()
app = FastAPI(lifespan=lifespan)
# Serve static files for CSS, JS, etc.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Serve index.html at root URL
@app.get("/")
def read_index():
return FileResponse(os.path.join("static", "public", "index.html"))
# Serve the data used to build the cartography graph
@app.get("/api/cartography-data")
async def get_cartography_data():
    return persistent_data["cluster_positions"]
@app.get("/cluster_positions", response_model=list[MetadataPosition])
async def get_cluster_positions():
    """Return the cluster positions preloaded at startup."""
return persistent_data["cluster_positions"]
@app.get("/records_by_ids", response_model=list[MetadataFull])
async def get_records_by_ids(input_ids: list[str]):
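    """Return full metadata records for the given DOIs.

    Note: because ``input_ids`` is a bare list parameter (not declared with
    ``Query``), FastAPI reads it from the request body rather than the query
    string, even though this is a GET endpoint.
    """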
    # Short-lived session per request, closed automatically on exit
    with Session(client_resources["engine_sql"]) as session:
        results = (
            session.query(MetadataDB)
            .filter(MetadataDB.doi.in_(input_ids))
            .all()
        )
    # Convert each SQLAlchemy row to a plain dict (preserving column types)
    # so it can be validated against the MetadataFull response model
response = [
{c.name: getattr(row, c.name) for c in MetadataDB.__table__.columns}
for row in results
]
return response
@app.post("/semantic_search", response_model=list[SemanticSearchResults])
async def semantic_search(request: SearchRequest):
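    """Run a hybrid (dense + sparse) search, optionally filtered by year,
    and return the matching DOIs with their scores."""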
hybrid_searcher = client_resources["hybrid_searcher"]
results = hybrid_searcher.search(
documents=request.input_text,
limit=request.limit,
limit_dense=request.limit_dense,
limit_sparse=request.limit_sparse,
score_threshold_dense=request.score_threshold_dense,
query_filter=build_year_filter(year_ge=request.min_year, year_le=request.max_year)
)
# Format results for JSON response
return [SemanticSearchResults(doi=point.payload["doi"], score=point.score) for point in results]
@app.post("/semantic_search_with_metadata", response_model=list[MetadataFull])
async def semantic_search_with_metadata(request: SearchRequest):
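    """Run a semantic search, then return the full metadata records for the matching DOIs."""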
search_results = await semantic_search(request)
dois = [result.doi for result in search_results]
# TODO: do something with the scores
metadata_records = await get_records_by_ids(dois)
return metadata_records
@app.get("/health")
async def health_check() -> dict[str, str]:
"""
Health check endpoint.
"""
return {"status": "healthy"}
# Allow launching directly with python (e.g. `python main_fast.py`)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)
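
# Example request (a sketch: assumes the server is running locally on port 7860
# and that SearchRequest accepts the fields read in semantic_search above; the
# values below are purely illustrative):
#
#   curl -X POST http://localhost:7860/semantic_search \
#       -H "Content-Type: application/json" \
#       -d '{"input_text": "haptic feedback", "limit": 10,
#            "limit_dense": 25, "limit_sparse": 25,
#            "score_threshold_dense": 0.3, "min_year": 2015, "max_year": 2024}'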