Lazar Radojevic commited on
Commit
d765d3d
·
1 Parent(s): 4095eb7

split main and run

Browse files
Files changed (2) hide show
  1. main.py +49 -0
  2. run.py +1 -54
main.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from typing import List
4
+ from src.search_engine import PromptSearchEngine
5
+ from src.prompt_loader import PromptLoader
6
+
7
+ # Constants
8
+ SEED = 42
9
+ DATA_SIZE = 100
10
+
11
+ # Initialize the prompt loader and search engine
12
+ prompts = PromptLoader(seed=SEED).load_data(size=DATA_SIZE)
13
+ engine = PromptSearchEngine(prompts)
14
+
15
+ # Initialize FastAPI
16
+ app = FastAPI()
17
+
18
+
19
+ # Request and Response Models
20
+ class QueryRequest(BaseModel):
21
+ query: str
22
+ n: int = 5
23
+
24
+
25
+ class SimilarPrompt(BaseModel):
26
+ score: float
27
+ prompt: str
28
+
29
+
30
+ class QueryResponse(BaseModel):
31
+ similar_prompts: List[SimilarPrompt]
32
+
33
+
34
+ # API endpoint
35
+ @app.post("/most_similar", response_model=QueryResponse)
36
+ async def get_most_similar(query_request: QueryRequest):
37
+ try:
38
+ similar_prompts = engine.most_similar(
39
+ query=query_request.query, n=query_request.n
40
+ )
41
+ response = QueryResponse(
42
+ similar_prompts=[
43
+ SimilarPrompt(score=score, prompt=prompt)
44
+ for score, prompt in similar_prompts
45
+ ]
46
+ )
47
+ return response
48
+ except Exception as e:
49
+ raise HTTPException(status_code=500, detail=str(e))
run.py CHANGED
@@ -1,57 +1,4 @@
1
- from fastapi import FastAPI, HTTPException
2
- from pydantic import BaseModel
3
- from typing import List
4
- from src.search_engine import PromptSearchEngine
5
- from src.prompt_loader import PromptLoader
6
- import os
7
 
8
- # Constants
9
- SEED = 42
10
- DATA_SIZE = 100
11
-
12
- # Initialize the prompt loader and search engine
13
- prompts = PromptLoader(seed=SEED).load_data(size=DATA_SIZE)
14
- engine = PromptSearchEngine(prompts)
15
-
16
- # Initialize FastAPI
17
- app = FastAPI()
18
-
19
-
20
- # Request and Response Models
21
- class QueryRequest(BaseModel):
22
- query: str
23
- n: int = 5
24
-
25
-
26
- class SimilarPrompt(BaseModel):
27
- score: float
28
- prompt: str
29
-
30
-
31
- class QueryResponse(BaseModel):
32
- similar_prompts: List[SimilarPrompt]
33
-
34
-
35
- # API endpoint
36
- @app.post("/most_similar", response_model=QueryResponse)
37
- async def get_most_similar(query_request: QueryRequest):
38
- try:
39
- similar_prompts = engine.most_similar(
40
- query=query_request.query, n=query_request.n
41
- )
42
- response = QueryResponse(
43
- similar_prompts=[
44
- SimilarPrompt(score=score, prompt=prompt)
45
- for score, prompt in similar_prompts
46
- ]
47
- )
48
- return response
49
- except Exception as e:
50
- raise HTTPException(status_code=500, detail=str(e))
51
-
52
-
53
- # Run the server with: uvicorn main:app --reload
54
  if __name__ == "__main__":
55
- import uvicorn
56
-
57
  uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)
 
1
+ import uvicorn
 
 
 
 
 
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  if __name__ == "__main__":
 
 
4
  uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)