"""LangGraph question-answering agent: a tool-equipped HuggingFace chat model
wired into an assistant/tools graph, with output-format validation helpers."""
import json
import os
import re
from dotenv import load_dotenv
from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage
from langchain_huggingface import (ChatHuggingFace, HuggingFaceEmbeddings,
HuggingFaceEndpoint)
from langgraph.graph import START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from tools import (absolute, add, analyze_csv_file, analyze_excel_file,
arvix_search, audio_transcription, compound_interest,
convert_temperature, divide, exponential,
extract_text_from_image, factorial, floor_divide,
get_current_time_in_timezone, greatest_common_divisor,
is_prime, least_common_multiple, logarithm, modulus,
multiply, percentage_calculator, power, python_code_parser,
reverse_sentence, roman_calculator_converter, square_root,
subtract, web_content_extract, web_search, wiki_search)
# Load Constants
load_dotenv()
# API credentials read from the environment (a local .env file is honoured
# via load_dotenv above).
HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")  # NOTE(review): not used anywhere below — confirm it is still needed
# Tools exposed to the LLM via bind_tools: arithmetic/number-theory helpers,
# converters, search (web / wiki / arxiv), and file/media analysis utilities.
tools = [
    multiply, add, subtract, power, divide, modulus,
    square_root, floor_divide, absolute, logarithm,
    exponential, web_search, roman_calculator_converter,
    get_current_time_in_timezone, compound_interest,
    convert_temperature, factorial, greatest_common_divisor,
    is_prime, least_common_multiple, percentage_calculator,
    wiki_search, analyze_excel_file, arvix_search,
    audio_transcription, python_code_parser, analyze_csv_file,
    extract_text_from_image, reverse_sentence, web_content_extract,
]
# Load system prompt.
# NOTE: this text is runtime behavior (it steers the model's output format,
# which is_valid_agent_output / extract_flat_answer then parse) — do not edit
# it without updating those parsers to match.
system_prompt = """
You are a general AI assistant. I will ask you a question.
Report your thoughts, and finish your answer with only the answer, no extra text, no prefix, and no explanation.
Your answer should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
If you are asked for a number, don't use a comma to write your number, nor use units such as $ or percent sign unless specified otherwise.
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
If you are asked for a comma separated list, apply the above rules depending on whether the element to be put in the list is a number or a string.
Format your output as: [{"task_id": ..., "submitted_answer": ...}]
Do NOT include the format string or any JSON inside the submitted_answer field. Only output a single flat list as: [{"task_id": ..., "submitted_answer": ...}]
"""
# System message prepended to every LLM call inside the assistant node.
sys_msg = SystemMessage(content=system_prompt)
def build_graph():
    """Build and compile the agent graph (assistant node + tool node).

    Returns:
        A compiled LangGraph: START -> assistant -> (tools -> assistant)* -> END,
        where routing to the tool node is decided by ``tools_condition``.
    """
    # First create the HuggingFaceEndpoint for the base model.
    llm_endpoint = HuggingFaceEndpoint(
        repo_id="mistralai/Mistral-7B-Instruct-v0.2",
        huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
        temperature=0.1,
        max_new_tokens=1024,
        timeout=60,
    )
    # Then wrap it with ChatHuggingFace to get chat model functionality.
    llm = ChatHuggingFace(llm=llm_endpoint)
    # Bind tools to LLM so it can emit tool calls.
    llm_with_tools = llm.bind_tools(tools)

    # --- Nodes ---
    def assistant(state: MessagesState):
        """Assistant node: invoke the LLM; format only final (non-tool-call) answers."""
        messages_with_system_prompt = [sys_msg] + state["messages"]
        llm_response = llm_with_tools.invoke(messages_with_system_prompt)
        # BUG FIX: if the model requested tool calls, return the AIMessage
        # untouched. The previous version rewrapped every response into a plain
        # formatted string, which dropped the tool_calls attribute —
        # tools_condition inspects the last message's tool_calls, so the
        # "tools" node was unreachable and no tool ever ran.
        if getattr(llm_response, "tool_calls", None):
            return {"messages": [llm_response]}
        # Extract the answer text (strip any "FINAL ANSWER:" prefix if present).
        answer_text = llm_response.content
        if answer_text.strip().lower().startswith("final answer:"):
            answer_text = answer_text.split(":", 1)[1].strip()
        # Get task_id from state or fall back to a placeholder.
        task_id = state.get("task_id", "1")
        formatted = f'Answers (answers): [{{"task_id": "{task_id}", "submitted_answer": "{answer_text}"}}]'
        # Collapse any accidentally nested format strings to a single flat answer.
        return {"messages": [extract_flat_answer(formatted)]}

    # --- Graph Definition ---
    builder = StateGraph(MessagesState)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "assistant")
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")
    # Compile graph
    return builder.compile()
def is_valid_agent_output(output):
    """Report whether *output* is exactly the required submission format.

    Expected shape (after stripping surrounding whitespace):
        Answers (answers): [{"task_id": ..., "submitted_answer": ...}]

    Returns True only when the bracketed part parses as JSON and every entry
    is a dict carrying both "task_id" and "submitted_answer" keys.
    """
    shape = re.fullmatch(r'Answers \(answers\): \[(\{.*\})\]', output.strip())
    if shape is None:
        return False
    try:
        entries = json.loads(f'[{shape.group(1)}]')
    except Exception:
        # Bracketed part was not valid JSON.
        return False
    return all(
        isinstance(entry, dict)
        and "task_id" in entry
        and "submitted_answer" in entry
        for entry in entries
    )
def extract_flat_answer(output):
    """Pull the last well-formed 'Answers (answers): [{...}]' span out of *output*.

    Useful when the model echoed the format string and nested it: the last
    (innermost) occurrence wins. Falls back to returning *output* unchanged
    when no valid span is found. Note the non-greedy pattern stops at the
    first '}]', so answers containing that sequence are truncated there.
    """
    candidates = re.findall(r'Answers \(answers\): \[(\{.*?\})\]', output)
    if not candidates:
        return output  # nothing resembling the format string present
    innermost = candidates[-1]
    try:
        parsed = json.loads(f'[{innermost}]')
        # Broad except intentionally also guards the key checks below.
        if isinstance(parsed, list) and "task_id" in parsed[0] and "submitted_answer" in parsed[0]:
            return f'Answers (answers): [{innermost}]'
    except Exception:
        pass
    return output  # fallback: span did not parse or lacked required keys
# Manual smoke test: run one question through the graph and check the output format.
if __name__ == "__main__":
    question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
    # Build the compiled agent graph.
    agent_graph = build_graph()
    # Initial state: the user question plus a task_id for the answer template.
    start_state = {
        "messages": [HumanMessage(content=question)],
        "task_id": "test123",
    }
    # Stream intermediate states so tool results are visible as they arrive.
    for step in agent_graph.stream(start_state, stream_mode="values"):
        latest = step["messages"][-1]
        if isinstance(latest, ToolMessage):
            print("---RETRIEVED CONTEXT---")
            print(latest.content)
            print("-----------------------")
            continue
        text = str(latest)
        print("Agent Output:", text)
        verdict = (
            "✅ Output is in the correct format!"
            if is_valid_agent_output(text)
            else "❌ Output is NOT in the correct format!"
        )
        print(verdict)