Ivan000 commited on
Commit
9d7767b
·
verified ·
1 Parent(s): ce94199

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -21
app.py CHANGED
@@ -162,28 +162,12 @@ def evaluate_model(model, env):
162
  mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=10, render=False)
163
  return mean_reward
164
 
165
- # Gradio interface
166
- def play_game():
167
- env = ArkanoidEnv()
168
- model = DQN.load("arkanoid_model")
169
- obs = env.reset()[0]
170
- done = False
171
- frames = []
172
- while not done:
173
- action, _states = model.predict(obs, deterministic=True)
174
- obs, reward, done, truncated, info = env.step(action)
175
- env.render()
176
- pygame.image.save(screen, "frame.png")
177
- frames.append(gr.Image(value="frame.png"))
178
- return frames
179
-
180
  # Real-time training function
181
  def train_and_play():
182
  env = ArkanoidEnv()
183
  model = DQN('MlpPolicy', env, verbose=1)
184
  total_timesteps = 10000
185
  timesteps_per_update = 1000
186
- frames = []
187
  video_frames = []
188
 
189
  for i in range(0, total_timesteps, timesteps_per_update):
@@ -191,7 +175,6 @@ def train_and_play():
191
  obs = env.reset()[0]
192
  done = False
193
  truncated = False
194
- episode_frames = []
195
  while not done and not truncated:
196
  action, _states = model.predict(obs, deterministic=True)
197
  obs, reward, done, truncated, info = env.step(action)
@@ -200,9 +183,6 @@ def train_and_play():
200
  frame = pygame.surfarray.array3d(pygame.display.get_surface())
201
  frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
202
  video_frames.append(frame)
203
- episode_frames.append(gr.Image(value="frame.png"))
204
- frames.extend(episode_frames)
205
- yield frames
206
 
207
  # Save the video
208
  video_path = "arkanoid_training.mp4"
@@ -213,7 +193,7 @@ def train_and_play():
213
  video_writer.release()
214
 
215
  # Return the video path
216
- return gr.Video(video_path)
217
 
218
  # Main function
219
  def main():
 
162
  mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=10, render=False)
163
  return mean_reward
164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  # Real-time training function
166
  def train_and_play():
167
  env = ArkanoidEnv()
168
  model = DQN('MlpPolicy', env, verbose=1)
169
  total_timesteps = 10000
170
  timesteps_per_update = 1000
 
171
  video_frames = []
172
 
173
  for i in range(0, total_timesteps, timesteps_per_update):
 
175
  obs = env.reset()[0]
176
  done = False
177
  truncated = False
 
178
  while not done and not truncated:
179
  action, _states = model.predict(obs, deterministic=True)
180
  obs, reward, done, truncated, info = env.step(action)
 
183
  frame = pygame.surfarray.array3d(pygame.display.get_surface())
184
  frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
185
  video_frames.append(frame)
 
 
 
186
 
187
  # Save the video
188
  video_path = "arkanoid_training.mp4"
 
193
  video_writer.release()
194
 
195
  # Return the video path
196
+ return video_path
197
 
198
  # Main function
199
  def main():