Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files- examples/gym_demo/agent.py +35 -0
- examples/gym_demo/run.py +30 -0
examples/gym_demo/agent.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
# Copyright (c) 2025 inclusionAI.
|
3 |
+
|
4 |
+
from typing import Any, Dict, Union, List
|
5 |
+
|
6 |
+
from examples.tools.common import Agents, Tools
|
7 |
+
from aworld.core.agent.base import AgentFactory
|
8 |
+
from aworld.config.conf import AgentConfig, ConfigDict
|
9 |
+
from aworld.agents.llm_agent import Agent
|
10 |
+
from aworld.core.common import Observation, ActionModel
|
11 |
+
|
12 |
+
|
13 |
+
class GymDemoAgent(Agent):
|
14 |
+
"""Example agent"""
|
15 |
+
|
16 |
+
def __init__(self, conf: Union[Dict[str, Any], ConfigDict, AgentConfig], **kwargs):
|
17 |
+
super(GymDemoAgent, self).__init__(conf, **kwargs)
|
18 |
+
|
19 |
+
def policy(self, observation: Observation, info: Dict[str, Any] = {}, **kwargs) -> Union[
|
20 |
+
List[ActionModel], None]:
|
21 |
+
import numpy as np
|
22 |
+
|
23 |
+
env_id = observation.info.get('env_id')
|
24 |
+
if env_id and env_id != 'CartPole-v1':
|
25 |
+
raise ValueError("Unsupported env")
|
26 |
+
|
27 |
+
res = np.random.randint(2)
|
28 |
+
action = [ActionModel(agent_name=self.id(), tool_name=Tools.GYM.value, action_name="play", params={"result": res})]
|
29 |
+
if observation.info.get("done"):
|
30 |
+
self._finished = True
|
31 |
+
return action
|
32 |
+
|
33 |
+
async def async_policy(self, observation: Observation, info: Dict[str, Any] = {}, **kwargs) -> Union[
|
34 |
+
List[ActionModel], None]:
|
35 |
+
return self.policy(observation, info, **kwargs)
|
examples/gym_demo/run.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
# Copyright (c) 2025 inclusionAI.
|
3 |
+
import asyncio
|
4 |
+
|
5 |
+
from aworld.agents.parallel_llm_agent import ParallelizableAgent
|
6 |
+
from aworld.config import RunConfig
|
7 |
+
|
8 |
+
from aworld.config.conf import AgentConfig
|
9 |
+
from aworld.core.task import Task
|
10 |
+
from aworld.runner import Runners
|
11 |
+
from examples.tools.common import Tools, Agents
|
12 |
+
from examples.gym_demo.agent import GymDemoAgent as GymAgent
|
13 |
+
from examples.tools.gym_tool.async_openai_gym import OpenAIGym
|
14 |
+
|
15 |
+
|
16 |
+
async def main():
|
17 |
+
agent = GymAgent(name=Agents.GYM.value, conf=AgentConfig(), tool_names=[Tools.GYM.value])
|
18 |
+
agent = ParallelizableAgent(name=agent.name(), conf=agent.conf,
|
19 |
+
tool_names=[Tools.GYM.value],
|
20 |
+
agents=[agent])
|
21 |
+
# It can also be used `ToolFactory` for simplification.
|
22 |
+
task = Task(agent=agent,
|
23 |
+
tools_conf={Tools.GYM.value: {"env_id": "CartPole-v1", "render_mode": "human", "render": True}})
|
24 |
+
res = await Runners.run_task(task=task, run_conf=RunConfig())
|
25 |
+
return res
|
26 |
+
|
27 |
+
|
28 |
+
if __name__ == "__main__":
|
29 |
+
# We use it as a showcase to demonstrate the framework's scalability.
|
30 |
+
asyncio.run(main())
|