Freddolin commited on
Commit
e0fa687
·
verified ·
1 Parent(s): 5c47ee8

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +221 -114
agent.py CHANGED
@@ -1,118 +1,225 @@
 
1
  import os
2
- import torch
3
- from huggingface_hub import InferenceClient
4
-
5
- # Importera LangChain-komponenter
6
- from langchain_community.llms import HuggingFaceHub # För att använda HF Inference API som LLM
7
- from langchain.agents import AgentExecutor, create_react_agent # Agentens exekverare och ReAct-agent konstruktorn
8
- from langchain.tools import Tool # Verktygsklassen i LangChain
9
- from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
10
- from langchain_core.messages import HumanMessage, AIMessage
11
-
12
- # Importera dina befintliga, anpassade verktygsfunktioner
13
- from tavily_search import search_tavily
14
- from asr_tool import transcribe_audio
15
- from excel_tool import analyze_excel
16
- from math_tool import calculate_math
17
-
18
- class GaiaAgent:
19
- def __init__(self, model_id: str = "google/gemma-2b-it"):
20
- """
21
- Initialiserar GaiaAgent, nu med LangChain.
22
- """
23
- print(f"Initialiserar GaiaAgent med modell: {model_id}")
24
-
25
- hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
26
- if not hf_token:
27
- raise ValueError(
28
- "Hugging Face token (HF_TOKEN eller HUGGING_FACE_HUB_TOKEN) är inte konfigurerad i miljövariabler."
29
- "Vänligen lägg till din token som en 'Repository secret' i dina Space-inställningar."
30
- )
31
-
32
- # 1. Initialisera LLM med LangChain's HuggingFaceHub
33
- try:
34
- # HuggingFaceHub ansluter till en fjärrmodell via HF Inference API
35
- self.llm = HuggingFaceHub(
36
- repo_id=model_id,
37
- huggingfacehub_api_token=hf_token,
38
- task="text-generation", # Specifiera task
39
- # model_kwargs={"temperature": 0.1, "max_new_tokens": 512} # Exempel på modell-kwargs
40
- )
41
- print("LangChain HuggingFaceHub LLM laddad framgångsrikt.")
42
- except Exception as e:
43
- raise RuntimeError(
44
- f"Misslyckades att initialisera HuggingFaceHub LLM: {e}."
45
- "Vänligen kontrollera din HF token och att modellen är tillgänglig/laddningsbar."
46
- )
47
-
48
- # 2. Definiera dina anpassade verktyg som LangChain Tool-objekt
49
- tools_list = [
50
- Tool.from_function(
51
- func=search_tavily,
52
- name="search_tavily",
53
- description="Användbart för att söka information online med Tavily Search. Returnerar en sammanfattning av de mest relevanta resultaten från webben. Kräver en fråga som input.",
54
- ),
55
- Tool.from_function(
56
- func=transcribe_audio,
57
- name="transcribe_audio",
58
- description="Transkriberar ljudfil till text. Användbart för att omvandla tal till text från en angiven ljudfilsväg. Kräver en filsökväg till ljudfilen som input.",
59
- ),
60
- Tool.from_function(
61
- func=analyze_excel,
62
- name="analyze_excel",
63
- description="Analysera Excel-filer och returnera detaljerad information om rader, kolumner, datatyper och statistik (summa, medelvärde, max, min för numeriska kolumner). Kan ta både en lokal filväg eller en URL till Excel-filen som input.",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  ),
65
- Tool.from_function(
66
- func=calculate_math,
67
- name="calculate_math",
68
- description="Beräkna matematiska uttryck. Användbart för att utföra aritmetiska operationer som addition, subtraktion, multiplikation, division och potenser. Tar ett matematiskt uttryck som en sträng som input.",
69
- )
70
- ]
71
- print(f"Laddade {len(tools_list)} anpassade verktyg för LangChain.")
72
-
73
- # 3. Skapa en prompt för ReAct-agenten
74
- # Detta prompt-format är viktigt för hur LLM:en förstår att använda verktyg.
75
- # MessagesPlaceholder används för att injicera verktyg och meddelandehistorik dynamiskt.
76
- prompt = ChatPromptTemplate.from_messages(
77
- [
78
- ("system", "Du är en hjälpsam AI-assistent. Använd tillgängliga verktyg för att svara på frågor."),
79
- MessagesPlaceholder("chat_history", optional=True),
80
- ("human", "{input}"),
81
- MessagesPlaceholder("agent_scratchpad"), # Detta är där agentens tankar och verktygskall kommer att finnas
82
- ]
83
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
- # 4. Initialisera LangChain ReAct-agenten
86
- # create_react_agent är en konstruktorfunktion för en ReAct-baserad agent
87
- agent = create_react_agent(self.llm, tools_list, prompt)
88
-
89
- # 5. Skapa AgentExecutor för att köra agenten
90
- # AgentExecutor är den körbara delen som hanterar agentens "tankeloop" och verktygskall
91
- self.agent_executor = AgentExecutor(
92
- agent=agent,
93
- tools=tools_list,
94
- verbose=True, # Sätt till True för att se agentens tankeprocess i loggarna
95
- handle_parsing_errors=True # Hantera parsningsfel graciöst
96
- )
97
- print("LangChain AgentExecutor initialiserad.")
98
-
99
- def process_task(self, task_prompt: str) -> str:
100
- """
101
- Bearbetar en uppgift med den interna LangChain AgentExecutor.
102
- """
103
- print(f"\nBearbetar uppgift med LangChain AgentExecutor: '{task_prompt}'")
104
- try:
105
- # Anropa agenten med invoke. Den returnerar ett dictionary.
106
- # "input" är användarens prompt.
107
- # "chat_history" kan skickas in om du har kontext från tidigare konversationer.
108
- result = self.agent_executor.invoke({"input": task_prompt})
109
-
110
- # Det slutgiltiga svaret finns vanligtvis under nyckeln "output"
111
- final_answer = result.get("output", "Agenten kunde inte generera ett slutgiltigt svar.")
112
-
113
- print(f"\nLangChain AgentExecutor avslutad. Slutgiltigt svar: {final_answer}")
114
- return final_answer
115
- except Exception as e:
116
- error_message = f"Ett fel uppstod under agentens bearbetning: {e}"
117
- print(error_message)
118
- return f"Agenten kunde inte slutföra uppgiften på grund av ett fel: {error_message}"
 
1
+ """LangGraph Agent"""
2
  import os
3
+ from dotenv import load_dotenv
4
+ from langgraph.graph import START, StateGraph, MessagesState
5
+ from langgraph.prebuilt import tools_condition
6
+ from langgraph.prebuilt import ToolNode
7
+ from langchain_google_genai import ChatGoogleGenerativeAI
8
+ from langchain_groq import ChatGroq
9
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
10
+ from langchain_community.tools.tavily_search import TavilySearchResults
11
+ from langchain_community.document_loaders import WikipediaLoader
12
+ from langchain_community.document_loaders import ArxivLoader
13
+ from langchain_community.vectorstores import SupabaseVectorStore
14
+ from langchain_core.messages import SystemMessage, HumanMessage
15
+ from langchain_core.tools import tool
16
+ from langchain.tools.retriever import create_retriever_tool
17
+ from supabase.client import Client, create_client
18
+
19
+ load_dotenv()
20
+
21
+ @tool
22
+ def multiply(a: int, b: int) -> int:
23
+ """Multiply two numbers.
24
+ Args:
25
+ a: first int
26
+ b: second int
27
+ """
28
+ return a * b
29
+
30
+ @tool
31
+ def add(a: int, b: int) -> int:
32
+ """Add two numbers.
33
+
34
+ Args:
35
+ a: first int
36
+ b: second int
37
+ """
38
+ return a + b
39
+
40
+ @tool
41
+ def subtract(a: int, b: int) -> int:
42
+ """Subtract two numbers.
43
+
44
+ Args:
45
+ a: first int
46
+ b: second int
47
+ """
48
+ return a - b
49
+
50
+ @tool
51
+ def divide(a: int, b: int) -> int:
52
+ """Divide two numbers.
53
+
54
+ Args:
55
+ a: first int
56
+ b: second int
57
+ """
58
+ if b == 0:
59
+ raise ValueError("Cannot divide by zero.")
60
+ return a / b
61
+
62
+ @tool
63
+ def modulus(a: int, b: int) -> int:
64
+ """Get the modulus of two numbers.
65
+
66
+ Args:
67
+ a: first int
68
+ b: second int
69
+ """
70
+ return a % b
71
+
72
+ @tool
73
+ def wiki_search(query: str) -> str:
74
+ """Search Wikipedia for a query and return maximum 2 results.
75
+
76
+ Args:
77
+ query: The search query."""
78
+ search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
79
+ formatted_search_docs = "\n\n---\n\n".join(
80
+ [
81
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
82
+ for doc in search_docs
83
+ ])
84
+ return {"wiki_results": formatted_search_docs}
85
+
86
+ @tool
87
+ def web_search(query: str) -> str:
88
+ """Search Tavily for a query and return maximum 3 results.
89
+
90
+ Args:
91
+ query: The search query."""
92
+ search_docs = TavilySearchResults(max_results=3).invoke(query=query)
93
+ formatted_search_docs = "\n\n---\n\n".join(
94
+ [
95
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
96
+ for doc in search_docs
97
+ ])
98
+ return {"web_results": formatted_search_docs}
99
+
100
+ @tool
101
+ def arvix_search(query: str) -> str:
102
+ """Search Arxiv for a query and return maximum 3 result.
103
+
104
+ Args:
105
+ query: The search query."""
106
+ search_docs = ArxivLoader(query=query, load_max_docs=3).load()
107
+ formatted_search_docs = "\n\n---\n\n".join(
108
+ [
109
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
110
+ for doc in search_docs
111
+ ])
112
+ return {"arvix_results": formatted_search_docs}
113
+
114
+
115
+
116
+ # load the system prompt from the file
117
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
118
+ system_prompt = f.read()
119
+
120
+ # System message
121
+ sys_msg = SystemMessage(content=system_prompt)
122
+
123
+ # build a retriever
124
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
125
+ supabase: Client = create_client(
126
+ os.environ.get("SUPABASE_URL"),
127
+ os.environ.get("SUPABASE_SERVICE_KEY"))
128
+ vector_store = SupabaseVectorStore(
129
+ client=supabase,
130
+ embedding= embeddings,
131
+ table_name="documents",
132
+ query_name="match_documents_langchain",
133
+ )
134
+ create_retriever_tool = create_retriever_tool(
135
+ retriever=vector_store.as_retriever(),
136
+ name="Question Search",
137
+ description="A tool to retrieve similar questions from a vector store.",
138
+ )
139
+
140
+
141
+
142
+ tools = [
143
+ multiply,
144
+ add,
145
+ subtract,
146
+ divide,
147
+ modulus,
148
+ wiki_search,
149
+ web_search,
150
+ arvix_search,
151
+ ]
152
+
153
+ # Build graph function
154
+ def build_graph(provider: str = "google"):
155
+ """Build the graph"""
156
+ # Load environment variables from .env file
157
+ if provider == "google":
158
+ # Google Gemini
159
+ llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
160
+ elif provider == "groq":
161
+ # Groq https://console.groq.com/docs/models
162
+ llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
163
+ elif provider == "huggingface":
164
+ # TODO: Add huggingface endpoint
165
+ llm = ChatHuggingFace(
166
+ llm=HuggingFaceEndpoint(
167
+ url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
168
+ temperature=0,
169
  ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  )
171
+ else:
172
+ raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
173
+ # Bind tools to LLM
174
+ llm_with_tools = llm.bind_tools(tools)
175
+
176
+ # Node
177
+ def assistant(state: MessagesState):
178
+ """Assistant node"""
179
+ return {"messages": [llm_with_tools.invoke(state["messages"])]}
180
+
181
+ # def retriever(state: MessagesState):
182
+ # """Retriever node"""
183
+ # similar_question = vector_store.similarity_search(state["messages"][0].content)
184
+ #example_msg = HumanMessage(
185
+ # content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
186
+ # )
187
+ # return {"messages": [sys_msg] + state["messages"] + [example_msg]}
188
+
189
+ from langchain_core.messages import AIMessage
190
+
191
+ def retriever(state: MessagesState):
192
+ query = state["messages"][-1].content
193
+ similar_doc = vector_store.similarity_search(query, k=1)[0]
194
+
195
+ content = similar_doc.page_content
196
+ if "Final answer :" in content:
197
+ answer = content.split("Final answer :")[-1].strip()
198
+ else:
199
+ answer = content.strip()
200
+
201
+ return {"messages": [AIMessage(content=answer)]}
202
+
203
+ # builder = StateGraph(MessagesState)
204
+ #builder.add_node("retriever", retriever)
205
+ #builder.add_node("assistant", assistant)
206
+ #builder.add_node("tools", ToolNode(tools))
207
+ #builder.add_edge(START, "retriever")
208
+ #builder.add_edge("retriever", "assistant")
209
+ #builder.add_conditional_edges(
210
+ # "assistant",
211
+ # tools_condition,
212
+ #)
213
+ #builder.add_edge("tools", "assistant")
214
+
215
+ builder = StateGraph(MessagesState)
216
+ builder.add_node("retriever", retriever)
217
+
218
+ # Retriever ist Start und Endpunkt
219
+ builder.set_entry_point("retriever")
220
+ builder.set_finish_point("retriever")
221
+
222
+ # Compile graph
223
+ return builder.compile()
224
+
225