File size: 4,898 Bytes
5d4ad83 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import os
from typing import List, Dict, Optional, Union
from dotenv import load_dotenv
from pymongo.errors import ConnectionFailure
from pymongo import MongoClient, errors
from bson import ObjectId
from src.utils.logger import get_logger
# Module-level logger obtained from the project's logging helper.
logger = get_logger()
# Load variables from a .env file into the process environment so that the
# os.getenv() reads below can see them; this must run before those reads.
load_dotenv()
# Connection settings; os.getenv() returns None when a variable is unset.
MONGODB_URI = os.getenv("MONGODB_URI")
MONGODB_DB_NAME = os.getenv("MONGODB_DB_NAME")
# Shared client and database handle created at import time. `db` is used by
# retrieve_document_by_id(); retrieve_documents() opens its own client via
# get_mongo_client() instead of reusing this one.
# NOTE(review): if MONGODB_DB_NAME is unset, client[None] is rejected by
# pymongo and importing this module fails; if MONGODB_URI is unset,
# MongoClient falls back to its default host (localhost). Confirm both are
# always configured in deployed environments.
client = MongoClient(MONGODB_URI)
db = client[MONGODB_DB_NAME]
# Name of the ESG report extracts collection; presumably imported by other
# modules (in this file it only appears in the commented-out example below).
ESG_REPORT_EXTRACTS_COLLECTION = "esg_report_extracts"
def get_mongo_client() -> Optional[MongoClient]:
    """
    Create and return a MongoClient for the URI in the MONGODB_URI env var.

    Since PyMongo 3.0 the MongoClient constructor does not block or contact the
    server, so an unreachable server is NOT detected here; it surfaces on the
    first database operation instead. What can fail at construction time is
    configuration (e.g. a malformed URI or a failed SRV lookup), which PyMongo
    reports as ConfigurationError.

    Returns:
        Optional[MongoClient]: The client, or None if it could not be created
        (the error is logged). The caller owns the client and should close it.
    """
    uri = os.getenv("MONGODB_URI")
    if not uri:
        # MongoClient(None) silently targets the default host (localhost),
        # which is rarely intended in deployment; make that visible.
        logger.warning("MONGODB_URI is not set; MongoClient will use its default host (localhost).")
    try:
        return MongoClient(uri)
    except (errors.ConfigurationError, ConnectionFailure) as e:
        # ConnectionFailure is kept for safety, but configuration problems are
        # what the constructor actually raises.
        logger.error(f"MongoDB client could not be created. Please check MONGODB_URI: {e}")
    except Exception as e:
        logger.exception(f"Unexpected error while connecting to MongoDB: {str(e)}")
    return None
def retrieve_documents(
    collection_name: str,
    query: Optional[Dict] = None,
    only_ids: bool = False,
    single: bool = False,
    company_legal_name: Optional[str] = None,
    reporting_year: Optional[int] = None
) -> Union[List[Dict], Dict, None]:
    """
    Retrieve documents from a MongoDB collection with optional filtering.

    A client is obtained from get_mongo_client() for each call and closed
    before returning. Results are fully materialized first, so they remain
    usable after the client is closed.

    Args:
        collection_name (str): MongoDB collection name.
        query (Optional[Dict]): Base MongoDB query filter. It is copied, never
            modified in place.
        only_ids (bool): If True, project only the _id field.
        single (bool): If True, return only the first matching document.
        company_legal_name (Optional[str]): If truthy, filter on
            'report_metadata.company_legal_name'.
        reporting_year (Optional[int]): If not None, filter on
            'esg_report.year'.

    Returns:
        Union[List[Dict], Dict, None]: A list of documents when single is
        False; a single document (or None if nothing matches) when single is
        True. On any error, [] or None respectively is returned and the error
        is logged.
    """
    client = None
    try:
        client = get_mongo_client()
        if client is None:
            logger.error("MongoDB client is not available.")
            return [] if not single else None
        collection = client[MONGODB_DB_NAME][collection_name]

        # Copy so the extra filters below never leak into the caller's dict.
        mongo_query = dict(query) if query else {}
        if company_legal_name:
            mongo_query["report_metadata.company_legal_name"] = company_legal_name
        if reporting_year is not None:
            mongo_query["esg_report.year"] = reporting_year

        projection = {"_id": 1} if only_ids else None

        if single:
            result = collection.find_one(mongo_query, projection)
            logger.info(f"Retrieved single document from {collection_name} for query: {mongo_query}")
            return result

        # list() exhausts the cursor so the client can be closed safely.
        documents = list(collection.find(mongo_query, projection))
        logger.info(f"Retrieved {len(documents)} documents from collection: {collection_name}")
        return documents
    except Exception as e:
        logger.exception(f"An error occurred while retrieving documents: {str(e)}")
        return [] if not single else None
    finally:
        # Each call creates its own client; release its connection pool and
        # monitor threads instead of leaking them.
        if client is not None:
            client.close()
def retrieve_document_by_id(collection_name: str, document_id, convert_to_object_id: bool = False):
    """
    Retrieve a single document from a MongoDB collection by _id.

    Uses the module-level `db` handle.

    Args:
        collection_name (str): The name of the MongoDB collection.
        document_id (str or ObjectId): The value of the _id to retrieve.
        convert_to_object_id (bool): Set to True to convert document_id to an
            ObjectId before querying (for collections whose _id is an ObjectId).

    Returns:
        dict or None: The document if found, otherwise None.

    Raises:
        ValueError: If collection_name is empty or not a str, document_id is
            None, or document_id cannot be converted to an ObjectId.
        pymongo.errors.PyMongoError: On database errors (logged, then re-raised).
        Exception: Any other unexpected error during the query (logged, then
            re-raised).
    """
    if not collection_name or not isinstance(collection_name, str):
        raise ValueError("Invalid collection name.")
    if document_id is None:
        raise ValueError("document_id must not be None.")

    # Validate/convert the id before touching the database so that an input
    # error is raised directly instead of being logged as an "Unexpected error".
    if convert_to_object_id:
        try:
            document_id = ObjectId(document_id)
        except Exception as e:
            raise ValueError(f"Invalid ObjectId format: {document_id}") from e

    try:
        document = db[collection_name].find_one({"_id": document_id})
    except errors.PyMongoError as e:
        logger.error(f"Database error while retrieving document: {e}")
        raise
    except Exception as ex:
        logger.error(f"Unexpected error: {ex}")
        raise

    if document is None:
        # A missing document is an expected outcome (None is returned), not an error.
        logger.warning(f"No document found with _id: {document_id}")
        return None
    logger.info(f"Document found with _id: {document_id}")
    return document
# all_docs = retrieve_documents(collection_name=ESG_REPORT_EXTRACTS_COLLECTION)
# print(all_docs[0]["_id"])
# collection = list_collections()
# print(collection) |