"""FastAPI application serving a hybrid (dense + sparse) semantic-search API over publication metadata."""
import os | |
from contextlib import asynccontextmanager | |
from logging import getLogger | |
from fastapi import FastAPI | |
from fastapi.responses import FileResponse | |
from fastapi.staticfiles import StaticFiles | |
from sqlalchemy import create_engine | |
from sqlalchemy.orm import Session | |
from sqlalchemy.orm import load_only | |
from models_semantic_search import MetadataDB | |
from models_semantic_search import MetadataFull | |
from models_semantic_search import MetadataPosition | |
from models_semantic_search import SearchRequest | |
from models_semantic_search import SemanticSearchResults | |
from settings import COLLECTION_HYBRID_NAME | |
from settings import SQL_BASE | |
from utils_semantic_search import HybridSearcher | |
from utils_semantic_search import build_year_filter | |
# Module-scoped logger named after this module; the original bare getLogger()
# returned the root logger, which defeats per-module log filtering/formatting.
logger = getLogger(__name__)

# Long-lived resources created by the lifespan handler (searcher, SQL engine).
client_resources = {}
# Data preloaded once at startup and served directly by read-only endpoints.
persistent_data = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: build shared resources on startup, release them on shutdown.

    Startup:
      1. construct the HybridSearcher used by the search endpoints,
      2. create the SQLAlchemy engine for the metadata database,
      3. preload every record's cluster/position fields into memory so the
         graph endpoints can answer without a per-request DB round trip.

    NOTE: the ``@asynccontextmanager`` decorator was missing in the original
    (its import was unused); ``FastAPI(lifespan=...)`` requires a context
    manager factory, so the app would fail at startup without it.
    """
    # Load the ML model / searcher shared by request handlers.
    client_resources["hybrid_searcher"] = HybridSearcher(
        collection_name=COLLECTION_HYBRID_NAME,
    )
    # Create the database engine and expose it to request handlers.
    engine = create_engine(SQL_BASE)
    client_resources["engine_sql"] = engine

    # Context-managed session: closed even if the query raises
    # (the original only closed it on the success path).
    with Session(engine) as session:
        rows = session.query(MetadataDB).options(
            load_only(
                MetadataDB.x,
                MetadataDB.y,
                MetadataDB.cluster,
                MetadataDB.doi,
                MetadataDB.title,
                MetadataDB.year,
            ),
        ).all()
        persistent_data["cluster_positions"] = [
            MetadataPosition(
                doi=row.doi,
                cluster=row.cluster,
                x=row.x,
                y=row.y,
                title=row.title,
                year=row.year,
            )
            for row in rows
        ]

    yield

    # Shutdown: release the engine's connection pool, then drop shared resources.
    engine.dispose()
    client_resources.clear()
# Application object, wired to the lifespan handler defined above.
app = FastAPI(lifespan=lifespan)

# Static assets (CSS, JS, images, the SPA bundle) are served under /static.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Serve index.html at root URL
def read_index():
    """Serve the single-page app's entry point (static/public/index.html)."""
    index_path = os.path.join("static", "public", "index.html")
    return FileResponse(index_path)
# get data to build graph
async def dummy():
    """Return the cluster positions preloaded at startup (same payload as get_cluster_positions)."""
    positions = persistent_data["cluster_positions"]
    return positions
async def get_cluster_positions():
    """Return the cluster/position records preloaded by the lifespan handler."""
    positions = persistent_data["cluster_positions"]
    return positions
async def get_records_by_ids(input_ids: list[str]):
    """Fetch full metadata rows for the given DOIs.

    Parameters
    ----------
    input_ids : list[str]
        DOIs to look up (matched against ``MetadataDB.doi``).

    Returns
    -------
    list[dict]
        One dict per matching row mapping column name to value, with the
        ORM's Python types preserved. DOIs without a matching row are
        silently omitted; order follows the database result, not ``input_ids``.
    """
    # Context-managed session: guaranteed to close even if the query raises
    # (the original leaked the session on the error path).
    with Session(client_resources["engine_sql"]) as session:
        rows = (
            session.query(MetadataDB)
            .filter(MetadataDB.doi.in_(input_ids))
            .all()
        )
        # Convert ORM objects to plain dicts while keeping value types.
        return [
            {column.name: getattr(row, column.name) for column in MetadataDB.__table__.columns}
            for row in rows
        ]
async def semantic_search(request: SearchRequest):
    """Run a hybrid (dense + sparse) search and return scored DOIs."""
    searcher = client_resources["hybrid_searcher"]
    year_filter = build_year_filter(year_ge=request.min_year, year_le=request.max_year)
    hits = searcher.search(
        documents=request.input_text,
        limit=request.limit,
        limit_dense=request.limit_dense,
        limit_sparse=request.limit_sparse,
        score_threshold_dense=request.score_threshold_dense,
        query_filter=year_filter,
    )
    # Shape the raw hits into the JSON-serializable response model.
    return [
        SemanticSearchResults(doi=hit.payload["doi"], score=hit.score)
        for hit in hits
    ]
async def semantic_search_with_metadata(request: SearchRequest):
    """Search, then enrich each hit's DOI with its full metadata record."""
    hits = await semantic_search(request)
    # TODO: do something with the scores
    dois = [hit.doi for hit in hits]
    return await get_records_by_ids(dois)
async def health_check() -> dict[str, str]:
    """Liveness probe: always reports the service as healthy."""
    status: dict[str, str] = {"status": "healthy"}
    return status
# launch through python
if __name__ == "__main__":
    import uvicorn

    # Bind all interfaces so the app is reachable from inside a container.
    uvicorn.run(app, host="0.0.0.0", port=7860)