semioz commited on
Commit
89836e0
·
1 Parent(s): 7dfcd3a
Files changed (2) hide show
  1. agent.py +32 -25
  2. app.py +15 -10
agent.py CHANGED
@@ -1,5 +1,4 @@
1
  import logging
2
-
3
  from langchain_core.messages import HumanMessage, SystemMessage
4
  from langchain_groq import ChatGroq
5
  from langchain_huggingface import (
@@ -9,12 +8,15 @@ from langchain_huggingface import (
9
  )
10
  from langgraph.graph import START, MessagesState, StateGraph
11
  from langgraph.prebuilt import ToolNode, tools_condition
 
 
 
 
12
 
13
  from tools import tools
14
 
15
  logger = logging.getLogger(__name__)
16
 
17
-
18
  # ----- Initializing vector store and retriever tool -------
19
 
20
  with open("system_prompt.txt", encoding="utf-8") as f:
@@ -25,24 +27,7 @@ sys_msg = SystemMessage(content=system_prompt)
25
 
26
  embeddings = HuggingFaceEmbeddings(
27
  model_name="sentence-transformers/all-mpnet-base-v2"
28
- ) # dim=768
29
-
30
- '''
31
- supabase: Client = create_client(
32
- os.environ.get("SUPABASE_URL"), os.environ.get("SUPABASE_SERVICE_ROLE_KEY")
33
- )
34
- vector_store = SupabaseVectorStore(
35
- client=supabase,
36
- embedding=embeddings,
37
- table_name="documents2",
38
- query_name="match_documents_2",
39
- )
40
- create_retriever_tool = create_retriever_tool(
41
- retriever=vector_store.as_retriever(),
42
- name="Question Search",
43
- description="A tool to retrieve similar questions from a vector store.",
44
  )
45
- '''
46
 
47
  def build_graph(provider: str = "groq"):
48
  """Build the graph"""
@@ -62,22 +47,44 @@ def build_graph(provider: str = "groq"):
62
  raise ValueError("Invalid provider. Choose 'groq' or 'huggingface'.")
63
  llm_with_tools = llm.bind_tools(tools)
64
 
 
 
 
 
 
 
 
 
 
 
65
  # Node
66
  def assistant(state: MessagesState):
67
  """Assistant node"""
68
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
69
 
 
70
  def retriever(state: MessagesState):
71
- """Retriever node"""
72
- similar_question = vector_store.similarity_search(state["messages"][0].content)
 
 
 
 
 
 
 
73
 
74
- if similar_question:
 
 
 
75
  example_msg = HumanMessage(
76
- content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
77
  )
78
  return {"messages": [sys_msg] + state["messages"] + [example_msg]}
79
- # no similar questions are found
80
- return {"messages": [sys_msg] + state["messages"]}
 
81
 
82
  builder = StateGraph(MessagesState)
83
  builder.add_node("retriever", retriever)
 
1
  import logging
 
2
  from langchain_core.messages import HumanMessage, SystemMessage
3
  from langchain_groq import ChatGroq
4
  from langchain_huggingface import (
 
8
  )
9
  from langgraph.graph import START, MessagesState, StateGraph
10
  from langgraph.prebuilt import ToolNode, tools_condition
11
+ from langchain_community.vectorstores import DuckDB
12
+ from langchain.tools.retriever import create_retriever_tool
13
+ from langchain_core.documents import Document
14
+ import uuid
15
 
16
  from tools import tools
17
 
18
  logger = logging.getLogger(__name__)
19
 
 
20
  # ----- Initializing vector store and retriever tool -------
21
 
22
  with open("system_prompt.txt", encoding="utf-8") as f:
 
27
 
28
  embeddings = HuggingFaceEmbeddings(
29
  model_name="sentence-transformers/all-mpnet-base-v2"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  )
 
31
 
32
  def build_graph(provider: str = "groq"):
33
  """Build the graph"""
 
47
  raise ValueError("Invalid provider. Choose 'groq' or 'huggingface'.")
48
  llm_with_tools = llm.bind_tools(tools)
49
 
50
+ def convert_messages_to_documents(messages: list[HumanMessage]) -> list[Document]:
51
+ return [
52
+ Document(
53
+ page_content=msg.content,
54
+ metadata={"role": "user", "index": idx, "id": str(uuid.uuid4())}
55
+ )
56
+ for idx, msg in enumerate(messages)
57
+ ]
58
+
59
+
60
  # Node
61
  def assistant(state: MessagesState):
62
  """Assistant node"""
63
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
64
 
65
+
66
  def retriever(state: MessagesState):
67
+ """Retriever node using DuckDB"""
68
+ docs = convert_messages_to_documents(state["messages"])
69
+ vector_store = DuckDB.from_documents(docs, embedding=embeddings)
70
+
71
+ create_retriever_tool = create_retriever_tool(
72
+ retriever=vector_store.as_retriever(),
73
+ name="Question Search",
74
+ description="A tool to retrieve similar questions from a vector store.",
75
+ )
76
 
77
+ query = state["messages"][0].content
78
+ similar = vector_store.similarity_search(query)
79
+
80
+ if similar:
81
  example_msg = HumanMessage(
82
+ content=f"Here I provide a similar question and answer for reference: \n\n{similar[0].page_content}"
83
  )
84
  return {"messages": [sys_msg] + state["messages"] + [example_msg]}
85
+ else:
86
+ return {"messages": [sys_msg] + state["messages"]}
87
+
88
 
89
  builder = StateGraph(MessagesState)
90
  builder.add_node("retriever", retriever)
app.py CHANGED
@@ -3,22 +3,27 @@ import os
3
  import gradio as gr
4
  import pandas as pd
5
  import requests
 
 
6
 
7
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
8
 
9
- # --- Agent Definition ---
10
- class BasicAgent:
11
  def __init__(self):
12
- print("BasicAgent initialized.")
 
 
13
  def __call__(self, question: str) -> str:
14
  print(f"Agent received question (first 50 chars): {question[:50]}...")
15
- fixed_answer = "This is a default answer."
16
- print(f"Agent returning fixed answer: {fixed_answer}")
17
- return fixed_answer
 
18
 
19
  def run_and_submit_all( profile: gr.OAuthProfile | None):
20
  """
21
- Fetches all questions, runs the BasicAgent on them, submits all answers,
22
  and displays the results.
23
  """
24
  # --- Determine HF Space Runtime URL and Repo URL ---
@@ -35,9 +40,9 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
35
  questions_url = f"{api_url}/questions"
36
  submit_url = f"{api_url}/submit"
37
 
38
- # 1. Instantiate Agent ( modify this part to create your agent)
39
  try:
40
- agent = BasicAgent()
41
  except Exception as e:
42
  print(f"Error instantiating agent: {e}")
43
  return f"Error initializing agent: {e}", None
@@ -66,7 +71,7 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
66
  print(f"An unexpected error occurred fetching questions: {e}")
67
  return f"An unexpected error occurred fetching questions: {e}", None
68
 
69
- # 3. Run your Agent
70
  results_log = []
71
  answers_payload = []
72
  print(f"Running agent on {len(questions_data)} questions...")
 
3
  import gradio as gr
4
  import pandas as pd
5
  import requests
6
+ from agent import build_graph
7
+ from langchain_core.messages import HumanMessage
8
 
9
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
10
 
11
+ class InfersenseAgent:
12
+ """A langgraph based agent."""
13
  def __init__(self):
14
+ print("InfersenseAgent initialized.")
15
+ self.graph = build_graph()
16
+
17
  def __call__(self, question: str) -> str:
18
  print(f"Agent received question (first 50 chars): {question[:50]}...")
19
+ messages = [HumanMessage(content=question)]
20
+ messages = self.graph.invoke({"messages": messages})
21
+ answer = messages['messages'][-1].content
22
+ return answer[14:]
23
 
24
  def run_and_submit_all( profile: gr.OAuthProfile | None):
25
  """
26
+ Fetches all questions, runs the Infersense agent on them, submits all answers,
27
  and displays the results.
28
  """
29
  # --- Determine HF Space Runtime URL and Repo URL ---
 
40
  questions_url = f"{api_url}/questions"
41
  submit_url = f"{api_url}/submit"
42
 
43
+ # 1. Instantiate Agent
44
  try:
45
+ agent = InfersenseAgent()
46
  except Exception as e:
47
  print(f"Error instantiating agent: {e}")
48
  return f"Error initializing agent: {e}", None
 
71
  print(f"An unexpected error occurred fetching questions: {e}")
72
  return f"An unexpected error occurred fetching questions: {e}", None
73
 
74
+ # 3. Run
75
  results_log = []
76
  answers_payload = []
77
  print(f"Running agent on {len(questions_data)} questions...")