File size: 4,212 Bytes
67c6de9
d6d38f9
e94f0c8
b32034a
2daff71
2b1dab2
bbf0eaf
1e21e45
2b1dab2
a05af2d
e94f0c8
ca55ba5
 
758cb32
b38a8b5
e94f0c8
a861628
3f3cb4d
67c6de9
ca55ba5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a05af2d
 
 
d5f6ff8
 
 
 
 
 
 
a05af2d
a67ef47
 
 
 
 
 
 
 
a861628
67c6de9
6a5de82
8349717
 
 
 
 
 
 
 
 
 
6a5de82
67c6de9
6a5de82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67c6de9
4dbec95
 
 
 
 
 
 
 
02e95e3
9f3c913
02e95e3
758cb32
 
 
 
 
 
 
 
 
 
 
 
5b7afbf
 
 
 
 
 
 
 
 
 
 
 
 
48febc6
4402fc7
6044a35
 
 
 
 
5b7afbf
8349717
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import gradio as gr
from typing import TypedDict, Annotated
from huggingface_hub import InferenceClient, login
import random
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFacePipeline
#from langchain.schema import AIMessage, HumanMessage
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
from langchain.tools import Tool
import os
import datasets
from langchain.docstore.document import Document
from langgraph.graph import START, StateGraph
from langchain_community.retrievers import BM25Retriever

# Read the Hub token from the environment (a missing variable raises KeyError)
# and authenticate this process with it.
HUGGINGFACEHUB_API_TOKEN = os.environ["HUGGINGFACEHUB_API_TOKEN"]
login(token=HUGGINGFACEHUB_API_TOKEN, add_to_git_credential=True)

# Guest list: the "train" split of the invitee dataset.
guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")

# One Document per guest: a four-line text profile is the searchable content,
# and the guest's name is kept in the metadata.
docs = [
    Document(
        page_content="\n".join(
            (
                f"Name: {invitee['name']}",
                f"Relation: {invitee['relation']}",
                f"Description: {invitee['description']}",
                f"Email: {invitee['email']}",
            )
        ),
        metadata={"name": invitee["name"]},
    )
    for invitee in guest_dataset
]

# BM25 retriever built over the guest profile documents.
bm25_retriever = BM25Retriever.from_documents(docs)

def extract_text(query: str) -> str:
    """Return the text profiles of gala guests matching *query*.

    The query is handed to the module-level ``bm25_retriever``. The page
    content of at most the first three returned documents is joined with a
    blank line between profiles. When the retriever returns nothing, a fixed
    fallback message is returned instead.
    """
    matches = bm25_retriever.invoke(query)
    if not matches:
        return "No matching guest information found."

    top_profiles = [match.page_content for match in matches[:3]]
    return "\n\n".join(top_profiles)

# Hosted text-generation endpoint that backs the chat model.
llm = HuggingFaceEndpoint(
    repo_id="HuggingFaceH4/zephyr-7b-beta",
    task="text-generation",
    max_new_tokens=512,
    do_sample=False,  # sampling disabled
    repetition_penalty=1.03,
)

# Chat-style wrapper around the endpoint; this is the object the agent's
# tool binding is created from.
model = ChatHuggingFace(llm=llm, verbose=True)

def predict(message, history):
    """Answer one chat turn by running the Alfred agent over the conversation.

    Args:
        message: Text of the new user message.
        history: Earlier turns as a list of dicts, each read through its
            ``"role"`` and ``"content"`` keys. Entries whose role is neither
            ``"user"`` nor ``"assistant"`` are skipped.

    Returns:
        The ``content`` of the last message in the agent's response.
    """
    # Convert Gradio history to LangChain message format
    history_langchain_format = []
    for msg in history:
        if msg["role"] == "user":
            history_langchain_format.append(HumanMessage(content=msg["content"]))
        elif msg["role"] == "assistant":
            history_langchain_format.append(AIMessage(content=msg["content"]))

    # Add new user message
    history_langchain_format.append(HumanMessage(content=message))

    # Invoke Alfred agent with full message history; the recursion limit caps
    # how many graph steps a single turn may take.
    response = alfred.invoke(
        input={"messages": history_langchain_format},
        config={"recursion_limit": 100},
    )

    # Extract final assistant message
    return response["messages"][-1].content

# setup agents

# Wrap the guest lookup function as a named tool; the description is the text
# supplied alongside the tool name.
guest_info_tool = Tool(
    name="guest_info_retriever",
    func=extract_text,
    description="Retrieves detailed information about gala guests based on their name or relation."
)

# Tools available to the agent, and the chat model with those tools bound.
tools = [guest_info_tool]
chat_with_tools = model.bind_tools(tools)

# Generate the AgentState and Agent graph
class AgentState(TypedDict):
    """State dictionary passed between the nodes of the agent graph."""

    # Conversation so far. `add_messages` is attached as annotation metadata;
    # presumably LangGraph uses it to merge each node's returned messages into
    # this list rather than overwrite it (confirm against the LangGraph docs).
    messages: Annotated[list[AnyMessage], add_messages]

def assistant(state: AgentState):
    """Run the tool-enabled chat model on the current conversation.

    Passes ``state["messages"]`` to ``chat_with_tools`` and returns a state
    update whose ``"messages"`` entry is a one-element list holding the
    model's reply.
    """
    reply = chat_with_tools.invoke(state["messages"])
    return {"messages": [reply]}

## The graph
# Graph builder whose state schema is AgentState.
builder = StateGraph(AgentState)

# Define nodes: these do the work
# "assistant" calls the chat model; "tools" is a ToolNode over the tool list.
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))

# Define edges: these determine how the control flow moves
# Every run starts at the assistant node.
builder.add_edge(START, "assistant")
builder.add_conditional_edges(
    "assistant",
    # If the latest message requires a tool, route to tools
    # Otherwise, provide a direct response
    tools_condition,
)
# After the tools node runs, control returns to the assistant.
builder.add_edge("tools", "assistant")
# Compiled graph; this is the object predict() invokes.
alfred = builder.compile()

# Chat UI that calls predict() for each user message. predict() reads history
# entries by their "role" and "content" keys, matching type="messages".
demo = gr.ChatInterface(
    predict,
    type="messages"
)


# Start the Gradio app when this module is executed.
demo.launch()