Spaces:
Sleeping
Sleeping
Lazar Radojevic
commited on
Commit
·
1cd5053
1
Parent(s):
3556e6f
refactor everything
Browse files- .env +3 -0
- README.md +8 -2
- backend/__init__.py +0 -0
- backend/main.py +9 -0
- backend/models.py +29 -0
- main.py → backend/routes.py +27 -26
- frontend/__init__.py +0 -0
- frontend/app_ui.py +34 -23
- run.py +1 -1
- src/prompt_loader.py +30 -0
.env
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
API_URL="https://lazarr19-prompt-engine.hf.space"
|
2 |
+
SEED=42
|
3 |
+
DATASET_SIZE=1000
|
README.md
CHANGED
@@ -43,6 +43,10 @@ Before you start, ensure you have the following tools installed:
|
|
43 |
poetry shell
|
44 |
```
|
45 |
|
|
|
|
|
|
|
|
|
46 |
### Backend Details
|
47 |
|
48 |
The backend API provides the following endpoint:
|
@@ -76,6 +80,8 @@ In case you want to set up your own HuggingFace Space, you must create a HF toke
|
|
76 |
git remote set-url origin https://USERNAME:[email protected]/spaces/USERNAME/REPO_NAME.git
|
77 |
```
|
78 |
|
|
|
|
|
79 |
### User Interface
|
80 |
|
81 |
The frontend UI is simple and includes:
|
@@ -86,7 +92,7 @@ The frontend UI is simple and includes:
|
|
86 |
To start only the UI service you can run:
|
87 |
|
88 |
```bash
|
89 |
-
poe frontend
|
90 |
```
|
91 |
|
92 |
-
|
|
|
43 |
poetry shell
|
44 |
```
|
45 |
|
46 |
+
### Environment Variables
|
47 |
+
|
48 |
+
Environment variables are set in the .env file in the root of the repo.
|
49 |
+
|
50 |
### Backend Details
|
51 |
|
52 |
The backend API provides the following endpoint:
|
|
|
80 |
git remote set-url origin https://USERNAME:[email protected]/spaces/USERNAME/REPO_NAME.git
|
81 |
```
|
82 |
|
83 |
+
Also, pay attention to the HF Space parameters at the beginning of this README.md.
|
84 |
+
|
85 |
### User Interface
|
86 |
|
87 |
The frontend UI is simple and includes:
|
|
|
92 |
To start only the UI service you can run:
|
93 |
|
94 |
```bash
|
95 |
+
poe frontend
|
96 |
```
|
97 |
|
98 |
+
Or you can do it with `https://lazarr19-prompt-engine.hf.space` if you prefer to hit the HF Space endpoint.
|
backend/__init__.py
ADDED
File without changes
|
backend/main.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI
|
2 |
+
|
3 |
+
from backend.routes import router
|
4 |
+
|
5 |
+
# Initialize FastAPI
|
6 |
+
app = FastAPI()
|
7 |
+
|
8 |
+
# Include routes from the routes module
|
9 |
+
app.include_router(router)
|
backend/models.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
|
3 |
+
from pydantic import BaseModel
|
4 |
+
|
5 |
+
|
6 |
+
class QueryRequest(BaseModel):
|
7 |
+
"""
|
8 |
+
Represents the request model for querying similar prompts.
|
9 |
+
"""
|
10 |
+
|
11 |
+
query: str
|
12 |
+
n: int = 5
|
13 |
+
|
14 |
+
|
15 |
+
class SimilarPrompt(BaseModel):
|
16 |
+
"""
|
17 |
+
Represents a single similar prompt with its similarity score.
|
18 |
+
"""
|
19 |
+
|
20 |
+
score: float
|
21 |
+
prompt: str
|
22 |
+
|
23 |
+
|
24 |
+
class QueryResponse(BaseModel):
|
25 |
+
"""
|
26 |
+
Represents the response model containing a list of similar prompts.
|
27 |
+
"""
|
28 |
+
|
29 |
+
similar_prompts: List[SimilarPrompt]
|
main.py → backend/routes.py
RENAMED
@@ -1,42 +1,38 @@
|
|
1 |
-
|
2 |
|
3 |
-
from fastapi import
|
4 |
from fastapi.responses import HTMLResponse
|
5 |
-
from pydantic import BaseModel
|
6 |
|
|
|
7 |
from src.prompt_loader import PromptLoader
|
8 |
from src.search_engine import PromptSearchEngine
|
9 |
|
10 |
# Constants
|
11 |
-
SEED = 42
|
12 |
-
|
13 |
|
14 |
# Initialize the prompt loader and search engine
|
15 |
-
prompts = PromptLoader(seed=SEED).load_data(size=
|
16 |
engine = PromptSearchEngine(prompts)
|
17 |
|
18 |
-
# Initialize
|
19 |
-
|
20 |
|
21 |
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
|
|
|
|
|
27 |
|
28 |
-
|
29 |
-
|
30 |
-
prompt: str
|
31 |
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
# API endpoint
|
38 |
-
@app.post("/most_similar", response_model=QueryResponse)
|
39 |
-
async def get_most_similar(query_request: QueryRequest):
|
40 |
try:
|
41 |
similar_prompts = engine.most_similar(
|
42 |
query=query_request.query, n=query_request.n
|
@@ -52,8 +48,14 @@ async def get_most_similar(query_request: QueryRequest):
|
|
52 |
raise HTTPException(status_code=500, detail=str(e))
|
53 |
|
54 |
|
55 |
-
@
|
56 |
-
async def home_page():
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
return HTMLResponse(
|
58 |
"""
|
59 |
<!DOCTYPE html>
|
@@ -77,7 +79,6 @@ async def home_page():
|
|
77 |
<h2>POST /most_similar</h2>
|
78 |
<p><strong>Request:</strong> <code>{"query": "string", "n": 5}</code></p>
|
79 |
<p><strong>Response:</strong> <code>{"similar_prompts": [{"score": 0.95, "prompt": "Example prompt 1"}]}</code></p>
|
80 |
-
<p>For more info, visit <a href="https://github.com/your-repository">GitHub</a>.</p>
|
81 |
</div>
|
82 |
</body>
|
83 |
</html>
|
|
|
1 |
+
import os
|
2 |
|
3 |
+
from fastapi import APIRouter, HTTPException
|
4 |
from fastapi.responses import HTMLResponse
|
|
|
5 |
|
6 |
+
from backend.models import QueryRequest, QueryResponse, SimilarPrompt
|
7 |
from src.prompt_loader import PromptLoader
|
8 |
from src.search_engine import PromptSearchEngine
|
9 |
|
10 |
# Constants
|
11 |
+
SEED = int(os.getenv("SEED", 42))
|
12 |
+
DATASET_SIZE = int(os.getenv("DATASET_SIZE", 1000))
|
13 |
|
14 |
# Initialize the prompt loader and search engine
|
15 |
+
prompts = PromptLoader(seed=SEED).load_data(size=DATASET_SIZE)
|
16 |
engine = PromptSearchEngine(prompts)
|
17 |
|
18 |
+
# Initialize the API router
|
19 |
+
router = APIRouter()
|
20 |
|
21 |
|
22 |
+
@router.post("/most_similar", response_model=QueryResponse)
|
23 |
+
async def get_most_similar(query_request: QueryRequest) -> QueryResponse:
|
24 |
+
"""
|
25 |
+
Endpoint to retrieve the most similar prompts based on a user query.
|
26 |
|
27 |
+
Args:
|
28 |
+
query_request (QueryRequest): The request payload containing the user query and the number of similar prompts to retrieve.
|
29 |
|
30 |
+
Returns:
|
31 |
+
QueryResponse: A response containing a list of similar prompts and their similarity scores.
|
|
|
32 |
|
33 |
+
Raises:
|
34 |
+
HTTPException: If an internal server error occurs while processing the request.
|
35 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
36 |
try:
|
37 |
similar_prompts = engine.most_similar(
|
38 |
query=query_request.query, n=query_request.n
|
|
|
48 |
raise HTTPException(status_code=500, detail=str(e))
|
49 |
|
50 |
|
51 |
+
@router.get("/", response_class=HTMLResponse)
|
52 |
+
async def home_page() -> HTMLResponse:
|
53 |
+
"""
|
54 |
+
Endpoint to serve a simple HTML page with information about the API.
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
HTMLResponse: An HTML page providing an overview of the API and how to use it.
|
58 |
+
"""
|
59 |
return HTMLResponse(
|
60 |
"""
|
61 |
<!DOCTYPE html>
|
|
|
79 |
<h2>POST /most_similar</h2>
|
80 |
<p><strong>Request:</strong> <code>{"query": "string", "n": 5}</code></p>
|
81 |
<p><strong>Response:</strong> <code>{"similar_prompts": [{"score": 0.95, "prompt": "Example prompt 1"}]}</code></p>
|
|
|
82 |
</div>
|
83 |
</body>
|
84 |
</html>
|
frontend/__init__.py
ADDED
File without changes
|
frontend/app_ui.py
CHANGED
@@ -1,26 +1,29 @@
|
|
1 |
-
import
|
2 |
|
3 |
import requests
|
4 |
import streamlit as st
|
5 |
|
|
|
|
|
6 |
|
7 |
-
def parse_arguments():
|
8 |
-
"""Parse command-line arguments."""
|
9 |
-
parser = argparse.ArgumentParser(description="Prompt Similarity Finder")
|
10 |
-
parser.add_argument(
|
11 |
-
"--api_url",
|
12 |
-
type=str,
|
13 |
-
default="https://lazarr19-prompt-engine.hf.space",
|
14 |
-
help="The URL of the FastAPI service",
|
15 |
-
)
|
16 |
-
return parser.parse_args()
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
-
|
20 |
-
|
|
|
21 |
try:
|
22 |
response = requests.post(
|
23 |
-
f"{
|
24 |
)
|
25 |
response.raise_for_status() # Raise an exception for HTTP errors
|
26 |
return response.json()
|
@@ -29,8 +32,16 @@ def get_similar_prompts(api_url, query, n):
|
|
29 |
return None
|
30 |
|
31 |
|
32 |
-
def get_color(score):
|
33 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
if score >= 0.8:
|
35 |
return "green"
|
36 |
elif score >= 0.5:
|
@@ -39,11 +50,13 @@ def get_color(score):
|
|
39 |
return "red"
|
40 |
|
41 |
|
42 |
-
def main(
|
43 |
-
"""
|
|
|
|
|
|
|
44 |
st.title("Prompt Similarity Finder")
|
45 |
|
46 |
-
# User input for query
|
47 |
query = st.text_input("Enter your query:", "")
|
48 |
n = st.slider(
|
49 |
"Number of similar prompts to retrieve:", min_value=1, max_value=40, value=5
|
@@ -52,7 +65,7 @@ def main(api_url):
|
|
52 |
if st.button("Find Similar Prompts"):
|
53 |
if query:
|
54 |
with st.spinner("Fetching similar prompts..."):
|
55 |
-
result = get_similar_prompts(
|
56 |
if result:
|
57 |
similar_prompts = result.get("similar_prompts", [])
|
58 |
if similar_prompts:
|
@@ -60,7 +73,6 @@ def main(api_url):
|
|
60 |
for item in similar_prompts:
|
61 |
score = item["score"]
|
62 |
color = get_color(score)
|
63 |
-
# Apply color only to the score part
|
64 |
st.markdown(
|
65 |
f"<p><strong>Score:</strong> <span style='color:{color};'>{score:.2f}</span> <br> <strong>Prompt:</strong> {item['prompt']}</p>",
|
66 |
unsafe_allow_html=True,
|
@@ -73,5 +85,4 @@ def main(api_url):
|
|
73 |
|
74 |
|
75 |
if __name__ == "__main__":
|
76 |
-
|
77 |
-
main(args.api_url)
|
|
|
1 |
+
import os
|
2 |
|
3 |
import requests
|
4 |
import streamlit as st
|
5 |
|
6 |
+
# Read API URL from environment variable
|
7 |
+
API_URL = os.getenv("API_URL", "http://localhost:8000")
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
+
def get_similar_prompts(query: str, n: int) -> dict:
|
11 |
+
"""
|
12 |
+
Fetches similar prompts from the API based on the user query.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
query (str): The user query for which similar prompts are to be retrieved.
|
16 |
+
n (int): The number of similar prompts to return.
|
17 |
+
|
18 |
+
Returns:
|
19 |
+
dict: A dictionary containing similar prompts, or None if there was an error.
|
20 |
|
21 |
+
Raises:
|
22 |
+
requests.RequestException: If an HTTP error occurs during the request.
|
23 |
+
"""
|
24 |
try:
|
25 |
response = requests.post(
|
26 |
+
f"{API_URL}/most_similar", json={"query": query, "n": n}
|
27 |
)
|
28 |
response.raise_for_status() # Raise an exception for HTTP errors
|
29 |
return response.json()
|
|
|
32 |
return None
|
33 |
|
34 |
|
35 |
+
def get_color(score: float) -> str:
|
36 |
+
"""
|
37 |
+
Determines the color based on the similarity score.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
score (float): The similarity score of a prompt.
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
str: The color representing the score, which could be "green", "orange", or "red".
|
44 |
+
"""
|
45 |
if score >= 0.8:
|
46 |
return "green"
|
47 |
elif score >= 0.5:
|
|
|
50 |
return "red"
|
51 |
|
52 |
|
53 |
+
def main():
|
54 |
+
"""
|
55 |
+
The main function for running the Streamlit app.
|
56 |
+
Sets up the UI for entering queries and retrieving similar prompts.
|
57 |
+
"""
|
58 |
st.title("Prompt Similarity Finder")
|
59 |
|
|
|
60 |
query = st.text_input("Enter your query:", "")
|
61 |
n = st.slider(
|
62 |
"Number of similar prompts to retrieve:", min_value=1, max_value=40, value=5
|
|
|
65 |
if st.button("Find Similar Prompts"):
|
66 |
if query:
|
67 |
with st.spinner("Fetching similar prompts..."):
|
68 |
+
result = get_similar_prompts(query, n)
|
69 |
if result:
|
70 |
similar_prompts = result.get("similar_prompts", [])
|
71 |
if similar_prompts:
|
|
|
73 |
for item in similar_prompts:
|
74 |
score = item["score"]
|
75 |
color = get_color(score)
|
|
|
76 |
st.markdown(
|
77 |
f"<p><strong>Score:</strong> <span style='color:{color};'>{score:.2f}</span> <br> <strong>Prompt:</strong> {item['prompt']}</p>",
|
78 |
unsafe_allow_html=True,
|
|
|
85 |
|
86 |
|
87 |
if __name__ == "__main__":
|
88 |
+
main()
|
|
run.py
CHANGED
@@ -3,7 +3,7 @@ import uvicorn
|
|
3 |
|
4 |
def run_fastapi_app():
|
5 |
uvicorn.run(
|
6 |
-
"main:app", # Module name and app instance
|
7 |
host="0.0.0.0",
|
8 |
port=8000,
|
9 |
reload=True, # Enable auto-reload for development
|
|
|
3 |
|
4 |
def run_fastapi_app():
|
5 |
uvicorn.run(
|
6 |
+
"backend.main:app", # Module name and app instance
|
7 |
host="0.0.0.0",
|
8 |
port=8000,
|
9 |
reload=True, # Enable auto-reload for development
|
src/prompt_loader.py
CHANGED
@@ -5,16 +5,46 @@ from datasets import load_dataset
|
|
5 |
|
6 |
|
7 |
class PromptLoader:
|
|
|
|
|
|
|
|
|
8 |
def __init__(self, seed: int = 42) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
self.randomizer = random.Random(seed)
|
10 |
self.data: Optional[List[str]] = None
|
11 |
|
12 |
def _load_data(self) -> None:
|
|
|
|
|
|
|
|
|
|
|
13 |
self.data = load_dataset("daspartho/stable-diffusion-prompts")["train"][
|
14 |
"prompt"
|
15 |
]
|
16 |
|
17 |
def load_data(self, size: Optional[int] = None) -> List[str]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
if not self.data:
|
19 |
self._load_data()
|
20 |
|
|
|
5 |
|
6 |
|
7 |
class PromptLoader:
|
8 |
+
"""
|
9 |
+
A class for loading and sampling prompts from a dataset.
|
10 |
+
"""
|
11 |
+
|
12 |
def __init__(self, seed: int = 42) -> None:
|
13 |
+
"""
|
14 |
+
Initializes the PromptLoader with a specified seed for random sampling.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
seed (int): The seed value for the random number generator. Default is 42.
|
18 |
+
"""
|
19 |
self.randomizer = random.Random(seed)
|
20 |
self.data: Optional[List[str]] = None
|
21 |
|
22 |
def _load_data(self) -> None:
|
23 |
+
"""
|
24 |
+
Loads the dataset of prompts and stores them in the `data` attribute.
|
25 |
+
|
26 |
+
This method uses the `datasets` library to load the dataset and extract prompts from the "train" split.
|
27 |
+
"""
|
28 |
self.data = load_dataset("daspartho/stable-diffusion-prompts")["train"][
|
29 |
"prompt"
|
30 |
]
|
31 |
|
32 |
def load_data(self, size: Optional[int] = None) -> List[str]:
|
33 |
+
"""
|
34 |
+
Loads and samples prompts from the dataset.
|
35 |
+
|
36 |
+
If the dataset is not already loaded, it calls `_load_data()` to load it.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
size (Optional[int]): The number of prompts to sample. If not specified, all loaded prompts are returned.
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
List[str]: A list of sampled prompts. If `size` is specified, returns a random sample of the specified size.
|
43 |
+
If `size` is not specified, returns all loaded prompts.
|
44 |
+
|
45 |
+
Raises:
|
46 |
+
ValueError: If `size` is specified and is greater than the number of available prompts.
|
47 |
+
"""
|
48 |
if not self.data:
|
49 |
self._load_data()
|
50 |
|