# Source: Hugging Face Space "EtienneB/updates", revision 30b4543 (7.39 kB).
import json
import os
import re
from dotenv import load_dotenv
from langchain_core.messages import (AIMessage, HumanMessage, SystemMessage,
ToolMessage)
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langgraph.graph import START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from tools import (absolute, add, analyze_csv_file, analyze_excel_file,
arvix_search, audio_transcription, compound_interest,
convert_temperature, divide, download_file, exponential,
extract_text_from_image, factorial, floor_divide,
get_current_time_in_timezone, greatest_common_divisor,
is_prime, least_common_multiple, logarithm, modulus,
multiply, percentage_calculator, power, python_code_parser,
reverse_sentence, roman_calculator_converter, square_root,
subtract, web_content_extract, web_search, wiki_search)
# Load Constants
# Read API keys from a local .env file so they never live in source control.
# os.getenv returns None when a variable is unset — downstream code must cope.
load_dotenv()
HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")  # NOTE(review): read but never used in this file — confirm it is needed
# Full toolbox exposed to the model; this exact list is both bound to the LLM
# (bind_tools) and wrapped in the graph's ToolNode inside build_graph().
tools = [
multiply, add, subtract, power, divide, modulus,
square_root, floor_divide, absolute, logarithm,
exponential, web_search, roman_calculator_converter,
get_current_time_in_timezone, compound_interest,
convert_temperature, factorial, greatest_common_divisor,
is_prime, least_common_multiple, percentage_calculator,
wiki_search, analyze_excel_file, arvix_search,
audio_transcription, python_code_parser, analyze_csv_file,
extract_text_from_image, reverse_sentence, web_content_extract,
download_file,
]
# Updated system prompt for cleaner output.
# The prompt forces answer-only responses; clean_answer() in build_graph()
# strips any boilerplate the model emits despite these instructions.
system_prompt = """
You are a helpful AI assistant. When asked a question, think through it step by step and provide only the final answer.
CRITICAL INSTRUCTIONS:
- If the question mentions attachments, files, images, documents, or URLs, use the download_file tool FIRST to download them
- Use available tools when needed to gather information or perform calculations
- For file analysis, use appropriate tools (analyze_csv_file, analyze_excel_file, extract_text_from_image, etc.)
- After using tools and analyzing the information, provide ONLY the final answer
- Do not include explanations, reasoning, or extra text in your final response
- If the answer is a number, provide just the number (no units unless specifically requested)
- If the answer is text, provide just the essential text (no articles or extra words unless necessary)
- If the answer is a list, provide it as comma-separated values
Your response should contain ONLY the answer - nothing else.
"""
# System message prepended to every model invocation (see assistant node).
sys_msg = SystemMessage(content=system_prompt)
def build_graph():
    """Build and compile the agent graph.

    The graph alternates between an LLM "assistant" node and a ToolNode:
    the assistant may request tool calls, ``tools_condition`` routes those
    to the ToolNode, and tool results flow back to the assistant until it
    produces a final answer, emitted as a JSON payload of the form
    ``[{"task_id": ..., "submitted_answer": ...}]``.

    Returns:
        The compiled LangGraph runnable.
    """
    # First create the raw HuggingFace text-generation endpoint.
    llm_endpoint = HuggingFaceEndpoint(
        repo_id="mistralai/Mistral-7B-Instruct-v0.2",
        huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
        temperature=0.1,  # low temperature: short, deterministic answers
        max_new_tokens=1024,
        timeout=60,
    )
    # Then wrap it with ChatHuggingFace to get chat-message functionality.
    llm = ChatHuggingFace(llm=llm_endpoint)
    # Bind tools so the model can emit structured tool calls.
    llm_with_tools = llm.bind_tools(tools)

    def clean_answer(text):
        """Strip boilerplate ("The answer is ...") from an LLM response and
        return only the first line of what remains."""
        if not text:
            return ""
        text = text.strip()
        # Common prefixes/suffixes the model adds despite the system prompt.
        patterns_to_remove = [
            r'^(The answer is:?\s*)',
            r'^(Answer:?\s*)',
            r'^(Final answer:?\s*)',
            r'^(Result:?\s*)',
            r'(\s*is the answer\.?)$',
            r'(\s*\.)$'
        ]
        for pattern in patterns_to_remove:
            text = re.sub(pattern, '', text, flags=re.IGNORECASE)
        # Keep only the first line if the model produced several.
        return text.split('\n')[0].strip()

    def assistant(state: MessagesState):
        """LLM node: either request tool calls or emit the final answer."""
        messages_with_system_prompt = [sys_msg] + state["messages"]
        llm_response = llm_with_tools.invoke(messages_with_system_prompt)
        # BUG FIX: the original replaced every response with a plain JSON
        # AIMessage, discarding any tool_calls — tools_condition could then
        # never route to the "tools" node, so tools were unreachable. Pass
        # tool-call responses through untouched so the tool loop works.
        if getattr(llm_response, "tool_calls", None):
            return {"messages": [llm_response]}
        # No tool calls: this is the final answer — clean and format it.
        clean_text = clean_answer(llm_response.content)
        # NOTE(review): MessagesState only declares "messages"; a "task_id"
        # key in the input state may be dropped by the graph runtime, in
        # which case the default "1" is used — confirm the state schema.
        task_id = str(state.get("task_id", "1"))
        formatted_response = [{"task_id": task_id, "submitted_answer": clean_text}]
        return {"messages": [AIMessage(content=json.dumps(formatted_response, ensure_ascii=False))]}

    # --- Graph Definition ---
    builder = StateGraph(MessagesState)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "assistant")
    # Route to "tools" when the last message contains tool calls, else END.
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")
    # Compile graph
    return builder.compile()
def is_valid_agent_output(output):
    """Return True if *output* is a JSON string in the required format:

        [{"task_id": ..., "submitted_answer": ...}, ...]

    Every element must be a dict carrying both keys. Any parse failure
    (invalid JSON, non-string input) yields False. An empty list is
    considered valid, matching the original behavior.
    """
    try:
        parsed = json.loads(output.strip())
    except (ValueError, TypeError, AttributeError):
        # Narrowed from a bare `except:` which also swallowed
        # KeyboardInterrupt/SystemExit. ValueError covers
        # json.JSONDecodeError; TypeError/AttributeError cover inputs
        # that are not strings (no .strip / not parseable).
        return False
    if not isinstance(parsed, list):
        return False
    for item in parsed:
        if not isinstance(item, dict):
            return False
        if "task_id" not in item or "submitted_answer" not in item:
            return False
    return True
def extract_flat_answer(output):
    """Return *output*, validating that it looks like an answer payload.

    Probes whether the string parses as a non-empty JSON list whose first
    element is a dict with "task_id" and "submitted_answer". NOTE: both the
    validated path and the fallback return the input unchanged — the parse
    is a validation probe only, preserved here for caller compatibility.
    """
    try:
        # Try to parse as JSON first.
        parsed = json.loads(output.strip())
        if isinstance(parsed, list) and len(parsed) > 0:
            first_item = parsed[0]
            if isinstance(first_item, dict) and "task_id" in first_item and "submitted_answer" in first_item:
                return output  # Already properly formatted
    except (ValueError, TypeError, AttributeError):
        # Narrowed from a bare `except:`; covers invalid JSON and
        # non-string inputs without hiding KeyboardInterrupt/SystemExit.
        pass
    # If not properly formatted, return as-is (fallback).
    return output
# Smoke test: run the graph on a trivial question and check the output format.
if __name__ == "__main__":
    question = "What is 2 + 2?"
    # Build the graph
    graph = build_graph()
    # Run the graph
    messages = [HumanMessage(content=question)]
    # The initial state for the graph (task_id is echoed into the payload).
    initial_state = {"messages": messages, "task_id": "test123"}
    # Invoke the graph stream to see the steps
    for s in graph.stream(initial_state, stream_mode="values"):
        message = s["messages"][-1]
        if isinstance(message, ToolMessage):
            print("---RETRIEVED CONTEXT---")
            print(message.content)
            print("-----------------------")
        else:
            output = message.content  # This is a string
            print(f"Raw output: {output}")
            try:
                parsed = json.loads(output)
                # BUG FIX: the original indexed parsed[0] after checking only
                # isinstance(parsed, list) — an empty list ("[]") raised
                # IndexError. Guard for non-emptiness and a dict element.
                if (isinstance(parsed, list) and parsed
                        and isinstance(parsed[0], dict)
                        and "task_id" in parsed[0]
                        and "submitted_answer" in parsed[0]):
                    print("✅ Output is in the correct format!")
                    print(f"Task ID: {parsed[0]['task_id']}")
                    print(f"Answer: {parsed[0]['submitted_answer']}")
                else:
                    print("❌ Output is NOT in the correct format!")
            except Exception as e:
                print("❌ Output is NOT in the correct format!", e)