Spaces:
Sleeping
Sleeping
import os | |
import json | |
from typing import Union, BinaryIO, Optional | |
from openai import OpenAI | |
from google import genai | |
from google.genai import types | |
from application.utils import logger | |
from application.schemas.response_schema import RESPONSE_FORMAT,GEMINI_RESPONSE_FORMAT | |
# Rebinds the name `logger` from the imported `application.utils.logger` module
# to the logger instance it returns; the module is not reachable by that name
# after this line.
logger = logger.get_logger()

# Shared OpenAI client used by the helpers below; the API key is read from the
# OPENAI_API_KEY environment variable (None if it is unset).
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
# --- Constants ---
# Instruction text sent together with the PDF in both the OpenAI and Gemini requests.
PROMPT = (
    "You are a PDF parsing agent. Your job is to extract GHG Protocol "
    "Parameters and ESG (Environmental, Social, Governance) Data from a "
    "company’s sustainability or ESG report in PDF format."
)
# --- OpenAI Helpers ---
def get_files() -> list:
    """List the files held by the module-level OpenAI client.

    Returns:
        The `data` attribute of the `client.files.list()` response.
    Raises:
        Exception: Any error from the client call is logged and re-raised.
    """
    try:
        listing = client.files.list()
        logger.info(f"Retrieved {len(listing.data)} files.")
        entries = listing.data
    except Exception as err:
        logger.error(f"Failed to retrieve files: {err}")
        raise
    else:
        return entries
def get_or_create_file(file_input: BinaryIO, client) -> object:
    """
    Return the OpenAI file whose filename equals ``file_input.name``, uploading it if absent.

    The lookup, the retrieval and the upload all go through the ``client``
    argument, so a file ID found during lookup is always used with the same
    client that listed it.

    Args:
        file_input: File-like object with a non-empty `.name` attribute.
        client: OpenAI client instance used for every API call in this function.
    Returns:
        The object returned by `client.files.retrieve` when a file with the same
        name already exists, otherwise the object returned by `client.files.create`.
    Raises:
        ValueError: If `file_input` has no (or an empty) `name` attribute.
        Exception: Any error raised by the client is logged and re-raised.
    """
    file_name = getattr(file_input, 'name', None)
    if not file_name:
        raise ValueError("File input must have a 'name' attribute.")
    try:
        # List via the supplied client (not the module-level one) so the IDs we
        # match against are valid for the retrieve call below.
        existing = next(
            (f for f in client.files.list().data if f.filename == file_name),
            None,
        )
        if existing is not None:
            logger.info(f"File '{file_name}' already exists with ID: {existing.id}")
            return client.files.retrieve(existing.id)

        logger.info(f"Uploading new file '{file_name}'...")
        new_file = client.files.create(file=(file_name, file_input), purpose="assistants")
        logger.info(f"File uploaded successfully with ID: {new_file.id}")
        return new_file
    except Exception as e:
        logger.error(f"Error during get_or_create_file: {e}")
        raise
def delete_file_by_size(size: int, client):
    """
    Delete every OpenAI file whose byte size equals ``size``.

    Files are listed and deleted through the same ``client`` argument, so the
    IDs passed to `client.files.delete` come from that client's own listing.

    Args:
        size: File size in bytes to match for deletion.
        client: OpenAI client instance used to list and delete files.
    Raises:
        Exception: Any error raised by the client is logged and re-raised;
            files deleted before the error stay deleted.
    """
    try:
        # List via the supplied client rather than the module-level one so the
        # listing and the deletions target the same client.
        for file in client.files.list().data:
            if file.bytes == size:
                client.files.delete(file.id)
                logger.info(f"File {file.filename} deleted (size matched: {size} bytes).")
            else:
                logger.info(f"File {file.filename} skipped (size mismatch).")
    except Exception as e:
        logger.error(f"Failed to delete files: {e}")
        raise
# --- Main Function ---
def _parse_json_or_raw(text):
    """Parse model output as JSON; on a decode error return it wrapped as {"raw_response": text}."""
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        logger.warning("Failed to parse JSON, returning raw response.")
        return {"raw_response": text}


def extract_emissions_data_as_json(
    api: str,
    model: str,
    file_input: Union[BinaryIO, bytes]
) -> Optional[dict]:
    """
    Extract ESG data from PDF using OpenAI or Gemini APIs.

    Args:
        api: 'openai' or 'gemini' (case-insensitive).
        model: Model name passed through to the selected API.
        file_input: The PDF. The OpenAI path needs a file-like object with a
            `.name` attribute (it is uploaded via `get_or_create_file`); the
            Gemini path accepts either raw bytes or a file-like object.
    Returns:
        The model's JSON output parsed with `json.loads`, or
        `{"raw_response": <text>}` if the output is not valid JSON.
        None if `api` is unsupported or any error occurs (the error is logged).
    """
    try:
        provider = api.lower()

        if provider == "openai":
            # Uses the module-level OpenAI `client`. The Gemini client below has
            # its own name on purpose: assigning to `client` anywhere in this
            # function would make it a local variable and raise
            # UnboundLocalError on this line.
            uploaded = get_or_create_file(file_input, client)
            logger.info("[OpenAI] Sending content for generation...")
            response = client.chat.completions.create(
                model=model,
                messages=[{
                    "role": "user",
                    "content": [
                        {"type": "file", "file": {"file_id": uploaded.id}},
                        {"type": "text", "text": PROMPT}
                    ]
                }],
                response_format=RESPONSE_FORMAT
            )
            result = response.choices[0].message.content
            logger.info("ESG data extraction successful.")
            # Parse so both providers return the same shape as documented.
            return _parse_json_or_raw(result)

        if provider == "gemini":
            gemini_client = genai.Client(api_key=os.getenv("gemini_api_key"))
            # The signature allows raw bytes, which have no .read().
            if isinstance(file_input, (bytes, bytearray)):
                file_bytes = bytes(file_input)
            else:
                file_bytes = file_input.read()
            logger.info("[Gemini] Sending content for generation...")
            response = gemini_client.models.generate_content(
                model=model,
                contents=[
                    types.Part.from_bytes(data=file_bytes, mime_type="application/pdf"),
                    PROMPT
                ],
                config={
                    'response_mime_type': 'application/json',
                    'response_schema': GEMINI_RESPONSE_FORMAT,
                }
            )
            logger.info("[Gemini] Response received.")
            return _parse_json_or_raw(response.text)

        logger.error(f"Unsupported API: {api}")
        return None
    except Exception:
        # Best-effort boundary: log the traceback and signal failure with None.
        logger.exception("Error during ESG data extraction.")
        return None
def list_all_files():
    """Log the ID, filename and byte size of each file returned by `get_files()`.

    Errors are logged and not re-raised, so this function returns None either way.
    """
    try:
        for entry in get_files():
            logger.info(f"File ID: {entry.id}, Name: {entry.filename}, Size: {entry.bytes} bytes")
    except Exception as err:
        logger.error(f"Failed to list files: {err}")