Spaces:
Sleeping
Sleeping
Vela
commited on
Commit
·
540db73
1
Parent(s):
75115cd
modified functions
Browse files- application/agents/scraper_agent.py +8 -32
- application/services/gemini_api_service.py +1 -7
- main.py +12 -12
application/agents/scraper_agent.py
CHANGED
@@ -42,57 +42,42 @@ model_with_tools = model.bind_tools(tools)
|
|
42 |
def invoke_model(state: AgentState) -> dict:
|
43 |
"""Invokes the LLM with the current conversation history."""
|
44 |
logger.info("--- Invoking Model ---")
|
45 |
-
# LangGraph automatically passes the entire state
|
46 |
-
# The model_with_tools expects a list of BaseMessages
|
47 |
response = model_with_tools.invoke(state['messages'])
|
48 |
-
|
49 |
-
# We return a dictionary with the key corresponding to the state field name
|
50 |
-
return {"messages": [response]} # The response is already an AIMessage
|
51 |
|
52 |
def invoke_tools(state: AgentState) -> dict:
|
53 |
"""Invokes the necessary tools based on the last AI message."""
|
54 |
logger.info("--- Invoking Tools ---")
|
55 |
-
# The state contains the history, the last message is the AI's request
|
56 |
last_message = state['messages'][-1]
|
57 |
|
58 |
-
# Check if the last message is an AIMessage with tool_calls
|
59 |
if not hasattr(last_message, 'tool_calls') or not last_message.tool_calls:
|
60 |
logger.info("No tool calls found in the last message.")
|
61 |
-
# This scenario might indicate the conversation should end or requires clarification
|
62 |
-
# For now, return an empty dict, which won't update the state significantly.
|
63 |
-
# Consider adding a message indicating no tools were called if needed.
|
64 |
return {}
|
65 |
-
# Alternative: return {"messages": [SystemMessage(content="No tool calls requested.")]}
|
66 |
|
67 |
tool_invocation_messages = []
|
68 |
|
69 |
-
# Find the tool object by name
|
70 |
tool_map = {tool.name: tool for tool in tools}
|
71 |
|
72 |
for tool_call in last_message.tool_calls:
|
73 |
tool_name = tool_call['name']
|
74 |
tool_args = tool_call['args']
|
75 |
-
tool_call_id = tool_call['id']
|
76 |
|
77 |
logger.info(f"Executing tool: {tool_name} with args: {tool_args}")
|
78 |
|
79 |
if tool_name in tool_map:
|
80 |
selected_tool = tool_map[tool_name]
|
81 |
try:
|
82 |
-
# Use the tool's invoke method, passing the arguments dictionary
|
83 |
result = selected_tool.invoke(tool_args)
|
84 |
|
85 |
-
# IMPORTANT: Convert the result to a string or a JSON serializable format
|
86 |
-
# if it's a complex object. ToolMessage content should be simple.
|
87 |
-
# Adjust this based on what your tools actually return.
|
88 |
if isinstance(result, list) or isinstance(result, dict):
|
89 |
-
result_content = json.dumps(result)
|
90 |
-
elif hasattr(result, 'companies') and isinstance(result.companies, list):
|
91 |
result_content = f"Companies found: {', '.join(result.companies)}"
|
92 |
elif result is None:
|
93 |
result_content = "Tool executed successfully, but returned no specific data (None)."
|
94 |
else:
|
95 |
-
result_content = str(result)
|
96 |
|
97 |
logger.info(f"Tool {tool_name} result: {result_content}")
|
98 |
tool_invocation_messages.append(
|
@@ -100,7 +85,6 @@ def invoke_tools(state: AgentState) -> dict:
|
|
100 |
)
|
101 |
except Exception as e:
|
102 |
logger.error(f"Error executing tool {tool_name}: {e}")
|
103 |
-
# Return an error message in the ToolMessage
|
104 |
tool_invocation_messages.append(
|
105 |
ToolMessage(content=f"Error executing tool {tool_name}: {str(e)}", tool_call_id=tool_call_id)
|
106 |
)
|
@@ -110,29 +94,22 @@ def invoke_tools(state: AgentState) -> dict:
|
|
110 |
ToolMessage(content=f"Error: Tool '{tool_name}' not found.", tool_call_id=tool_call_id)
|
111 |
)
|
112 |
|
113 |
-
# Return the collected ToolMessages to be added to the state
|
114 |
return {"messages": tool_invocation_messages}
|
115 |
|
116 |
-
# --- Graph Definition ---
|
117 |
graph_builder = StateGraph(AgentState)
|
118 |
|
119 |
-
# Add nodes
|
120 |
graph_builder.add_node("scraper_agent", invoke_model)
|
121 |
-
graph_builder.add_node("tools", invoke_tools)
|
122 |
|
123 |
-
# Define edges
|
124 |
graph_builder.set_entry_point("scraper_agent")
|
125 |
|
126 |
-
# Conditional edge: After the agent runs, decide whether to call tools or end.
|
127 |
def router(state: AgentState) -> str:
|
128 |
"""Determines the next step based on the last message."""
|
129 |
last_message = state['messages'][-1]
|
130 |
if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
|
131 |
-
# If the AI message has tool calls, invoke the tools node
|
132 |
logger.info("--- Routing to Tools ---")
|
133 |
return "tools"
|
134 |
else:
|
135 |
-
# Otherwise, the conversation can end
|
136 |
logger.info("--- Routing to End ---")
|
137 |
return END
|
138 |
|
@@ -140,12 +117,11 @@ graph_builder.add_conditional_edges(
|
|
140 |
"scraper_agent",
|
141 |
router,
|
142 |
{
|
143 |
-
"tools": "tools",
|
144 |
-
END: END,
|
145 |
}
|
146 |
)
|
147 |
|
148 |
-
# After tools are invoked, their results (ToolMessages) should go back to the agent
|
149 |
graph_builder.add_edge("tools", "scraper_agent")
|
150 |
|
151 |
# Compile the graph
|
|
|
42 |
def invoke_model(state: AgentState) -> dict:
|
43 |
"""Invokes the LLM with the current conversation history."""
|
44 |
logger.info("--- Invoking Model ---")
|
|
|
|
|
45 |
response = model_with_tools.invoke(state['messages'])
|
46 |
+
return {"messages": [response]}
|
|
|
|
|
47 |
|
48 |
def invoke_tools(state: AgentState) -> dict:
|
49 |
"""Invokes the necessary tools based on the last AI message."""
|
50 |
logger.info("--- Invoking Tools ---")
|
|
|
51 |
last_message = state['messages'][-1]
|
52 |
|
|
|
53 |
if not hasattr(last_message, 'tool_calls') or not last_message.tool_calls:
|
54 |
logger.info("No tool calls found in the last message.")
|
|
|
|
|
|
|
55 |
return {}
|
|
|
56 |
|
57 |
tool_invocation_messages = []
|
58 |
|
|
|
59 |
tool_map = {tool.name: tool for tool in tools}
|
60 |
|
61 |
for tool_call in last_message.tool_calls:
|
62 |
tool_name = tool_call['name']
|
63 |
tool_args = tool_call['args']
|
64 |
+
tool_call_id = tool_call['id']
|
65 |
|
66 |
logger.info(f"Executing tool: {tool_name} with args: {tool_args}")
|
67 |
|
68 |
if tool_name in tool_map:
|
69 |
selected_tool = tool_map[tool_name]
|
70 |
try:
|
|
|
71 |
result = selected_tool.invoke(tool_args)
|
72 |
|
|
|
|
|
|
|
73 |
if isinstance(result, list) or isinstance(result, dict):
|
74 |
+
result_content = json.dumps(result)
|
75 |
+
elif hasattr(result, 'companies') and isinstance(result.companies, list):
|
76 |
result_content = f"Companies found: {', '.join(result.companies)}"
|
77 |
elif result is None:
|
78 |
result_content = "Tool executed successfully, but returned no specific data (None)."
|
79 |
else:
|
80 |
+
result_content = str(result)
|
81 |
|
82 |
logger.info(f"Tool {tool_name} result: {result_content}")
|
83 |
tool_invocation_messages.append(
|
|
|
85 |
)
|
86 |
except Exception as e:
|
87 |
logger.error(f"Error executing tool {tool_name}: {e}")
|
|
|
88 |
tool_invocation_messages.append(
|
89 |
ToolMessage(content=f"Error executing tool {tool_name}: {str(e)}", tool_call_id=tool_call_id)
|
90 |
)
|
|
|
94 |
ToolMessage(content=f"Error: Tool '{tool_name}' not found.", tool_call_id=tool_call_id)
|
95 |
)
|
96 |
|
|
|
97 |
return {"messages": tool_invocation_messages}
|
98 |
|
|
|
99 |
graph_builder = StateGraph(AgentState)
|
100 |
|
|
|
101 |
graph_builder.add_node("scraper_agent", invoke_model)
|
102 |
+
graph_builder.add_node("tools", invoke_tools)
|
103 |
|
|
|
104 |
graph_builder.set_entry_point("scraper_agent")
|
105 |
|
|
|
106 |
def router(state: AgentState) -> str:
|
107 |
"""Determines the next step based on the last message."""
|
108 |
last_message = state['messages'][-1]
|
109 |
if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
|
|
|
110 |
logger.info("--- Routing to Tools ---")
|
111 |
return "tools"
|
112 |
else:
|
|
|
113 |
logger.info("--- Routing to End ---")
|
114 |
return END
|
115 |
|
|
|
117 |
"scraper_agent",
|
118 |
router,
|
119 |
{
|
120 |
+
"tools": "tools",
|
121 |
+
END: END,
|
122 |
}
|
123 |
)
|
124 |
|
|
|
125 |
graph_builder.add_edge("tools", "scraper_agent")
|
126 |
|
127 |
# Compile the graph
|
application/services/gemini_api_service.py
CHANGED
@@ -152,13 +152,11 @@ def upload_file(
|
|
152 |
Exception: If upload fails.
|
153 |
"""
|
154 |
try:
|
155 |
-
# Determine if input is a URL
|
156 |
is_url = isinstance(file, str) and file.startswith(('http://', 'https://'))
|
157 |
|
158 |
-
# Determine file name if not provided
|
159 |
if not file_name:
|
160 |
if is_url:
|
161 |
-
file_name = os.path.basename(file.split("?")[0])
|
162 |
elif isinstance(file, str):
|
163 |
file_name = os.path.basename(file)
|
164 |
elif hasattr(file, "name"):
|
@@ -172,14 +170,12 @@ def upload_file(
|
|
172 |
config.update({"name": sanitized_name, "mime_type": mime_type})
|
173 |
gemini_file_key = f"files/{sanitized_name}"
|
174 |
|
175 |
-
# Check if file already exists
|
176 |
if gemini_file_key in get_files():
|
177 |
logger.info(f"File already exists on Gemini: {gemini_file_key}")
|
178 |
return client.files.get(name=gemini_file_key)
|
179 |
|
180 |
logger.info(f"Uploading file to Gemini: {gemini_file_key}")
|
181 |
|
182 |
-
# Handle URL
|
183 |
if is_url:
|
184 |
headers = {
|
185 |
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
|
@@ -189,14 +185,12 @@ def upload_file(
|
|
189 |
file_content = io.BytesIO(response.content)
|
190 |
return client.files.upload(file=file_content, config=config)
|
191 |
|
192 |
-
# Handle local file path
|
193 |
if isinstance(file, str):
|
194 |
if not os.path.isfile(file):
|
195 |
raise FileNotFoundError(f"Local file '{file}' does not exist.")
|
196 |
with open(file, "rb") as f:
|
197 |
return client.files.upload(file=f, config=config)
|
198 |
|
199 |
-
# Handle already opened binary file object
|
200 |
return client.files.upload(file=file, config=config)
|
201 |
|
202 |
except Exception as e:
|
|
|
152 |
Exception: If upload fails.
|
153 |
"""
|
154 |
try:
|
|
|
155 |
is_url = isinstance(file, str) and file.startswith(('http://', 'https://'))
|
156 |
|
|
|
157 |
if not file_name:
|
158 |
if is_url:
|
159 |
+
file_name = os.path.basename(file.split("?")[0])
|
160 |
elif isinstance(file, str):
|
161 |
file_name = os.path.basename(file)
|
162 |
elif hasattr(file, "name"):
|
|
|
170 |
config.update({"name": sanitized_name, "mime_type": mime_type})
|
171 |
gemini_file_key = f"files/{sanitized_name}"
|
172 |
|
|
|
173 |
if gemini_file_key in get_files():
|
174 |
logger.info(f"File already exists on Gemini: {gemini_file_key}")
|
175 |
return client.files.get(name=gemini_file_key)
|
176 |
|
177 |
logger.info(f"Uploading file to Gemini: {gemini_file_key}")
|
178 |
|
|
|
179 |
if is_url:
|
180 |
headers = {
|
181 |
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
|
|
|
185 |
file_content = io.BytesIO(response.content)
|
186 |
return client.files.upload(file=file_content, config=config)
|
187 |
|
|
|
188 |
if isinstance(file, str):
|
189 |
if not os.path.isfile(file):
|
190 |
raise FileNotFoundError(f"Local file '{file}' does not exist.")
|
191 |
with open(file, "rb") as f:
|
192 |
return client.files.upload(file=f, config=config)
|
193 |
|
|
|
194 |
return client.files.upload(file=file, config=config)
|
195 |
|
196 |
except Exception as e:
|
main.py
CHANGED
@@ -147,15 +147,15 @@ workflow.set_entry_point("supervisor")
|
|
147 |
graph = workflow.compile()
|
148 |
|
149 |
# # === Example Run ===
|
150 |
-
if __name__ == "__main__":
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
|
|
147 |
graph = workflow.compile()
|
148 |
|
149 |
# # === Example Run ===
|
150 |
+
# if __name__ == "__main__":
|
151 |
+
# logger.info("Starting the graph execution...")
|
152 |
+
# initial_message = HumanMessage(content="Can you get zalando pdf link")
|
153 |
+
# input_state = {"messages": [initial_message]}
|
154 |
+
|
155 |
+
# for step in graph.stream(input_state):
|
156 |
+
# if "__end__" not in step:
|
157 |
+
# logger.info(f"Graph Step Output: {step}")
|
158 |
+
# print(step)
|
159 |
+
# print("----")
|
160 |
+
|
161 |
+
# logger.info("Graph execution completed.")
|