"""MongoDB helpers for retrieving documents from named collections."""

import os
from typing import List, Dict, Optional, Union

from dotenv import load_dotenv
from pymongo.errors import ConnectionFailure
from pymongo import MongoClient, errors
from bson import ObjectId

from src.utils.logger import get_logger

logger = get_logger()

load_dotenv()

MONGODB_URI = os.getenv("MONGODB_URI")
MONGODB_DB_NAME = os.getenv("MONGODB_DB_NAME")

# Module-level client/database handle, used by retrieve_document_by_id.
client = MongoClient(MONGODB_URI)
db = client[MONGODB_DB_NAME]

ESG_REPORT_EXTRACTS_COLLECTION = "esg_report_extracts"


def get_mongo_client() -> Optional[MongoClient]:
    """
    Create and return a new MongoDB client using MONGODB_URI from the environment.

    The URI is read from the environment at call time. The caller owns the
    returned client and is responsible for closing it.

    NOTE(review): per the PyMongo documentation the MongoClient constructor
    does not block on connecting, so ConnectionFailure may not be raised here
    for an unreachable server; it would surface on the first operation
    instead. Confirm before relying on this function as a connectivity check.

    Returns:
        Optional[MongoClient]: A new client, or None if constructing it raised.
    """
    try:
        mongo_client = MongoClient(os.getenv("MONGODB_URI"))
        return mongo_client
    except ConnectionFailure:
        logger.error("MongoDB connection failed. Please check MONGODB_URI.")
    except Exception as e:
        logger.exception(f"Unexpected error while connecting to MongoDB: {str(e)}")
    return None


def retrieve_documents(
    collection_name: str,
    query: Optional[Dict] = None,
    only_ids: bool = False,
    single: bool = False,
    company_legal_name: Optional[str] = None,
    reporting_year: Optional[int] = None
) -> Union[List[Dict], Dict, None]:
    """
    Retrieves documents from a specified MongoDB collection with optional filtering.

    The function never raises: any error is logged and an empty result is
    returned instead. The ``query`` dict passed by the caller is copied and
    is not modified. The client obtained from ``get_mongo_client`` is closed
    before returning.

    Args:
        collection_name (str): MongoDB collection name.
        query (Optional[Dict]): MongoDB query filter.
        only_ids (bool): If True, return only _id field for all documents.
        single (bool): If True, return only a single matching document.
        company_legal_name (Optional[str]): If truthy, adds an equality filter
            on 'report_metadata.company_legal_name'.
        reporting_year (Optional[int]): If not None, adds an equality filter
            on 'esg_report.year'.

    Returns:
        Union[List[Dict], Dict, None]: With ``single=False``, a list of
        documents (empty on error). With ``single=True``, the matching
        document, or None if nothing matched or an error occurred.
    """
    mongo_client = None
    try:
        mongo_client = get_mongo_client()
        if mongo_client is None:
            logger.error("MongoDB client is not available.")
            return [] if not single else None

        collection = mongo_client[MONGODB_DB_NAME][collection_name]

        # Shallow-copy the filter so the extra keys added below do not leak
        # into the dict owned by the caller.
        mongo_query = dict(query) if query else {}
        if company_legal_name:
            mongo_query["report_metadata.company_legal_name"] = company_legal_name
        if reporting_year is not None:
            mongo_query["esg_report.year"] = reporting_year

        projection = {"_id": 1} if only_ids else None

        if single:
            result = collection.find_one(mongo_query, projection)
            logger.info(f"Retrieved single document from {collection_name} for query: {mongo_query}")
            return result

        # Materialize the cursor before the client is closed in `finally`.
        documents = list(collection.find(mongo_query, projection))
        logger.info(f"Retrieved {len(documents)} documents from collection: {collection_name}")
        return documents

    except Exception as e:
        logger.exception(f"An error occurred while retrieving documents: {str(e)}")
        return [] if not single else None

    finally:
        # A fresh client is created per call, so release its connections here.
        if mongo_client is not None:
            try:
                mongo_client.close()
            except Exception:
                # Keep the "never raises" contract even if cleanup fails.
                logger.exception("Failed to close MongoDB client.")


def retrieve_document_by_id(
    collection_name: str,
    document_id: Union[str, ObjectId],
    convert_to_object_id: bool = False,
) -> Optional[Dict]:
    """
    Retrieve a single document from a MongoDB collection by _id.

    Uses the module-level ``db`` handle.

    Args:
        collection_name (str): The name of the MongoDB collection.
        document_id (str or ObjectId): The value of the _id to retrieve.
        convert_to_object_id (bool): Set to True if _id is an ObjectId, not a string.

    Returns:
        dict or None: The document if found, otherwise None.

    Raises:
        ValueError: If inputs are invalid.
        Exception: For any unexpected database errors.
    """
    if not collection_name or not isinstance(collection_name, str):
        raise ValueError("Invalid collection name.")

    if document_id is None:
        raise ValueError("document_id must not be None.")

    # Validate/convert the id before the database `try` block so that a bad
    # id is reported as a plain ValueError and is not additionally logged as
    # an "Unexpected error" by the catch-all handler below.
    if convert_to_object_id:
        try:
            document_id = ObjectId(document_id)
        except Exception as e:
            raise ValueError(f"Invalid ObjectId format: {document_id}") from e

    try:
        collection = db[collection_name]
        document = collection.find_one({"_id": document_id})

        if document:
            logger.info(f"Document found with _id: {document_id}")
            return document

        logger.error(f"No document found with _id: {document_id}")
        return None

    except errors.PyMongoError as e:
        logger.error(f"Database error while retrieving document: {e}")
        raise
    except Exception as ex:
        logger.error(f"Unexpected error: {ex}")
        raise


# NOTE(review): leftover scratch code; `list_collections` is not defined in
# the visible part of this module.
# all_docs = retrieve_documents(collection_name=ESG_REPORT_EXTRACTS_COLLECTION)
# print(all_docs[0]["_id"])
# collection = list_collections()
# print(collection)