Update app.py
Browse files
app.py
CHANGED
@@ -162,28 +162,12 @@ def evaluate_model(model, env):
|
|
162 |
mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=10, render=False)
|
163 |
return mean_reward
|
164 |
|
165 |
-
# Gradio interface
|
166 |
-
def play_game():
|
167 |
-
env = ArkanoidEnv()
|
168 |
-
model = DQN.load("arkanoid_model")
|
169 |
-
obs = env.reset()[0]
|
170 |
-
done = False
|
171 |
-
frames = []
|
172 |
-
while not done:
|
173 |
-
action, _states = model.predict(obs, deterministic=True)
|
174 |
-
obs, reward, done, truncated, info = env.step(action)
|
175 |
-
env.render()
|
176 |
-
pygame.image.save(screen, "frame.png")
|
177 |
-
frames.append(gr.Image(value="frame.png"))
|
178 |
-
return frames
|
179 |
-
|
180 |
# Real-time training function
|
181 |
def train_and_play():
|
182 |
env = ArkanoidEnv()
|
183 |
model = DQN('MlpPolicy', env, verbose=1)
|
184 |
total_timesteps = 10000
|
185 |
timesteps_per_update = 1000
|
186 |
-
frames = []
|
187 |
video_frames = []
|
188 |
|
189 |
for i in range(0, total_timesteps, timesteps_per_update):
|
@@ -191,7 +175,6 @@ def train_and_play():
|
|
191 |
obs = env.reset()[0]
|
192 |
done = False
|
193 |
truncated = False
|
194 |
-
episode_frames = []
|
195 |
while not done and not truncated:
|
196 |
action, _states = model.predict(obs, deterministic=True)
|
197 |
obs, reward, done, truncated, info = env.step(action)
|
@@ -200,9 +183,6 @@ def train_and_play():
|
|
200 |
frame = pygame.surfarray.array3d(pygame.display.get_surface())
|
201 |
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
202 |
video_frames.append(frame)
|
203 |
-
episode_frames.append(gr.Image(value="frame.png"))
|
204 |
-
frames.extend(episode_frames)
|
205 |
-
yield frames
|
206 |
|
207 |
# Save the video
|
208 |
video_path = "arkanoid_training.mp4"
|
@@ -213,7 +193,7 @@ def train_and_play():
|
|
213 |
video_writer.release()
|
214 |
|
215 |
# Return the video path
|
216 |
-
return
|
217 |
|
218 |
# Main function
|
219 |
def main():
|
|
|
162 |
mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=10, render=False)
|
163 |
return mean_reward
|
164 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
# Real-time training function
|
166 |
def train_and_play():
|
167 |
env = ArkanoidEnv()
|
168 |
model = DQN('MlpPolicy', env, verbose=1)
|
169 |
total_timesteps = 10000
|
170 |
timesteps_per_update = 1000
|
|
|
171 |
video_frames = []
|
172 |
|
173 |
for i in range(0, total_timesteps, timesteps_per_update):
|
|
|
175 |
obs = env.reset()[0]
|
176 |
done = False
|
177 |
truncated = False
|
|
|
178 |
while not done and not truncated:
|
179 |
action, _states = model.predict(obs, deterministic=True)
|
180 |
obs, reward, done, truncated, info = env.step(action)
|
|
|
183 |
frame = pygame.surfarray.array3d(pygame.display.get_surface())
|
184 |
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
185 |
video_frames.append(frame)
|
|
|
|
|
|
|
186 |
|
187 |
# Save the video
|
188 |
video_path = "arkanoid_training.mp4"
|
|
|
193 |
video_writer.release()
|
194 |
|
195 |
# Return the video path
|
196 |
+
return video_path
|
197 |
|
198 |
# Main function
|
199 |
def main():
|