Spaces:
Sleeping
Sleeping
Lazar Radojevic
commited on
Commit
·
d765d3d
1
Parent(s):
4095eb7
split main and run
Browse files
main.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, HTTPException
|
2 |
+
from pydantic import BaseModel
|
3 |
+
from typing import List
|
4 |
+
from src.search_engine import PromptSearchEngine
|
5 |
+
from src.prompt_loader import PromptLoader
|
6 |
+
|
7 |
+
# Constants
|
8 |
+
SEED = 42
|
9 |
+
DATA_SIZE = 100
|
10 |
+
|
11 |
+
# Initialize the prompt loader and search engine
|
12 |
+
prompts = PromptLoader(seed=SEED).load_data(size=DATA_SIZE)
|
13 |
+
engine = PromptSearchEngine(prompts)
|
14 |
+
|
15 |
+
# Initialize FastAPI
|
16 |
+
app = FastAPI()
|
17 |
+
|
18 |
+
|
19 |
+
# Request and Response Models
|
20 |
+
class QueryRequest(BaseModel):
|
21 |
+
query: str
|
22 |
+
n: int = 5
|
23 |
+
|
24 |
+
|
25 |
+
class SimilarPrompt(BaseModel):
|
26 |
+
score: float
|
27 |
+
prompt: str
|
28 |
+
|
29 |
+
|
30 |
+
class QueryResponse(BaseModel):
|
31 |
+
similar_prompts: List[SimilarPrompt]
|
32 |
+
|
33 |
+
|
34 |
+
# API endpoint
|
35 |
+
@app.post("/most_similar", response_model=QueryResponse)
|
36 |
+
async def get_most_similar(query_request: QueryRequest):
|
37 |
+
try:
|
38 |
+
similar_prompts = engine.most_similar(
|
39 |
+
query=query_request.query, n=query_request.n
|
40 |
+
)
|
41 |
+
response = QueryResponse(
|
42 |
+
similar_prompts=[
|
43 |
+
SimilarPrompt(score=score, prompt=prompt)
|
44 |
+
for score, prompt in similar_prompts
|
45 |
+
]
|
46 |
+
)
|
47 |
+
return response
|
48 |
+
except Exception as e:
|
49 |
+
raise HTTPException(status_code=500, detail=str(e))
|
run.py
CHANGED
@@ -1,57 +1,4 @@
|
|
1 |
-
|
2 |
-
from pydantic import BaseModel
|
3 |
-
from typing import List
|
4 |
-
from src.search_engine import PromptSearchEngine
|
5 |
-
from src.prompt_loader import PromptLoader
|
6 |
-
import os
|
7 |
|
8 |
-
# Constants
|
9 |
-
SEED = 42
|
10 |
-
DATA_SIZE = 100
|
11 |
-
|
12 |
-
# Initialize the prompt loader and search engine
|
13 |
-
prompts = PromptLoader(seed=SEED).load_data(size=DATA_SIZE)
|
14 |
-
engine = PromptSearchEngine(prompts)
|
15 |
-
|
16 |
-
# Initialize FastAPI
|
17 |
-
app = FastAPI()
|
18 |
-
|
19 |
-
|
20 |
-
# Request and Response Models
|
21 |
-
class QueryRequest(BaseModel):
|
22 |
-
query: str
|
23 |
-
n: int = 5
|
24 |
-
|
25 |
-
|
26 |
-
class SimilarPrompt(BaseModel):
|
27 |
-
score: float
|
28 |
-
prompt: str
|
29 |
-
|
30 |
-
|
31 |
-
class QueryResponse(BaseModel):
|
32 |
-
similar_prompts: List[SimilarPrompt]
|
33 |
-
|
34 |
-
|
35 |
-
# API endpoint
|
36 |
-
@app.post("/most_similar", response_model=QueryResponse)
|
37 |
-
async def get_most_similar(query_request: QueryRequest):
|
38 |
-
try:
|
39 |
-
similar_prompts = engine.most_similar(
|
40 |
-
query=query_request.query, n=query_request.n
|
41 |
-
)
|
42 |
-
response = QueryResponse(
|
43 |
-
similar_prompts=[
|
44 |
-
SimilarPrompt(score=score, prompt=prompt)
|
45 |
-
for score, prompt in similar_prompts
|
46 |
-
]
|
47 |
-
)
|
48 |
-
return response
|
49 |
-
except Exception as e:
|
50 |
-
raise HTTPException(status_code=500, detail=str(e))
|
51 |
-
|
52 |
-
|
53 |
-
# Run the server with: uvicorn main:app --reload
|
54 |
if __name__ == "__main__":
|
55 |
-
import uvicorn
|
56 |
-
|
57 |
uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)
|
|
|
1 |
+
import uvicorn
|
|
|
|
|
|
|
|
|
|
|
2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
if __name__ == "__main__":
|
|
|
|
|
4 |
uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)
|