Vela committed on
Commit
5d4ad83
·
0 Parent(s):

Created a frontend dashboard

Browse files
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .venv
2
+ logs
3
+ .env
README.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: PDFExtractor
3
+ emoji: 🌍
4
+ colorFrom: gray
5
+ colorTo: pink
6
+ sdk: streamlit
7
+ sdk_version: 1.44.1
8
+ app_file: app.py
9
+ pinned: false
10
+ short_description: An AI-powered tool that extracts sustainability data
11
+ ---
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import os
4
+ from src.utils import streamlit_function
5
+ from src.utils import logger
6
+
7
# Entry page: configure Streamlit, render the header, and collect PDF uploads.
logger = logger.get_logger()
streamlit_function.config_homepage()

st.title("Sustainability Report Analyzer")
st.write("Upload your sustainability report PDF and generate insights using Gemini models.")

uploaded_files = streamlit_function.upload_file(
    "pdf",
    label="πŸ“€ Upload Sustainability Report PDF",
)

# A non-empty upload replaces whatever the session holds; otherwise make sure
# the key exists so the check below never hits a missing attribute.
if uploaded_files:
    st.session_state.uploaded_files = uploaded_files
elif "uploaded_files" not in st.session_state:
    st.session_state.uploaded_files = []

if st.session_state.uploaded_files:
    # Single-column layout container; nothing is rendered into it yet.
    columns = st.columns(1)
22
+
23
+
24
+
25
+
26
+
27
+
28
+
29
+
30
+
31
+
32
+
33
+
34
+
35
+
36
+
37
+
38
+
39
+
40
+
41
+
42
+
43
+
44
+
45
+
46
+
47
+
48
+ # # import streamlit as st
49
+ # # from application.schemas.response_schema import GEMINI_GHG_PARAMETERS, GEMINI_ENVIRONMENTAL_PARAMETERS_CSRD,GEMINI_ENVIRONMENT_PARAMETERS,GEMINI_SOCIAL_PARAMETERS, GEMINI_GOVERNANCE_PARAMETERS, GEMINI_MATERIALITY_PARAMETERS, GEMINI_NET_ZERO_INTERVENTION_PARAMETERS
50
+ # # from application.services import streamlit_function, gemini_model
51
+ # # from application.utils import logger
52
+ # # import test
53
+
54
+ # # logger = logger.get_logger()
55
+ # # streamlit_function.config_homepage()
56
+ # # st.title("Sustainability Report Analyzer")
57
+ # # st.write("Upload your sustainability report PDF and generate insights using different models.")
58
+
59
+ # # MODEL = ["gemini-1.5-pro-latest", "gemini-2.0-flash", "gemini-1.5-flash", "gemini-2.5-pro-exp-03-25"]
60
+
61
+ # # MODEL_1 = "gemini-1.5-pro-latest"
62
+ # # MODEL_2 = "gemini-2.0-flash"
63
+ # # MODEL_3 = "gemini-1.5-flash"
64
+
65
+ # # API_1 = "gemini"
66
+ # # API_2 = "gemini"
67
+ # # API_3 = "gemini"
68
+
69
+ # # response_schema = [ GEMINI_GHG_PARAMETERS, GEMINI_ENVIRONMENTAL_PARAMETERS_CSRD,
70
+ # # GEMINI_ENVIRONMENT_PARAMETERS,GEMINI_SOCIAL_PARAMETERS,
71
+ # # GEMINI_GOVERNANCE_PARAMETERS, GEMINI_MATERIALITY_PARAMETERS,
72
+ # # GEMINI_NET_ZERO_INTERVENTION_PARAMETERS]
73
+
74
+ # # if "uploaded_files" not in st.session_state:
75
+ # # st.session_state.uploaded_files = []
76
+
77
+ # # MODEL = st.selectbox(
78
+ # # "Select Model",
79
+ # # options=MODEL,
80
+ # # index=0,
81
+ # # )
82
+
83
+ # # uploaded_files = streamlit_function.upload_file("pdf", label="Upload Sustainability Report PDF")
84
+
85
+ # # if uploaded_files:
86
+ # # st.session_state.uploaded_files = uploaded_files
87
+
88
+ # # if st.session_state.uploaded_files:
89
+ # # columns = st.columns([5, 5, 5], gap="small")
90
+
91
+ # # for i, col in enumerate(columns):
92
+ # # if i < len(st.session_state.uploaded_files):
93
+ # # pdf_file = st.session_state.uploaded_files[i]
94
+ # # file_name = pdf_file.name.removesuffix(".pdf")
95
+ # # result_key = f"{MODEL}_result_file_{i+1}"
96
+
97
+ # # with col:
98
+ # # st.write(f"**File {i+1}:** `{pdf_file.name}`")
99
+ # # if st.button(f"Extract Data from File {i+1}", key=f"extract_btn_{i}"):
100
+ # # with st.spinner(f"Extracting data from File {i+1} using {MODEL}..."):
101
+ # # for schema in response_schema:
102
+ # # result = gemini_model.extract_emissions_data_as_json(API_1, MODEL, pdf_file, schema)
103
+ # # if schema == GEMINI_GHG_PARAMETERS:
104
+ # # column = "Greenhouse Gas (GHG) Protocol Parameters"
105
+ # # elif schema == GEMINI_ENVIRONMENTAL_PARAMETERS_CSRD:
106
+ # # column = "Environmental Parameters (CSRD)"
107
+ # # elif schema == GEMINI_ENVIRONMENT_PARAMETERS:
108
+ # # column = "Environmental Parameters"
109
+ # # elif schema == GEMINI_SOCIAL_PARAMETERS:
110
+ # # column = "Social Parameters"
111
+ # # elif schema == GEMINI_GOVERNANCE_PARAMETERS:
112
+ # # column = "Governance Parameters"
113
+ # # elif schema == GEMINI_MATERIALITY_PARAMETERS:
114
+ # # column = "Materiality Parameters"
115
+ # # elif schema == GEMINI_NET_ZERO_INTERVENTION_PARAMETERS:
116
+ # # column = "Net Zero Intervention Parameters"
117
+ # # else:
118
+ # # column = None
119
+
120
+ # # test.export_results_to_excel(result, sheet_name=MODEL, filename=file_name, column=column )
121
+ # # st.session_state[result_key] = result
122
+
123
+ # # if st.session_state.get(result_key):
124
+ # # st.write(f"**Extracted Metrics for File {i+1}:**")
125
+ # # st.json(st.session_state[result_key])
pages/database.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+
4
+ from src.utils import streamlit_function
5
+ from src.utils.logger import get_logger
6
+ from src.services.mongo_db_service import retrieve_documents
7
+ from src.utils.common_functions import prepare_comparison_df
8
+
9
# Page bootstrap for the comparison dashboard.
logger = get_logger()
streamlit_function.config_homepage()

st.title("πŸ“Š ESG Report Comparison Dashboard")

# Tab label -> metric keys rendered inside that tab (in display order).
METRIC_OPTIONS = {
    "Report Metadata": [
        "report_metadata",
    ],
    "Environmental Parameters": [
        "Emissions",
        "Energy Consumption",
        "Water Withdrawal",
        "Water Discharge",
        "Waste Generation",
        "Waste Disposal",
        "Waste Recovery",
    ],
    "Social Parameters": [
        "Human Rights Training Coverage",
        "LTIFR",
        "Other Safety Incidents",
        "Health & Safety Training Coverage",
        "Grievances Reported",
        "Third-party Assessment Coverage",
        "CSR Beneficiaries",
        "Female Wage Share",
        "Wages by Location",
        "Well-being Cost",
        "Worker Well-being Coverage",
        "Employee Well-being Coverage",
        "Turnover Count",
        "Workforce Gender Diversity",
    ],
    "Governance Parameters": [
        "Non-compliance Instances",
        "Disciplinary Actions",
        "Consumer Complaints",
        "Customer Data Breaches",
        "Governance Diversity",
        "Purchase Concentration",
        "Sales Concentration",
        "Related Party Transactions",
    ],
    "Materiality": [
        "material_topics",
    ],
}

ESG_EXTRACTOR_COLLECTION = "esg_report_extracts"

# Every document in the collection is loaded up front; the multiselect offers
# the `_id` of each document that has one.
company_docs = retrieve_documents(collection_name=ESG_EXTRACTOR_COLLECTION)
available_company_data = [
    entry["_id"] for entry in company_docs if "_id" in entry
]

selected_companies = st.multiselect(
    "Select up to 3 companies",
    options=available_company_data,
    max_selections=3,
)
45
+
46
def get_all_years(docs) -> list:
    """Return the distinct keys of every dict-valued ``esg_reports`` field, sorted descending.

    Documents without an ``esg_reports`` key, or whose ``esg_reports`` value is
    not a dict, contribute nothing.
    """
    report_maps = (doc["esg_reports"] for doc in docs if "esg_reports" in doc)
    distinct_years = {
        year
        for reports in report_maps
        if isinstance(reports, dict)
        for year in reports
    }
    ordered = list(distinct_years)
    ordered.sort(reverse=True)
    return ordered
52
+
53
+ def highlight_missing_values(df):
54
+ return df.style.map(lambda v: "background-color: #ffe6e6" if pd.isna(v) or str(v).strip() in ["", "nan", "None", "Not Available","N/A"] else "background-color: #e6ffe6")
55
+
56
def extract_company_name_from_doc(doc, default_name):
    """Return ``doc["report_metadata"]["company_legal_name"]`` or ``default_name``.

    Args:
        doc: Mapping that may carry a ``report_metadata`` sub-document.
        default_name: Value returned when the legal name cannot be read.

    Returns:
        The stored legal name, or ``default_name`` when ``report_metadata`` is
        absent, is not a dict (e.g. stored as ``None``), or lacks the
        ``company_legal_name`` key.
    """
    metadata = doc.get("report_metadata")
    # A key that is present but holds None/other non-dict data must not crash
    # the chained lookup; treat it the same as a missing key.
    if not isinstance(metadata, dict):
        return default_name
    return metadata.get("company_legal_name", default_name)
58
+
59
# Comparison view: pick a year, then render one tab per metric category with a
# company-by-company table for each metric in that category.
if selected_companies:
    # Offer only years that at least one *selected* company has reported;
    # building the list from every document let users pick a year for which
    # none of the chosen companies has data.
    selected_docs = [
        doc for doc in company_docs if doc.get("_id") in selected_companies
    ]
    all_years = get_all_years(selected_docs)

    selected_year = st.selectbox(
        "Select a report year (applies to all selected companies)",
        options=["-- Select Year --"] + all_years,
        key="common_year",
    )

    if selected_year != "-- Select Year --":
        metric_categories = list(METRIC_OPTIONS)
        tabs = st.tabs(metric_categories)
        for category, tab in zip(metric_categories, tabs):
            with tab:
                st.subheader(category)
                for metric in METRIC_OPTIONS[category]:
                    st.markdown(f"### {metric}")

                    comparison_df = prepare_comparison_df(
                        selected_companies,
                        selected_year,
                        metric,
                        company_docs,
                    )

                    # prepare_comparison_df returns None when no selected
                    # company has data for this metric/year.
                    if comparison_df is not None:
                        st.dataframe(
                            highlight_missing_values(comparison_df),
                            use_container_width=True,
                        )
                    else:
                        st.warning(f"No data found for **{metric}** in {selected_year}")
    else:
        st.info("Please select a year to view report comparisons.")
else:
    st.info("Please select at least one company to continue.")
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ streamlit
2
+ pymongo
3
+ openpyxl
4
+ python-dotenv
5
+ unidecode
src/services/__pycache__/mongo_db_service.cpython-313.pyc ADDED
Binary file (5.83 kB). View file
 
src/services/mongo_db_service.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Dict, Optional, Union
3
+ from dotenv import load_dotenv
4
+ from pymongo.errors import ConnectionFailure
5
+ from pymongo import MongoClient, errors
6
+ from bson import ObjectId
7
+
8
+ from src.utils.logger import get_logger
9
+
10
# Module-level configuration: logger, .env loading, and a shared client/db pair.
logger = get_logger()
load_dotenv()

ESG_REPORT_EXTRACTS_COLLECTION = "esg_report_extracts"

# Connection settings are read once, at import time, after .env is loaded.
MONGODB_URI = os.environ.get("MONGODB_URI")
MONGODB_DB_NAME = os.environ.get("MONGODB_DB_NAME")

client = MongoClient(MONGODB_URI)
# NOTE(review): MONGODB_DB_NAME is None when the variable is unset, and this
# lookup then receives None - confirm every deployment defines it.
db = client[MONGODB_DB_NAME]
18
+
19
def get_mongo_client() -> Optional[MongoClient]:
    """
    Return the shared module-level MongoDB client after verifying it is reachable.

    The module already builds one ``client`` at import time; reusing it avoids
    opening (and never closing) a fresh connection pool on every call.

    Returns:
        Optional[MongoClient]: The shared client, or None when the server
        cannot be reached or an unexpected error occurs (both are logged).
    """
    try:
        # MongoClient connects lazily, so constructing it never reports an
        # unreachable server; a ping is what actually raises ConnectionFailure.
        client.admin.command("ping")
        return client
    except ConnectionFailure:
        logger.error("MongoDB connection failed. Please check MONGODB_URI.")
    except Exception as e:
        logger.exception(f"Unexpected error while connecting to MongoDB: {str(e)}")
    return None
31
+
32
+
33
def retrieve_documents(
    collection_name: str,
    query: Optional[Dict] = None,
    only_ids: bool = False,
    single: bool = False,
    company_legal_name: Optional[str] = None,
    reporting_year: Optional[int] = None,
) -> Union[List[Dict], Dict, None]:
    """
    Retrieves documents from a specified MongoDB collection with optional filtering.

    Args:
        collection_name (str): MongoDB collection name.
        query (Optional[Dict]): MongoDB query filter. It is copied, never modified.
        only_ids (bool): If True, return only the _id field of each document.
        single (bool): If True, return only a single matching document.
        company_legal_name (Optional[str]): Adds an equality filter on
            ``report_metadata.company_legal_name``.
        reporting_year (Optional[int]): Adds an equality filter on
            ``esg_report.year``.

    Returns:
        Union[List[Dict], Dict, None]: With ``single=True``, the matching
        document or None. Otherwise a list of documents. On any failure the
        error is logged and None (single) or an empty list is returned.
    """
    empty_result = None if single else []

    try:
        client = get_mongo_client()
        if client is None:
            logger.error("MongoDB client is not available.")
            return empty_result

        collection = client[MONGODB_DB_NAME][collection_name]

        # Work on a copy so the filters added below never leak into the
        # caller's dict (which may be reused for another query).
        mongo_query = dict(query) if query else {}

        if company_legal_name:
            mongo_query["report_metadata.company_legal_name"] = company_legal_name
        if reporting_year is not None:
            # NOTE(review): other code in this project reads a year-keyed
            # "esg_reports" mapping; confirm "esg_report.year" matches the
            # stored schema.
            mongo_query["esg_report.year"] = reporting_year

        projection = {"_id": 1} if only_ids else None

        if single:
            result = collection.find_one(mongo_query, projection)
            logger.info(f"Retrieved single document from {collection_name} for query: {mongo_query}")
            return result

        documents = list(collection.find(mongo_query, projection))
        logger.info(f"Retrieved {len(documents)} documents from collection: {collection_name}")
        return documents

    except Exception as e:
        # Deliberate best-effort boundary: callers get an empty result
        # instead of an exception, and the traceback goes to the log.
        logger.exception(f"An error occurred while retrieving documents: {str(e)}")
        return empty_result
86
+
87
def retrieve_document_by_id(collection_name: str, document_id, convert_to_object_id: bool = False):
    """
    Retrieve a single document from a MongoDB collection by _id.

    Args:
        collection_name (str): The name of the MongoDB collection.
        document_id (str or ObjectId): The value of the _id to retrieve.
        convert_to_object_id (bool): Set to True if _id is an ObjectId, not a string.

    Returns:
        dict or None: The document if found, otherwise None.

    Raises:
        ValueError: If the collection name is empty or not a string, if
            document_id is None, or if it cannot be converted to an ObjectId.
        pymongo.errors.PyMongoError: For database errors (logged, then re-raised).
        Exception: For any other unexpected error (logged, then re-raised).
    """
    if not collection_name or not isinstance(collection_name, str):
        raise ValueError("Invalid collection name.")

    if document_id is None:
        raise ValueError("document_id must not be None.")

    # Input validation happens before the database try-block so a malformed
    # id is reported as a ValueError only, not also logged as "Unexpected error".
    if convert_to_object_id:
        from bson.errors import InvalidId

        try:
            document_id = ObjectId(document_id)
        except (InvalidId, TypeError) as e:
            raise ValueError(f"Invalid ObjectId format: {document_id}") from e

    try:
        document = db[collection_name].find_one({"_id": document_id})
    except errors.PyMongoError as e:
        logger.error(f"Database error while retrieving document: {e}")
        raise
    except Exception as ex:
        logger.error(f"Unexpected error: {ex}")
        raise

    if document:
        logger.info(f"Document found with _id: {document_id}")
        return document

    logger.error(f"No document found with _id: {document_id}")
    return None
134
+
135
+ # all_docs = retrieve_documents(collection_name=ESG_REPORT_EXTRACTS_COLLECTION)
136
+ # print(all_docs[0]["_id"])
137
+
138
+ # collection = list_collections()
139
+ # print(collection)
src/utils/__pycache__/common_functions.cpython-313.pyc ADDED
Binary file (3.92 kB). View file
 
src/utils/__pycache__/gics_schema.cpython-313.pyc ADDED
Binary file (14.7 kB). View file
 
src/utils/__pycache__/logger.cpython-313.pyc ADDED
Binary file (1.85 kB). View file
 
src/utils/__pycache__/streamlit_function.cpython-313.pyc ADDED
Binary file (8.48 kB). View file
 
src/utils/__pycache__/system_prompts.cpython-313.pyc ADDED
Binary file (20.5 kB). View file
 
src/utils/common_functions.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import pandas as pd
3
+
4
def _format_measurement(value):
    """Render a dict carrying a non-None ``numeric_value`` as text; otherwise return None."""
    if not isinstance(value, dict):
        return None
    numeric = value.get("numeric_value")
    if numeric is None:
        return None
    unit = value.get("measurement_unit")
    return f"{numeric} {unit}".strip() if unit else str(numeric)


def _flatten_metric(data, parent_key=""):
    """Flatten nested metric dicts into ``{"Parent - Child": text}`` pairs.

    Non-dict input yields an empty mapping. ``None`` leaves become
    "Not Available"; every other non-dict leaf is converted with ``str``.
    """
    flat = {}
    if not isinstance(data, dict):
        return flat

    for key, val in data.items():
        label = str(key).replace("_", " ").title()
        full_key = f"{parent_key} - {label}" if parent_key else label

        if isinstance(val, dict):
            rendered = _format_measurement(val)
            if rendered is not None:
                flat[full_key] = rendered
            else:
                # No usable numeric value: descend and label children with
                # the accumulated path.
                flat.update(_flatten_metric(val, full_key))
        else:
            flat[full_key] = str(val) if val is not None else "Not Available"

    return flat


def prepare_comparison_df(selected_companies, selected_year, metric_key, company_docs):
    """
    Prepares a wide-format comparison DataFrame for the selected companies and metric.

    Args:
        selected_companies: Iterable of document ``_id`` values to compare.
        selected_year: Key into each document's ``esg_reports`` mapping.
        metric_key: Key of the metric inside the selected year's report.
        company_docs: Iterable of document dicts to search.

    Returns:
        A DataFrame indexed by flattened metric label ("Metric") with one
        column per company that contributed data, gaps filled with
        "Not Available"; or None when no selected company has data.
    """
    rows = {}
    for company_id in selected_companies:
        # .get() so documents without an "_id" are skipped instead of raising.
        doc = next((d for d in company_docs if d.get("_id") == company_id), None)
        if not doc:
            continue

        reports = doc.get("esg_reports")
        if not isinstance(reports, dict):
            continue

        report = reports.get(selected_year)
        if not isinstance(report, dict):
            continue

        for key, val in _flatten_metric(report.get(metric_key)).items():
            rows.setdefault(key, {})[company_id] = val

    if not rows:
        return None

    # rows is {metric label: {company: value}}; transpose so metrics are rows.
    df = pd.DataFrame(rows).T
    df.index.name = "Metric"
    return df.fillna("Not Available")
src/utils/logger.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ from logging.handlers import RotatingFileHandler
4
+
5
log_file = 'sustainability_report_extractor.log'
log_dir = 'logs/app'
log_level = logging.INFO


def get_logger():
    """Return this module's logger, attaching console and rotating-file handlers once.

    The log directory is created if needed. On the first call the logger level
    is set to ``log_level`` and two handlers sharing one format are attached: a
    console handler (DEBUG) and a rotating file handler (INFO, 5 MiB per file,
    3 backups). Later calls return the same configured logger unchanged.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(log_dir, exist_ok=True)

    log_file_path = os.path.join(log_dir, log_file)
    logger = logging.getLogger(__name__)

    # Check this logger's own handler list: hasHandlers() also reports
    # handlers attached to ancestor loggers, so a configured root logger
    # would leave this logger without its handlers and level.
    if not logger.handlers:
        logger.setLevel(log_level)

        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.DEBUG)

        file_handler = RotatingFileHandler(log_file_path, maxBytes=5*1024*1024, backupCount=3)
        file_handler.setLevel(logging.INFO)

        log_format = '%(asctime)s - %(levelname)s - %(message)s'
        formatter = logging.Formatter(log_format, datefmt='%Y-%m-%d %H:%M')
        console_handler.setFormatter(formatter)
        file_handler.setFormatter(formatter)

        logger.addHandler(console_handler)
        logger.addHandler(file_handler)

    return logger
src/utils/streamlit_function.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from typing import Union, List
3
+ from src.utils import logger
4
+
5
logger = logger.get_logger()

PAGE_TITLE = "PDF Extractor"
PAGE_LAYOUT = "wide"


def config_homepage(page_title=PAGE_TITLE):
    """
    Apply the shared Streamlit page configuration.

    Calls `st.set_page_config()` with the given title, the `PAGE_LAYOUT`
    layout and a collapsed sidebar, unless the key "page_config_set" is
    present in `st.session_state`. This module never sets that key, so as far
    as this file is concerned the configuration is applied on every call.

    Args:
        page_title (str): The title displayed on the browser tab (default is PAGE_TITLE).

    Example:
        >>> config_homepage("My Custom App")
    """
    if "page_config_set" in st.session_state:
        return

    st.set_page_config(
        page_title=page_title,
        layout=PAGE_LAYOUT,
        initial_sidebar_state="collapsed",
    )
36
+
37
def upload_file(
    file_types: Union[str, List[str]] = "pdf",
    label: str = "πŸ“€ Upload a file",
    help_text: str = "Upload your file for processing.",
    allow_multiple: bool = True,
):
    """
    Render a file uploader followed by a "Submit" button.

    Args:
        file_types (str or list): Allowed file type(s), e.g., "pdf" or ["pdf", "docx"].
        label (str): Label passed to the uploader widget.
        help_text (str): Help text passed to the uploader widget.
        allow_multiple (bool): Passed as `accept_multiple_files`.

    Returns:
        Whatever `st.file_uploader` returned for the current run. When the
        "Submit" button is pressed, the same value is also stored in
        `st.session_state.pdf_file`.
    """
    # The uploader is always handed a list of types, so wrap a lone string.
    accepted_types = [file_types] if isinstance(file_types, str) else file_types

    uploaded_files = st.file_uploader(
        label=label,
        type=accepted_types,
        help=help_text,
        accept_multiple_files=allow_multiple,
    )

    submitted = st.button("Submit")
    if submitted:
        st.session_state.pdf_file = uploaded_files

    return uploaded_files