import asyncio
import inspect
import time
from typing import Optional, TypeVar

from agents import Agent, ModelSettings, Runner, set_default_openai_client
from openai import (
    APIConnectionError,
    APIStatusError,
    AsyncOpenAI,
    OpenAIError,
    RateLimitError,
)
from pydantic import BaseModel

from ankigen_core.logging import logger
from ankigen_core.utils import ResponseCache

T = TypeVar("T", bound=BaseModel)


class OpenAIClientManager:
    """Manages the AsyncOpenAI client instance."""

    def __init__(self):
        self._client: Optional[AsyncOpenAI] = None
        self._api_key: Optional[str] = None

    async def initialize_client(self, api_key: str):
        """Initializes the AsyncOpenAI client with the given API key."""
        if not api_key or not api_key.startswith("sk-"):
            logger.error("Invalid OpenAI API key provided for client initialization.")
            raise ValueError("Invalid OpenAI API key format.")
        self._api_key = api_key
        try:
            self._client = AsyncOpenAI(api_key=self._api_key)
            logger.info("AsyncOpenAI client initialized successfully.")
        except OpenAIError as e:
            logger.error(f"Failed to initialize AsyncOpenAI client: {e}", exc_info=True)
            self._client = None
            raise
        except Exception as e:
            logger.error(
                f"An unexpected error occurred during AsyncOpenAI client initialization: {e}",
                exc_info=True,
            )
            self._client = None
            raise RuntimeError("Unexpected error initializing AsyncOpenAI client.") from e

    def get_client(self) -> AsyncOpenAI:
        """Returns the initialized AsyncOpenAI client; raises an error if not initialized."""
        if self._client is None:
            logger.error(
                "AsyncOpenAI client accessed before initialization or after a failed initialization."
            )
            raise RuntimeError(
                "AsyncOpenAI client is not initialized. Please provide a valid API key."
            )
        return self._client

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit - cleanup resources."""
        self.close()
        return False

    async def __aenter__(self):
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit - cleanup resources."""
        await self.aclose()
        return False
    def close(self) -> None:
        """Close the OpenAI client synchronously.

        AsyncOpenAI.close() is a coroutine, so a sync close can only run it
        when no event loop is active; prefer aclose() from async code.
        """
        if self._client:
            try:
                result = self._client.close()
                if inspect.isawaitable(result):
                    # Run the close coroutine to completion; raises (and is
                    # logged below) if an event loop is already running.
                    asyncio.run(result)
                logger.debug("OpenAI client closed")
            except Exception as e:
                logger.warning(f"Error closing OpenAI client: {e}")
            finally:
                self._client = None

    async def aclose(self) -> None:
        """Close the OpenAI client asynchronously."""
        if self._client:
            try:
                # AsyncOpenAI.close() is a coroutine and must be awaited.
                result = self._client.close()
                if inspect.isawaitable(result):
                    await result
                logger.debug("OpenAI client closed (async)")
            except Exception as e:
                logger.warning(f"Error closing OpenAI client: {e}")
            finally:
                self._client = None
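

# Example usage (illustrative sketch; the key below is a placeholder):
#
#     async def main() -> None:
#         async with OpenAIClientManager() as manager:
#             await manager.initialize_client("sk-...")
#             client = manager.get_client()
#             ...  # use `client` for API calls
#
#     asyncio.run(main())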


async def structured_agent_call(
    openai_client: AsyncOpenAI,
    model: str,
    instructions: str,
    user_input: str,
    output_type: type[T],
    cache: Optional[ResponseCache] = None,
    cache_key: Optional[str] = None,
    temperature: float = 0.7,
    timeout: float = 120.0,
    retry_attempts: int = 3,
) -> T:
    """
    Make a single-turn structured output call using the agents SDK.

    This is a lightweight wrapper for simple structured output calls,
    not intended for complex multi-agent workflows.

    Args:
        openai_client: AsyncOpenAI client instance
        model: Model name (e.g., "gpt-5.2", "gpt-5.2-chat-latest")
        instructions: System instructions for the agent
        user_input: User prompt/input
        output_type: Pydantic model class for structured output
        cache: Optional ResponseCache instance
        cache_key: Cache key (required if cache is provided)
        temperature: Model temperature (default 0.7)
        timeout: Request timeout in seconds (default 120)
        retry_attempts: Number of retry attempts (default 3)

    Returns:
        Instance of output_type with the structured response
    """
    # Serve from cache when possible.
    if cache and cache_key:
        cached = cache.get(cache_key, model)
        if cached is not None:
            logger.info(f"Using cached response for model {model}")
            # Cached values may be stored as plain dicts; revalidate into the model.
            if isinstance(cached, dict):
                return output_type.model_validate(cached)
            return cached

    # Route the agents SDK through the provided client (tracing disabled).
    set_default_openai_client(openai_client, use_for_tracing=False)

    model_settings_kwargs: dict = {"temperature": temperature}

    # Reasoning-capable gpt-5* models (excluding the chat-latest variants)
    # take an explicit reasoning-effort setting.
    if model.startswith("gpt-5") and "chat-latest" not in model:
        from openai.types.shared import Reasoning

        model_settings_kwargs["reasoning"] = Reasoning(effort="none")

    model_settings = ModelSettings(**model_settings_kwargs)

    agent = Agent(
        name="structured_output_agent",
        instructions=instructions,
        model=model,
        model_settings=model_settings,
        output_type=output_type,
    )

    # Retry with exponential backoff (4s, 8s, 16s, ...).
    last_error: Optional[Exception] = None
    for attempt in range(retry_attempts):
        try:
            result = await asyncio.wait_for(
                Runner.run(agent, user_input),
                timeout=timeout,
            )

            output = result.final_output

            # Store BaseModel outputs as dicts so cached entries stay serializable.
            if cache and cache_key and output is not None:
                if isinstance(output, BaseModel):
                    cache.set(cache_key, model, output.model_dump())
                else:
                    cache.set(cache_key, model, output)

            logger.debug(f"Successfully received response from model {model}")
            return output

        except asyncio.TimeoutError as e:
            last_error = e
            if attempt < retry_attempts - 1:
                wait_time = 4 * (2**attempt)
                logger.warning(
                    f"Agent timed out (attempt {attempt + 1}/{retry_attempts}), "
                    f"retrying in {wait_time}s..."
                )
                await asyncio.sleep(wait_time)
                continue
            logger.error(f"Agent timed out after {retry_attempts} attempts")
            raise
        except Exception as e:
            last_error = e
            if attempt < retry_attempts - 1:
                wait_time = 4 * (2**attempt)
                logger.warning(
                    f"Agent failed (attempt {attempt + 1}/{retry_attempts}): {e}, "
                    f"retrying in {wait_time}s..."
                )
                await asyncio.sleep(wait_time)
                continue
            logger.error(f"Agent failed after {retry_attempts} attempts: {e}")
            raise

    # Defensive: only reachable if retry_attempts <= 0.
    raise RuntimeError(f"Retry loop exited without result: {last_error}")
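

# Example usage (illustrative sketch; CardDraft, the prompts, and `manager` are
# hypothetical, not part of this module):
#
#     class CardDraft(BaseModel):
#         front: str
#         back: str
#
#     draft = await structured_agent_call(
#         openai_client=manager.get_client(),
#         model="gpt-5.2",
#         instructions="You write concise flashcards.",
#         user_input="Make a card about the Python GIL.",
#         output_type=CardDraft,
#     )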


class GenericJsonOutput(BaseModel):
    """Generic container for JSON output - allows any structure."""

    model_config = {"extra": "allow"}


async def structured_output_completion(
    openai_client: AsyncOpenAI,
    model: str,
    response_format: dict,
    system_prompt: str,
    user_prompt: str,
    cache: ResponseCache,
) -> Optional[dict]:
    """
    Makes an API call with structured output using the agents SDK.

    Note: the response_format parameter is ignored - the agents SDK handles
    JSON parsing automatically. For typed outputs, use structured_agent_call() directly.
    """
    cache_key = f"{system_prompt}:{user_prompt}"

    # Nudge the model toward schema-shaped JSON unless the prompt already asks for it.
    effective_system_prompt = system_prompt
    if "JSON object matching the specified schema" not in system_prompt:
        effective_system_prompt = f"{system_prompt}\nProvide your response as a JSON object matching the specified schema."

    try:
        result = await structured_agent_call(
            openai_client=openai_client,
            model=model,
            instructions=effective_system_prompt.strip(),
            user_input=user_prompt.strip(),
            output_type=GenericJsonOutput,
            cache=cache,
            cache_key=cache_key,
            temperature=0.7,
        )

        # Callers expect a plain dict, not a Pydantic model.
        if isinstance(result, BaseModel):
            return result.model_dump()
        return result

    except Exception as e:
        logger.error(
            f"structured_output_completion failed for model {model}: {e}",
            exc_info=True,
        )
        raise
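

# Example usage (illustrative sketch; the prompts are made up and this assumes
# ResponseCache() accepts default construction):
#
#     data = await structured_output_completion(
#         openai_client=manager.get_client(),
#         model="gpt-5.2",
#         response_format={},  # ignored by this wrapper
#         system_prompt="You write concise flashcards as JSON.",
#         user_prompt="Make a card about the Python GIL.",
#         cache=ResponseCache(),
#     )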


RETRYABLE_OPENAI_ERRORS = (
    APIConnectionError,
    RateLimitError,
    APIStatusError,
)
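
# Example usage (illustrative sketch; `do_request` stands in for any coroutine
# that calls the OpenAI API):
#
#     try:
#         response = await do_request()
#     except RETRYABLE_OPENAI_ERRORS as e:
#         logger.warning(f"Transient OpenAI error, worth retrying: {e}")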


class OpenAIRateLimiter:
    """Manages token usage to proactively stay within (estimated) OpenAI rate limits."""

    def __init__(self, tokens_per_minute: int = 60000):
        self.tokens_per_minute_limit: int = tokens_per_minute
        self.tokens_used_current_window: int = 0
        self.current_window_start_time: float = time.monotonic()

    async def wait_if_needed(self, estimated_tokens_for_request: int):
        """Waits if adding the estimated tokens would exceed the rate limit for the current window."""
        current_time = time.monotonic()

        # Start a fresh window once 60 seconds have elapsed.
        if current_time - self.current_window_start_time >= 60.0:
            self.current_window_start_time = current_time
            self.tokens_used_current_window = 0
            logger.debug("OpenAIRateLimiter: Window reset.")

        # If this request would overflow the window, sleep until the window
        # expires, then reset it.
        if (
            self.tokens_used_current_window + estimated_tokens_for_request
            > self.tokens_per_minute_limit
        ):
            time_to_wait = (self.current_window_start_time + 60.0) - current_time
            if time_to_wait > 0:
                logger.info(
                    f"OpenAIRateLimiter: Approaching token limit. Waiting for {time_to_wait:.2f} seconds to reset window."
                )
                await asyncio.sleep(time_to_wait)

            self.current_window_start_time = time.monotonic()
            self.tokens_used_current_window = 0
            logger.debug("OpenAIRateLimiter: Window reset after waiting.")

        self.tokens_used_current_window += estimated_tokens_for_request
        logger.debug(
            f"OpenAIRateLimiter: Tokens used in current window: {self.tokens_used_current_window}/{self.tokens_per_minute_limit}"
        )


# Module-level singleton so callers share one rate-limit window.
openai_rate_limiter = OpenAIRateLimiter()
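
# Example usage (illustrative sketch; the 500-token estimate is arbitrary):
#
#     await openai_rate_limiter.wait_if_needed(estimated_tokens_for_request=500)
#     ...  # make the API call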