Spaces:
Sleeping
Sleeping
| import abc | |
| import asyncio | |
| from abc import abstractmethod | |
| from dataclasses import field, dataclass | |
| from typing import AsyncIterator, Any, Union, Iterator | |
| from aworld.logs.util import logger | |
| from aworld.output import Output | |
| from aworld.output.base import RUN_FINISHED_SIGNAL | |
@dataclass
class Outputs(abc.ABC):
    """Base class for managing output streams in the AWorld framework.

    Declares the interface for adding and streaming outputs both synchronously
    and asynchronously. None of the methods are declared abstract: each has a
    no-op default here, and subclasses override the ones they support.

    reference: https://github.com/openai/openai-agents-python/blob/main/src/agents/result.py
    """

    # Free-form metadata attached to this outputs object. The @dataclass
    # decorator is required for field(default_factory=...) to take effect;
    # without it this attribute would be a dataclasses.Field object shared
    # by the class rather than a per-instance dict.
    _metadata: dict = field(default_factory=dict)

    async def add_output(self, output: Output):
        """Add an output asynchronously to the output stream.

        The base implementation does nothing.

        Args:
            output (Output): The output to be added
        """

    def sync_add_output(self, output: Output):
        """Add an output synchronously to the output stream.

        The base implementation does nothing.

        Args:
            output (Output): The output to be added
        """

    async def stream_events(self) -> Union[AsyncIterator[Output], list]:
        """Stream outputs asynchronously.

        The base implementation does nothing and returns None.

        Returns:
            AsyncIterator[Output] | list: the outputs, in subclass-defined form
        """

    def sync_stream_events(self) -> Union[Iterator[Output], list]:
        """Stream outputs synchronously.

        The base implementation does nothing and returns None.

        Returns:
            Iterator[Output] | list: the outputs, in subclass-defined form
        """

    async def mark_completed(self):
        """Signal that no more outputs will be added. No-op in the base class."""

    async def get_metadata(self) -> dict:
        """Return the metadata dict itself (not a copy)."""
        return self._metadata

    async def set_metadata(self, metadata: dict):
        """Replace the metadata dict with ``metadata``."""
        self._metadata = metadata
class AsyncOutputs(Outputs):
    """Intermediate base class for async-oriented output implementations.

    Every method here is a no-op: outputs passed to ``add_output`` /
    ``sync_add_output`` are discarded, and the stream methods return None.
    Concrete subclasses override the methods they actually support.
    """

    async def add_output(self, output: Output):
        """No-op; ``output`` is discarded. Subclasses override this."""

    def sync_add_output(self, output: Output):
        """No-op; ``output`` is discarded. Subclasses override this."""

    async def stream_events(self) -> Union[AsyncIterator[Output], list]:
        """No-op; returns None. Subclasses override this."""

    # Return annotation aligned with Outputs.sync_stream_events; the former
    # single-member Union[Iterator[Output]] narrowed the base contract.
    def sync_stream_events(self) -> Union[Iterator[Output], list]:
        """No-op; returns None. Subclasses override this."""
@dataclass
class DefaultOutputs(Outputs):
    """In-memory outputs container that collects every output in a list.

    The async and sync add/stream methods all operate on the same list.
    The stream methods return that list itself rather than a copy, so later
    additions are visible through a previously returned reference.
    """

    # Per-instance list of collected outputs. The @dataclass decorator is
    # required for default_factory to take effect; without it this attribute
    # is a dataclasses.Field object and .append() fails.
    _outputs: list = field(default_factory=list)

    async def add_output(self, output: Output):
        """Append ``output`` to the collected outputs."""
        self._outputs.append(output)

    def sync_add_output(self, output: Output):
        """Append ``output`` to the collected outputs (synchronous variant)."""
        self._outputs.append(output)

    async def stream_events(self) -> Union[AsyncIterator[Output], list]:
        """Return the list of collected outputs.

        This is a coroutine, not an async generator, so callers await it to
        obtain the list.
        """
        return self._outputs

    def sync_stream_events(self) -> Union[Iterator[Output], list]:
        """Return the list of collected outputs (synchronous variant)."""
        return self._outputs

    async def mark_completed(self):
        """No-op: this container does not track completion."""
@dataclass
class StreamingOutputs(AsyncOutputs):
    """Concrete implementation of AsyncOutputs that provides streaming functionality.

    Outputs are pushed onto an asyncio queue by ``add_output``. They are
    consumed by ``stream_events`` until ``RUN_FINISHED_SIGNAL`` is dequeued or
    an error from ``_run_impl_task`` is detected.

    Every dequeued item, including the finish signal, is cached in
    ``_visited_outputs``, so a later call to ``stream_events`` replays the
    cached items before reading from the queue again.

    ``sync_add_output`` and ``sync_stream_events`` are not overridden here and
    are inherited from AsyncOutputs.
    """

    # Task and input related fields
    input: Any = field(default=None)  # Input data for the task
    usage: dict | None = field(default=None)  # Usage statistics

    # State tracking
    is_complete: bool = field(default=False)  # Flag indicating if streaming is complete

    # Queue for managing outputs; the default asyncio.Queue() has no maxsize.
    # @dataclass is required for these default_factory fields to become real
    # per-instance values instead of dataclasses.Field class attributes.
    _output_queue: asyncio.Queue[Output] = field(
        default_factory=asyncio.Queue, repr=False
    )

    # Internal state management
    _visited_outputs: list[Output] = field(default_factory=list)
    _stored_exception: Exception | None = field(default=None, repr=False)  # Stores any exceptions that occur
    _run_impl_task: asyncio.Task[Any] | None = field(default=None, repr=False)  # The running task

    async def add_output(self, output: Output):
        """Add an output to the queue without waiting.

        With the default unbounded queue, this never blocks.

        Args:
            output (Output): The output to be added to the queue
        """
        self._output_queue.put_nowait(output)

    async def stream_events(self) -> AsyncIterator[Output]:
        """Stream outputs asynchronously.

        First replays outputs cached by earlier calls. Then it reads new
        outputs from the queue until ``RUN_FINISHED_SIGNAL`` arrives, the get
        is cancelled, or an error from ``_run_impl_task`` is detected.

        Yields:
            Output: The next output in the stream

        Raises:
            Exception: Any exception recorded from ``_run_impl_task``
        """
        # Replay outputs consumed by a previous call. Those items were already
        # marked done on the queue when first consumed, so task_done() must NOT
        # be called again here: doing so raises ValueError once the queue's
        # unfinished-task counter is zero.
        for output in self._visited_outputs:
            if output == RUN_FINISHED_SIGNAL:
                return
            yield output

        # Main streaming loop
        while True:
            self._check_errors()
            if self._stored_exception:
                logger.debug("Breaking due to stored exception")
                self.is_complete = True
                break

            if self.is_complete and self._output_queue.empty():
                break

            try:
                output = await self._output_queue.get()
            except asyncio.CancelledError:
                # NOTE(review): cancellation is swallowed and treated as
                # end-of-stream rather than propagated to the caller.
                break
            # Cache every dequeued item (including the finish signal) for replay.
            self._visited_outputs.append(output)

            if output == RUN_FINISHED_SIGNAL:
                self._output_queue.task_done()
                # Re-check in case the run task failed just before finishing.
                self._check_errors()
                break

            yield output
            self._output_queue.task_done()

        self._cleanup_tasks()
        if self._stored_exception:
            raise self._stored_exception

    def _check_errors(self):
        """Record an exception raised by ``_run_impl_task``, if any.

        If the task has finished with an ``Exception``, it is stored in
        ``_stored_exception``. A cancelled task is skipped, because
        ``Task.exception()`` raises ``CancelledError`` for it instead of
        returning. That case occurs, for example, after ``_cleanup_tasks``
        has cancelled the task.
        """
        task = self._run_impl_task
        if task is None or not task.done() or task.cancelled():
            return
        exc = task.exception()
        if isinstance(exc, Exception):
            self._stored_exception = exc

    def _cleanup_tasks(self):
        """Cancel ``_run_impl_task`` if it is still running."""
        if self._run_impl_task and not self._run_impl_task.done():
            self._run_impl_task.cancel()

    async def mark_completed(self) -> None:
        """Signal end-of-stream by enqueueing RUN_FINISHED_SIGNAL.

        This does not set ``is_complete``. The stream ends when
        ``stream_events`` dequeues the signal.
        """
        await self._output_queue.put(RUN_FINISHED_SIGNAL)