import json
import os
from contextlib import asynccontextmanager
from logging import getLogger
from fastapi import FastAPI
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from sqlalchemy.orm import load_only
from models_semantic_search import MetadataDB
from models_semantic_search import MetadataFull
from models_semantic_search import MetadataPosition
from models_semantic_search import SearchRequestVector
from models_semantic_search import SearchRequestHybrid
from models_semantic_search import SemanticSearchResults
from models_semantic_search import VectorType
from settings import COLLECTION_HYBRID_NAME
from settings import SQL_BASE
from utils_semantic_search import HybridSearcher
from utils_semantic_search import build_year_filter
from qdrant_client import models
# NOTE(review): getLogger() with no name returns the root logger — consider
# getLogger(__name__) so log records are attributable to this module.
logger = getLogger()
# Per-process resources created in lifespan() and cleared on shutdown:
# "hybrid_searcher", "cluster_metadata", "engine_sql".
client_resources = {}
# Data loaded once at startup and served verbatim ("cluster_positions").
persistent_data = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load heavyweight resources once per process.

    On startup: instantiate the hybrid searcher, create the SQL engine, read
    the pre-computed cluster metadata JSON, and cache every document's map
    position from the metadata table. On shutdown: release the resources.
    """
    # Load the ML model / searcher once per process.
    client_resources["hybrid_searcher"] = HybridSearcher(collection_name=COLLECTION_HYBRID_NAME)
    # Create a database engine (backend configured via SQL_BASE).
    engine = create_engine(SQL_BASE)
    # TODO: make async
    # JSON is defined as UTF-8 — don't depend on the platform default encoding.
    with open("static/public/js/cartography/clusterInfo.json", "r", encoding="utf-8") as f:
        client_resources["cluster_metadata"] = json.load(f)
    client_resources["engine_sql"] = engine
    # Session as a context manager guarantees close() even if the query raises
    # (the original leaked the session on error).
    with Session(engine) as session:
        # load_only restricts the SELECT to the columns MetadataPosition needs.
        results = session.query(MetadataDB).options(
            load_only(
                MetadataDB.x,
                MetadataDB.y,
                MetadataDB.cluster,
                MetadataDB.doi,
                MetadataDB.title,
                MetadataDB.year,
                MetadataDB.abstract,
            ),
        ).all()
        persistent_data["cluster_positions"] = [
            MetadataPosition(
                doi=row.doi,
                cluster=row.cluster,
                x=row.x,
                y=row.y,
                title=row.title,
                year=row.year,
                abstract=row.abstract,
            )
            for row in results
        ]
    yield
    # Clean up the ML models and release the resources
    client_resources.clear()
# Application instance; resource setup/teardown is handled by lifespan().
app = FastAPI(lifespan=lifespan)
# Serve static files for CSS, JS, etc.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Serve index.html at root URL
@app.get("/")
def read_index():
    """Serve the single-page app entry point."""
    index_path = os.path.join("static", "public", "index.html")
    return FileResponse(index_path)
# get data to build graph
@app.get("/get_cluster_metadata")
async def get_cluster_metadata():
    """Return the cluster metadata JSON loaded at startup."""
    metadata = client_resources["cluster_metadata"]
    return metadata
@app.get("/cluster_positions", response_model=list[MetadataPosition])
async def get_cluster_positions():
    """Return the per-document map positions cached at startup."""
    positions = persistent_data["cluster_positions"]
    return positions
@app.get("/records_by_ids", response_model=list[MetadataFull])
async def get_records_by_ids(input_ids: list[str]):
    """Fetch the full metadata rows for the given DOIs.

    Args:
        input_ids: DOIs to look up; order of results is not guaranteed.

    Returns:
        One dict per matching row (all table columns), validated against
        MetadataFull by FastAPI.
    """
    # Skip the round-trip entirely for an empty ID list.
    if not input_ids:
        return []
    # Context-managed session guarantees close() even if the query raises
    # (the original leaked the session on error).
    with Session(client_resources["engine_sql"]) as session:
        results = (
            session.query(MetadataDB)
            .filter(MetadataDB.doi.in_(input_ids))
            .all()
        )
        # Convert SQLAlchemy objects to dicts while keeping column types.
        return [
            {c.name: getattr(row, c.name) for c in MetadataDB.__table__.columns}
            for row in results
        ]
@app.post("/semantic_search", response_model=list[SemanticSearchResults])
async def semantic_search_hybrid(request: SearchRequestHybrid):
    """Run a hybrid (dense + sparse) search and return DOI/score pairs."""
    searcher = client_resources["hybrid_searcher"]
    year_filter = build_year_filter(year_ge=request.min_year, year_le=request.max_year)
    points = searcher.search(
        documents=request.input_text,
        limit=request.limit,
        limit_dense=request.limit_dense,
        limit_sparse=request.limit_sparse,
        score_threshold_dense=request.score_threshold_dense,
        query_filter=year_filter,
    )
    # Shape the search hits into the JSON response model.
    response = []
    for point in points:
        response.append(SemanticSearchResults(doi=point.payload["doi"], score=point.score))
    return response
@app.post("/semantic_search_vector", response_model=list[SemanticSearchResults])
async def semantic_search_vector(request: SearchRequestVector):
    """Query the collection with a single vector type (dense or sparse)."""
    searcher = client_resources["hybrid_searcher"]
    # A score threshold is only meaningful for dense similarity scores.
    threshold = None
    if request.vector_type == VectorType.dense:
        threshold = request.score_threshold_dense
    # Build the query vector for the requested embedding type.
    if request.vector_type == VectorType.sparse:
        sparse_embedding = searcher.embed_sparse(request.input_text)
        query = models.SparseVector(
            indices=sparse_embedding.indices.tolist(),
            values=sparse_embedding.values.tolist(),
        )
    elif request.vector_type == VectorType.dense:
        query = searcher.embed_dense(request.input_text)
    else:
        raise ValueError(f"Unsupported embedding type {request.vector_type}")
    hits = searcher.client.query_points(
        collection_name=COLLECTION_HYBRID_NAME,
        query=query,
        limit=request.limit,
        using=request.vector_type,
        score_threshold=threshold,
        query_filter=build_year_filter(year_ge=request.min_year, year_le=request.max_year)
    )
    # Shape the Qdrant points into the JSON response model.
    return [
        SemanticSearchResults(doi=hit.payload["doi"], score=hit.score)
        for hit in hits.points
    ]
@app.post("/semantic_search_with_metadata", response_model=list[MetadataFull])
async def semantic_search_with_metadata_dense(request: SearchRequestVector):
    """Vector search, then join the hits with their full SQL metadata."""
    hits = await semantic_search_vector(request)
    # TODO: do something with the scores
    dois = [hit.doi for hit in hits]
    return await get_records_by_ids(dois)
@app.post("/semantic_search_with_metadata_hybrid", response_model=list[MetadataFull])
async def semantic_search_with_metadata_hybrid(request: SearchRequestHybrid):
    """Hybrid search, then join the hits with their full SQL metadata."""
    hits = await semantic_search_hybrid(request)
    # TODO: do something with the scores
    dois = [hit.doi for hit in hits]
    return await get_records_by_ids(dois)
@app.get("/health")
async def health_check() -> dict[str, str]:
    """Liveness probe: reports healthy whenever the process can respond."""
    status_payload = {"status": "healthy"}
    return status_payload
# launch through python
if __name__ == "__main__":
    # Local import: uvicorn is only needed when running this file directly.
    import uvicorn
    # Binds on all interfaces; port 7860 is presumably the hosting platform's
    # expected port (Hugging Face Spaces default) — confirm before changing.
    uvicorn.run(app, host="0.0.0.0", port=7860)