wt002 committed
Commit 1d71658 · verified · 1 Parent(s): 3ddca4e

Update app.py

Files changed (1):
  app.py  +31 -52
app.py CHANGED
@@ -152,12 +152,12 @@ import os
 import time
 import json
 from typing import TypedDict, List, Union, Any, Dict
-from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_huggingface import ChatHuggingFace
+from langchain_huggingface.llms import HuggingFaceEndpoint
 from langchain.schema import HumanMessage, AIMessage, SystemMessage
 from langchain.prompts import ChatPromptTemplate
 from langgraph.graph import StateGraph, END
-from google.api_core.exceptions import ResourceExhausted
-from langchain.tools import Tool  # Import Tool for consistent tool definitions
+from langchain.tools import Tool
 
 # Assume these tools are defined elsewhere and imported
 # Placeholder for your actual tool implementations
@@ -248,7 +248,7 @@ def should_continue(state: AgentState) -> str:
     print(f"DEBUG: Entering should_continue. Current context: {state.get('context', {})}")
 
     # End if agent has produced a final answer
-    if state.get("final_answer"):
+    if state.get("final_answer") is not None:  # Check for None explicitly
         print("DEBUG: should_continue -> END (Final Answer set in state)")
         return "end"
 
@@ -270,29 +270,31 @@ def reasoning_node(state: AgentState) -> AgentState:
     print(f"DEBUG: Entering reasoning_node. Iteration: {state['iterations']}")
     print(f"DEBUG: Current history length: {len(state.get('history', []))}")
 
-    # Load API key
-    GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
-    if not GOOGLE_API_KEY:
-        raise ValueError("GOOGLE_API_KEY not set in environment variables.")
+    # --- CHANGE: Use HF_TOKEN environment variable ---
+    HF_TOKEN = os.getenv("HF_TOKEN")
+    if not HF_TOKEN:
+        raise ValueError("HF_TOKEN not set in environment variables.")
+    # --- END CHANGE ---
 
-    # Initialize/update state fields
     state.setdefault("context", {})
     state.setdefault("reasoning", "")
     state.setdefault("iterations", 0)
     state.setdefault("current_task", "Understand the question and plan the next step.")
     state.setdefault("current_thoughts", "")
 
-    # Clear any old pending action from context before generating a new one
     state["context"].pop("pending_action", None)
 
-    # Create Gemini model wrapper
-    llm = ChatGoogleGenerativeAI(
-        model="gemini-1.5-flash",
-        temperature=0.1,
-        google_api_key=GOOGLE_API_KEY
+    model_id = "meta-llama/Llama-2-7b-chat-hf"
+
+    llm = ChatHuggingFace(
+        llm=HuggingFaceEndpoint(
+            repo_id=model_id,
+            max_new_tokens=512,
+            temperature=0.1,
+            huggingfacehub_api_token=HF_TOKEN,  # --- CHANGE: Pass HF_TOKEN here ---
+        )
     )
 
-    # Dynamically generate tool descriptions for the prompt
     tool_descriptions = "\n".join([
         f"- **{t.name}**: {t.description}" for t in state.get("tools", [])
     ])
@@ -329,34 +331,30 @@ def reasoning_node(state: AgentState) -> AgentState:
 
     prompt = ChatPromptTemplate.from_messages([
         SystemMessage(content=system_prompt),
-        *state["history"]  # Include full history for conversational context
+        *state["history"]
     ])
 
     chain = prompt | llm
 
-    def call_with_retry(inputs, retries=3, delay=60):
+    def call_with_retry(inputs, retries=3, delay=30):
         for attempt in range(retries):
             try:
                 response = chain.invoke(inputs)
-                # Attempt to parse immediately to catch bad JSON before returning
-                json.loads(response.content)  # Validate JSON structure
+                json.loads(response.content)
                 return response
-            except ResourceExhausted as e:
-                print(f"[Retry {attempt+1}/{retries}] Gemini rate limit hit. Waiting {delay}s...")
-                time.sleep(delay)
             except json.JSONDecodeError as e:
                 print(f"[Retry {attempt+1}/{retries}] LLM returned invalid JSON. Retrying...")
                 print(f"Invalid JSON content: {response.content[:200]}...")
-                time.sleep(5)  # Shorter delay for parsing errors
-            except Exception as e:
-                print(f"[Retry {attempt+1}/{retries}] An unexpected error occurred during LLM call: {e}. Retrying...")
+                time.sleep(5)
+            except Exception as e:
+                print(f"[Retry {attempt+1}/{retries}] An unexpected error occurred during LLM call: {e}. Waiting {delay}s...")
                 time.sleep(delay)
-        raise RuntimeError("Failed after multiple retries due to Gemini quota limit or invalid JSON.")
+        raise RuntimeError("Failed after multiple retries due to Hugging Face API issues or invalid JSON.")
 
     response = call_with_retry({
         "context": state["context"],
         "reasoning": state["reasoning"],
-        "question": state["question"],  # Redundant as it's in history, but keeps prompt consistent
+        "question": state["question"],
         "current_task": state["current_task"],
         "current_thoughts": state["current_thoughts"]
     })
@@ -367,22 +365,18 @@ def reasoning_node(state: AgentState) -> AgentState:
     print(f"DEBUG: LLM Raw Response Content: {content[:200]}...")
     print(f"DEBUG: Parsed Action: '{action}', Action Input: '{action_input[:100]}...'")
 
-    # Update state
-    state["history"].append(AIMessage(content=content))  # Store the raw LLM response
+    state["history"].append(AIMessage(content=content))
     state["reasoning"] += f"\nStep {state['iterations'] + 1}: {reasoning}"
     state["iterations"] += 1
-    state["current_thoughts"] = reasoning  # Update current thoughts for next iteration
+    state["current_thoughts"] = reasoning
 
     if "final answer" in action.lower():
-        state["final_answer"] = action_input  # Set final answer directly in state
-        # The should_continue check will handle ending the graph based on final_answer presence
+        state["final_answer"] = action_input
     else:
-        # Store the action request in context, not in history
         state["context"]["pending_action"] = {
             "tool": action,
             "input": action_input
         }
-        # Add a message to history to indicate the agent's intent for the LLM
        state["history"].append(AIMessage(content=f"Agent decided to use tool: {action} with input: {action_input}"))
 
 
@@ -396,12 +390,9 @@ def tool_node(state: AgentState) -> AgentState:
     """
     print(f"DEBUG: Entering tool_node. Iteration: {state['iterations']}")
 
-    # Get the pending action from context
     tool_call_dict = state["context"].pop("pending_action", None)
 
     if not tool_call_dict:
-        # This case should ideally not be reached if should_continue is robust,
-        # but provides a fallback.
         error_message = "[Tool Error] No pending_action found in context. This indicates an issue with graph flow."
         print(f"ERROR: {error_message}")
         state["history"].append(AIMessage(content=error_message))
@@ -410,34 +401,29 @@ def tool_node(state: AgentState) -> AgentState:
     tool_name = tool_call_dict.get("tool")
     tool_input = tool_call_dict.get("input")
 
-    # Defensive check for empty tool name or input (still needed as LLM might generate empty strings)
     if not tool_name or tool_input is None:
         error_message = f"[Tool Error] Invalid action request from LLM: Tool name '{tool_name}' or input '{tool_input}' was empty. LLM needs to provide valid 'Action' and 'Action Input'."
         print(f"ERROR: {error_message}")
         state["history"].append(AIMessage(content=error_message))
-        # Clear any problematic pending action
         state["context"].pop("pending_action", None)
         return state
 
-    # Look up and invoke the tool from the state's tool list
     available_tools = state.get("tools", [])
     tool_fn = next((t for t in available_tools if t.name == tool_name), None)
 
     if tool_fn is None:
-        # Fallback for unrecognized tool - feedback to LLM
         tool_output = f"[Tool Error] Tool '{tool_name}' not found or not available. Please choose from: {', '.join([t.name for t in available_tools])}"
         print(f"ERROR: {tool_output}")
     else:
         try:
             print(f"DEBUG: Invoking tool '{tool_name}' with input: '{tool_input[:100]}...'")
             tool_output = tool_fn.run(tool_input)
-            if not tool_output and tool_output is not False:  # Ensure 'False' is not treated as empty
+            if not tool_output and tool_output is not False:
                 tool_output = f"[{tool_name} output] No specific result found for '{tool_input}'. The tool might have returned an empty response."
         except Exception as e:
             tool_output = f"[Tool Error] An error occurred while running '{tool_name}': {str(e)}"
             print(f"ERROR: {tool_output}")
 
-    # Add tool output to history as an AIMessage for the LLM to process next
     state["history"].append(AIMessage(content=f"[{tool_name} output]\n{tool_output}"))
 
     print(f"DEBUG: Exiting tool_node. Tool output added to history. New history length: {len(state['history'])}")
@@ -445,17 +431,14 @@
 
 
 # ====== Agent Graph ======
-def create_agent_workflow(tools: List[Tool]):  # tools are passed in now
+def create_agent_workflow(tools: List[Tool]):
     workflow = StateGraph(AgentState)
 
-    # Define nodes
     workflow.add_node("reason", reasoning_node)
     workflow.add_node("action", tool_node)
 
-    # Set entry point
     workflow.set_entry_point("reason")
 
-    # Define edges
     workflow.add_conditional_edges(
         "reason",
         should_continue,
@@ -500,8 +483,6 @@ class BasicAgent:
             "tools": self.tools
         }
 
-        # The invoke method will now return the final state, or raise an error if it hits a dead end
-        # LangGraph runs are synchronous by default here.
         final_state = self.workflow.invoke(state)
 
         if final_state.get("final_answer") is not None:
@@ -509,8 +490,6 @@ class BasicAgent:
             print(f"--- Agent returning FINAL ANSWER: {answer} ---")
             return answer
         else:
-            # This should ideally not happen if the agent is designed to always provide a final answer
-            # or a specific "cannot answer" message.
             print(f"--- ERROR: Agent finished without setting 'final_answer' for question: {question} ---")
             raise ValueError("Agent finished without providing a final answer.")
 
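For a quick smoke test of the new model wiring outside the graph, a minimal sketch (my own example, not part of the commit; it assumes `langchain-huggingface` is installed, `HF_TOKEN` is exported, and the account has been granted access to the gated `meta-llama/Llama-2-7b-chat-hf` repository):

    import os

    from langchain_huggingface import ChatHuggingFace
    from langchain_huggingface.llms import HuggingFaceEndpoint
    from langchain.schema import HumanMessage

    # Same construction as in reasoning_node, lifted out for a standalone check.
    llm = ChatHuggingFace(
        llm=HuggingFaceEndpoint(
            repo_id="meta-llama/Llama-2-7b-chat-hf",
            max_new_tokens=512,
            temperature=0.1,
            huggingfacehub_api_token=os.environ["HF_TOKEN"],
        )
    )

    # reasoning_node expects strict JSON back, so ask for JSON here too.
    reply = llm.invoke([HumanMessage(content='Reply with JSON: {"status": "ok"}')])
    print(reply.content)

Note that the commit also drops the dedicated ResourceExhausted handler, so Hugging Face rate limits now fall into the generic `except Exception` branch of `call_with_retry`.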
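The reworked retry loop can be exercised in isolation as well. Below is a sketch with a stand-in chain (`FlakyChain` and `FakeResponse` are illustrative stubs; in app.py the chain is `prompt | llm`) and shorter delays than the production delay=30:

    import json
    import time
    from dataclasses import dataclass

    @dataclass
    class FakeResponse:
        content: str

    class FlakyChain:
        """Stand-in for `prompt | llm`: bad JSON on the first call, valid JSON afterwards."""
        def __init__(self):
            self.calls = 0

        def invoke(self, inputs):
            self.calls += 1
            if self.calls == 1:
                return FakeResponse(content="not-json")
            return FakeResponse(content=json.dumps({"action": "Final Answer"}))

    chain = FlakyChain()

    # Same shape as call_with_retry in reasoning_node.
    def call_with_retry(inputs, retries=3, delay=2):
        for attempt in range(retries):
            try:
                response = chain.invoke(inputs)
                json.loads(response.content)  # validate before returning, as in the commit
                return response
            except json.JSONDecodeError:
                print(f"[Retry {attempt+1}/{retries}] LLM returned invalid JSON. Retrying...")
                time.sleep(1)
            except Exception as e:
                print(f"[Retry {attempt+1}/{retries}] Unexpected error: {e}. Waiting {delay}s...")
                time.sleep(delay)
        raise RuntimeError("Failed after multiple retries.")

    print(call_with_retry({}).content)  # prints {"action": "Final Answer"} after one retry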