Ivan000 commited on
Commit
995402e
·
verified ·
1 Parent(s): 965d906

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -27
app.py CHANGED
@@ -1,8 +1,3 @@
1
- # app.py
2
- # =============
3
- # This is a complete app.py file for an Arkanoid game that a neural network will play and learn using reinforcement learning.
4
- # The game is built using pygame, and the neural network is trained using stable-baselines3. Gradio is used for the interface.
5
-
6
  import os
7
  import numpy as np
8
  import pygame
@@ -12,7 +7,6 @@ from stable_baselines3 import DQN
12
  from stable_baselines3.common.evaluation import evaluate_policy
13
  import gradio as gr
14
  import cv2
15
- import imageio
16
 
17
  # Constants
18
  SCREEN_WIDTH = 640
@@ -68,7 +62,7 @@ class Ball:
68
 
69
  class Brick:
70
  def __init__(self, x, y):
71
- self.rect = pygame.Rect(x, y, BRICK_WIDTH, BRICK_HEIGHT)
72
 
73
  class ArkanoidEnv(gym.Env):
74
  def __init__(self):
@@ -85,7 +79,8 @@ class ArkanoidEnv(gym.Env):
85
  self.seed_value = seed
86
  self.paddle = Paddle()
87
  self.ball = Ball()
88
- self.bricks = [Brick(x, y) for y in range(BRICK_HEIGHT, BRICK_HEIGHT * (BRICK_ROWS + 1), BRICK_HEIGHT) for x in range(BRICK_WIDTH, SCREEN_WIDTH - BRICK_WIDTH, BRICK_WIDTH)]
 
89
  self.done = False
90
  self.score = 0
91
  return self._get_state(), {}
@@ -172,27 +167,27 @@ def train_and_play():
172
 
173
  for i in range(0, total_timesteps, timesteps_per_update):
174
  model.learn(total_timesteps=timesteps_per_update)
175
- obs = env.reset()[0]
176
  done = False
177
  truncated = False
178
  while not done and not truncated:
179
  action, _states = model.predict(obs, deterministic=True)
180
- obs, reward, done, truncated, info = env.step(action)
181
  env.render()
182
  # Capture the current frame
183
  frame = pygame.surfarray.array3d(pygame.display.get_surface())
 
184
  frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
185
  video_frames.append(frame)
186
 
187
  # Save the video
188
  video_path = "arkanoid_training.mp4"
189
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
190
- video_writer = cv2.VideoWriter(video_path, fourcc, FPS, (SCREEN_WIDTH, SCREEN_HEIGHT))
191
  for frame in video_frames:
192
  video_writer.write(frame)
193
  video_writer.release()
194
 
195
- # Return the video path
196
  return video_path
197
 
198
  # Main function
@@ -208,17 +203,3 @@ def main():
208
 
209
  if __name__ == "__main__":
210
  main()
211
-
212
- # Dependencies
213
- # =============
214
- # The following dependencies are required to run this app:
215
- # - pygame
216
- # - stable-baselines3
217
- # - torch
218
- # - gradio
219
- # - gymnasium
220
- # - opencv-python
221
- # - imageio
222
- #
223
- # You can install these dependencies using pip:
224
- # pip install pygame stable-baselines3 torch gradio gymnasium opencv-python imageio
 
 
 
 
 
 
1
  import os
2
  import numpy as np
3
  import pygame
 
7
  from stable_baselines3.common.evaluation import evaluate_policy
8
  import gradio as gr
9
  import cv2
 
10
 
11
  # Constants
12
  SCREEN_WIDTH = 640
 
62
 
63
  class Brick:
64
  def __init__(self, x, y):
65
+ self.rect = pygame.Rect(x, y, BRICK_WIDTH - 5, BRICK_HEIGHT - 5)
66
 
67
  class ArkanoidEnv(gym.Env):
68
  def __init__(self):
 
79
  self.seed_value = seed
80
  self.paddle = Paddle()
81
  self.ball = Ball()
82
+ self.bricks = [Brick(x, y) for y in range(BRICK_HEIGHT, BRICK_HEIGHT * (BRICK_ROWS + 1), BRICK_HEIGHT)
83
+ for x in range(BRICK_WIDTH, SCREEN_WIDTH - BRICK_WIDTH, BRICK_WIDTH)]
84
  self.done = False
85
  self.score = 0
86
  return self._get_state(), {}
 
167
 
168
  for i in range(0, total_timesteps, timesteps_per_update):
169
  model.learn(total_timesteps=timesteps_per_update)
170
+ obs, _ = env.reset()
171
  done = False
172
  truncated = False
173
  while not done and not truncated:
174
  action, _states = model.predict(obs, deterministic=True)
175
+ obs, reward, done, truncated, _ = env.step(action)
176
  env.render()
177
  # Capture the current frame
178
  frame = pygame.surfarray.array3d(pygame.display.get_surface())
179
+ frame = np.rot90(frame) # Fix orientation
180
  frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
181
  video_frames.append(frame)
182
 
183
  # Save the video
184
  video_path = "arkanoid_training.mp4"
185
+ video_writer = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), FPS, (SCREEN_WIDTH, SCREEN_HEIGHT))
 
186
  for frame in video_frames:
187
  video_writer.write(frame)
188
  video_writer.release()
189
 
190
+ env.close() # Ensure the environment is properly closed
191
  return video_path
192
 
193
  # Main function
 
203
 
204
  if __name__ == "__main__":
205
  main()