File size: 1,266 Bytes
baa44d9
48af48e
baa44d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from langchain_community.retrievers import BM25Retriever


import datasets
from langchain.docstore.document import Document

class GuestInfoRetriever:
    """Look up gala guest records through a BM25 retriever.

    Instance attributes (both assigned in ``__init__``):
        docs: the documents handed in by the caller, stored unchanged.
        dataset: the ``BM25Retriever`` built from those documents. Despite
            its name, this holds the retriever object, not a raw dataset.
    """

    def __init__(self, docs):
        """Store *docs* and build a ``BM25Retriever`` over them."""
        self.docs = docs
        self.dataset = BM25Retriever.from_documents(docs)

    def retrieve(self, query: str) -> str:
        """Return the text of up to three documents retrieved for *query*.

        The query is passed to the retriever's ``invoke``. The
        ``page_content`` of the first three results is joined with blank
        lines; when the retriever returns nothing, a fixed "not found"
        message is returned instead.
        """
        matches = self.dataset.invoke(query)
        if not matches:
            return "No matching guest information found."

        # Only the first three results are reported, however many come back.
        top_matches = matches[:3]
        return "\n\n".join(match.page_content for match in top_matches)

# Load the dataset
def load_guest_dataset():
    """Build a ``GuestInfoRetriever`` from the invitee dataset.

    Loads the ``train`` split of ``agents-course/unit3-invitees``, renders
    each entry as a four-line text record (name, relation, description,
    email), wraps it in a ``Document`` whose metadata carries the guest's
    name, and passes the resulting list to ``GuestInfoRetriever``.

    Returns:
        GuestInfoRetriever: a retriever over every entry of the split.
    """
    invitees = datasets.load_dataset("agents-course/unit3-invitees", split="train")

    documents = []
    for entry in invitees:
        # One labelled field per line; this text is what gets indexed.
        record_text = "\n".join((
            f"Name: {entry['name']}",
            f"Relation: {entry['relation']}",
            f"Description: {entry['description']}",
            f"Email: {entry['email']}",
        ))
        documents.append(
            Document(page_content=record_text, metadata={"name": entry["name"]})
        )

    return GuestInfoRetriever(documents)