applemuncy committed on
Commit
b89e612
·
1 Parent(s): 049485c

adding files

Browse files
Files changed (2) hide show
  1. requirements.txt +3 -0
  2. retriever.py +48 -0
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ datasets
2
+ smolagents
3
+ langchain-community
retriever.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.docstore.document import Document
2
+ import datasets
3
+
4
class GuestInfoRetrieverTool(Tool):
    """Agent tool that looks up gala guest information by keyword search.

    The tool wraps a ``BM25Retriever`` built from the documents passed to the
    constructor and returns the text of the best-matching documents.

    NOTE(review): ``Tool`` and ``BM25Retriever`` are not brought into scope by
    the import lines visible at the top of this file. Presumably they come
    from ``smolagents`` and ``langchain_community.retrievers`` -- confirm and
    add the imports, otherwise this class raises ``NameError`` on import.
    """

    name = "guest_info_retriever"
    # Kept on a single line: a plain double-quoted literal cannot span lines.
    description = "Retrieves detailed information about gala guests based on their name or relation."
    inputs = {
        "query": {
            "type": "string",
            "description": "The name or relation of the guest you want information about.",
        }
    }
    output_type = "string"

    def __init__(self, docs):
        """Build the BM25 index over ``docs``.

        Args:
            docs: Documents to index; passed unchanged to
                ``BM25Retriever.from_documents``.
        """
        self.is_initialized = False
        self.retriever = BM25Retriever.from_documents(docs)

    def forward(self, query: str):
        """Return the text of up to three documents matching ``query``.

        Args:
            query: Free-text search string (a guest name or relation).

        Returns:
            The ``page_content`` of the first three results separated by
            blank lines, or a fixed "not found" message when the retriever
            returns nothing.
        """
        # ``invoke`` is the supported retriever entry point;
        # ``get_relevant_documents`` is deprecated in LangChain.
        results = self.retriever.invoke(query)
        if not results:
            return "No matching guest information found."
        return "\n\n".join(doc.page_content for doc in results[:3])
26
+
27
+
28
def load_guest_dataset():
    """Load the invitee dataset and wrap it in a ``GuestInfoRetrieverTool``.

    Reads the ``train`` split of ``agents-course/unit3-invitees``, turns each
    row into a ``Document`` whose text lists the guest's name, relation,
    description and email (one per line), and stores the guest's name in the
    document metadata.

    Returns:
        A ``GuestInfoRetrieverTool`` built from the generated documents.
    """
    invitees = datasets.load_dataset("agents-course/unit3-invitees", split="train")

    documents = []
    for guest in invitees:
        # One labelled field per line forms the searchable text for a guest.
        profile_lines = [
            f"Name: {guest['name']}",
            f"Relation: {guest['relation']}",
            f"Description: {guest['description']}",
            f"Email: {guest['email']}",
        ]
        documents.append(
            Document(
                page_content="\n".join(profile_lines),
                metadata={"name": guest["name"]},
            )
        )

    return GuestInfoRetrieverTool(documents)
48
+