Vela committed
Commit: 75115cd
Parent(s): 00f1bc6

enhanced graph
Files changed:
- .gitignore +2 -1
- app.py +4 -4
- application/services/{gemini_model.py → gemini_api_service.py} +29 -6
- application/services/mongo_db_service.py +2 -1
- application/tools/emission_data_extractor.py +1 -1
- application/tools/web_search_tools.py +3 -1
- main.py +161 -0
- pages/chatbot.py +10 -16
- pages/multiple_pdf_extractor.py +2 -2
.gitignore
CHANGED

@@ -3,4 +3,5 @@
 data
 __pycache__/
 logs/
-test.py
+test.py
+reports/
app.py
CHANGED

@@ -1,6 +1,6 @@
 import streamlit as st
 import os
-from application.services import
+from application.services import gemini_api_service, streamlit_function
 from google.genai.errors import ClientError
 from application.utils import logger
 from application.schemas.response_schema import (
@@ -44,7 +44,7 @@ if st.session_state.pdf_file:
     with col1:
         if st.button(f"Generate {MODEL_1} Response"):
             with st.spinner(f"Calling {MODEL_1}..."):
-                result =
+                result = gemini_api_service.extract_emissions_data_as_json(API_1, MODEL_1, st.session_state.pdf_file[0], FULL_RESPONSE_SCHEMA)
                 excel_file = streamlit_function.export_results_to_excel(result, MODEL_1, file_name)
                 st.session_state[f"{MODEL_1}_result"] = result
         if st.session_state[f"{MODEL_1}_result"]:
@@ -54,7 +54,7 @@ if st.session_state.pdf_file:
     with col2:
         if st.button(f"Generate {MODEL_2} Response"):
            with st.spinner(f"Calling {MODEL_2}..."):
-                result =
+                result = gemini_api_service.extract_emissions_data_as_json(API_2, MODEL_2, st.session_state.pdf_file[0], FULL_RESPONSE_SCHEMA)
                excel_file = streamlit_function.export_results_to_excel(result, MODEL_2, file_name)
                st.session_state[f"{MODEL_2}_result"] = result
         if st.session_state[f"{MODEL_2}_result"]:
@@ -65,7 +65,7 @@ if st.session_state.pdf_file:
     try:
         if st.button(f"Generate {MODEL_3} Response"):
             with st.spinner(f"Calling {MODEL_3}..."):
-                result =
+                result = gemini_api_service.extract_emissions_data_as_json(API_3, MODEL_3, st.session_state.pdf_file[0], FULL_RESPONSE_SCHEMA)
                 excel_file = streamlit_function.export_results_to_excel(result, MODEL_3, file_name)
                 st.session_state[f"{MODEL_3}_result"] = result
     except ClientError as e:
application/services/{gemini_model.py → gemini_api_service.py}
RENAMED

@@ -5,6 +5,8 @@ from typing import Optional, Dict, Union, IO, List, BinaryIO
 from google import genai
 from google.genai import types
 from application.utils import logger
+import requests
+import io
 
 logger=logger.get_logger()
 
@@ -136,11 +138,11 @@ def upload_file(
     config: Optional[Dict[str, str]] = None
 ) -> Optional[types.File]:
     """
-    Uploads a file to the Gemini API, handling
+    Uploads a file to the Gemini API, handling local file paths, binary streams, and URLs.
 
     Args:
-        file (Union[str, IO[bytes]]):
-        file_name (Optional[str]): Name for the file. If None,
+        file (Union[str, IO[bytes]]): Local file path, URL, or binary file object.
+        file_name (Optional[str]): Name for the file. If None, tries to infer it from the source.
         config (Optional[Dict[str, str]]): Extra config like 'mime_type'.
 
     Returns:
@@ -150,8 +152,14 @@ def upload_file(
         Exception: If upload fails.
     """
     try:
+        # Determine if input is a URL
+        is_url = isinstance(file, str) and file.startswith(('http://', 'https://'))
+
+        # Determine file name if not provided
        if not file_name:
-            if
+            if is_url:
+                file_name = os.path.basename(file.split("?")[0])  # Remove query params
+            elif isinstance(file, str):
                file_name = os.path.basename(file)
            elif hasattr(file, "name"):
                file_name = os.path.basename(file.name)
@@ -164,17 +172,32 @@ def upload_file(
         config.update({"name": sanitized_name, "mime_type": mime_type})
         gemini_file_key = f"files/{sanitized_name}"
 
+        # Check if file already exists
         if gemini_file_key in get_files():
             logger.info(f"File already exists on Gemini: {gemini_file_key}")
             return client.files.get(name=gemini_file_key)
 
         logger.info(f"Uploading file to Gemini: {gemini_file_key}")
 
+        # Handle URL
+        if is_url:
+            headers = {
+                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
+            }
+            response = requests.get(file, headers=headers)
+            response.raise_for_status()
+            file_content = io.BytesIO(response.content)
+            return client.files.upload(file=file_content, config=config)
+
+        # Handle local file path
         if isinstance(file, str):
+            if not os.path.isfile(file):
+                raise FileNotFoundError(f"Local file '{file}' does not exist.")
             with open(file, "rb") as f:
                 return client.files.upload(file=f, config=config)
-
-
+
+        # Handle already opened binary file object
+        return client.files.upload(file=file, config=config)
 
     except Exception as e:
         logger.error(f"Failed to upload file '{file_name}': {e}")
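Note on the reworked upload_file: it now accepts a local path, an http(s) URL, or an already-open binary stream, and short-circuits if the sanitized name already exists on Gemini. A minimal usage sketch (file names and URL are hypothetical, not from the commit):

    from application.services.gemini_api_service import upload_file

    # Local path: existence-checked, opened in binary mode, then uploaded
    f1 = upload_file("reports/example_2023.pdf")

    # URL: fetched via requests with a browser User-Agent, wrapped in io.BytesIO
    f2 = upload_file("https://example.com/sustainability/report.pdf")

    # Open binary stream, e.g. a Streamlit UploadedFile; name inferred from .name if present
    with open("reports/example_2023.pdf", "rb") as fh:
        f3 = upload_file(fh, file_name="example_2023.pdf")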
application/services/mongo_db_service.py
CHANGED

@@ -84,4 +84,5 @@ def retrieve_documents(collection_name: str, query: Optional[Dict] = None) -> List:
         logger.exception(f"An error occurred while retrieving documents: {str(e)}")
         return []
 
-# all_docs = retrieve_documents("Zalando")
+# all_docs = retrieve_documents("Zalando")
+# print(all_docs)
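Per the signature in the hunk header, the second parameter takes an optional MongoDB-style filter. A small sketch of both call forms (the filter field name is an assumption for illustration):

    # Fetch every document in the "Zalando" collection
    all_docs = retrieve_documents("Zalando")

    # Or narrow with a filter dict (hypothetical field name)
    recent = retrieve_documents("Zalando", query={"year": 2023})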
application/tools/emission_data_extractor.py
CHANGED

@@ -6,7 +6,7 @@ import requests
 from google import genai
 from google.genai import types
 from application.utils.logger import get_logger
-from application.services.
+from application.services.gemini_api_service import upload_file
 from application.services.mongo_db_service import store_document
 from application.schemas.response_schema import GEMINI_GHG_PARAMETERS
 from langchain_core.tools import tool
application/tools/web_search_tools.py
CHANGED

@@ -8,6 +8,7 @@ from typing import Literal
 from duckduckgo_search import DDGS
 from tavily import TavilyClient
 from langchain_core.tools import tool
+import ast
 
 logger = get_logger()
 load_dotenv()
@@ -54,7 +55,8 @@ def get_top_companies_from_web(query: str):
 
     output = response.output_text
     # logger.info(f"Raw Output: {output}")
-    parsed_list =
+    parsed_list = ast.literal_eval(output.strip())
+    # parsed_list = eval(output.strip())
     logger.info(f"Parsed List: {parsed_list}")
     result = CompanyListResponse(companies=parsed_list)
     return result
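Choosing ast.literal_eval over the commented-out eval is the safer way to parse the model's list output: it only evaluates Python literals and raises on anything executable. A small illustration (example strings are assumptions, not actual model output):

    import ast

    ast.literal_eval("['Delta Air Lines', 'United', 'Lufthansa']")  # returns a list of strings
    ast.literal_eval("__import__('os').remove('x')")  # raises ValueError instead of executing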
main.py
ADDED

@@ -0,0 +1,161 @@
+import os
+import operator
+import functools
+from typing import Annotated, Sequence, TypedDict, Union, Optional
+
+from dotenv import load_dotenv
+from langchain_openai import ChatOpenAI
+from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain_core.runnables import Runnable
+from langchain.output_parsers.openai_tools import JsonOutputKeyToolsParser
+from langgraph.graph import StateGraph, END
+
+from application.agents.scraper_agent import scraper_agent
+from application.agents.extractor_agent import extractor_agent
+from application.utils.logger import get_logger
+
+load_dotenv()
+logger = get_logger()
+
+OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
+if not OPENAI_API_KEY:
+    logger.error("OPENAI_API_KEY is missing. Please set it in your environment variables.")
+    raise EnvironmentError("OPENAI_API_KEY not found in environment variables.")
+
+MEMBERS = ["Scraper", "Extractor"]
+OPTIONS = ["FINISH"] + MEMBERS
+
+SUPERVISOR_SYSTEM_PROMPT = (
+    "You are a supervisor tasked with managing a conversation between the following workers: {members}. "
+    "Given the user's request and the previous messages, determine what to do next:\n"
+    "- If the user asks to search, find, or scrape data from the web, choose 'Scraper'.\n"
+    "- If the user asks to extract ESG emissions data from a file or PDF, choose 'Extractor'.\n"
+    "- If the task is complete, choose 'FINISH'.\n"
+    "- If the message is general conversation (like greetings, questions, thanks, chatting), directly respond with a message.\n"
+    "Each worker will perform its task and report back.\n"
+    "When you respond directly, make sure your message is friendly and helpful."
+)
+
+FUNCTION_DEF = {
+    "name": "route_or_respond",
+    "description": "Select the next role OR respond directly.",
+    "parameters": {
+        "title": "RouteOrRespondSchema",
+        "type": "object",
+        "properties": {
+            "next": {
+                "title": "Next Worker",
+                "anyOf": [{"enum": OPTIONS}],
+                "description": "Choose next worker if needed."
+            },
+            "response": {
+                "title": "Supervisor Response",
+                "type": "string",
+                "description": "Respond directly if no worker action is needed."
+            }
+        },
+        "required": [],
+    },
+}
+
+class AgentState(TypedDict):
+    messages: Annotated[Sequence[BaseMessage], operator.add]
+    next: Optional[str]
+    response: Optional[str]
+
+llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
+
+def agent_node(state: AgentState, agent: Runnable, name: str) -> dict:
+    logger.info(f"Agent {name} invoked.")
+    try:
+        result = agent.invoke(state)
+        logger.info(f"Agent {name} completed successfully.")
+        return {"messages": [HumanMessage(content=result["output"], name=name)]}
+    except Exception as e:
+        logger.exception(f"Agent {name} failed with error: {str(e)}")
+        raise
+
+prompt = ChatPromptTemplate.from_messages(
+    [
+        ("system", SUPERVISOR_SYSTEM_PROMPT),
+        MessagesPlaceholder(variable_name="messages"),
+        (
+            "system",
+            "Based on the conversation, either select next worker (one of: {options}) or respond directly with a message.",
+        ),
+    ]
+).partial(options=str(OPTIONS), members=", ".join(MEMBERS))
+
+# supervisor_chain = (
+#     prompt
+#     | llm.bind_functions(functions=[FUNCTION_DEF], function_call="route_or_respond")
+#     | JsonOutputFunctionsParser()
+# )
+
+supervisor_chain = (
+    prompt
+    | llm.bind_tools(tools=[FUNCTION_DEF], tool_choice="route_or_respond")
+    | JsonOutputKeyToolsParser(key_name="route_or_respond")
+)
+
+def supervisor_node(state: AgentState) -> AgentState:
+    logger.info("Supervisor invoked.")
+    output = supervisor_chain.invoke(state)
+    logger.info(f"Supervisor output: {output}")
+
+    if isinstance(output, list) and len(output) > 0:
+        output = output[0]
+
+    next_step = output.get("next")
+    response = output.get("response")
+
+    if not next_step and not response:
+        raise ValueError(f"Supervisor produced invalid output: {output}")
+
+    return {
+        "messages": state["messages"],
+        "next": next_step,
+        "response": response,
+    }
+
+workflow = StateGraph(AgentState)
+
+workflow.add_node("Scraper", functools.partial(agent_node, agent=scraper_agent, name="Scraper"))
+workflow.add_node("Extractor", functools.partial(agent_node, agent=extractor_agent, name="Extractor"))
+workflow.add_node("supervisor", supervisor_node)
+# workflow.add_node("supervisor", supervisor_chain)
+workflow.add_node("supervisor_response", lambda state: {"messages": [AIMessage(content=state["response"], name="Supervisor")]})
+
+for member in MEMBERS:
+    workflow.add_edge(member, "supervisor")
+
+def router(state: AgentState):
+    if state.get("response"):
+        return "supervisor_response"
+    return state.get("next")
+
+conditional_map = {member: member for member in MEMBERS}
+conditional_map["FINISH"] = END
+conditional_map["supervisor_response"] = "supervisor_response"
+
+
+workflow.add_conditional_edges("supervisor", router, conditional_map)
+
+workflow.set_entry_point("supervisor")
+
+graph = workflow.compile()
+
+# # === Example Run ===
+if __name__ == "__main__":
+    logger.info("Starting the graph execution...")
+    initial_message = HumanMessage(content="Can you get zalando pdf link")
+    input_state = {"messages": [initial_message]}
+
+    for step in graph.stream(input_state):
+        if "__end__" not in step:
+            logger.info(f"Graph Step Output: {step}")
+            print(step)
+            print("----")
+
+    logger.info("Graph execution completed.")
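One structural note on the new graph: the supervisor_response node emits the supervisor's direct reply but has no outgoing edge in this commit, so depending on the langgraph version, workflow.compile() may flag it as a dead end. If it does, the likely one-line fix (an assumption, not part of this commit) is:

    workflow.add_edge("supervisor_response", END)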
pages/chatbot.py
CHANGED

@@ -2,16 +2,10 @@ import streamlit as st
 from dotenv import load_dotenv
 
 from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
-# from application.agents.scraper_agent import app
-# from application.utils.logger import get_logger
 
-
-
-
-    from application.utils.logger import get_logger
-except ImportError as e:
-    st.error(f"Import Error: Ensure backend modules are accessible. Details: {e}")
-    st.stop()
+from application.agents.scraper_agent import app
+from main import graph
+from application.utils.logger import get_logger
 
 logger = get_logger()
 
@@ -19,8 +13,8 @@ st.set_page_config(page_title="Sustainability AI Assistant", layout="wide")
 st.title("♻️ Sustainability Report AI Assistant")
 st.caption(
     "Ask about sustainability reports by company or industry! "
-    "(e.g., 'Get report for Apple', 'Download report for Microsoft 2023', "
-    "'Find reports for top 3 airline companies', 'Download this pdf <link>')"
+    "(e.g., 'Get sustainability report for Apple', 'Download sustainability report for Microsoft 2023', "
+    "'Find sustainability reports for top 3 airline companies', 'Download this pdf <link>')"
 )
 
 load_dotenv()
@@ -34,10 +28,10 @@ def initialize_chat_history():
 def display_chat_history():
     """Render previous chat messages."""
     for message in st.session_state.messages:
-        if isinstance(message, SystemMessage):
-
-
-
+        # if isinstance(message, SystemMessage):
+        #     # st.info(f"System: {message.content}")
+        #     pass
+        if isinstance(message, HumanMessage):
             with st.chat_message("user"):
                 st.markdown(message.content)
         elif isinstance(message, AIMessage):
@@ -77,10 +71,10 @@ def display_last_ai_response():
         logger.warning("No AI message found in the final output.")
 
 initialize_chat_history()
-display_chat_history()
 
 if user_query := st.chat_input("Your question about sustainability reports..."):
     logger.info(f"User input received: {user_query}")
+    display_chat_history()
 
     st.session_state.messages.append(HumanMessage(content=user_query))
 
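With from main import graph, the chat page can hand user queries to the supervisor workflow. The actual invocation sits outside this diff's context lines; a plausible sketch of how the page would stream the graph (an assumption, mirroring the __main__ block in main.py):

    from langchain_core.messages import HumanMessage

    input_state = {"messages": [HumanMessage(content=user_query)]}
    for step in graph.stream(input_state):
        if "__end__" not in step:
            logger.info(f"Graph Step Output: {step}")  # each step maps node name -> its state update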
pages/multiple_pdf_extractor.py
CHANGED

@@ -6,7 +6,7 @@ from application.schemas.response_schema import (
     GEMINI_GOVERNANCE_PARAMETERS, GEMINI_MATERIALITY_PARAMETERS,
     GEMINI_NET_ZERO_INTERVENTION_PARAMETERS
 )
-from application.services import
+from application.services import gemini_api_service, streamlit_function
 from application.utils import logger
 
 logger = logger.get_logger()
@@ -58,7 +58,7 @@ if st.session_state.uploaded_files:
         all_results = {}
 
         for label, schema in RESPONSE_SCHEMAS.items():
-            result =
+            result = gemini_api_service.extract_emissions_data_as_json("gemini", selected_model, pdf_file, schema)
             streamlit_function.export_results_to_excel(result, sheet_name=selected_model, filename=file_name, column=label)
             all_results[label] = result
         st.session_state[result_key] = all_results