Spaces:
Runtime error
Runtime error
"""Guest Information Retrieval Tool"""
from smolagents import Tool | |
from langchain_community.retrievers import BM25Retriever | |
from langchain.docstore.document import Document | |
import datasets | |
class GuestInfoRetrieverTool(Tool):
    """Tool that looks up gala guest information with a BM25 retriever.

    The retriever is built once, at construction time, from the documents
    passed to ``__init__``; ``forward`` then runs each query against it.
    """

    name = "guest_info_retriever"
    description = "Retrieves detailed information about gala guests based on their name or relation."  # pylint: disable=line-too-long
    inputs = {
        "query": {
            "type": "string",
            "description": "The name or relation of the guest you want information about.",
        }
    }
    output_type = "string"

    def __init__(self, docs):  # pylint: disable=super-init-not-called
        """Index ``docs`` for retrieval.

        Args:
            docs: Documents handed to ``BM25Retriever.from_documents``.
        """
        # The base initializer is deliberately not called, so the flag is set
        # here by hand.
        # NOTE(review): presumably mirrors what ``Tool.__init__`` would set;
        # confirm against the installed smolagents version.
        self.is_initialized = False
        self.retriever = BM25Retriever.from_documents(docs)

    def forward(self, query: str) -> str:  # pylint: disable=arguments-differ
        """Return the text of the best-matching guest documents.

        Args:
            query: The name or relation of the guest to look up.

        Returns:
            The ``page_content`` of up to three retrieved documents, separated
            by blank lines, or a fixed "not found" message when the retriever
            returns nothing.
        """
        # ``invoke`` is the supported retriever entry point; the older
        # ``get_relevant_documents`` is deprecated in its favour.
        results = self.retriever.invoke(query)
        if not results:
            return "No matching guest information found."
        return "\n\n".join(doc.page_content for doc in results[:3])
def load_guest_dataset():
    """Load every guest dataset and wrap the rows in a retrieval tool.

    The "train" split of each dataset in ``DATASETS`` is loaded and every row
    is normalized to the four expected columns (name, relation, description,
    email). A column that a row does not provide is replaced by a
    "<Column> not in data" placeholder, so datasets with different schemas
    can be mixed. Each normalized row becomes one ``Document``.

    Returns:
        GuestInfoRetrieverTool: A tool indexing one document per guest row.
    """
    DATASETS = [  # pylint: disable=invalid-name
        "agents-course/unit3-invitees",
        "ANT-TECH/Unit3-dataset3",
        "Data-Gem/agents-course-unit3-invitees-expanded",
    ]
    NAME = "name"  # pylint: disable=invalid-name
    RELATION = "relation"  # pylint: disable=invalid-name
    DESCRIPTION = "description"  # pylint: disable=invalid-name
    EMAIL = "email"  # pylint: disable=invalid-name
    COLUMN_LIST = [NAME, RELATION, DESCRIPTION, EMAIL]  # pylint: disable=invalid-name

    guest_list = []
    for dataset_name in DATASETS:
        guest_dataset = datasets.load_dataset(dataset_name, split="train")
        # Walk the individual rows: each row is one guest. Collecting the
        # dataset object itself would yield a single bogus "guest" per
        # dataset instead of one per invitee.
        for raw_guest in guest_dataset:
            guest_list.append(
                {
                    col_name: (
                        raw_guest[col_name]
                        if col_name in raw_guest
                        else f"{col_name} not in data".capitalize()
                    )
                    for col_name in COLUMN_LIST
                }
            )

    # Convert dataset entries into Document objects
    docs = [
        Document(
            page_content="\n".join(
                [
                    f"Name: {guest[NAME]}",
                    f"Relation: {guest[RELATION]}",
                    f"Description: {guest[DESCRIPTION]}",
                    f"Email: {guest[EMAIL]}",
                ]
            ),
            metadata={"name": guest[NAME]},
        )
        for guest in guest_list
    ]

    # Return the tool
    return GuestInfoRetrieverTool(docs)