Vela committed on
Commit
75115cd
·
1 Parent(s): 00f1bc6

enhanced graph

Browse files
.gitignore CHANGED
@@ -3,4 +3,5 @@
3
  data
4
  __pycache__/
5
  logs/
6
- test.py
 
 
3
  data
4
  __pycache__/
5
  logs/
6
+ test.py
7
+ reports/
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import streamlit as st
2
  import os
3
- from application.services import streamlit_function, gemini_model
4
  from google.genai.errors import ClientError
5
  from application.utils import logger
6
  from application.schemas.response_schema import (
@@ -44,7 +44,7 @@ if st.session_state.pdf_file:
44
  with col1:
45
  if st.button(f"Generate {MODEL_1} Response"):
46
  with st.spinner(f"Calling {MODEL_1}..."):
47
- result = gemini_model.extract_emissions_data_as_json(API_1 , MODEL_1, st.session_state.pdf_file[0],FULL_RESPONSE_SCHEMA)
48
  excel_file = streamlit_function.export_results_to_excel(result, MODEL_1, file_name)
49
  st.session_state[f"{MODEL_1}_result"] = result
50
  if st.session_state[f"{MODEL_1}_result"]:
@@ -54,7 +54,7 @@ if st.session_state.pdf_file:
54
  with col2:
55
  if st.button(f"Generate {MODEL_2} Response"):
56
  with st.spinner(f"Calling {MODEL_2}..."):
57
- result = gemini_model.extract_emissions_data_as_json(API_2, MODEL_2, st.session_state.pdf_file[0],FULL_RESPONSE_SCHEMA)
58
  excel_file = streamlit_function.export_results_to_excel(result, MODEL_2, file_name)
59
  st.session_state[f"{MODEL_2}_result"] = result
60
  if st.session_state[f"{MODEL_2}_result"]:
@@ -65,7 +65,7 @@ if st.session_state.pdf_file:
65
  try:
66
  if st.button(f"Generate {MODEL_3} Response"):
67
  with st.spinner(f"Calling {MODEL_3}..."):
68
- result = gemini_model.extract_emissions_data_as_json(API_3, MODEL_3, st.session_state.pdf_file[0], FULL_RESPONSE_SCHEMA)
69
  excel_file = streamlit_function.export_results_to_excel(result, MODEL_3, file_name)
70
  st.session_state[f"{MODEL_3}_result"] = result
71
  except ClientError as e:
 
1
  import streamlit as st
2
  import os
3
+ from application.services import gemini_api_service, streamlit_function
4
  from google.genai.errors import ClientError
5
  from application.utils import logger
6
  from application.schemas.response_schema import (
 
44
  with col1:
45
  if st.button(f"Generate {MODEL_1} Response"):
46
  with st.spinner(f"Calling {MODEL_1}..."):
47
+ result = gemini_api_service.extract_emissions_data_as_json(API_1 , MODEL_1, st.session_state.pdf_file[0],FULL_RESPONSE_SCHEMA)
48
  excel_file = streamlit_function.export_results_to_excel(result, MODEL_1, file_name)
49
  st.session_state[f"{MODEL_1}_result"] = result
50
  if st.session_state[f"{MODEL_1}_result"]:
 
54
  with col2:
55
  if st.button(f"Generate {MODEL_2} Response"):
56
  with st.spinner(f"Calling {MODEL_2}..."):
57
+ result = gemini_api_service.extract_emissions_data_as_json(API_2, MODEL_2, st.session_state.pdf_file[0],FULL_RESPONSE_SCHEMA)
58
  excel_file = streamlit_function.export_results_to_excel(result, MODEL_2, file_name)
59
  st.session_state[f"{MODEL_2}_result"] = result
60
  if st.session_state[f"{MODEL_2}_result"]:
 
65
  try:
66
  if st.button(f"Generate {MODEL_3} Response"):
67
  with st.spinner(f"Calling {MODEL_3}..."):
68
+ result = gemini_api_service.extract_emissions_data_as_json(API_3, MODEL_3, st.session_state.pdf_file[0], FULL_RESPONSE_SCHEMA)
69
  excel_file = streamlit_function.export_results_to_excel(result, MODEL_3, file_name)
70
  st.session_state[f"{MODEL_3}_result"] = result
71
  except ClientError as e:
application/services/{gemini_model.py → gemini_api_service.py} RENAMED
@@ -5,6 +5,8 @@ from typing import Optional, Dict, Union, IO, List, BinaryIO
5
  from google import genai
6
  from google.genai import types
7
  from application.utils import logger
 
 
8
 
9
  logger=logger.get_logger()
10
 
@@ -136,11 +138,11 @@ def upload_file(
136
  config: Optional[Dict[str, str]] = None
137
  ) -> Optional[types.File]:
138
  """
139
- Uploads a file to the Gemini API, handling both file paths and binary streams.
140
 
141
  Args:
142
- file (Union[str, IO[bytes]]): File path or binary file object (e.g., from Streamlit).
143
- file_name (Optional[str]): Name for the file. If None, attempts to use file.name.
144
  config (Optional[Dict[str, str]]): Extra config like 'mime_type'.
145
 
146
  Returns:
@@ -150,8 +152,14 @@ def upload_file(
150
  Exception: If upload fails.
151
  """
152
  try:
 
 
 
 
153
  if not file_name:
154
- if isinstance(file, str):
 
 
155
  file_name = os.path.basename(file)
156
  elif hasattr(file, "name"):
157
  file_name = os.path.basename(file.name)
@@ -164,17 +172,32 @@ def upload_file(
164
  config.update({"name": sanitized_name, "mime_type": mime_type})
165
  gemini_file_key = f"files/{sanitized_name}"
166
 
 
167
  if gemini_file_key in get_files():
168
  logger.info(f"File already exists on Gemini: {gemini_file_key}")
169
  return client.files.get(name=gemini_file_key)
170
 
171
  logger.info(f"Uploading file to Gemini: {gemini_file_key}")
172
 
 
 
 
 
 
 
 
 
 
 
 
173
  if isinstance(file, str):
 
 
174
  with open(file, "rb") as f:
175
  return client.files.upload(file=f, config=config)
176
- else:
177
- return client.files.upload(file=file, config=config)
 
178
 
179
  except Exception as e:
180
  logger.error(f"Failed to upload file '{file_name}': {e}")
 
5
  from google import genai
6
  from google.genai import types
7
  from application.utils import logger
8
+ import requests
9
+ import io
10
 
11
  logger=logger.get_logger()
12
 
 
138
  config: Optional[Dict[str, str]] = None
139
  ) -> Optional[types.File]:
140
  """
141
+ Uploads a file to the Gemini API, handling local file paths, binary streams, and URLs.
142
 
143
  Args:
144
+ file (Union[str, IO[bytes]]): Local file path, URL, or binary file object.
145
+ file_name (Optional[str]): Name for the file. If None, tries to infer it from the source.
146
  config (Optional[Dict[str, str]]): Extra config like 'mime_type'.
147
 
148
  Returns:
 
152
  Exception: If upload fails.
153
  """
154
  try:
155
+ # Determine if input is a URL
156
+ is_url = isinstance(file, str) and file.startswith(('http://', 'https://'))
157
+
158
+ # Determine file name if not provided
159
  if not file_name:
160
+ if is_url:
161
+ file_name = os.path.basename(file.split("?")[0]) # Remove query params
162
+ elif isinstance(file, str):
163
  file_name = os.path.basename(file)
164
  elif hasattr(file, "name"):
165
  file_name = os.path.basename(file.name)
 
172
  config.update({"name": sanitized_name, "mime_type": mime_type})
173
  gemini_file_key = f"files/{sanitized_name}"
174
 
175
+ # Check if file already exists
176
  if gemini_file_key in get_files():
177
  logger.info(f"File already exists on Gemini: {gemini_file_key}")
178
  return client.files.get(name=gemini_file_key)
179
 
180
  logger.info(f"Uploading file to Gemini: {gemini_file_key}")
181
 
182
+ # Handle URL
183
+ if is_url:
184
+ headers = {
185
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
186
+ }
187
+ response = requests.get(file, headers=headers)
188
+ response.raise_for_status()
189
+ file_content = io.BytesIO(response.content)
190
+ return client.files.upload(file=file_content, config=config)
191
+
192
+ # Handle local file path
193
  if isinstance(file, str):
194
+ if not os.path.isfile(file):
195
+ raise FileNotFoundError(f"Local file '{file}' does not exist.")
196
  with open(file, "rb") as f:
197
  return client.files.upload(file=f, config=config)
198
+
199
+ # Handle already opened binary file object
200
+ return client.files.upload(file=file, config=config)
201
 
202
  except Exception as e:
203
  logger.error(f"Failed to upload file '{file_name}': {e}")
application/services/mongo_db_service.py CHANGED
@@ -84,4 +84,5 @@ def retrieve_documents(collection_name: str, query: Optional[Dict] = None) -> Li
84
  logger.exception(f"An error occurred while retrieving documents: {str(e)}")
85
  return []
86
 
87
- # all_docs = retrieve_documents("Zalando")
 
 
84
  logger.exception(f"An error occurred while retrieving documents: {str(e)}")
85
  return []
86
 
87
+ # all_docs = retrieve_documents("Zalando")
88
+ # print(all_docs)
application/tools/emission_data_extractor.py CHANGED
@@ -6,7 +6,7 @@ import requests
6
  from google import genai
7
  from google.genai import types
8
  from application.utils.logger import get_logger
9
- from application.services.gemini_model import upload_file
10
  from application.services.mongo_db_service import store_document
11
  from application.schemas.response_schema import GEMINI_GHG_PARAMETERS
12
  from langchain_core.tools import tool
 
6
  from google import genai
7
  from google.genai import types
8
  from application.utils.logger import get_logger
9
+ from application.services.gemini_api_service import upload_file
10
  from application.services.mongo_db_service import store_document
11
  from application.schemas.response_schema import GEMINI_GHG_PARAMETERS
12
  from langchain_core.tools import tool
application/tools/web_search_tools.py CHANGED
@@ -8,6 +8,7 @@ from typing import Literal
8
  from duckduckgo_search import DDGS
9
  from tavily import TavilyClient
10
  from langchain_core.tools import tool
 
11
 
12
  logger = get_logger()
13
  load_dotenv()
@@ -54,7 +55,8 @@ def get_top_companies_from_web(query: str):
54
 
55
  output = response.output_text
56
  # logger.info(f"Raw Output: {output}")
57
- parsed_list = eval(output.strip())
 
58
  logger.info(f"Parsed List: {parsed_list}")
59
  result = CompanyListResponse(companies=parsed_list)
60
  return result
 
8
  from duckduckgo_search import DDGS
9
  from tavily import TavilyClient
10
  from langchain_core.tools import tool
11
+ import ast
12
 
13
  logger = get_logger()
14
  load_dotenv()
 
55
 
56
  output = response.output_text
57
  # logger.info(f"Raw Output: {output}")
58
+ parsed_list = ast.literal_eval(output.strip())
59
+ # parsed_list = eval(output.strip())
60
  logger.info(f"Parsed List: {parsed_list}")
61
  result = CompanyListResponse(companies=parsed_list)
62
  return result
main.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import operator
3
+ import functools
4
+ from typing import Annotated, Sequence, TypedDict, Union, Optional
5
+
6
+ from dotenv import load_dotenv
7
+ from langchain_openai import ChatOpenAI
8
+ from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
9
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
10
+ from langchain_core.runnables import Runnable
11
+ from langchain.output_parsers.openai_tools import JsonOutputKeyToolsParser
12
+ from langgraph.graph import StateGraph, END
13
+
14
+ from application.agents.scraper_agent import scraper_agent
15
+ from application.agents.extractor_agent import extractor_agent
16
+ from application.utils.logger import get_logger
17
+
18
+ load_dotenv()
19
+ logger = get_logger()
20
+
21
+ OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
22
+ if not OPENAI_API_KEY:
23
+ logger.error("OPENAI_API_KEY is missing. Please set it in your environment variables.")
24
+ raise EnvironmentError("OPENAI_API_KEY not found in environment variables.")
25
+
26
+ MEMBERS = ["Scraper", "Extractor"]
27
+ OPTIONS = ["FINISH"] + MEMBERS
28
+
29
+ SUPERVISOR_SYSTEM_PROMPT = (
30
+ "You are a supervisor tasked with managing a conversation between the following workers: {members}. "
31
+ "Given the user's request and the previous messages, determine what to do next:\n"
32
+ "- If the user asks to search, find, or scrape data from the web, choose 'Scraper'.\n"
33
+ "- If the user asks to extract ESG emissions data from a file or PDF, choose 'Extractor'.\n"
34
+ "- If the task is complete, choose 'FINISH'.\n"
35
+ "- If the message is general conversation (like greetings, questions, thanks, chatting), directly respond with a message.\n"
36
+ "Each worker will perform its task and report back.\n"
37
+ "When you respond directly, make sure your message is friendly and helpful."
38
+ )
39
+
40
+ FUNCTION_DEF = {
41
+ "name": "route_or_respond",
42
+ "description": "Select the next role OR respond directly.",
43
+ "parameters": {
44
+ "title": "RouteOrRespondSchema",
45
+ "type": "object",
46
+ "properties": {
47
+ "next": {
48
+ "title": "Next Worker",
49
+ "anyOf": [{"enum": OPTIONS}],
50
+ "description": "Choose next worker if needed."
51
+ },
52
+ "response": {
53
+ "title": "Supervisor Response",
54
+ "type": "string",
55
+ "description": "Respond directly if no worker action is needed."
56
+ }
57
+ },
58
+ "required": [],
59
+ },
60
+ }
61
+
62
+ class AgentState(TypedDict):
63
+ messages: Annotated[Sequence[BaseMessage], operator.add]
64
+ next: Optional[str]
65
+ response: Optional[str]
66
+
67
+ llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
68
+
69
+ def agent_node(state: AgentState, agent: Runnable, name: str) -> dict:
70
+ logger.info(f"Agent {name} invoked.")
71
+ try:
72
+ result = agent.invoke(state)
73
+ logger.info(f"Agent {name} completed successfully.")
74
+ return {"messages": [HumanMessage(content=result["output"], name=name)]}
75
+ except Exception as e:
76
+ logger.exception(f"Agent {name} failed with error: {str(e)}")
77
+ raise
78
+
79
+ prompt = ChatPromptTemplate.from_messages(
80
+ [
81
+ ("system", SUPERVISOR_SYSTEM_PROMPT),
82
+ MessagesPlaceholder(variable_name="messages"),
83
+ (
84
+ "system",
85
+ "Based on the conversation, either select next worker (one of: {options}) or respond directly with a message.",
86
+ ),
87
+ ]
88
+ ).partial(options=str(OPTIONS), members=", ".join(MEMBERS))
89
+
90
+ # supervisor_chain = (
91
+ # prompt
92
+ # | llm.bind_functions(functions=[FUNCTION_DEF], function_call="route_or_respond")
93
+ # | JsonOutputFunctionsParser()
94
+ # )
95
+
96
+ supervisor_chain = (
97
+ prompt
98
+ | llm.bind_tools(tools=[FUNCTION_DEF], tool_choice="route_or_respond")
99
+ | JsonOutputKeyToolsParser(key_name="route_or_respond")
100
+ )
101
+
102
+ def supervisor_node(state: AgentState) -> AgentState:
103
+ logger.info("Supervisor invoked.")
104
+ output = supervisor_chain.invoke(state)
105
+ logger.info(f"Supervisor output: {output}")
106
+
107
+ if isinstance(output, list) and len(output) > 0:
108
+ output = output[0]
109
+
110
+ next_step = output.get("next")
111
+ response = output.get("response")
112
+
113
+ if not next_step and not response:
114
+ raise ValueError(f"Supervisor produced invalid output: {output}")
115
+
116
+ return {
117
+ "messages": state["messages"],
118
+ "next": next_step,
119
+ "response": response,
120
+ }
121
+
122
+ workflow = StateGraph(AgentState)
123
+
124
+ workflow.add_node("Scraper", functools.partial(agent_node, agent=scraper_agent, name="Scraper"))
125
+ workflow.add_node("Extractor", functools.partial(agent_node, agent=extractor_agent, name="Extractor"))
126
+ workflow.add_node("supervisor", supervisor_node)
127
+ # workflow.add_node("supervisor", supervisor_chain)
128
+ workflow.add_node("supervisor_response", lambda state: {"messages": [AIMessage(content=state["response"], name="Supervisor")]})
129
+
130
+ for member in MEMBERS:
131
+ workflow.add_edge(member, "supervisor")
132
+
133
+ def router(state: AgentState):
134
+ if state.get("response"):
135
+ return "supervisor_response"
136
+ return state.get("next")
137
+
138
+ conditional_map = {member: member for member in MEMBERS}
139
+ conditional_map["FINISH"] = END
140
+ conditional_map["supervisor_response"] = "supervisor_response"
141
+
142
+
143
+ workflow.add_conditional_edges("supervisor", router, conditional_map)
144
+
145
+ workflow.set_entry_point("supervisor")
146
+
147
+ graph = workflow.compile()
148
+
149
+ # # === Example Run ===
150
+ if __name__ == "__main__":
151
+ logger.info("Starting the graph execution...")
152
+ initial_message = HumanMessage(content="Can you get zalando pdf link")
153
+ input_state = {"messages": [initial_message]}
154
+
155
+ for step in graph.stream(input_state):
156
+ if "__end__" not in step:
157
+ logger.info(f"Graph Step Output: {step}")
158
+ print(step)
159
+ print("----")
160
+
161
+ logger.info("Graph execution completed.")
pages/chatbot.py CHANGED
@@ -2,16 +2,10 @@ import streamlit as st
2
  from dotenv import load_dotenv
3
 
4
  from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
5
- # from application.agents.scraper_agent import app
6
- # from application.utils.logger import get_logger
7
 
8
- try:
9
- from application.agents.scraper_agent import app
10
- # from application.main import graph
11
- from application.utils.logger import get_logger
12
- except ImportError as e:
13
- st.error(f"Import Error: Ensure backend modules are accessible. Details: {e}")
14
- st.stop()
15
 
16
  logger = get_logger()
17
 
@@ -19,8 +13,8 @@ st.set_page_config(page_title="Sustainability AI Assistant", layout="wide")
19
  st.title("♻️ Sustainability Report AI Assistant")
20
  st.caption(
21
  "Ask about sustainability reports by company or industry! "
22
- "(e.g., 'Get report for Apple', 'Download report for Microsoft 2023', "
23
- "'Find reports for top 3 airline companies', 'Download this pdf <link>')"
24
  )
25
 
26
  load_dotenv()
@@ -34,10 +28,10 @@ def initialize_chat_history():
34
  def display_chat_history():
35
  """Render previous chat messages."""
36
  for message in st.session_state.messages:
37
- if isinstance(message, SystemMessage):
38
- # st.info(f"System: {message.content}")
39
- pass
40
- elif isinstance(message, HumanMessage):
41
  with st.chat_message("user"):
42
  st.markdown(message.content)
43
  elif isinstance(message, AIMessage):
@@ -77,10 +71,10 @@ def display_last_ai_response():
77
  logger.warning("No AI message found in the final output.")
78
 
79
  initialize_chat_history()
80
- display_chat_history()
81
 
82
  if user_query := st.chat_input("Your question about sustainability reports..."):
83
  logger.info(f"User input received: {user_query}")
 
84
 
85
  st.session_state.messages.append(HumanMessage(content=user_query))
86
 
 
2
  from dotenv import load_dotenv
3
 
4
  from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
 
 
5
 
6
+ from application.agents.scraper_agent import app
7
+ from main import graph
8
+ from application.utils.logger import get_logger
 
 
 
 
9
 
10
  logger = get_logger()
11
 
 
13
  st.title("♻️ Sustainability Report AI Assistant")
14
  st.caption(
15
  "Ask about sustainability reports by company or industry! "
16
+ "(e.g., 'Get sustainability report for Apple', 'Download sustainability report for Microsoft 2023', "
17
+ "'Find sustainability reports for top 3 airline companies', 'Download this pdf <link>')"
18
  )
19
 
20
  load_dotenv()
 
28
  def display_chat_history():
29
  """Render previous chat messages."""
30
  for message in st.session_state.messages:
31
+ # if isinstance(message, SystemMessage):
32
+ # # st.info(f"System: {message.content}")
33
+ # pass
34
+ if isinstance(message, HumanMessage):
35
  with st.chat_message("user"):
36
  st.markdown(message.content)
37
  elif isinstance(message, AIMessage):
 
71
  logger.warning("No AI message found in the final output.")
72
 
73
  initialize_chat_history()
 
74
 
75
  if user_query := st.chat_input("Your question about sustainability reports..."):
76
  logger.info(f"User input received: {user_query}")
77
+ display_chat_history()
78
 
79
  st.session_state.messages.append(HumanMessage(content=user_query))
80
 
pages/multiple_pdf_extractor.py CHANGED
@@ -6,7 +6,7 @@ from application.schemas.response_schema import (
6
  GEMINI_GOVERNANCE_PARAMETERS, GEMINI_MATERIALITY_PARAMETERS,
7
  GEMINI_NET_ZERO_INTERVENTION_PARAMETERS
8
  )
9
- from application.services import streamlit_function, gemini_model
10
  from application.utils import logger
11
 
12
  logger = logger.get_logger()
@@ -58,7 +58,7 @@ if st.session_state.uploaded_files:
58
  all_results = {}
59
 
60
  for label, schema in RESPONSE_SCHEMAS.items():
61
- result = gemini_model.extract_emissions_data_as_json("gemini", selected_model, pdf_file, schema)
62
  streamlit_function.export_results_to_excel(result, sheet_name=selected_model, filename=file_name, column=label)
63
  all_results[label] = result
64
  st.session_state[result_key] = all_results
 
6
  GEMINI_GOVERNANCE_PARAMETERS, GEMINI_MATERIALITY_PARAMETERS,
7
  GEMINI_NET_ZERO_INTERVENTION_PARAMETERS
8
  )
9
+ from application.services import gemini_api_service, streamlit_function
10
  from application.utils import logger
11
 
12
  logger = logger.get_logger()
 
58
  all_results = {}
59
 
60
  for label, schema in RESPONSE_SCHEMAS.items():
61
+ result = gemini_api_service.extract_emissions_data_as_json("gemini", selected_model, pdf_file, schema)
62
  streamlit_function.export_results_to_excel(result, sheet_name=selected_model, filename=file_name, column=label)
63
  all_results[label] = result
64
  st.session_state[result_key] = all_results