Lazar Radojevic committed on
Commit
1cd5053
·
1 Parent(s): 3556e6f

refactor everything

Browse files
.env ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ API_URL="https://lazarr19-prompt-engine.hf.space"
2
+ SEED=42
3
+ DATASET_SIZE=1000
README.md CHANGED
@@ -43,6 +43,10 @@ Before you start, ensure you have the following tools installed:
43
  poetry shell
44
  ```
45
 
 
 
 
 
46
  ### Backend Details
47
 
48
  The backend API provides the following endpoint:
@@ -76,6 +80,8 @@ In case you want to set up your own HuggingFace Space, you must create a HF toke
76
  git remote set-url origin https://USERNAME:HF_TOKEN@huggingface.co/spaces/USERNAME/REPO_NAME.git
77
  ```
78
 
 
 
79
  ### User Interface
80
 
81
  The frontend UI is simple and includes:
@@ -86,7 +92,7 @@ The frontend UI is simple and includes:
86
  To start only the UI service you can run:
87
 
88
  ```bash
89
- poe frontend --api_url http://localhost:8000
90
  ```
91
 
92
- The default api points to HF Space of this repository.
 
43
  poetry shell
44
  ```
45
 
46
+ ### Environment Variables
47
+
48
+ Environment variables are set in the .env file in the root of the repo.
49
+
50
  ### Backend Details
51
 
52
  The backend API provides the following endpoint:
 
80
  git remote set-url origin https://USERNAME:HF_TOKEN@huggingface.co/spaces/USERNAME/REPO_NAME.git
81
  ```
82
 
83
+ Also, review the HF Space parameters at the beginning of this README.md and adjust them to match your own Space.
84
+
85
  ### User Interface
86
 
87
  The frontend UI is simple and includes:
 
92
  To start only the UI service you can run:
93
 
94
  ```bash
95
+ poe frontend
96
  ```
97
 
98
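+ The UI reads the backend address from the `API_URL` environment variable (falling back to `http://localhost:8000` when it is unset). Set `API_URL` to `https://lazarr19-prompt-engine.hf.space` if you prefer to hit the HF Space endpoint.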
backend/__init__.py ADDED
File without changes
backend/main.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+
3
+ from backend.routes import router
4
+
5
+ # Initialize FastAPI
6
+ app = FastAPI()
7
+
8
+ # Include routes from the routes module
9
+ app.include_router(router)
backend/models.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ from pydantic import BaseModel
4
+
5
+
6
+ class QueryRequest(BaseModel):
7
+ """
8
+ Represents the request model for querying similar prompts.
9
+ """
10
+
11
+ query: str
12
+ n: int = 5
13
+
14
+
15
+ class SimilarPrompt(BaseModel):
16
+ """
17
+ Represents a single similar prompt with its similarity score.
18
+ """
19
+
20
+ score: float
21
+ prompt: str
22
+
23
+
24
+ class QueryResponse(BaseModel):
25
+ """
26
+ Represents the response model containing a list of similar prompts.
27
+ """
28
+
29
+ similar_prompts: List[SimilarPrompt]
main.py → backend/routes.py RENAMED
@@ -1,42 +1,38 @@
1
- from typing import List
2
 
3
- from fastapi import FastAPI, HTTPException
4
  from fastapi.responses import HTMLResponse
5
- from pydantic import BaseModel
6
 
 
7
  from src.prompt_loader import PromptLoader
8
  from src.search_engine import PromptSearchEngine
9
 
10
  # Constants
11
- SEED = 42
12
- DATA_SIZE = 100
13
 
14
  # Initialize the prompt loader and search engine
15
- prompts = PromptLoader(seed=SEED).load_data(size=DATA_SIZE)
16
  engine = PromptSearchEngine(prompts)
17
 
18
- # Initialize FastAPI
19
- app = FastAPI()
20
 
21
 
22
- # Request and Response Models
23
- class QueryRequest(BaseModel):
24
- query: str
25
- n: int = 5
26
 
 
 
27
 
28
- class SimilarPrompt(BaseModel):
29
- score: float
30
- prompt: str
31
 
32
-
33
- class QueryResponse(BaseModel):
34
- similar_prompts: List[SimilarPrompt]
35
-
36
-
37
- # API endpoint
38
- @app.post("/most_similar", response_model=QueryResponse)
39
- async def get_most_similar(query_request: QueryRequest):
40
  try:
41
  similar_prompts = engine.most_similar(
42
  query=query_request.query, n=query_request.n
@@ -52,8 +48,14 @@ async def get_most_similar(query_request: QueryRequest):
52
  raise HTTPException(status_code=500, detail=str(e))
53
 
54
 
55
- @app.get("/", response_class=HTMLResponse)
56
- async def home_page():
 
 
 
 
 
 
57
  return HTMLResponse(
58
  """
59
  <!DOCTYPE html>
@@ -77,7 +79,6 @@ async def home_page():
77
  <h2>POST /most_similar</h2>
78
  <p><strong>Request:</strong> <code>{"query": "string", "n": 5}</code></p>
79
  <p><strong>Response:</strong> <code>{"similar_prompts": [{"score": 0.95, "prompt": "Example prompt 1"}]}</code></p>
80
- <p>For more info, visit <a href="https://github.com/your-repository">GitHub</a>.</p>
81
  </div>
82
  </body>
83
  </html>
 
1
+ import os
2
 
3
+ from fastapi import APIRouter, HTTPException
4
  from fastapi.responses import HTMLResponse
 
5
 
6
+ from backend.models import QueryRequest, QueryResponse, SimilarPrompt
7
  from src.prompt_loader import PromptLoader
8
  from src.search_engine import PromptSearchEngine
9
 
10
  # Constants
11
+ SEED = int(os.getenv("SEED", 42))
12
+ DATASET_SIZE = int(os.getenv("DATASET_SIZE", 1000))
13
 
14
  # Initialize the prompt loader and search engine
15
+ prompts = PromptLoader(seed=SEED).load_data(size=DATASET_SIZE)
16
  engine = PromptSearchEngine(prompts)
17
 
18
+ # Initialize the API router
19
+ router = APIRouter()
20
 
21
 
22
+ @router.post("/most_similar", response_model=QueryResponse)
23
+ async def get_most_similar(query_request: QueryRequest) -> QueryResponse:
24
+ """
25
+ Endpoint to retrieve the most similar prompts based on a user query.
26
 
27
+ Args:
28
+ query_request (QueryRequest): The request payload containing the user query and the number of similar prompts to retrieve.
29
 
30
+ Returns:
31
+ QueryResponse: A response containing a list of similar prompts and their similarity scores.
 
32
 
33
+ Raises:
34
+ HTTPException: If an internal server error occurs while processing the request.
35
+ """
 
 
 
 
 
36
  try:
37
  similar_prompts = engine.most_similar(
38
  query=query_request.query, n=query_request.n
 
48
  raise HTTPException(status_code=500, detail=str(e))
49
 
50
 
51
+ @router.get("/", response_class=HTMLResponse)
52
+ async def home_page() -> HTMLResponse:
53
+ """
54
+ Endpoint to serve a simple HTML page with information about the API.
55
+
56
+ Returns:
57
+ HTMLResponse: An HTML page providing an overview of the API and how to use it.
58
+ """
59
  return HTMLResponse(
60
  """
61
  <!DOCTYPE html>
 
79
  <h2>POST /most_similar</h2>
80
  <p><strong>Request:</strong> <code>{"query": "string", "n": 5}</code></p>
81
  <p><strong>Response:</strong> <code>{"similar_prompts": [{"score": 0.95, "prompt": "Example prompt 1"}]}</code></p>
 
82
  </div>
83
  </body>
84
  </html>
frontend/__init__.py ADDED
File without changes
frontend/app_ui.py CHANGED
@@ -1,26 +1,29 @@
1
- import argparse
2
 
3
  import requests
4
  import streamlit as st
5
 
 
 
6
 
7
- def parse_arguments():
8
- """Parse command-line arguments."""
9
- parser = argparse.ArgumentParser(description="Prompt Similarity Finder")
10
- parser.add_argument(
11
- "--api_url",
12
- type=str,
13
- default="https://lazarr19-prompt-engine.hf.space",
14
- help="The URL of the FastAPI service",
15
- )
16
- return parser.parse_args()
17
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- def get_similar_prompts(api_url, query, n):
20
- """Fetch similar prompts from the FastAPI service."""
 
21
  try:
22
  response = requests.post(
23
- f"{api_url}/most_similar", json={"query": query, "n": n}
24
  )
25
  response.raise_for_status() # Raise an exception for HTTP errors
26
  return response.json()
@@ -29,8 +32,16 @@ def get_similar_prompts(api_url, query, n):
29
  return None
30
 
31
 
32
- def get_color(score):
33
- """Determine the color based on the score."""
 
 
 
 
 
 
 
 
34
  if score >= 0.8:
35
  return "green"
36
  elif score >= 0.5:
@@ -39,11 +50,13 @@ def get_color(score):
39
  return "red"
40
 
41
 
42
- def main(api_url):
43
- """Main function to run the Streamlit app."""
 
 
 
44
  st.title("Prompt Similarity Finder")
45
 
46
- # User input for query
47
  query = st.text_input("Enter your query:", "")
48
  n = st.slider(
49
  "Number of similar prompts to retrieve:", min_value=1, max_value=40, value=5
@@ -52,7 +65,7 @@ def main(api_url):
52
  if st.button("Find Similar Prompts"):
53
  if query:
54
  with st.spinner("Fetching similar prompts..."):
55
- result = get_similar_prompts(api_url, query, n)
56
  if result:
57
  similar_prompts = result.get("similar_prompts", [])
58
  if similar_prompts:
@@ -60,7 +73,6 @@ def main(api_url):
60
  for item in similar_prompts:
61
  score = item["score"]
62
  color = get_color(score)
63
- # Apply color only to the score part
64
  st.markdown(
65
  f"<p><strong>Score:</strong> <span style='color:{color};'>{score:.2f}</span> <br> <strong>Prompt:</strong> {item['prompt']}</p>",
66
  unsafe_allow_html=True,
@@ -73,5 +85,4 @@ def main(api_url):
73
 
74
 
75
  if __name__ == "__main__":
76
- args = parse_arguments()
77
- main(args.api_url)
 
1
+ import os
2
 
3
  import requests
4
  import streamlit as st
5
 
6
+ # Read API URL from environment variable
7
+ API_URL = os.getenv("API_URL", "http://localhost:8000")
8
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ def get_similar_prompts(query: str, n: int) -> dict:
11
+ """
12
+ Fetches similar prompts from the API based on the user query.
13
+
14
+ Args:
15
+ query (str): The user query for which similar prompts are to be retrieved.
16
+ n (int): The number of similar prompts to return.
17
+
18
+ Returns:
19
+ dict: A dictionary containing similar prompts, or None if there was an error.
20
 
21
+ Raises:
22
+ requests.RequestException: If an HTTP error occurs during the request.
23
+ """
24
  try:
25
  response = requests.post(
26
+ f"{API_URL}/most_similar", json={"query": query, "n": n}
27
  )
28
  response.raise_for_status() # Raise an exception for HTTP errors
29
  return response.json()
 
32
  return None
33
 
34
 
35
+ def get_color(score: float) -> str:
36
+ """
37
+ Determines the color based on the similarity score.
38
+
39
+ Args:
40
+ score (float): The similarity score of a prompt.
41
+
42
+ Returns:
43
+ str: The color representing the score, which could be "green", "orange", or "red".
44
+ """
45
  if score >= 0.8:
46
  return "green"
47
  elif score >= 0.5:
 
50
  return "red"
51
 
52
 
53
+ def main():
54
+ """
55
+ The main function for running the Streamlit app.
56
+ Sets up the UI for entering queries and retrieving similar prompts.
57
+ """
58
  st.title("Prompt Similarity Finder")
59
 
 
60
  query = st.text_input("Enter your query:", "")
61
  n = st.slider(
62
  "Number of similar prompts to retrieve:", min_value=1, max_value=40, value=5
 
65
  if st.button("Find Similar Prompts"):
66
  if query:
67
  with st.spinner("Fetching similar prompts..."):
68
+ result = get_similar_prompts(query, n)
69
  if result:
70
  similar_prompts = result.get("similar_prompts", [])
71
  if similar_prompts:
 
73
  for item in similar_prompts:
74
  score = item["score"]
75
  color = get_color(score)
 
76
  st.markdown(
77
  f"<p><strong>Score:</strong> <span style='color:{color};'>{score:.2f}</span> <br> <strong>Prompt:</strong> {item['prompt']}</p>",
78
  unsafe_allow_html=True,
 
85
 
86
 
87
  if __name__ == "__main__":
88
+ main()
 
run.py CHANGED
@@ -3,7 +3,7 @@ import uvicorn
3
 
4
  def run_fastapi_app():
5
  uvicorn.run(
6
- "main:app", # Module name and app instance
7
  host="0.0.0.0",
8
  port=8000,
9
  reload=True, # Enable auto-reload for development
 
3
 
4
  def run_fastapi_app():
5
  uvicorn.run(
6
+ "backend.main:app", # Module name and app instance
7
  host="0.0.0.0",
8
  port=8000,
9
  reload=True, # Enable auto-reload for development
src/prompt_loader.py CHANGED
@@ -5,16 +5,46 @@ from datasets import load_dataset
5
 
6
 
7
  class PromptLoader:
 
 
 
 
8
  def __init__(self, seed: int = 42) -> None:
 
 
 
 
 
 
9
  self.randomizer = random.Random(seed)
10
  self.data: Optional[List[str]] = None
11
 
12
  def _load_data(self) -> None:
 
 
 
 
 
13
  self.data = load_dataset("daspartho/stable-diffusion-prompts")["train"][
14
  "prompt"
15
  ]
16
 
17
  def load_data(self, size: Optional[int] = None) -> List[str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  if not self.data:
19
  self._load_data()
20
 
 
5
 
6
 
7
  class PromptLoader:
8
+ """
9
+ A class for loading and sampling prompts from a dataset.
10
+ """
11
+
12
  def __init__(self, seed: int = 42) -> None:
13
+ """
14
+ Initializes the PromptLoader with a specified seed for random sampling.
15
+
16
+ Args:
17
+ seed (int): The seed value for the random number generator. Default is 42.
18
+ """
19
  self.randomizer = random.Random(seed)
20
  self.data: Optional[List[str]] = None
21
 
22
  def _load_data(self) -> None:
23
+ """
24
+ Loads the dataset of prompts and stores them in the `data` attribute.
25
+
26
+ This method uses the `datasets` library to load the dataset and extract prompts from the "train" split.
27
+ """
28
  self.data = load_dataset("daspartho/stable-diffusion-prompts")["train"][
29
  "prompt"
30
  ]
31
 
32
  def load_data(self, size: Optional[int] = None) -> List[str]:
33
+ """
34
+ Loads and samples prompts from the dataset.
35
+
36
+ If the dataset is not already loaded, it calls `_load_data()` to load it.
37
+
38
+ Args:
39
+ size (Optional[int]): The number of prompts to sample. If not specified, all loaded prompts are returned.
40
+
41
+ Returns:
42
+ List[str]: A list of sampled prompts. If `size` is specified, returns a random sample of the specified size.
43
+ If `size` is not specified, returns all loaded prompts.
44
+
45
+ Raises:
46
+ ValueError: If `size` is specified and is greater than the number of available prompts.
47
+ """
48
  if not self.data:
49
  self._load_data()
50