Quentin Gallouédec committed on
Commit
3922a8b
·
1 Parent(s): 041b899

add a lot of envs

Browse files
Files changed (1) hide show
  1. app.py +114 -14
app.py CHANGED
@@ -9,9 +9,10 @@ from apscheduler.schedulers.background import BackgroundScheduler
9
  from huggingface_hub import HfApi, hf_hub_download
10
 
11
  from src.backend import backend_routine
12
- from src.css_html_js import dark_mode_gradio_js
13
  from src.logging import configure_root_logger, setup_logger
14
 
 
15
  configure_root_logger()
16
  logger = setup_logger(__name__)
17
 
@@ -19,20 +20,101 @@ API = HfApi(token=os.environ.get("TOKEN"))
19
  RESULTS_REPO = f"open-rl-leaderboard/results"
20
  ALL_ENV_IDS = {
21
  "Atari": [
22
- "BeamRiderNoFrameskip-v4",
23
- "BreakoutNoFrameskip-v4",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  ],
25
  "Box2D": [
26
  "LunarLander-v2",
 
27
  "BipedalWalker-v3",
 
 
 
 
 
 
 
 
28
  ],
29
  "Classic control": [
 
30
  "CartPole-v1",
31
  "MountainCar-v0",
 
 
32
  ],
33
  "MuJoCo": [
34
- "Hopper-v4",
 
 
 
35
  "HalfCheetah-v4",
 
 
 
 
 
 
36
  ],
37
  }
38
 
@@ -77,12 +159,14 @@ The Open RL Leaderboard is a community-driven benchmark for reinforcement learni
77
 
78
  def select_env(df: pd.DataFrame, env_id: str):
79
  df = df[df["env_id"] == env_id]
80
-
81
- # Add the ranking
82
  df = df.sort_values("mean_episodic_return", ascending=False)
83
  df["ranking"] = np.arange(1, len(df) + 1)
 
84
 
 
 
85
  # Add hyperlinks
 
86
  for index, row in df.iterrows():
87
  user_id = row["user_id"]
88
  model_id = row["model_id"]
@@ -105,26 +189,42 @@ with gr.Blocks() as demo:
105
  for env_id in env_ids:
106
  with gr.TabItem(env_id):
107
  with gr.Row(equal_height=False):
 
 
 
108
  gr.components.Dataframe(
109
- value=select_env(df, env_id),
110
  headers=["🏆 Ranking", "🧑 User", "🤖 Model id", "📊 Mean episodic return"],
111
  datatype=["number", "markdown", "markdown", "number"],
112
  row_count=(10, "fixed"),
113
  scale=3,
114
  )
115
- gr.Video(
116
- "https://huggingface.co/qgallouedec/MsPacmanNoFrameskip-v4-dqn_atari-seed1/resolve/main/replay.mp4",
117
- autoplay=True,
118
- scale=1,
119
- min_width=50,
120
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  with gr.TabItem("📝 About", elem_id="llm-benchmark-tab-table", id=2):
123
  gr.Markdown(ABOUT_TEXT)
124
 
125
 
126
  scheduler = BackgroundScheduler()
127
- scheduler.add_job(func=backend_routine, trigger="interval", seconds=0.5 * 60, max_instances=1)
128
  scheduler.start()
129
 
130
 
 
9
  from huggingface_hub import HfApi, hf_hub_download
10
 
11
  from src.backend import backend_routine
12
+
13
  from src.logging import configure_root_logger, setup_logger
14
 
15
+
16
  configure_root_logger()
17
  logger = setup_logger(__name__)
18
 
 
20
  RESULTS_REPO = f"open-rl-leaderboard/results"
21
  ALL_ENV_IDS = {
22
  "Atari": [
23
+ "Adventure",
24
+ "AirRaid",
25
+ "Alien",
26
+ "Amidar",
27
+ "Assault",
28
+ "Asterix",
29
+ "Asteroids",
30
+ "Atlantis",
31
+ "BankHeist",
32
+ "BattleZone",
33
+ "BeamRider",
34
+ "Berzerk",
35
+ "Bowling",
36
+ "Boxing",
37
+ "Breakout",
38
+ "Carnival",
39
+ "Centipede",
40
+ "ChopperCommand",
41
+ "CrazyClimber",
42
+ "Defender",
43
+ "DemonAttack",
44
+ "DoubleDunk",
45
+ "ElevatorAction",
46
+ "Enduro",
47
+ "FishingDerby",
48
+ "Freeway",
49
+ "Frostbite",
50
+ "Gopher",
51
+ "Gravitar",
52
+ "Hero",
53
+ "IceHockey",
54
+ "Jamesbond",
55
+ "JourneyEscape",
56
+ "Kangaroo",
57
+ "Krull",
58
+ "KungFuMaster",
59
+ "MontezumaRevenge",
60
+ "MsPacman",
61
+ "NameThisGame",
62
+ "Phoenix",
63
+ "Pitfall",
64
+ "Pong",
65
+ "Pooyan",
66
+ "PrivateEye",
67
+ "Qbert",
68
+ "Riverraid",
69
+ "RoadRunner",
70
+ "Robotank",
71
+ "Seaquest",
72
+ "Skiing",
73
+ "Solaris",
74
+ "SpaceInvaders",
75
+ "StarGunner",
76
+ "Tennis",
77
+ "TimePilot",
78
+ "Tutankham",
79
+ "UpNDown",
80
+ "Venture",
81
+ "VideoPinball",
82
+ "WizardOfWor",
83
+ "YarsRevenge",
84
+ "Zaxxon",
85
  ],
86
  "Box2D": [
87
  "LunarLander-v2",
88
+ "LunarLanderContinuous-v2",
89
  "BipedalWalker-v3",
90
+ "BipedalWalkerHardcore-v3",
91
+ "CarRacing-v2",
92
+ ],
93
+ "Toy text": [
94
+ "Blackjack-v1",
95
+ "FrozenLake-v1",
96
+ "FrozenLake8x8-v1",
97
+ "CliffWalking-v0",
98
  ],
99
  "Classic control": [
100
+ "Acrobot-v1",
101
  "CartPole-v1",
102
  "MountainCar-v0",
103
+ "MountainCarContinuous-v0",
104
+ "Pendulum-v1",
105
  ],
106
  "MuJoCo": [
107
+ "Reacher-v4",
108
+ "Pusher-v4",
109
+ "InvertedPendulum-v4",
110
+ "InvertedDoublePendulum-v4",
111
  "HalfCheetah-v4",
112
+ "Hopper-v4",
113
+ "Swimmer-v4",
114
+ "Walker2d-v4",
115
+ "Ant-v4",
116
+ "Humanoid-v4",
117
+ "HumanoidStandup-v4",
118
  ],
119
  }
120
 
 
159
 
160
  def select_env(df: pd.DataFrame, env_id: str):
161
  df = df[df["env_id"] == env_id]
 
 
162
  df = df.sort_values("mean_episodic_return", ascending=False)
163
  df["ranking"] = np.arange(1, len(df) + 1)
164
+ return df
165
 
166
+
167
+ def format_df(df: pd.DataFrame):
168
  # Add hyperlinks
169
+ df = df.copy()
170
  for index, row in df.iterrows():
171
  user_id = row["user_id"]
172
  model_id = row["model_id"]
 
189
  for env_id in env_ids:
190
  with gr.TabItem(env_id):
191
  with gr.Row(equal_height=False):
192
+ if env_domain == "Atari":
193
+ env_id = f"{env_id}NoFrameskip-v4"
194
+ env_df = select_env(df, env_id)
195
  gr.components.Dataframe(
196
+ value=format_df(env_df),
197
  headers=["🏆 Ranking", "🧑 User", "🤖 Model id", "📊 Mean episodic return"],
198
  datatype=["number", "markdown", "markdown", "number"],
199
  row_count=(10, "fixed"),
200
  scale=3,
201
  )
202
+ # Get the best model and display its replay video
203
+ if not env_df.empty:
204
+ user_id = env_df.iloc[0]["user_id"]
205
+ model_id = env_df.iloc[0]["model_id"]
206
+ video_path = hf_hub_download(
207
+ repo_id=f"{user_id}/{model_id}",
208
+ filename="replay.mp4",
209
+ revision="main",
210
+ repo_type="model",
211
+ )
212
+ video = gr.PlayableVideo(
213
+ video_path,
214
+ autoplay=True,
215
+ scale=1,
216
+ min_width=50,
217
+ show_download_button=False,
218
+ label=model_id,
219
+ )
220
+ # Doesn't loop for the moment, see https://github.com/gradio-app/gradio/issues/7689
221
 
222
  with gr.TabItem("📝 About", elem_id="llm-benchmark-tab-table", id=2):
223
  gr.Markdown(ABOUT_TEXT)
224
 
225
 
226
  scheduler = BackgroundScheduler()
227
+ scheduler.add_job(func=backend_routine, trigger="interval", seconds=10 * 60, max_instances=1)
228
  scheduler.start()
229
 
230