ALLOUNE committed on
Commit
62151ed
·
1 Parent(s): 290a90d

add purpose search

Browse files
Files changed (2) hide show
  1. main.py +3 -2
  2. src/processor.py +9 -4
main.py CHANGED
@@ -28,6 +28,7 @@ dataset = load_dataset("heymenn/Technologies", streaming=True, split="train")
28
 
29
  class SearchInput(BaseModel):
30
  title: str
 
31
 
32
  class SearchOutput(BaseModel):
33
  title: str
@@ -53,7 +54,7 @@ def post_search(payload: SearchInput):
53
  """
54
  Endpoint that returns a search result.
55
  """
56
- config = {"dataset": dataset, "model": model}
57
  res = search_and_retrieve(payload.title, config)
58
  return res
59
 
@@ -63,7 +64,7 @@ def post_generate_and_push(payload: GenerateInput):
63
  Endpoint to generate a technology and push it to the dataset
64
  """
65
 
66
- config = {"dataset": dataset, "model": model}
67
  res = search_and_retrieve(payload.title, config)
68
  if res["score"] >= 0.7 and not payload.force:
69
  raise HTTPException(status_code=500, detail=f"Cannot generate the technology a high score of {res['score']} have been found for the technology : {res['title']}")
 
28
 
29
  class SearchInput(BaseModel):
30
  title: str
31
+ type: str = "title"
32
 
33
  class SearchOutput(BaseModel):
34
  title: str
 
54
  """
55
  Endpoint that returns a search result.
56
  """
57
+ config = {"dataset": dataset, "model": model, "type": payload.type}
58
  res = search_and_retrieve(payload.title, config)
59
  return res
60
 
 
64
  Endpoint to generate a technology and push it to the dataset
65
  """
66
 
67
+ config = {"dataset": dataset, "model": model, "type": "title"}
68
  res = search_and_retrieve(payload.title, config)
69
  if res["score"] >= 0.7 and not payload.force:
70
  raise HTTPException(status_code=500, detail=f"Cannot generate the technology a high score of {res['score']} have been found for the technology : {res['title']}")
src/processor.py CHANGED
@@ -18,8 +18,11 @@ def search_and_retrieve(user_input, config):
18
  purpose = row["purpose"]
19
 
20
  cosim = model.similarity(row["embeddings"], user_embedding)
21
- token_set_ratio = fuzz.token_set_ratio(user_input, name)
22
-
 
 
 
23
  fuzzy_score = token_set_ratio / 100
24
  alpha = 0.6
25
  combined_score = alpha * cosim + (1 - alpha) * fuzzy_score
@@ -77,9 +80,11 @@ def generate_tech(user_input, user_instructions):
77
 
78
  <USER_INPUT>
79
  {user_input}
80
- </USER_INPUT>
81
  """
82
 
 
 
83
  client = Client(api_key=os.getenv("GEMINI_API_KEY"))
84
 
85
  # Define the grounding tool
@@ -111,4 +116,4 @@ def send_to_dataset(data, model):
111
 
112
  dataset = load_dataset("heymenn/Technologies", split="train")
113
  updated_dataset = dataset.add_item(data)
114
- updated_dataset.push_to_hub("heymenn/Technologies")
 
18
  purpose = row["purpose"]
19
 
20
  cosim = model.similarity(row["embeddings"], user_embedding)
21
+ if config["type"] == "purpose":
22
+ token_set_ratio = fuzz.token_set_ratio(user_input, purpose)
23
+ else:
24
+ token_set_ratio = fuzz.token_set_ratio(user_input, name)
25
+
26
  fuzzy_score = token_set_ratio / 100
27
  alpha = 0.6
28
  combined_score = alpha * cosim + (1 - alpha) * fuzzy_score
 
80
 
81
  <USER_INPUT>
82
  {user_input}
83
+ </USER_INPUT>
84
  """
85
 
86
+ client = Client(api_key=os.getenv("GEMINI_API_KEY"))
87
+
88
  client = Client(api_key=os.getenv("GEMINI_API_KEY"))
89
 
90
  # Define the grounding tool
 
116
 
117
  dataset = load_dataset("heymenn/Technologies", split="train")
118
  updated_dataset = dataset.add_item(data)
119
+ updated_dataset.push_to_hub("heymenn/Technologies")