"""CHI-tography / main_fast.py: FastAPI app serving the cartography UI and the
Qdrant-backed semantic-search endpoints."""

import json
import os
from contextlib import asynccontextmanager
from logging import getLogger
from fastapi import FastAPI
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from sqlalchemy.orm import load_only
from models_semantic_search import MetadataDB
from models_semantic_search import MetadataFull
from models_semantic_search import MetadataPosition
from models_semantic_search import SearchRequestVector
from models_semantic_search import SearchRequestHybrid
from models_semantic_search import SemanticSearchResults
from models_semantic_search import VectorType
from settings import COLLECTION_HYBRID_NAME
from settings import SQL_BASE
from utils_semantic_search import HybridSearcher
from utils_semantic_search import build_year_filter
from qdrant_client import models
logger = getLogger()
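# In-process caches: populated once at startup in the lifespan handler below and
# read by the request handlers.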
client_resources = {}
persistent_data = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load the hybrid searcher (embedding models plus the Qdrant client)
client_resources["hybrid_searcher"] = HybridSearcher(collection_name=COLLECTION_HYBRID_NAME)
    # Create a database engine from the configured connection string
engine = create_engine(SQL_BASE)
    # TODO: make async
with open("static/public/js/cartography/clusterInfo.json", "r") as f:
client_resources["cluster_metadata"] = json.load(f)
client_resources["engine_sql"] = engine
session = Session(engine)
results = session.query(MetadataDB).options(
load_only(
MetadataDB.x,
MetadataDB.y,
MetadataDB.cluster,
MetadataDB.doi,
MetadataDB.title,
MetadataDB.year,
MetadataDB.abstract,
),
).all()
cluster_positions = [
MetadataPosition(
doi=row.doi,
cluster=row.cluster,
x=row.x,
y=row.y,
title=row.title,
year=row.year,
abstract=row.abstract,
)
for row in results
]
persistent_data["cluster_positions"] = cluster_positions
session.close()
yield
    # Release the searcher, SQL engine and cached metadata
client_resources.clear()
app = FastAPI(lifespan=lifespan)
# Serve static files for CSS, JS, etc.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Serve index.html at root URL
@app.get("/")
def read_index():
return FileResponse(os.path.join("static", "public", "index.html"))
# Cluster-level metadata used to build the graph view
@app.get("/get_cluster_metadata")
async def get_cluster_metadata():
return client_resources["cluster_metadata"]
@app.get("/cluster_positions", response_model=list[MetadataPosition])
async def get_cluster_positions():
    # Serve the positions preloaded at startup
return persistent_data["cluster_positions"]
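
# Note: without an explicit fastapi.Query(...) annotation, FastAPI reads `input_ids`
# from the JSON request body; this endpoint is mostly called internally by the
# *_with_metadata routes below.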
@app.get("/records_by_ids", response_model=list[MetadataFull])
async def get_records_by_ids(input_ids: list[str]):
session = Session(client_resources["engine_sql"])
results = (
session.query(MetadataDB)
.filter(MetadataDB.doi.in_(input_ids))
.all()
)
session.close()
    # Convert SQLAlchemy rows to dicts while preserving column types
response = [
{c.name: getattr(row, c.name) for c in MetadataDB.__table__.columns}
for row in results
]
return response
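
# Hybrid search: delegates to HybridSearcher.search, which combines dense and sparse
# retrieval (fusion details live in utils_semantic_search).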
@app.post("/semantic_search", response_model=list[SemanticSearchResults])
async def semantic_search_hybrid(request: SearchRequestHybrid):
hybrid_searcher = client_resources["hybrid_searcher"]
results = hybrid_searcher.search(
documents=request.input_text,
limit=request.limit,
limit_dense=request.limit_dense,
limit_sparse=request.limit_sparse,
score_threshold_dense=request.score_threshold_dense,
query_filter=build_year_filter(year_ge=request.min_year, year_le=request.max_year)
)
# Format results for JSON response
return [SemanticSearchResults(doi=point.payload["doi"], score=point.score) for point in results]
@app.post("/semantic_search_vector", response_model=list[SemanticSearchResults])
async def semantic_search_vector(request: SearchRequestVector):
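    # The score threshold is only defined for dense queries; sparse queries run without one.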
    score_th = request.score_threshold_dense if request.vector_type == VectorType.dense else None
hybrid_searcher = client_resources["hybrid_searcher"]
match request.vector_type:
case VectorType.sparse:
sparse = hybrid_searcher.embed_sparse(request.input_text)
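            # Qdrant takes sparse queries as an explicit SparseVector of indices and values.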
query = models.SparseVector(
indices=sparse.indices.tolist(),
values=sparse.values.tolist(),
)
case VectorType.dense:
query = hybrid_searcher.embed_dense(request.input_text)
case _:
raise ValueError(f"Unsupported embedding type {request.vector_type}")
results = hybrid_searcher.client.query_points(
collection_name=COLLECTION_HYBRID_NAME,
query=query,
limit=request.limit,
using=request.vector_type,
score_threshold=score_th,
query_filter=build_year_filter(year_ge=request.min_year, year_le=request.max_year)
)
# Format results for JSON response
return [SemanticSearchResults(doi=point.payload["doi"], score=point.score) for point in results.points]
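
# The two *_with_metadata endpoints chain a search (single-vector or hybrid) with a
# SQL lookup so the client gets full metadata records instead of bare DOIs and scores.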
@app.post("/semantic_search_with_metadata", response_model=list[MetadataFull])
async def semantic_search_with_metadata_dense(request: SearchRequestVector):
search_results = await semantic_search_vector(request)
dois = [result.doi for result in search_results]
# TODO: do something with the scores
metadata_records = await get_records_by_ids(dois)
return metadata_records
@app.post("/semantic_search_with_metadata_hybrid", response_model=list[MetadataFull])
async def semantic_search_with_metadata_hybrid(request: SearchRequestHybrid):
search_results = await semantic_search_hybrid(request)
dois = [result.doi for result in search_results]
# TODO: do something with the scores
metadata_records = await get_records_by_ids(dois)
return metadata_records
@app.get("/health")
async def health_check() -> dict[str, str]:
"""
Health check endpoint.
"""
return {"status": "healthy"}
# Launch directly with Python (python main_fast.py)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)
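
# Example request against a local instance. Field names come from the handlers above;
# the values are illustrative and some fields may be optional depending on the
# SearchRequestHybrid model:
#   curl -X POST http://localhost:7860/semantic_search \
#     -H "Content-Type: application/json" \
#     -d '{"input_text": "example query", "limit": 10, "limit_dense": 50,
#          "limit_sparse": 50, "score_threshold_dense": 0.0,
#          "min_year": 2000, "max_year": 2024}'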