EtienneB commited on
Commit
374dd02
Β·
1 Parent(s): b565efa

trying adjusted agent.

Browse files
Files changed (2) hide show
  1. agent copy.py +174 -0
  2. agent.py +75 -55
agent copy.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import json
3
+ import os
4
+ import re
5
+
6
+ from dotenv import load_dotenv
7
+ from langchain_core.messages import (AIMessage, HumanMessage, SystemMessage,
8
+ ToolMessage)
9
+ from langchain_huggingface import (ChatHuggingFace, HuggingFaceEmbeddings,
10
+ HuggingFaceEndpoint)
11
+ from langgraph.graph import START, MessagesState, StateGraph
12
+ from langgraph.prebuilt import ToolNode, tools_condition
13
+
14
+ from tools import (absolute, add, analyze_csv_file, analyze_excel_file,
15
+ arvix_search, audio_transcription, compound_interest,
16
+ convert_temperature, divide, exponential,
17
+ extract_text_from_image, factorial, floor_divide,
18
+ get_current_time_in_timezone, greatest_common_divisor,
19
+ is_prime, least_common_multiple, logarithm, modulus,
20
+ multiply, percentage_calculator, power, python_code_parser,
21
+ reverse_sentence, roman_calculator_converter, square_root,
22
+ subtract, web_content_extract, web_search, wiki_search)
23
+
24
+ # Load Constants
25
+ load_dotenv()
26
+ HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
27
+ GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
28
+
29
+ tools = [
30
+ multiply, add, subtract, power, divide, modulus,
31
+ square_root, floor_divide, absolute, logarithm,
32
+ exponential, web_search, roman_calculator_converter,
33
+ get_current_time_in_timezone, compound_interest,
34
+ convert_temperature, factorial, greatest_common_divisor,
35
+ is_prime, least_common_multiple, percentage_calculator,
36
+ wiki_search, analyze_excel_file, arvix_search,
37
+ audio_transcription, python_code_parser, analyze_csv_file,
38
+ extract_text_from_image, reverse_sentence, web_content_extract,
39
+ ]
40
+
41
+ # Load system prompt
42
+ system_prompt = """
43
+ You are a general AI assistant. I will ask you a question.
44
+ Report your thoughts, and finish your answer with only the answer, no extra text, no prefix, and no explanation.
45
+ Your answer should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
46
+ If you are asked for a number, don't use a comma to write your number, nor use units such as $ or percent sign unless specified otherwise.
47
+ If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
48
+ If you are asked for a comma separated list, apply the above rules depending on whether the element to be put in the list is a number or a string.
49
+ Format your output as: [{"task_id": ..., "submitted_answer": ...}]
50
+ Do NOT include the format string or any JSON inside the submitted_answer field. Only output a single flat list as: [{"task_id": ..., "submitted_answer": ...}]
51
+ """
52
+
53
+ # System message
54
+ sys_msg = SystemMessage(content=system_prompt)
55
+
56
+
57
+ def build_graph():
58
+ """Build the graph"""
59
+ # First create the HuggingFaceEndpoint
60
+ llm_endpoint = HuggingFaceEndpoint(
61
+ repo_id="mistralai/Mistral-7B-Instruct-v0.2",
62
+ huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
63
+ #api_key=GEMINI_API_KEY,
64
+ temperature=0.1,
65
+ max_new_tokens=1024,
66
+ timeout=60,
67
+ )
68
+
69
+ # Then wrap it with ChatHuggingFace to get chat model functionality
70
+ llm = ChatHuggingFace(llm=llm_endpoint)
71
+
72
+ # Bind tools to LLM
73
+ llm_with_tools = llm.bind_tools(tools)
74
+
75
+ # --- Nodes ---
76
+ def extract_answer(llm_output):
77
+ # Try to parse as JSON if possible
78
+ try:
79
+ # If the LLM output is a JSON list, extract the answer
80
+ parsed = json.loads(llm_output.strip().split('\n')[0])
81
+ if isinstance(parsed, list) and isinstance(parsed[0], dict) and "submitted_answer" in parsed[0]:
82
+ return parsed[0]["submitted_answer"]
83
+ except Exception:
84
+ pass
85
+ # Otherwise, just return the first line (before any explanation)
86
+ return llm_output.strip().split('\n')[0]
87
+
88
+ def assistant(state: MessagesState):
89
+ messages_with_system_prompt = [sys_msg] + state["messages"]
90
+ llm_response = llm_with_tools.invoke(messages_with_system_prompt)
91
+ answer_text = extract_answer(llm_response.content)
92
+ task_id = str(state.get("task_id", "1")) # Ensure task_id is a string
93
+ formatted = [{"task_id": task_id, "submitted_answer": answer_text}]
94
+ return {"messages": [AIMessage(content=json.dumps(formatted, ensure_ascii=False))]}
95
+
96
+ # --- Graph Definition ---
97
+ builder = StateGraph(MessagesState)
98
+ builder.add_node("assistant", assistant)
99
+ builder.add_node("tools", ToolNode(tools))
100
+
101
+ builder.add_edge(START, "assistant")
102
+ builder.add_conditional_edges("assistant", tools_condition)
103
+ builder.add_edge("tools", "assistant")
104
+
105
+ # Compile graph
106
+ return builder.compile()
107
+
108
+
109
+ def is_valid_agent_output(output):
110
+ """
111
+ Checks if the output matches the required format:
112
+ Answers (answers): [{"task_id": ..., "submitted_answer": ...}]
113
+ """
114
+ # Basic regex to check the format
115
+ pattern = r'^Answers \(answers\): \[(\{.*\})\]$'
116
+ match = re.match(pattern, output.strip())
117
+ if not match:
118
+ return False
119
+
120
+ # Try to parse the JSON part
121
+ try:
122
+ answers_list = json.loads(f'[{match.group(1)}]')
123
+ # Check required keys
124
+ for ans in answers_list:
125
+ if not isinstance(ans, dict):
126
+ return False
127
+ if "task_id" not in ans or "submitted_answer" not in ans:
128
+ return False
129
+ return True
130
+ except Exception:
131
+ return False
132
+
133
+
134
+ def extract_flat_answer(output):
135
+ # Try to find the innermost Answers (answers): [{...}]
136
+ pattern = r'Answers \(answers\): \[(\{.*?\})\]'
137
+ matches = re.findall(pattern, output)
138
+ if matches:
139
+ # Use the last match (innermost)
140
+ try:
141
+ answers_list = json.loads(f'[{matches[-1]}]')
142
+ if isinstance(answers_list, list) and "task_id" in answers_list[0] and "submitted_answer" in answers_list[0]:
143
+ return f'Answers (answers): [{matches[-1]}]'
144
+ except Exception:
145
+ pass
146
+ return output # fallback
147
+
148
+ # test
149
+ if __name__ == "__main__":
150
+ question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
151
+ # Build the graph
152
+ graph = build_graph()
153
+ # Run the graph
154
+ messages = [HumanMessage(content=question)]
155
+ # The initial state for the graph
156
+ initial_state = {"messages": messages, "task_id": "test123"}
157
+
158
+ # Invoke the graph stream to see the steps
159
+ for s in graph.stream(initial_state, stream_mode="values"):
160
+ message = s["messages"][-1]
161
+ if isinstance(message, ToolMessage):
162
+ print("---RETRIEVED CONTEXT---")
163
+ print(message.content)
164
+ print("-----------------------")
165
+ else:
166
+ output = message.content # This is a string
167
+ try:
168
+ parsed = json.loads(output)
169
+ if isinstance(parsed, list) and "task_id" in parsed[0] and "submitted_answer" in parsed[0]:
170
+ print("βœ… Output is in the correct format!")
171
+ else:
172
+ print("❌ Output is NOT in the correct format!")
173
+ except Exception as e:
174
+ print("❌ Output is NOT in the correct format!", e)
agent.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import json
3
  import os
4
  import re
@@ -38,16 +37,19 @@ tools = [
38
  extract_text_from_image, reverse_sentence, web_content_extract,
39
  ]
40
 
41
- # Load system prompt
42
  system_prompt = """
43
- You are a general AI assistant. I will ask you a question.
44
- Report your thoughts, and finish your answer with only the answer, no extra text, no prefix, and no explanation.
45
- Your answer should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
46
- If you are asked for a number, don't use a comma to write your number, nor use units such as $ or percent sign unless specified otherwise.
47
- If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
48
- If you are asked for a comma separated list, apply the above rules depending on whether the element to be put in the list is a number or a string.
49
- Format your output as: [{"task_id": ..., "submitted_answer": ...}]
50
- Do NOT include the format string or any JSON inside the submitted_answer field. Only output a single flat list as: [{"task_id": ..., "submitted_answer": ...}]
 
 
 
51
  """
52
 
53
  # System message
@@ -60,7 +62,6 @@ def build_graph():
60
  llm_endpoint = HuggingFaceEndpoint(
61
  repo_id="mistralai/Mistral-7B-Instruct-v0.2",
62
  huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
63
- #api_key=GEMINI_API_KEY,
64
  temperature=0.1,
65
  max_new_tokens=1024,
66
  timeout=60,
@@ -72,26 +73,44 @@ def build_graph():
72
  # Bind tools to LLM
73
  llm_with_tools = llm.bind_tools(tools)
74
 
75
- # --- Nodes ---
76
- def extract_answer(llm_output):
77
- # Try to parse as JSON if possible
78
- try:
79
- # If the LLM output is a JSON list, extract the answer
80
- parsed = json.loads(llm_output.strip().split('\n')[0])
81
- if isinstance(parsed, list) and isinstance(parsed[0], dict) and "submitted_answer" in parsed[0]:
82
- return parsed[0]["submitted_answer"]
83
- except Exception:
84
- pass
85
- # Otherwise, just return the first line (before any explanation)
86
- return llm_output.strip().split('\n')[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
  def assistant(state: MessagesState):
89
  messages_with_system_prompt = [sys_msg] + state["messages"]
90
  llm_response = llm_with_tools.invoke(messages_with_system_prompt)
91
- answer_text = extract_answer(llm_response.content)
92
- task_id = str(state.get("task_id", "1")) # Ensure task_id is a string
93
- formatted = [{"task_id": task_id, "submitted_answer": answer_text}]
94
- return {"messages": [AIMessage(content=json.dumps(formatted, ensure_ascii=False))]}
 
 
 
 
 
95
 
96
  # --- Graph Definition ---
97
  builder = StateGraph(MessagesState)
@@ -109,45 +128,42 @@ def build_graph():
109
  def is_valid_agent_output(output):
110
  """
111
  Checks if the output matches the required format:
112
- Answers (answers): [{"task_id": ..., "submitted_answer": ...}]
113
  """
114
- # Basic regex to check the format
115
- pattern = r'^Answers \(answers\): \[(\{.*\})\]$'
116
- match = re.match(pattern, output.strip())
117
- if not match:
118
- return False
119
-
120
- # Try to parse the JSON part
121
  try:
122
- answers_list = json.loads(f'[{match.group(1)}]')
123
- # Check required keys
124
- for ans in answers_list:
125
- if not isinstance(ans, dict):
 
 
126
  return False
127
- if "task_id" not in ans or "submitted_answer" not in ans:
128
  return False
129
  return True
130
- except Exception:
131
  return False
132
 
133
 
134
  def extract_flat_answer(output):
135
- # Try to find the innermost Answers (answers): [{...}]
136
- pattern = r'Answers \(answers\): \[(\{.*?\})\]'
137
- matches = re.findall(pattern, output)
138
- if matches:
139
- # Use the last match (innermost)
140
- try:
141
- answers_list = json.loads(f'[{matches[-1]}]')
142
- if isinstance(answers_list, list) and "task_id" in answers_list[0] and "submitted_answer" in answers_list[0]:
143
- return f'Answers (answers): [{matches[-1]}]'
144
- except Exception:
145
- pass
146
- return output # fallback
 
 
147
 
148
  # test
149
  if __name__ == "__main__":
150
- question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
151
  # Build the graph
152
  graph = build_graph()
153
  # Run the graph
@@ -164,11 +180,15 @@ if __name__ == "__main__":
164
  print("-----------------------")
165
  else:
166
  output = message.content # This is a string
 
167
  try:
168
  parsed = json.loads(output)
169
  if isinstance(parsed, list) and "task_id" in parsed[0] and "submitted_answer" in parsed[0]:
170
  print("βœ… Output is in the correct format!")
 
 
171
  else:
172
  print("❌ Output is NOT in the correct format!")
173
  except Exception as e:
174
- print("❌ Output is NOT in the correct format!", e)
 
 
 
1
  import json
2
  import os
3
  import re
 
37
  extract_text_from_image, reverse_sentence, web_content_extract,
38
  ]
39
 
40
+ # Updated system prompt for cleaner output
41
  system_prompt = """
42
+ You are a helpful AI assistant. When asked a question, think through it step by step and provide only the final answer.
43
+
44
+ CRITICAL INSTRUCTIONS:
45
+ - Use available tools when needed to gather information or perform calculations
46
+ - After using tools and analyzing the information, provide ONLY the final answer
47
+ - Do not include explanations, reasoning, or extra text in your final response
48
+ - If the answer is a number, provide just the number (no units unless specifically requested)
49
+ - If the answer is text, provide just the essential text (no articles or extra words unless necessary)
50
+ - If the answer is a list, provide it as comma-separated values
51
+
52
+ Your response should contain ONLY the answer - nothing else.
53
  """
54
 
55
  # System message
 
62
  llm_endpoint = HuggingFaceEndpoint(
63
  repo_id="mistralai/Mistral-7B-Instruct-v0.2",
64
  huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
 
65
  temperature=0.1,
66
  max_new_tokens=1024,
67
  timeout=60,
 
73
  # Bind tools to LLM
74
  llm_with_tools = llm.bind_tools(tools)
75
 
76
+ def clean_answer(text):
77
+ """Extract clean answer from LLM response"""
78
+ if not text:
79
+ return ""
80
+
81
+ # Remove common prefixes and suffixes
82
+ text = text.strip()
83
+
84
+ # Remove common response patterns
85
+ patterns_to_remove = [
86
+ r'^(The answer is:?\s*)',
87
+ r'^(Answer:?\s*)',
88
+ r'^(Final answer:?\s*)',
89
+ r'^(Result:?\s*)',
90
+ r'(\s*is the answer\.?)$',
91
+ r'(\s*\.)$'
92
+ ]
93
+
94
+ for pattern in patterns_to_remove:
95
+ text = re.sub(pattern, '', text, flags=re.IGNORECASE)
96
+
97
+ # Take only the first line if multiple lines
98
+ first_line = text.split('\n')[0].strip()
99
+
100
+ return first_line
101
 
102
  def assistant(state: MessagesState):
103
  messages_with_system_prompt = [sys_msg] + state["messages"]
104
  llm_response = llm_with_tools.invoke(messages_with_system_prompt)
105
+
106
+ # Clean the answer
107
+ clean_text = clean_answer(llm_response.content)
108
+
109
+ # Format the response properly
110
+ task_id = str(state.get("task_id", "1"))
111
+ formatted_response = [{"task_id": task_id, "submitted_answer": clean_text}]
112
+
113
+ return {"messages": [AIMessage(content=json.dumps(formatted_response, ensure_ascii=False))]}
114
 
115
  # --- Graph Definition ---
116
  builder = StateGraph(MessagesState)
 
128
  def is_valid_agent_output(output):
129
  """
130
  Checks if the output matches the required format:
131
+ [{"task_id": ..., "submitted_answer": ...}]
132
  """
 
 
 
 
 
 
 
133
  try:
134
+ parsed = json.loads(output.strip())
135
+ if not isinstance(parsed, list):
136
+ return False
137
+
138
+ for item in parsed:
139
+ if not isinstance(item, dict):
140
  return False
141
+ if "task_id" not in item or "submitted_answer" not in item:
142
  return False
143
  return True
144
+ except:
145
  return False
146
 
147
 
148
  def extract_flat_answer(output):
149
+ """Extract properly formatted answer from output"""
150
+ try:
151
+ # Try to parse as JSON first
152
+ parsed = json.loads(output.strip())
153
+ if isinstance(parsed, list) and len(parsed) > 0:
154
+ first_item = parsed[0]
155
+ if isinstance(first_item, dict) and "task_id" in first_item and "submitted_answer" in first_item:
156
+ return output # Already properly formatted
157
+ except:
158
+ pass
159
+
160
+ # If not properly formatted, return as-is (fallback)
161
+ return output
162
+
163
 
164
  # test
165
  if __name__ == "__main__":
166
+ question = "What is 2 + 2?"
167
  # Build the graph
168
  graph = build_graph()
169
  # Run the graph
 
180
  print("-----------------------")
181
  else:
182
  output = message.content # This is a string
183
+ print(f"Raw output: {output}")
184
  try:
185
  parsed = json.loads(output)
186
  if isinstance(parsed, list) and "task_id" in parsed[0] and "submitted_answer" in parsed[0]:
187
  print("βœ… Output is in the correct format!")
188
+ print(f"Task ID: {parsed[0]['task_id']}")
189
+ print(f"Answer: {parsed[0]['submitted_answer']}")
190
  else:
191
  print("❌ Output is NOT in the correct format!")
192
  except Exception as e:
193
+ print("❌ Output is NOT in the correct format!", e)
194
+