import json
import os
from contextlib import asynccontextmanager
from logging import getLogger

from fastapi import FastAPI
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from sqlalchemy.orm import load_only

from models_semantic_search import MetadataDB
from models_semantic_search import MetadataFull
from models_semantic_search import MetadataPosition
from models_semantic_search import SearchRequestVector
from models_semantic_search import SearchRequestHybrid
from models_semantic_search import SemanticSearchResults
from models_semantic_search import VectorType
from settings import COLLECTION_HYBRID_NAME
from settings import SQL_BASE
from utils_semantic_search import HybridSearcher
from utils_semantic_search import build_year_filter
from qdrant_client import models

logger = getLogger()

# Shared state created by the lifespan handler: client_resources is cleared on
# shutdown, persistent_data is kept for the lifetime of the process.
client_resources = {}
persistent_data = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load the hybrid searcher (embedding models + Qdrant client)
    client_resources["hybrid_searcher"] = HybridSearcher(collection_name=COLLECTION_HYBRID_NAME)
    # Create a database engine (SQLite in this example)
    engine = create_engine(SQL_BASE)

    # TODO: make async
    with open("static/public/js/cartography/clusterInfo.json", "r") as f:
        client_resources["cluster_metadata"] = json.load(f)

    client_resources["engine_sql"] = engine
    session = Session(engine)
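    # Preload positions and basic metadata for every record so /cluster_positions
    # can be served from memory instead of querying the database per request.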
    results = session.query(MetadataDB).options(
        load_only(
            MetadataDB.x,
            MetadataDB.y,
            MetadataDB.cluster,
            MetadataDB.doi,
            MetadataDB.title,
            MetadataDB.year,
            MetadataDB.abstract,
        ),
    ).all()
    cluster_positions = [
        MetadataPosition(
            doi=row.doi,
            cluster=row.cluster,
            x=row.x,
            y=row.y,
            title=row.title,
            year=row.year,
            abstract=row.abstract,
        )
        for row in results
    ]
    persistent_data["cluster_positions"] = cluster_positions
    session.close()
    yield
    # Release shared client resources on shutdown
    client_resources.clear()

app = FastAPI(lifespan=lifespan)
# Serve static files for CSS, JS, etc.
app.mount("/static", StaticFiles(directory="static"), name="static")


# Serve index.html at root URL
@app.get("/")
def read_index():
    return FileResponse(os.path.join("static", "public", "index.html"))


# Cluster metadata used by the front end to build the graph view
@app.get("/get_cluster_metadata")
async def get_cluster_metadata():
    return client_resources["cluster_metadata"]


@app.get("/cluster_positions", response_model=list[MetadataPosition])
async def get_cluster_positions():
    # Serve the positions preloaded at startup in the lifespan handler
    return persistent_data["cluster_positions"]


@app.get("/records_by_ids", response_model=list[MetadataFull])
async def get_records_by_ids(input_ids: list[str]):
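    """Fetch full metadata rows for the given DOIs.

    Also called directly (not over HTTP) by the *_with_metadata endpoints below.
    """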
    session = Session(client_resources["engine_sql"])
    results = (
        session.query(MetadataDB)
        .filter(MetadataDB.doi.in_(input_ids))
        .all()
    )
    session.close()
    # Convert SQLAlchemy rows to dicts, keeping column names and value types
    response = [
        {c.name: getattr(row, c.name) for c in MetadataDB.__table__.columns}
        for row in results
    ]
    return response


@app.post("/semantic_search", response_model=list[SemanticSearchResults])
async def semantic_search_hybrid(request: SearchRequestHybrid):
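    """Hybrid (dense + sparse) semantic search over the Qdrant collection.

    Illustrative request body (field names follow SearchRequestHybrid as used
    below; values are examples, actual defaults live in models_semantic_search):
        {"input_text": "graph neural networks", "limit": 10,
         "limit_dense": 50, "limit_sparse": 50, "score_threshold_dense": 0.3,
         "min_year": 2015, "max_year": 2024}
    """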
    hybrid_searcher = client_resources["hybrid_searcher"]
    results = hybrid_searcher.search(
        documents=request.input_text,
        limit=request.limit,
        limit_dense=request.limit_dense,
        limit_sparse=request.limit_sparse,
        score_threshold_dense=request.score_threshold_dense,
        query_filter=build_year_filter(year_ge=request.min_year, year_le=request.max_year)
    )
    # Format results for JSON response
    return [SemanticSearchResults(doi=point.payload["doi"], score=point.score) for point in results]


@app.post("/semantic_search_vector", response_model=list[SemanticSearchResults])
async def semantic_search_vector(request: SearchRequestVector):
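    """Semantic search using a single vector type (dense or sparse).

    Illustrative request body (field names follow SearchRequestVector as used
    below; VectorType values are assumed to serialise to the Qdrant vector
    names passed via `using=`):
        {"input_text": "graph neural networks", "vector_type": "dense",
         "limit": 10, "score_threshold_dense": 0.3,
         "min_year": 2015, "max_year": 2024}
    """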
    # The score threshold is only applied to dense scores; sparse results are left unfiltered
    score_th = request.score_threshold_dense if request.vector_type == VectorType.dense else None
    hybrid_searcher = client_resources["hybrid_searcher"]
    match request.vector_type:
        case VectorType.sparse:
            sparse = hybrid_searcher.embed_sparse(request.input_text)
            query = models.SparseVector(
                indices=sparse.indices.tolist(),
                values=sparse.values.tolist(),
            )
        case VectorType.dense:
            query = hybrid_searcher.embed_dense(request.input_text)
        case _:
            raise ValueError(f"Unsupported embedding type {request.vector_type}")

    results = hybrid_searcher.client.query_points(
        collection_name=COLLECTION_HYBRID_NAME,
        query=query,
        limit=request.limit,
        using=request.vector_type,
        score_threshold=score_th,
        query_filter=build_year_filter(year_ge=request.min_year, year_le=request.max_year)
    )
    # Format results for JSON response
    return [SemanticSearchResults(doi=point.payload["doi"], score=point.score) for point in results.points]


@app.post("/semantic_search_with_metadata", response_model=list[MetadataFull])
async def semantic_search_with_metadata_dense(request: SearchRequestVector):
    search_results = await semantic_search_vector(request)
    dois = [result.doi for result in search_results]
    # TODO: do something with the scores
    metadata_records = await get_records_by_ids(dois)
    return metadata_records


@app.post("/semantic_search_with_metadata_hybrid", response_model=list[MetadataFull])
async def semantic_search_with_metadata_hybrid(request: SearchRequestHybrid):
    search_results = await semantic_search_hybrid(request)
    dois = [result.doi for result in search_results]
    # TODO: do something with the scores
    metadata_records = await get_records_by_ids(dois)
    return metadata_records


@app.get("/health")
async def health_check() -> dict[str, str]:
    """
    Health check endpoint.
    """
    return {"status": "healthy"}


# launch through python
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
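
# Quick smoke test once the server is running (port 7860 as configured above):
#   curl http://localhost:7860/health
#   curl http://localhost:7860/cluster_positions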