Spaces:
Runtime error
Runtime error
Quentin Gallouédec
committed on
Commit
·
3922a8b
1
Parent(s):
041b899
add a lot of envs
Browse files
app.py
CHANGED
@@ -9,9 +9,10 @@ from apscheduler.schedulers.background import BackgroundScheduler
|
|
9 |
from huggingface_hub import HfApi, hf_hub_download
|
10 |
|
11 |
from src.backend import backend_routine
|
12 |
-
|
13 |
from src.logging import configure_root_logger, setup_logger
|
14 |
|
|
|
15 |
configure_root_logger()
|
16 |
logger = setup_logger(__name__)
|
17 |
|
@@ -19,20 +20,101 @@ API = HfApi(token=os.environ.get("TOKEN"))
|
|
19 |
RESULTS_REPO = f"open-rl-leaderboard/results"
|
20 |
ALL_ENV_IDS = {
|
21 |
"Atari": [
|
22 |
-
"
|
23 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
],
|
25 |
"Box2D": [
|
26 |
"LunarLander-v2",
|
|
|
27 |
"BipedalWalker-v3",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
],
|
29 |
"Classic control": [
|
|
|
30 |
"CartPole-v1",
|
31 |
"MountainCar-v0",
|
|
|
|
|
32 |
],
|
33 |
"MuJoCo": [
|
34 |
-
"
|
|
|
|
|
|
|
35 |
"HalfCheetah-v4",
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
],
|
37 |
}
|
38 |
|
@@ -77,12 +159,14 @@ The Open RL Leaderboard is a community-driven benchmark for reinforcement learni
|
|
77 |
|
78 |
def select_env(df: pd.DataFrame, env_id: str):
|
79 |
df = df[df["env_id"] == env_id]
|
80 |
-
|
81 |
-
# Add the ranking
|
82 |
df = df.sort_values("mean_episodic_return", ascending=False)
|
83 |
df["ranking"] = np.arange(1, len(df) + 1)
|
|
|
84 |
|
|
|
|
|
85 |
# Add hyperlinks
|
|
|
86 |
for index, row in df.iterrows():
|
87 |
user_id = row["user_id"]
|
88 |
model_id = row["model_id"]
|
@@ -105,26 +189,42 @@ with gr.Blocks() as demo:
|
|
105 |
for env_id in env_ids:
|
106 |
with gr.TabItem(env_id):
|
107 |
with gr.Row(equal_height=False):
|
|
|
|
|
|
|
108 |
gr.components.Dataframe(
|
109 |
-
value=
|
110 |
headers=["🏆 Ranking", "🧑 User", "🤖 Model id", "📊 Mean episodic return"],
|
111 |
datatype=["number", "markdown", "markdown", "number"],
|
112 |
row_count=(10, "fixed"),
|
113 |
scale=3,
|
114 |
)
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
|
122 |
with gr.TabItem("📝 About", elem_id="llm-benchmark-tab-table", id=2):
|
123 |
gr.Markdown(ABOUT_TEXT)
|
124 |
|
125 |
|
126 |
scheduler = BackgroundScheduler()
|
127 |
-
scheduler.add_job(func=backend_routine, trigger="interval", seconds=
|
128 |
scheduler.start()
|
129 |
|
130 |
|
|
|
9 |
from huggingface_hub import HfApi, hf_hub_download
|
10 |
|
11 |
from src.backend import backend_routine
|
12 |
+
|
13 |
from src.logging import configure_root_logger, setup_logger
|
14 |
|
15 |
+
|
16 |
configure_root_logger()
|
17 |
logger = setup_logger(__name__)
|
18 |
|
|
|
20 |
RESULTS_REPO = f"open-rl-leaderboard/results"
|
21 |
ALL_ENV_IDS = {
|
22 |
"Atari": [
|
23 |
+
"Adventure",
|
24 |
+
"AirRaid",
|
25 |
+
"Alien",
|
26 |
+
"Amidar",
|
27 |
+
"Assault",
|
28 |
+
"Asterix",
|
29 |
+
"Asteroids",
|
30 |
+
"Atlantis",
|
31 |
+
"BankHeist",
|
32 |
+
"BattleZone",
|
33 |
+
"BeamRider",
|
34 |
+
"Berzerk",
|
35 |
+
"Bowling",
|
36 |
+
"Boxing",
|
37 |
+
"Breakout",
|
38 |
+
"Carnival",
|
39 |
+
"Centipede",
|
40 |
+
"ChopperCommand",
|
41 |
+
"CrazyClimber",
|
42 |
+
"Defender",
|
43 |
+
"DemonAttack",
|
44 |
+
"DoubleDunk",
|
45 |
+
"ElevatorAction",
|
46 |
+
"Enduro",
|
47 |
+
"FishingDerby",
|
48 |
+
"Freeway",
|
49 |
+
"Frostbite",
|
50 |
+
"Gopher",
|
51 |
+
"Gravitar",
|
52 |
+
"Hero",
|
53 |
+
"IceHockey",
|
54 |
+
"Jamesbond",
|
55 |
+
"JourneyEscape",
|
56 |
+
"Kangaroo",
|
57 |
+
"Krull",
|
58 |
+
"KungFuMaster",
|
59 |
+
"MontezumaRevenge",
|
60 |
+
"MsPacman",
|
61 |
+
"NameThisGame",
|
62 |
+
"Phoenix",
|
63 |
+
"Pitfall",
|
64 |
+
"Pong",
|
65 |
+
"Pooyan",
|
66 |
+
"PrivateEye",
|
67 |
+
"Qbert",
|
68 |
+
"Riverraid",
|
69 |
+
"RoadRunner",
|
70 |
+
"Robotank",
|
71 |
+
"Seaquest",
|
72 |
+
"Skiing",
|
73 |
+
"Solaris",
|
74 |
+
"SpaceInvaders",
|
75 |
+
"StarGunner",
|
76 |
+
"Tennis",
|
77 |
+
"TimePilot",
|
78 |
+
"Tutankham",
|
79 |
+
"UpNDown",
|
80 |
+
"Venture",
|
81 |
+
"VideoPinball",
|
82 |
+
"WizardOfWor",
|
83 |
+
"YarsRevenge",
|
84 |
+
"Zaxxon",
|
85 |
],
|
86 |
"Box2D": [
|
87 |
"LunarLander-v2",
|
88 |
+
"LunarLanderContinuous-v2",
|
89 |
"BipedalWalker-v3",
|
90 |
+
"BipedalWalkerHardcore-v3",
|
91 |
+
"CarRacing-v2",
|
92 |
+
],
|
93 |
+
"Toy text": [
|
94 |
+
"Blackjack-v1",
|
95 |
+
"FrozenLake-v1",
|
96 |
+
"FrozenLake8x8-v1",
|
97 |
+
"CliffWalking-v0",
|
98 |
],
|
99 |
"Classic control": [
|
100 |
+
"Acrobot-v1",
|
101 |
"CartPole-v1",
|
102 |
"MountainCar-v0",
|
103 |
+
"MountainCarContinuous-v0",
|
104 |
+
"Pendulum-v1",
|
105 |
],
|
106 |
"MuJoCo": [
|
107 |
+
"Reacher-v4",
|
108 |
+
"Pusher-v4",
|
109 |
+
"InvertedPendulum-v4",
|
110 |
+
"InvertedDoublePendulum-v4",
|
111 |
"HalfCheetah-v4",
|
112 |
+
"Hopper-v4",
|
113 |
+
"Swimmer-v4",
|
114 |
+
"Walker2d-v4",
|
115 |
+
"Ant-v4",
|
116 |
+
"Humanoid-v4",
|
117 |
+
"HumanoidStandup-v4",
|
118 |
],
|
119 |
}
|
120 |
|
|
|
159 |
|
160 |
def select_env(df: pd.DataFrame, env_id: str):
|
161 |
df = df[df["env_id"] == env_id]
|
|
|
|
|
162 |
df = df.sort_values("mean_episodic_return", ascending=False)
|
163 |
df["ranking"] = np.arange(1, len(df) + 1)
|
164 |
+
return df
|
165 |
|
166 |
+
|
167 |
+
def format_df(df: pd.DataFrame):
|
168 |
# Add hyperlinks
|
169 |
+
df = df.copy()
|
170 |
for index, row in df.iterrows():
|
171 |
user_id = row["user_id"]
|
172 |
model_id = row["model_id"]
|
|
|
189 |
for env_id in env_ids:
|
190 |
with gr.TabItem(env_id):
|
191 |
with gr.Row(equal_height=False):
|
192 |
+
if env_domain == "Atari":
|
193 |
+
env_id = f"{env_id}NoFrameskip-v4"
|
194 |
+
env_df = select_env(df, env_id)
|
195 |
gr.components.Dataframe(
|
196 |
+
value=format_df(env_df),
|
197 |
headers=["🏆 Ranking", "🧑 User", "🤖 Model id", "📊 Mean episodic return"],
|
198 |
datatype=["number", "markdown", "markdown", "number"],
|
199 |
row_count=(10, "fixed"),
|
200 |
scale=3,
|
201 |
)
|
202 |
+
# Get the best model and display its replay video
|
203 |
+
if not env_df.empty:
|
204 |
+
user_id = env_df.iloc[0]["user_id"]
|
205 |
+
model_id = env_df.iloc[0]["model_id"]
|
206 |
+
video_path = hf_hub_download(
|
207 |
+
repo_id=f"{user_id}/{model_id}",
|
208 |
+
filename="replay.mp4",
|
209 |
+
revision="main",
|
210 |
+
repo_type="model",
|
211 |
+
)
|
212 |
+
video = gr.PlayableVideo(
|
213 |
+
video_path,
|
214 |
+
autoplay=True,
|
215 |
+
scale=1,
|
216 |
+
min_width=50,
|
217 |
+
show_download_button=False,
|
218 |
+
label=model_id,
|
219 |
+
)
|
220 |
+
# The video doesn't loop for the moment, see https://github.com/gradio-app/gradio/issues/7689
|
221 |
|
222 |
with gr.TabItem("📝 About", elem_id="llm-benchmark-tab-table", id=2):
|
223 |
gr.Markdown(ABOUT_TEXT)
|
224 |
|
225 |
|
226 |
scheduler = BackgroundScheduler()
|
227 |
+
scheduler.add_job(func=backend_routine, trigger="interval", seconds=10 * 60, max_instances=1)
|
228 |
scheduler.start()
|
229 |
|
230 |
|