import json
import os
import re

from dotenv import load_dotenv
from langchain_core.messages import (AIMessage, HumanMessage, SystemMessage,
                                     ToolMessage)
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langgraph.graph import START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition

from tools import (absolute, add, analyze_csv_file, analyze_excel_file,
                   arvix_search, audio_transcription, compound_interest,
                   convert_temperature, divide, download_file, exponential,
                   extract_text_from_image, factorial, floor_divide,
                   get_current_time_in_timezone, greatest_common_divisor,
                   is_prime, least_common_multiple, logarithm, modulus,
                   multiply, percentage_calculator, power, python_code_parser,
                   reverse_sentence, roman_calculator_converter, square_root,
                   subtract, web_content_extract, web_search, wiki_search)

# Load environment variables (API tokens)
load_dotenv()
HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

tools = [
    multiply, add, subtract, power, divide, modulus,
    square_root, floor_divide, absolute, logarithm,
    exponential, web_search, roman_calculator_converter,
    get_current_time_in_timezone, compound_interest,
    convert_temperature, factorial, greatest_common_divisor,
    is_prime, least_common_multiple, percentage_calculator,
    wiki_search, analyze_excel_file, arvix_search,
    audio_transcription, python_code_parser, analyze_csv_file,
    extract_text_from_image, reverse_sentence, web_content_extract,
    download_file,
]

# System prompt tuned for clean, answer-only output
system_prompt = """
You are a helpful AI assistant. When asked a question, think through it step by step and provide only the final answer.

CRITICAL INSTRUCTIONS:
- If the question mentions attachments, files, images, documents, or URLs, use the download_file tool FIRST to download them
- Use available tools when needed to gather information or perform calculations
- For file analysis, use appropriate tools (analyze_csv_file, analyze_excel_file, extract_text_from_image, etc.)
- After using tools and analyzing the information, provide ONLY the final answer
- Do not include explanations, reasoning, or extra text in your final response
- If the answer is a number, provide just the number (no units unless specifically requested)
- If the answer is text, provide just the essential text (no articles or extra words unless necessary)
- If the answer is a list, provide it as comma-separated values

Your response should contain ONLY the answer - nothing else.
"""

# System message
sys_msg = SystemMessage(content=system_prompt)
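
# Expected wire format of the final agent output (as produced by the assistant
# node and checked by is_valid_agent_output below), e.g.:
#   [{"task_id": "test123", "submitted_answer": "4"}]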


def build_graph():
    """Build the graph"""
    # First create the HuggingFaceEndpoint
    llm_endpoint = HuggingFaceEndpoint(
        repo_id="mistralai/Mistral-7B-Instruct-v0.2",
        huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
        temperature=0.1,
        max_new_tokens=1024,
        timeout=60,
    )

    # Then wrap it with ChatHuggingFace to get chat model functionality
    llm = ChatHuggingFace(llm=llm_endpoint)

    # Bind tools to LLM
    llm_with_tools = llm.bind_tools(tools)

    def clean_answer(text):
        """Extract clean answer from LLM response"""
        if not text:
            return ""
        
        # Remove common prefixes and suffixes
        text = text.strip()
        
        # Remove common response patterns
        patterns_to_remove = [
            r'^(The answer is:?\s*)',
            r'^(Answer:?\s*)',
            r'^(Final answer:?\s*)',
            r'^(Result:?\s*)',
            r'(\s*is the answer\.?)$',
            r'(\s*\.)$'
        ]
        
        for pattern in patterns_to_remove:
            text = re.sub(pattern, '', text, flags=re.IGNORECASE)
        
        # Take only the first line if multiple lines
        first_line = text.split('\n')[0].strip()
        
        return first_line
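
    # Illustrative behavior of clean_answer (example inputs are assumptions, not from the source):
    #   clean_answer("The answer is: 42.") -> "42"
    #   clean_answer("Answer: Paris")      -> "Paris"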

    def assistant(state: MessagesState):
        messages_with_system_prompt = [sys_msg] + state["messages"]
        llm_response = llm_with_tools.invoke(messages_with_system_prompt)

        # If the model requested tool calls, return its message unchanged so that
        # tools_condition can route execution to the tool node; re-wrapping it in a
        # fresh AIMessage would drop the tool_calls and the tools would never run.
        if getattr(llm_response, "tool_calls", None):
            return {"messages": [llm_response]}

        # Clean the answer
        clean_text = clean_answer(llm_response.content)

        # Format the response properly. Note: MessagesState only tracks "messages",
        # so task_id falls back to the default unless the state schema is extended.
        task_id = str(state.get("task_id", "1"))
        formatted_response = [{"task_id": task_id, "submitted_answer": clean_text}]

        return {"messages": [AIMessage(content=json.dumps(formatted_response, ensure_ascii=False))]}

    # --- Graph Definition ---
    builder = StateGraph(MessagesState)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))

    builder.add_edge(START, "assistant")
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")

    # Compile graph
    return builder.compile()
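
# Minimal usage sketch (illustrative; mirrors the __main__ block below):
#   graph = build_graph()
#   result = graph.invoke({"messages": [HumanMessage(content="What is 2 + 2?")]})
#   print(result["messages"][-1].content)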


def is_valid_agent_output(output):
    """
    Checks if the output matches the required format:
    [{"task_id": ..., "submitted_answer": ...}]
    """
    try:
        parsed = json.loads(output.strip())
        if not isinstance(parsed, list):
            return False
        
        for item in parsed:
            if not isinstance(item, dict):
                return False
            if "task_id" not in item or "submitted_answer" not in item:
                return False
        return True
    except Exception:
        return False


def extract_flat_answer(output):
    """Extract properly formatted answer from output"""
    try:
        # Try to parse as JSON first
        parsed = json.loads(output.strip())
        if isinstance(parsed, list) and len(parsed) > 0:
            first_item = parsed[0]
            if isinstance(first_item, dict) and "task_id" in first_item and "submitted_answer" in first_item:
                return output  # Already properly formatted
    except Exception:
        pass
    
    # If not properly formatted, return as-is (fallback)
    return output


# test
if __name__ == "__main__":
    question = "What is 2 + 2?"
    # Build the graph
    graph = build_graph()
    # Run the graph
    messages = [HumanMessage(content=question)]
    # The initial state for the graph
    initial_state = {"messages": messages, "task_id": "test123"}

    # Invoke the graph stream to see the steps
    for s in graph.stream(initial_state, stream_mode="values"):
        message = s["messages"][-1]
        if isinstance(message, ToolMessage):
            print("---RETRIEVED CONTEXT---")
            print(message.content)
            print("-----------------------")
        else:
            output = message.content  # This is a string
            print(f"Raw output: {output}")
            try:
                parsed = json.loads(output)
                if (isinstance(parsed, list) and parsed and isinstance(parsed[0], dict)
                        and "task_id" in parsed[0] and "submitted_answer" in parsed[0]):
                    print("✅ Output is in the correct format!")
                    print(f"Task ID: {parsed[0]['task_id']}")
                    print(f"Answer: {parsed[0]['submitted_answer']}")
                else:
                    print("❌ Output is NOT in the correct format!")
            except Exception as e:
                print("❌ Output is NOT in the correct format!", e)