Ali-Developments commited on
Commit
63de395
ยท
verified ยท
1 Parent(s): 051c1ef

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +21 -53
agent.py CHANGED
@@ -1,6 +1,4 @@
1
- # app.py
2
  import os
3
- import streamlit as st
4
  from dotenv import load_dotenv
5
  from langchain.docstore.document import Document
6
  from langchain_community.retrievers import BM25Retriever
@@ -16,11 +14,11 @@ import fitz # PyMuPDF
16
 
17
  # Load environment variables
18
  load_dotenv()
19
- os.environ["GROQ_API_KEY"] = os.getenv("GROQ_API_KEY")
20
  groq_api_key = os.getenv("GROQ_API_KEY")
21
  serpapi_api_key = os.getenv("SERPAPI_API_KEY")
22
 
23
- # --- PDF uploader and parser ---
 
24
  def parse_pdfs(uploaded_files):
25
  pdf_docs = []
26
  for uploaded_file in uploaded_files:
@@ -31,10 +29,12 @@ def parse_pdfs(uploaded_files):
31
  pdf_docs.append(Document(page_content=text, metadata={"source": uploaded_file.name}))
32
  return pdf_docs
33
 
34
- # --- Guest info retrieval ---
 
35
  def build_retriever(all_docs):
36
  return BM25Retriever.from_documents(all_docs)
37
 
 
38
  def extract_text(query: str, retriever):
39
  results = retriever.invoke(query)
40
  if results:
@@ -42,25 +42,11 @@ def extract_text(query: str, retriever):
42
  else:
43
  return "ู„ู… ูŠุชู… ุงู„ุนุซูˆุฑ ุนู„ู‰ ู…ุนู„ูˆู…ุงุช ู…ุทุงุจู‚ุฉ ููŠ ุงู„ู…ู„ูุงุช."
44
 
45
- # --- Streamlit UI ---
46
- st.set_page_config(page_title="NINU Agent", page_icon="๐Ÿ›๏ธ")
47
- st.title("๐Ÿ›๏ธ NINU - Guest & PDF & Web Assistant")
48
-
49
- st.markdown("** Hint:** NINU can help summarize lectures, answer questions from PDFs, and search the web interactively.")
50
-
51
- if "conversation_history" not in st.session_state:
52
- st.session_state.conversation_history = []
53
-
54
- query = st.text_area("๐Ÿ“ ุงูƒุชุจ ุณุคุงู„ูƒ ุฃูˆ ูƒู…ู„ ู…ุฐุงูƒุฑุชูƒ ู‡ู†ุง:")
55
-
56
- uploaded_files = st.file_uploader("๐Ÿ“„ ุงุฑูุน ู…ู„ูุงุช PDF ู„ู„ู…ุญุงุถุฑุงุช", type=["pdf"], accept_multiple_files=True)
57
 
58
- if st.button("Ask NINU") and query:
59
- # Parse PDFs if uploaded
60
- user_docs = parse_pdfs(uploaded_files) if uploaded_files else []
61
  bm25_retriever = build_retriever(user_docs) if user_docs else None
62
 
63
- # Tool for PDF retrieval (if PDFs uploaded)
64
  def pdf_tool_func(q):
65
  if bm25_retriever:
66
  return extract_text(q, bm25_retriever)
@@ -73,7 +59,6 @@ if st.button("Ask NINU") and query:
73
  description="Retrieves content from uploaded PDFs based on a query."
74
  )
75
 
76
- # Tool for Web search using SerpAPI
77
  serpapi = SerpAPIWrapper(serpapi_api_key=serpapi_api_key)
78
  SerpAPI_tool = Tool(
79
  name="WebSearch",
@@ -81,34 +66,33 @@ if st.button("Ask NINU") and query:
81
  description="Searches the web for recent information."
82
  )
83
 
84
- # Combine tools
85
  tools = [NINU_tool, SerpAPI_tool]
86
 
87
- # Create LLM and bind tools
88
  llm = ChatGroq(model="deepseek-r1-distill-llama-70b", groq_api_key=groq_api_key)
89
  llm_with_tools = llm.bind_tools(tools)
90
 
91
- # Define Agent state and assistant function
92
  class AgentState(TypedDict):
93
  messages: Annotated[list[AnyMessage], add_messages]
94
 
95
  def assistant(state: AgentState):
96
- return {
97
- "messages": [llm_with_tools.invoke(state["messages"])]
98
- }
99
 
100
- # Build the StateGraph agent
101
  builder = StateGraph(AgentState)
102
  builder.add_node("assistant", assistant)
103
  builder.add_node("tools", ToolNode(tools))
104
  builder.add_edge(START, "assistant")
105
  builder.add_conditional_edges("assistant", tools_condition)
106
  builder.add_edge("tools", "assistant")
107
- NINU = builder.compile()
108
 
109
- # Add intro prompt if first message
110
- if len(st.session_state.conversation_history) == 0:
111
- intro_prompt = """
 
 
 
 
 
112
  You are a general AI assistant with access to two tools:
113
 
114
  1. NINU_Lec_retriever: retrieves content from uploaded PDFs based on a query.
@@ -125,24 +109,8 @@ If you are asked for a number, don't use commas or units (like $, %, etc.) unles
125
 
126
  If you are asked for a string, avoid articles, abbreviations, and write digits in plain text unless specified.
127
  """
128
- st.session_state.conversation_history.append(HumanMessage(content=intro_prompt))
129
-
130
- # Add user query
131
- st.session_state.conversation_history.append(HumanMessage(content=query))
132
-
133
- # Invoke the agent
134
- response = NINU.invoke({"messages": st.session_state.conversation_history})
135
-
136
- # Append assistant reply to conversation history
137
- assistant_reply = response["messages"][-1]
138
- st.session_state.conversation_history.append(assistant_reply)
139
-
140
- # Show assistant reply
141
- st.markdown("### NINU's Response:")
142
- st.write(assistant_reply.content)
143
 
144
- # Show full conversation history (optional)
145
- with st.expander("๐Ÿงพ Show full conversation history"):
146
- for msg in st.session_state.conversation_history:
147
- role = "You" if msg.type == "human" else "NINU"
148
- st.markdown(f"**{role}:** {msg.content}")
 
 
1
  import os
 
2
  from dotenv import load_dotenv
3
  from langchain.docstore.document import Document
4
  from langchain_community.retrievers import BM25Retriever
 
14
 
15
  # Load environment variables
16
  load_dotenv()
 
17
  groq_api_key = os.getenv("GROQ_API_KEY")
18
  serpapi_api_key = os.getenv("SERPAPI_API_KEY")
19
 
20
+
21
+ # --- PDF parsing ---
22
  def parse_pdfs(uploaded_files):
23
  pdf_docs = []
24
  for uploaded_file in uploaded_files:
 
29
  pdf_docs.append(Document(page_content=text, metadata={"source": uploaded_file.name}))
30
  return pdf_docs
31
 
32
+
33
+ # --- BM25 Retrieval ---
34
  def build_retriever(all_docs):
35
  return BM25Retriever.from_documents(all_docs)
36
 
37
+
38
  def extract_text(query: str, retriever):
39
  results = retriever.invoke(query)
40
  if results:
 
42
  else:
43
  return "ู„ู… ูŠุชู… ุงู„ุนุซูˆุฑ ุนู„ู‰ ู…ุนู„ูˆู…ุงุช ู…ุทุงุจู‚ุฉ ููŠ ุงู„ู…ู„ูุงุช."
44
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
+ # --- Create NINU Agent ---
47
+ def create_ninu_agent(user_docs=None):
 
48
  bm25_retriever = build_retriever(user_docs) if user_docs else None
49
 
 
50
  def pdf_tool_func(q):
51
  if bm25_retriever:
52
  return extract_text(q, bm25_retriever)
 
59
  description="Retrieves content from uploaded PDFs based on a query."
60
  )
61
 
 
62
  serpapi = SerpAPIWrapper(serpapi_api_key=serpapi_api_key)
63
  SerpAPI_tool = Tool(
64
  name="WebSearch",
 
66
  description="Searches the web for recent information."
67
  )
68
 
 
69
  tools = [NINU_tool, SerpAPI_tool]
70
 
 
71
  llm = ChatGroq(model="deepseek-r1-distill-llama-70b", groq_api_key=groq_api_key)
72
  llm_with_tools = llm.bind_tools(tools)
73
 
 
74
  class AgentState(TypedDict):
75
  messages: Annotated[list[AnyMessage], add_messages]
76
 
77
  def assistant(state: AgentState):
78
+ return {"messages": [llm_with_tools.invoke(state["messages"])]}
 
 
79
 
 
80
  builder = StateGraph(AgentState)
81
  builder.add_node("assistant", assistant)
82
  builder.add_node("tools", ToolNode(tools))
83
  builder.add_edge(START, "assistant")
84
  builder.add_conditional_edges("assistant", tools_condition)
85
  builder.add_edge("tools", "assistant")
86
+ return builder.compile()
87
 
88
+
89
+ # --- Main interaction function ---
90
+ def run_ninu(query, user_docs=None):
91
+ agent = create_ninu_agent(user_docs)
92
+
93
+ conversation = []
94
+
95
+ intro_prompt = """
96
  You are a general AI assistant with access to two tools:
97
 
98
  1. NINU_Lec_retriever: retrieves content from uploaded PDFs based on a query.
 
109
 
110
  If you are asked for a string, avoid articles, abbreviations, and write digits in plain text unless specified.
111
  """
112
+ conversation.append(HumanMessage(content=intro_prompt))
113
+ conversation.append(HumanMessage(content=query))
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
+ response = agent.invoke({"messages": conversation})
116
+ return response["messages"][-1].content