File size: 1,212 Bytes
7af929b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import logging
from typing import List

from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field

from src.prompt_loader import PromptLoader
from src.search_engine import PromptSearchEngine

# FastAPI application; the /most_similar route below is registered on it.
app = FastAPI()

# Seed handed to PromptLoader, and the number of prompts it is asked to load.
SEED = 42
DATA_SIZE = 100

# Load the prompts and build the search engine once, at import time; the
# request handler reuses this single module-level ``engine`` for every call.
prompts = PromptLoader(seed=SEED).load_data(size=DATA_SIZE)
engine = PromptSearchEngine(prompts)


# Request and Response Models
class QueryRequest(BaseModel):
    """Request body for ``POST /most_similar``."""

    # Free-text query to compare against the loaded prompts.
    query: str
    # How many matches to ask the search engine for. Defaults to 5; values
    # below 1 are rejected during request validation (HTTP 422) instead of
    # being forwarded to the engine.
    n: int = Field(5, ge=1)


class SimilarPrompt(BaseModel):
    """A single search hit: one stored prompt plus its similarity score."""

    # Score reported by the search engine for this prompt.
    score: float
    # Text of the matching prompt.
    prompt: str


class QueryResponse(BaseModel):
    """Response body returned by ``POST /most_similar``."""

    # Search hits, kept in the order the search engine produced them.
    similar_prompts: List[SimilarPrompt]


# API endpoint
@app.post("/most_similar", response_model=QueryResponse)
async def get_most_similar(query_request: QueryRequest) -> QueryResponse:
    """Return the stored prompts most similar to the request's query.

    Args:
        query_request: Parsed request body carrying the query text and the
            number ``n`` of matches to ask the search engine for.

    Returns:
        A ``QueryResponse`` whose ``similar_prompts`` are built, in order, from
        the ``(score, prompt)`` pairs returned by ``engine.most_similar``.

    Raises:
        HTTPException: status 500 if the search or the response construction
            fails. The original error is logged server-side and chained as the
            cause, but its message is not echoed back to the client.
    """
    try:
        # ``engine.most_similar`` is a plain synchronous call; awaiting it on
        # the threadpool keeps a slow search from stalling the event loop and
        # every other in-flight request.
        # NOTE(review): concurrent requests may now call the shared ``engine``
        # from several threads at once -- confirm most_similar is thread-safe.
        similar_prompts = await run_in_threadpool(
            engine.most_similar, query=query_request.query, n=query_request.n
        )
        return QueryResponse(
            similar_prompts=[
                SimilarPrompt(score=score, prompt=prompt)
                for score, prompt in similar_prompts
            ]
        )
    except Exception as exc:
        # Top-level request boundary: keep the full traceback in the server
        # log, but send clients a generic message so internal details
        # (paths, data, library internals) are not disclosed.
        logging.getLogger(__name__).exception("most_similar request failed")
        raise HTTPException(
            status_code=500, detail="Internal error while searching prompts."
        ) from exc