sid385 commited on
Commit
329acf2
·
verified ·
1 Parent(s): d8680d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -42
app.py CHANGED
@@ -5,61 +5,87 @@ import inspect
5
  import pandas as pd
6
 
7
 
 
8
  from langgraph.graph import StateGraph, END
9
- from langchain_core.runnables import Runnable
10
- from langchain_core.agents import AgentFinish
11
- from langchain.agents import create_react_agent
12
- from langchain.agents.agent import RunnableAgent
13
- from langchain.tools import Tool
14
- from langchain_core.messages import HumanMessage
15
  from langchain_google_genai import ChatGoogleGenerativeAI
 
 
 
 
 
 
 
16
 
17
 
18
- # (Keep Constants as is)
19
- # --- Constants ---
20
- DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
 
 
21
 
22
- GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
 
 
 
23
 
 
 
 
 
24
 
25
- # --- Basic Agent Definition ---
26
- # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
27
- class BasicAgent:
28
- def __init__(self):
29
- print("LangGraph ReAct Agent initialized.")
30
- if not GOOGLE_API_KEY:
31
- raise ValueError("GOOGLE_API_KEY environment variable not found.")
32
-
33
- self.llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash")
34
 
 
 
35
 
36
- # Example tools
37
- self.tools = [
38
- Tool(
39
- name="DuckDuckGo Search",
40
- func=lambda q: f"Pretend search result for: {q}",
41
- description="Search the web using DuckDuckGo. Useful for recent events."
42
- )
43
- ]
44
-
45
- react_agent = create_react_agent(self.llm, self.tools)
46
 
47
- workflow = StateGraph({"agent_outcome": Runnable})
48
- workflow.add_node("agent", react_agent)
49
- workflow.set_entry_point("agent")
50
- workflow.add_edge("agent", END)
51
 
52
- app = workflow.compile()
53
- self.agent = RunnableAgent(app)
54
 
55
- def __call__(self, question: str) -> str:
56
- result = self.agent.invoke({"input": HumanMessage(content=question)})
57
- if isinstance(result, AgentFinish):
58
- return result.return_values["output"]
59
- elif isinstance(result, dict) and "output" in result:
60
- return result["output"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  else:
62
- return "No valid output returned."
 
 
 
 
63
 
64
  def run_and_submit_all( profile: gr.OAuthProfile | None):
65
  """
 
5
  import pandas as pd
6
 
7
 
8
+ from typing import Sequence, Annotated, TypedDict, Union
9
  from langgraph.graph import StateGraph, END
10
+ from langgraph.graph.message import add_messages
11
+ from langgraph.prebuilt import ToolNode
12
+ from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage
 
 
 
13
  from langchain_google_genai import ChatGoogleGenerativeAI
14
+ from langchain_community.tools import DuckDuckGoSearchRun
15
+ from langchain.tools import tool
16
+ from dotenv import load_dotenv
17
+ load_dotenv()
18
+ # ----------- Agent State -----------
19
+ class AgentState(TypedDict):
20
+ messages: Annotated[Sequence[BaseMessage], add_messages]
21
 
22
 
23
+ # ----------- Math Tools ------------
24
+ @tool
25
+ def add(a: int, b: int):
26
+ """Adds two numbers."""
27
+ return a + b
28
 
29
+ @tool
30
+ def subtract(a: int, b: int):
31
+ """Subtracts two numbers."""
32
+ return a - b
33
 
34
+ @tool
35
+ def multiply(a: int, b: int):
36
+ """Multiplies two numbers."""
37
+ return a * b
38
 
39
+ # ----------- DuckDuckGo Tool (LangChain built-in) -----------
40
+ ddg_tool = DuckDuckGoSearchRun(name="duckduckgo_search")
 
 
 
 
 
 
 
41
 
42
+ # ----------- Combine all tools -----------
43
+ tools = [add, subtract, multiply, ddg_tool]
44
 
 
 
 
 
 
 
 
 
 
 
45
 
46
+ # ----------- BasicAgent Class -----------
47
+ class BasicAgent:
48
+ def __init__(self):
49
+ print("LangGraph Gemini Agent with LangChain DuckDuckGo Tool Initialized.")
50
 
51
+ self.model = ChatGoogleGenerativeAI(model="gemini-2.0-flash").bind_tools(tools)
 
52
 
53
+ def model_call(state: AgentState) -> AgentState:
54
+ system_prompt = SystemMessage(
55
+ content="You are an AI assistant. Use tools like math functions and web search (DuckDuckGo) to answer queries."
56
+ )
57
+ response = self.model.invoke([system_prompt] + state["messages"])
58
+ return {"messages": [response]}
59
+
60
+ def should_continue(state: AgentState):
61
+ last_message = state["messages"][-1]
62
+ if not getattr(last_message, "tool_calls", None):
63
+ return "end"
64
+ return "continue"
65
+
66
+ graph = StateGraph(AgentState)
67
+ graph.add_node("our_agent", model_call)
68
+ graph.add_node("tools", ToolNode(tools=tools))
69
+ graph.set_entry_point("our_agent")
70
+ graph.add_conditional_edges("our_agent", should_continue, {
71
+ "continue": "tools",
72
+ "end": END
73
+ })
74
+ graph.add_edge("tools", "our_agent")
75
+
76
+ self.agent = graph.compile()
77
+
78
+ def __call__(self, query: Union[str, Sequence[BaseMessage]]) -> str:
79
+ if isinstance(query, str):
80
+ messages = [HumanMessage(content=query)]
81
+ elif isinstance(query, list):
82
+ messages = query
83
  else:
84
+ raise ValueError("Invalid input: Must be a string or list of messages.")
85
+
86
+ result = self.agent.invoke({"messages": messages})
87
+ last_message = result["messages"][-1]
88
+ return last_message.content if hasattr(last_message, "content") else str(last_message)
89
 
90
  def run_and_submit_all( profile: gr.OAuthProfile | None):
91
  """