File size: 6,461 Bytes
3e8c06e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import abc
import asyncio
from abc import abstractmethod
from dataclasses import field, dataclass
from typing import AsyncIterator, Any, Union, Iterator

from aworld.logs.util import logger
from aworld.output import Output
from aworld.output.base import RUN_FINISHED_SIGNAL


@dataclass
class Outputs(abc.ABC):
    """Abstract base for output containers in the AWorld framework.

    Concrete subclasses decide how outputs are stored and how they are handed
    back to consumers. This base only fixes the interface: async and sync
    variants for adding and for streaming outputs, a completion hook, and a
    free-form metadata dictionary.

    reference: https://github.com/openai/openai-agents-python/blob/main/src/agents/result.py
    """

    # Free-form metadata, read and replaced through get_metadata / set_metadata.
    _metadata: dict = field(default_factory=dict)

    # ---- adding outputs -------------------------------------------------

    @abstractmethod
    async def add_output(self, output: Output):
        """Append an output to the stream from asynchronous code.

        Args:
            output (Output): The output to record.
        """

    @abstractmethod
    def sync_add_output(self, output: Output):
        """Append an output to the stream from synchronous code.

        Args:
            output (Output): The output to record.
        """

    # ---- consuming outputs ----------------------------------------------

    @abstractmethod
    async def stream_events(self) -> Union[AsyncIterator[Output], list]:
        """Hand the recorded outputs back to an asynchronous consumer.

        Returns:
            AsyncIterator[Output] | list: Either an async iterator over the
            outputs or a plain list, depending on the subclass.
        """

    @abstractmethod
    def sync_stream_events(self) -> Union[Iterator[Output], list]:
        """Hand the recorded outputs back to a synchronous consumer.

        Returns:
            Iterator[Output] | list: Either an iterator over the outputs or a
            plain list, depending on the subclass.
        """

    # ---- lifecycle ------------------------------------------------------

    @abstractmethod
    async def mark_completed(self):
        """Signal that no further outputs will be added."""

    # ---- metadata -------------------------------------------------------

    async def get_metadata(self) -> dict:
        """Return the metadata dictionary itself (not a copy)."""
        return self._metadata

    async def set_metadata(self, metadata: dict):
        """Replace the metadata dictionary with ``metadata``."""
        self._metadata = metadata

@dataclass
class AsyncOutputs(Outputs):
    """Intermediate base that stubs out the Outputs interface for async-oriented subclasses.

    Every method defined here is a no-op placeholder: the coroutines resolve to
    ``None`` and the sync methods return ``None``. Concrete subclasses override
    the ones they actually support. ``mark_completed`` is not implemented here,
    so this class is still abstract; only subclasses that define it can be
    instantiated.
    """

    async def add_output(self, output: Output):
        """No-op placeholder; subclasses override to record ``output``."""
        pass

    def sync_add_output(self, output: Output):
        """No-op placeholder; ``output`` is discarded unless a subclass overrides this."""
        pass

    async def stream_events(self) -> Union[AsyncIterator[Output], list]:
        """No-op placeholder; resolves to ``None`` unless a subclass overrides it."""
        pass

    # Annotation aligned with Outputs.sync_stream_events (it was a one-member Union).
    def sync_stream_events(self) -> Union[Iterator[Output], list]:
        """No-op placeholder; returns ``None`` unless a subclass overrides it."""
        pass


@dataclass
class DefaultOutputs(Outputs):
    """List-backed, non-streaming implementation of Outputs.

    Outputs are appended to an in-memory list. Both ``stream_events`` and
    ``sync_stream_events`` return that same list object rather than a copy, so
    outputs added afterwards are visible through a previously returned
    reference.
    """

    # Recorded outputs, in insertion order.
    _outputs: list = field(default_factory=list)

    # ---- synchronous API ------------------------------------------------

    def sync_add_output(self, output: Output):
        """Record ``output`` at the end of the list."""
        backing = self._outputs
        backing.append(output)

    def sync_stream_events(self) -> Union[Iterator[Output], list]:
        """Return the backing list of recorded outputs."""
        backing = self._outputs
        return backing

    # ---- asynchronous API -----------------------------------------------

    async def add_output(self, output: Output):
        """Record ``output`` at the end of the list."""
        backing = self._outputs
        backing.append(output)

    async def stream_events(self) -> Union[AsyncIterator[Output], list]:
        """Return the backing list of recorded outputs."""
        backing = self._outputs
        return backing

    async def mark_completed(self):
        """Do nothing: a plain list has no completion state to record."""


@dataclass
class StreamingOutputs(AsyncOutputs):
    """Concrete implementation of AsyncOutputs that provides streaming functionality.

    Producers push outputs onto an ``asyncio.Queue`` via ``add_output`` and
    finish the stream with ``mark_completed``, which enqueues
    ``RUN_FINISHED_SIGNAL``. ``stream_events`` consumes the queue. Every item it
    takes off the queue, including the finish signal, is also recorded in
    ``_visited_outputs``. A later ``stream_events`` call first replays those
    recorded outputs, then continues with whatever is still in the queue.

    ``sync_add_output`` and ``sync_stream_events`` are not overridden here and
    keep the AsyncOutputs placeholder behaviour.
    """

    # Task and input related fields
    # task: Task = Field(default=None)  # The task associated with these outputs
    input: Any = field(default=None)  # Input data for the task
    usage: dict | None = field(default=None)  # Usage statistics

    # State tracking: only set to True by stream_events when a stored exception
    # ends the stream; receiving RUN_FINISHED_SIGNAL does not set it.
    is_complete: bool = field(default=False)

    # Queue for managing outputs
    _output_queue: asyncio.Queue[Output] = field(
        default_factory=asyncio.Queue, repr=False
    )

    # Internal state management
    # Every item taken off the queue (including RUN_FINISHED_SIGNAL), replayed by
    # subsequent stream_events calls.
    _visited_outputs: list[Output] = field(default_factory=list)
    _stored_exception: Exception | None = field(default=None, repr=False)  # Stores any exceptions that occur
    _run_impl_task: asyncio.Task[Any] | None = field(default=None, repr=False)  # The running task

    async def add_output(self, output: Output):
        """Add an output to the queue.

        The queue is created without a ``maxsize``, so ``put_nowait`` does not
        block.

        Args:
            output (Output): The output to be added to the queue
        """
        self._output_queue.put_nowait(output)

    async def stream_events(self) -> AsyncIterator[Output]:
        """Stream outputs, replaying previously consumed ones first.

        Yields:
            Output: The next output in the stream

        Raises:
            Exception: The exception stored from ``_run_impl_task``, if that task
                failed. It is raised after the stream stops and after
                ``_cleanup_tasks`` runs.
        """
        # Replay outputs an earlier call already took off the queue. Those items
        # were acknowledged with task_done() when first consumed, so the replay
        # must not touch the queue. A second task_done() for the finish signal
        # raises ValueError once the unfinished count is zero, or silently
        # acknowledges some other pending item.
        for output in self._visited_outputs:
            if output == RUN_FINISHED_SIGNAL:
                return
            yield output

        # Main streaming loop
        while True:
            self._check_errors()
            if self._stored_exception:
                logger.debug("Breaking due to stored exception")
                self.is_complete = True
                break

            if self.is_complete and self._output_queue.empty():
                break

            try:
                output = await self._output_queue.get()
            except asyncio.CancelledError:
                # NOTE(review): cancellation while waiting is swallowed and treated
                # as end of stream (mirrors openai-agents' result.py). The caller's
                # task therefore does not observe the cancellation here.
                break
            self._visited_outputs.append(output)

            if output == RUN_FINISHED_SIGNAL:
                self._output_queue.task_done()
                # The run task may have failed between the last check and the
                # signal being consumed; pick up its exception before stopping.
                self._check_errors()
                break

            yield output
            self._output_queue.task_done()

        self._cleanup_tasks()

        if self._stored_exception:
            raise self._stored_exception

    def _check_errors(self):
        """Record the exception of ``_run_impl_task`` if that task has finished with one.

        Only ``Exception`` subclasses are stored; ``BaseException`` subclasses that
        are not ``Exception`` are ignored. If the task was cancelled,
        ``Task.exception()`` raises ``CancelledError``, which propagates to the
        caller.
        """
        if self._run_impl_task and self._run_impl_task.done():
            exc = self._run_impl_task.exception()
            if exc and isinstance(exc, Exception):
                self._stored_exception = exc

    def _cleanup_tasks(self):
        """Cancel ``_run_impl_task`` if it is set and still running."""
        if self._run_impl_task and not self._run_impl_task.done():
            self._run_impl_task.cancel()

    async def mark_completed(self) -> None:
        """Mark the stream as finished by enqueuing ``RUN_FINISHED_SIGNAL``.

        This does not change ``is_complete``. ``stream_events`` stops when it
        dequeues the signal.
        """
        await self._output_queue.put(RUN_FINISHED_SIGNAL)