|
import os
|
|
from typing import List, Dict, Optional, Union
|
|
from dotenv import load_dotenv
|
|
from pymongo.errors import ConnectionFailure
|
|
from pymongo import MongoClient, errors
|
|
from bson import ObjectId
|
|
|
|
from src.utils.logger import get_logger
|
|
|
|
logger = get_logger()

# Populate the process environment from a .env file. This must run before the
# os.getenv() calls below so that values defined in .env are visible to them.
load_dotenv()

# Connection settings, read once at import time.
# NOTE(review): os.getenv returns None when a variable is unset. MONGODB_DB_NAME=None
# is passed straight into client[...] below, which will likely fail at import time
# with an unclear error — confirm whether an explicit startup check is wanted.
MONGODB_URI = os.getenv("MONGODB_URI")

MONGODB_DB_NAME = os.getenv("MONGODB_DB_NAME")

# Shared client and database handle, created at import time.
# `db` is what retrieve_document_by_id() queries. retrieve_documents() does not use
# it and obtains its own client from get_mongo_client() instead.
client = MongoClient(MONGODB_URI)

db = client[MONGODB_DB_NAME]

# Collection-name constant for ESG report extracts. It is not referenced by the
# functions defined below; presumably it is imported by callers — verify.
ESG_REPORT_EXTRACTS_COLLECTION = "esg_report_extracts"
|
|
|
|
def get_mongo_client() -> Optional[MongoClient]:
    """
    Create a MongoDB client from ``MONGODB_URI`` and verify the server is reachable.

    Constructing ``MongoClient`` does not by itself contact the server. A ``ping``
    command is therefore sent to force a round trip, so connection problems are
    detected (and logged) here rather than on the caller's first query.

    Returns:
        Optional[MongoClient]: A client whose server answered ``ping``, or None if
        the client could not be created or the server could not be reached.
        Every call returns a new client; the caller owns it and should call
        ``close()`` when finished.
    """
    client = None
    try:
        client = MongoClient(os.getenv("MONGODB_URI"))
        # Force a server round trip; construction alone would not surface
        # unreachable hosts or bad credentials.
        client.admin.command("ping")
        return client

    except ConnectionFailure:
        logger.error("MongoDB connection failed. Please check MONGODB_URI.")

    except Exception as e:
        logger.exception(f"Unexpected error while connecting to MongoDB: {str(e)}")

    # Only reached on failure: release whatever the partially-initialised
    # client opened before handing back None.
    if client is not None:
        client.close()
    return None
|
|
|
|
|
|
def retrieve_documents(
    collection_name: str,
    query: Optional[Dict] = None,
    only_ids: bool = False,
    single: bool = False,
    company_legal_name: Optional[str] = None,
    reporting_year: Optional[int] = None
) -> Union[List[Dict], Dict, None]:
    """
    Retrieves documents from a specified MongoDB collection with optional filtering.

    A client is obtained from ``get_mongo_client()`` for each call and is closed
    before this function returns. The caller's ``query`` dict is never modified.

    Args:
        collection_name (str): Collection name inside the ``MONGODB_DB_NAME`` database.
        query (Optional[Dict]): Base MongoDB query filter. It is copied before the
            convenience filters below are added.
        only_ids (bool): If True, project only the ``_id`` field.
        single (bool): If True, return only the first matching document.
        company_legal_name (Optional[str]): If truthy, adds an equality filter on
            ``report_metadata.company_legal_name``.
        reporting_year (Optional[int]): If not None, adds an equality filter on
            ``esg_report.year``.

    Returns:
        Union[List[Dict], Dict, None]: With ``single=False``, the list of matching
        documents (``[]`` if the client is unavailable or an error occurs). With
        ``single=True``, the matching document, or None if nothing matched, the
        client is unavailable, or an error occurs.
    """
    # Shape returned on every failure path, matching the success-path shape.
    empty_result = None if single else []
    client = None
    try:
        client = get_mongo_client()
        if client is None:
            logger.error("MongoDB client is not available.")
            return empty_result

        collection = client[MONGODB_DB_NAME][collection_name]

        # Copy so the extra filters below do not leak into (and accumulate in)
        # the caller's dict. Shallow is enough: only top-level keys are added.
        mongo_query = dict(query) if query else {}

        if company_legal_name:
            mongo_query["report_metadata.company_legal_name"] = company_legal_name
        if reporting_year is not None:
            mongo_query["esg_report.year"] = reporting_year

        projection = {"_id": 1} if only_ids else None

        if single:
            result = collection.find_one(mongo_query, projection)
            logger.info(f"Retrieved single document from {collection_name} for query: {mongo_query}")
            return result

        # Materialise the cursor now; the client is closed in `finally`.
        documents = list(collection.find(mongo_query, projection))
        logger.info(f"Retrieved {len(documents)} documents from collection: {collection_name}")
        return documents

    except Exception as e:
        logger.exception(f"An error occurred while retrieving documents: {str(e)}")
        return empty_result

    finally:
        # get_mongo_client() creates a fresh client per call; close it so its
        # connections are not left open after this call.
        if client is not None:
            client.close()
|
|
|
|
def retrieve_document_by_id(collection_name: str, document_id, convert_to_object_id: bool = False) -> Optional[Dict]:
    """
    Retrieve a single document from a MongoDB collection by _id.

    Queries through the module-level ``db`` handle (database ``MONGODB_DB_NAME``).

    Args:
        collection_name (str): The name of the MongoDB collection.
        document_id (str or ObjectId): The value of the _id to retrieve.
        convert_to_object_id (bool): Set to True to convert ``document_id`` with
            ``ObjectId(...)`` before querying (for collections whose _id is an ObjectId).

    Returns:
        dict or None: The document if found, otherwise None.

    Raises:
        ValueError: If ``collection_name`` is empty or not a str, if ``document_id``
            is None, or if ``document_id`` cannot be converted to an ObjectId.
        pymongo.errors.PyMongoError: For database errors (logged, then re-raised).
        Exception: For any other unexpected error (logged with traceback, then re-raised).
    """
    if not collection_name or not isinstance(collection_name, str):
        raise ValueError("Invalid collection name.")

    if document_id is None:
        raise ValueError("document_id must not be None.")

    # Convert before entering the database `try` below, so a malformed id
    # surfaces as a plain ValueError instead of being logged as "Unexpected error".
    if convert_to_object_id:
        try:
            document_id = ObjectId(document_id)
        except Exception as e:
            raise ValueError(f"Invalid ObjectId format: {document_id}") from e

    try:
        document = db[collection_name].find_one({"_id": document_id})

    except errors.PyMongoError as e:
        logger.error(f"Database error while retrieving document: {e}")
        raise

    except Exception as ex:
        # logger.exception keeps the traceback for genuinely unexpected failures.
        logger.exception(f"Unexpected error: {ex}")
        raise

    if document is not None:
        logger.info(f"Document found with _id: {document_id}")
        return document

    # A missing document is a normal, documented outcome (None), not a failure.
    logger.warning(f"No document found with _id: {document_id}")
    return None
|
|
|
|
|
|
|
|
|
|
|
|
|