# NOTE: removed non-Python scrape residue ("Spaces: / Sleeping / Sleeping" —
# HuggingFace Space status text captured by the page extractor).
import json
import os
import re

from dotenv import load_dotenv
from langchain_core.messages import (AIMessage, HumanMessage, SystemMessage,
                                     ToolMessage)
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langgraph.graph import START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition

from tools import (absolute, add, analyze_csv_file, analyze_excel_file,
                   arvix_search, audio_transcription, compound_interest,
                   convert_temperature, divide, download_file, exponential,
                   extract_text_from_image, factorial, floor_divide,
                   get_current_time_in_timezone, greatest_common_divisor,
                   is_prime, least_common_multiple, logarithm, modulus,
                   multiply, percentage_calculator, power, python_code_parser,
                   reverse_sentence, roman_calculator_converter, square_root,
                   subtract, web_content_extract, web_search, wiki_search)
# Load environment variables (API keys) from a local .env file.
load_dotenv()
HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

# Every tool the agent may call; bound to the LLM in build_graph() and also
# served by the ToolNode so both sides see the same list.
tools = [
    multiply, add, subtract, power, divide, modulus,
    square_root, floor_divide, absolute, logarithm,
    exponential, web_search, roman_calculator_converter,
    get_current_time_in_timezone, compound_interest,
    convert_temperature, factorial, greatest_common_divisor,
    is_prime, least_common_multiple, percentage_calculator,
    wiki_search, analyze_excel_file, arvix_search,
    audio_transcription, python_code_parser, analyze_csv_file,
    extract_text_from_image, reverse_sentence, web_content_extract,
    download_file,
]
# System prompt steering the model toward bare, submission-ready answers
# (no reasoning or prose in the final reply).
system_prompt = """
You are a helpful AI assistant. When asked a question, think through it step by step and provide only the final answer.
CRITICAL INSTRUCTIONS:
- If the question mentions attachments, files, images, documents, or URLs, use the download_file tool FIRST to download them
- Use available tools when needed to gather information or perform calculations
- For file analysis, use appropriate tools (analyze_csv_file, analyze_excel_file, extract_text_from_image, etc.)
- After using tools and analyzing the information, provide ONLY the final answer
- Do not include explanations, reasoning, or extra text in your final response
- If the answer is a number, provide just the number (no units unless specifically requested)
- If the answer is text, provide just the essential text (no articles or extra words unless necessary)
- If the answer is a list, provide it as comma-separated values
Your response should contain ONLY the answer - nothing else.
"""

# System message prepended to every conversation turn in the assistant node.
sys_msg = SystemMessage(content=system_prompt)
def build_graph(): | |
"""Build the graph""" | |
# First create the HuggingFaceEndpoint | |
llm_endpoint = HuggingFaceEndpoint( | |
repo_id="mistralai/Mistral-7B-Instruct-v0.2", | |
huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN, | |
temperature=0.1, | |
max_new_tokens=1024, | |
timeout=60, | |
) | |
# Then wrap it with ChatHuggingFace to get chat model functionality | |
llm = ChatHuggingFace(llm=llm_endpoint) | |
# Bind tools to LLM | |
llm_with_tools = llm.bind_tools(tools) | |
def clean_answer(text): | |
"""Extract clean answer from LLM response""" | |
if not text: | |
return "" | |
# Remove common prefixes and suffixes | |
text = text.strip() | |
# Remove common response patterns | |
patterns_to_remove = [ | |
r'^(The answer is:?\s*)', | |
r'^(Answer:?\s*)', | |
r'^(Final answer:?\s*)', | |
r'^(Result:?\s*)', | |
r'(\s*is the answer\.?)$', | |
r'(\s*\.)$' | |
] | |
for pattern in patterns_to_remove: | |
text = re.sub(pattern, '', text, flags=re.IGNORECASE) | |
# Take only the first line if multiple lines | |
first_line = text.split('\n')[0].strip() | |
return first_line | |
def assistant(state: MessagesState): | |
messages_with_system_prompt = [sys_msg] + state["messages"] | |
llm_response = llm_with_tools.invoke(messages_with_system_prompt) | |
# Clean the answer | |
clean_text = clean_answer(llm_response.content) | |
# Format the response properly | |
task_id = str(state.get("task_id", "1")) | |
formatted_response = [{"task_id": task_id, "submitted_answer": clean_text}] | |
return {"messages": [AIMessage(content=json.dumps(formatted_response, ensure_ascii=False))]} | |
# --- Graph Definition --- | |
builder = StateGraph(MessagesState) | |
builder.add_node("assistant", assistant) | |
builder.add_node("tools", ToolNode(tools)) | |
builder.add_edge(START, "assistant") | |
builder.add_conditional_edges("assistant", tools_condition) | |
builder.add_edge("tools", "assistant") | |
# Compile graph | |
return builder.compile() | |
def is_valid_agent_output(output):
    """Check that *output* is a JSON string in the required submission format:

        [{"task_id": ..., "submitted_answer": ...}, ...]

    Args:
        output: Candidate value; anything that is not a valid JSON string
            of the shape above yields False.

    Returns:
        bool: True when the string parses to a list whose every element is a
        dict containing both "task_id" and "submitted_answer" keys.
    """
    try:
        parsed = json.loads(output.strip())
    # Narrowed from a bare `except:` (which also swallowed SystemExit /
    # KeyboardInterrupt): non-strings and malformed JSON are simply invalid.
    except (json.JSONDecodeError, TypeError, AttributeError):
        return False
    if not isinstance(parsed, list):
        return False
    return all(
        isinstance(item, dict)
        and "task_id" in item
        and "submitted_answer" in item
        for item in parsed
    )
def extract_flat_answer(output):
    """Return *output*, validating (but not altering) its submission format.

    Args:
        output: Agent output string, ideally already the JSON list
            [{"task_id": ..., "submitted_answer": ...}].

    Returns:
        The input unchanged — whether or not it matches the format (the
        non-matching path is an explicit pass-through fallback).
    """
    try:
        parsed = json.loads(output.strip())
        # Guard the empty list before indexing parsed[0].
        if (isinstance(parsed, list) and parsed
                and isinstance(parsed[0], dict)
                and "task_id" in parsed[0]
                and "submitted_answer" in parsed[0]):
            return output  # Already properly formatted.
    # Narrowed from a bare `except:`; parse failure just means "fallback".
    except (json.JSONDecodeError, TypeError, AttributeError):
        pass
    # Not properly formatted: return as-is (fallback).
    return output
# Manual smoke test: run the agent on a trivial question and verify the final
# message is in the expected submission format.
if __name__ == "__main__":
    question = "What is 2 + 2?"
    graph = build_graph()
    # Initial graph state; task_id rides along with the message history.
    initial_state = {"messages": [HumanMessage(content=question)],
                     "task_id": "test123"}
    # Stream intermediate states so tool activity is visible step by step.
    for step in graph.stream(initial_state, stream_mode="values"):
        message = step["messages"][-1]
        if isinstance(message, ToolMessage):
            print("---RETRIEVED CONTEXT---")
            print(message.content)
            print("-----------------------")
            continue
        output = message.content  # This is a string
        print(f"Raw output: {output}")
        try:
            parsed = json.loads(output)
            # Check the list is non-empty before indexing parsed[0]; the
            # original could IndexError on "[]", masked only by the broad
            # except below.
            if (isinstance(parsed, list) and parsed
                    and "task_id" in parsed[0]
                    and "submitted_answer" in parsed[0]):
                print("✅ Output is in the correct format!")
                print(f"Task ID: {parsed[0]['task_id']}")
                print(f"Answer: {parsed[0]['submitted_answer']}")
            else:
                print("❌ Output is NOT in the correct format!")
        except Exception as e:
            print("❌ Output is NOT in the correct format!", e)