File size: 2,498 Bytes
67c6de9
d6d38f9
e94f0c8
b32034a
2daff71
8349717
a05af2d
e94f0c8
ca55ba5
 
758cb32
b38a8b5
a05af2d
e94f0c8
a861628
 
67c6de9
ca55ba5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a05af2d
 
 
 
a67ef47
 
 
 
 
 
 
 
a861628
67c6de9
8349717
 
 
 
 
 
 
 
 
 
67c6de9
 
8349717
b32034a
67c6de9
 
4dbec95
 
 
 
 
 
 
 
02e95e3
9f3c913
02e95e3
4dbec95
758cb32
 
 
 
 
 
 
 
 
 
 
 
8349717
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import random
from typing import Annotated, TypedDict

import datasets
import gradio as gr
from huggingface_hub import InferenceClient, login
from langchain.docstore.document import Document
from langchain.schema import AIMessage, HumanMessage
from langchain.tools import Tool
from langchain_community.retrievers import BM25Retriever
from langchain_core.messages import AnyMessage
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFacePipeline
from langgraph.graph import START, StateGraph
from langgraph.graph.message import add_messages

from retriever import extract_text

HUGGINGFACEHUB_API_TOKEN = os.environ["HUGGINGFACEHUB_API_TOKEN"]
login(token=HUGGINGFACEHUB_API_TOKEN)

# Load the dataset
guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")

# Convert dataset entries into Document objects
docs = [
    Document(
        page_content="\n".join([
            f"Name: {guest['name']}",
            f"Relation: {guest['relation']}",
            f"Description: {guest['description']}",
            f"Email: {guest['email']}"
        ]),
        metadata={"name": guest["name"]}
    )
    for guest in guest_dataset
]


bm25_retriever = BM25Retriever.from_documents(docs)


llm = HuggingFaceEndpoint(
    repo_id="HuggingFaceH4/zephyr-7b-beta",
    task="text-generation",
    max_new_tokens=512,
    do_sample=False,
    repetition_penalty=1.03,
)

model = ChatHuggingFace(llm=llm, verbose=True)

def predict(message, history):
    history_langchain_format = []
    for msg in history:
        if msg['role'] == "user":
            history_langchain_format.append(HumanMessage(content=msg['content']))
        elif msg['role'] == "assistant":
            history_langchain_format.append(AIMessage(content=msg['content']))
    history_langchain_format.append(HumanMessage(content=message))
    gpt_response = model.invoke(history_langchain_format)
    return gpt_response.content

demo = gr.ChatInterface(
    predict,
    type="messages"
)

# setup agents

guest_info_tool = Tool(
    name="guest_info_retriever",
    func=extract_text,
    description="Retrieves detailed information about gala guests based on their name or relation."
)

tools = [guest_info_tool]
chat_with_tools = model.bind_tools(tools)


# Generate the AgentState and Agent graph
class AgentState(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]

def assistant(state: AgentState):
    return {
        "messages": [chat_with_tools.invoke(state["messages"])],
    }

## The graph
builder = StateGraph(AgentState)

demo.launch()