ALLOUNE
committed on
Commit
·
62151ed
1
Parent(s):
290a90d
add purpose search
Browse files
- main.py +3 -2
- src/processor.py +9 -4
main.py
CHANGED
@@ -28,6 +28,7 @@ dataset = load_dataset("heymenn/Technologies", streaming=True, split="train")
|
|
28 |
|
29 |
class SearchInput(BaseModel):
|
30 |
title: str
|
|
|
31 |
|
32 |
class SearchOutput(BaseModel):
|
33 |
title: str
|
@@ -53,7 +54,7 @@ def post_search(payload: SearchInput):
|
|
53 |
"""
|
54 |
Endpoint that returns a search result.
|
55 |
"""
|
56 |
-
config = {"dataset": dataset, "model": model}
|
57 |
res = search_and_retrieve(payload.title, config)
|
58 |
return res
|
59 |
|
@@ -63,7 +64,7 @@ def post_generate_and_push(payload: GenerateInput):
|
|
63 |
Endpoint to generate a technology and push it to the dataset
|
64 |
"""
|
65 |
|
66 |
-
config = {"dataset": dataset, "model": model}
|
67 |
res = search_and_retrieve(payload.title, config)
|
68 |
if res["score"] >= 0.7 and not payload.force:
|
69 |
raise HTTPException(status_code=500, detail=f"Cannot generate the technology a high score of {res['score']} have been found for the technology : {res['title']}")
|
|
|
28 |
|
29 |
class SearchInput(BaseModel):
|
30 |
title: str
|
31 |
+
type: str = "title"
|
32 |
|
33 |
class SearchOutput(BaseModel):
|
34 |
title: str
|
|
|
54 |
"""
|
55 |
Endpoint that returns a search result.
|
56 |
"""
|
57 |
+
config = {"dataset": dataset, "model": model, "type": payload.type}
|
58 |
res = search_and_retrieve(payload.title, config)
|
59 |
return res
|
60 |
|
|
|
64 |
Endpoint to generate a technology and push it to the dataset
|
65 |
"""
|
66 |
|
67 |
+
config = {"dataset": dataset, "model": model, "type": "title"}
|
68 |
res = search_and_retrieve(payload.title, config)
|
69 |
if res["score"] >= 0.7 and not payload.force:
|
70 |
raise HTTPException(status_code=500, detail=f"Cannot generate the technology a high score of {res['score']} have been found for the technology : {res['title']}")
|
src/processor.py
CHANGED
@@ -18,8 +18,11 @@ def search_and_retrieve(user_input, config):
|
|
18 |
purpose = row["purpose"]
|
19 |
|
20 |
cosim = model.similarity(row["embeddings"], user_embedding)
|
21 |
-
|
22 |
-
|
|
|
|
|
|
|
23 |
fuzzy_score = token_set_ratio / 100
|
24 |
alpha = 0.6
|
25 |
combined_score = alpha * cosim + (1 - alpha) * fuzzy_score
|
@@ -77,9 +80,11 @@ def generate_tech(user_input, user_instructions):
|
|
77 |
|
78 |
<USER_INPUT>
|
79 |
{user_input}
|
80 |
-
</USER_INPUT>
|
81 |
"""
|
82 |
|
|
|
|
|
83 |
client = Client(api_key=os.getenv("GEMINI_API_KEY"))
|
84 |
|
85 |
# Define the grounding tool
|
@@ -111,4 +116,4 @@ def send_to_dataset(data, model):
|
|
111 |
|
112 |
dataset = load_dataset("heymenn/Technologies", split="train")
|
113 |
updated_dataset = dataset.add_item(data)
|
114 |
-
updated_dataset.push_to_hub("heymenn/Technologies")
|
|
|
18 |
purpose = row["purpose"]
|
19 |
|
20 |
cosim = model.similarity(row["embeddings"], user_embedding)
|
21 |
+
if config["type"] == "purpose":
|
22 |
+
token_set_ratio = fuzz.token_set_ratio(user_input, purpose)
|
23 |
+
else:
|
24 |
+
token_set_ratio = fuzz.token_set_ratio(user_input, name)
|
25 |
+
|
26 |
fuzzy_score = token_set_ratio / 100
|
27 |
alpha = 0.6
|
28 |
combined_score = alpha * cosim + (1 - alpha) * fuzzy_score
|
|
|
80 |
|
81 |
<USER_INPUT>
|
82 |
{user_input}
|
83 |
+
</USER_INPUT>
|
84 |
"""
|
85 |
|
86 |
+
client = Client(api_key=os.getenv("GEMINI_API_KEY"))
|
87 |
+
|
88 |
client = Client(api_key=os.getenv("GEMINI_API_KEY"))
|
89 |
|
90 |
# Define the grounding tool
|
|
|
116 |
|
117 |
dataset = load_dataset("heymenn/Technologies", split="train")
|
118 |
updated_dataset = dataset.add_item(data)
|
119 |
+
updated_dataset.push_to_hub("heymenn/Technologies")
|