Update app.py
Browse files
app.py
CHANGED
@@ -1,8 +1,3 @@
|
|
1 |
-
# app.py
|
2 |
-
# =============
|
3 |
-
# This is a complete app.py file for an Arkanoid game that a neural network will play and learn using reinforcement learning.
|
4 |
-
# The game is built using pygame, and the neural network is trained using stable-baselines3. Gradio is used for the interface.
|
5 |
-
|
6 |
import os
|
7 |
import numpy as np
|
8 |
import pygame
|
@@ -12,7 +7,6 @@ from stable_baselines3 import DQN
|
|
12 |
from stable_baselines3.common.evaluation import evaluate_policy
|
13 |
import gradio as gr
|
14 |
import cv2
|
15 |
-
import imageio
|
16 |
|
17 |
# Constants
|
18 |
SCREEN_WIDTH = 640
|
@@ -68,7 +62,7 @@ class Ball:
|
|
68 |
|
69 |
class Brick:
|
70 |
def __init__(self, x, y):
|
71 |
-
self.rect = pygame.Rect(x, y, BRICK_WIDTH, BRICK_HEIGHT)
|
72 |
|
73 |
class ArkanoidEnv(gym.Env):
|
74 |
def __init__(self):
|
@@ -85,7 +79,8 @@ class ArkanoidEnv(gym.Env):
|
|
85 |
self.seed_value = seed
|
86 |
self.paddle = Paddle()
|
87 |
self.ball = Ball()
|
88 |
-
self.bricks = [Brick(x, y) for y in range(BRICK_HEIGHT, BRICK_HEIGHT * (BRICK_ROWS + 1), BRICK_HEIGHT)
|
|
|
89 |
self.done = False
|
90 |
self.score = 0
|
91 |
return self._get_state(), {}
|
@@ -172,27 +167,27 @@ def train_and_play():
|
|
172 |
|
173 |
for i in range(0, total_timesteps, timesteps_per_update):
|
174 |
model.learn(total_timesteps=timesteps_per_update)
|
175 |
-
obs = env.reset()
|
176 |
done = False
|
177 |
truncated = False
|
178 |
while not done and not truncated:
|
179 |
action, _states = model.predict(obs, deterministic=True)
|
180 |
-
obs, reward, done, truncated,
|
181 |
env.render()
|
182 |
# Capture the current frame
|
183 |
frame = pygame.surfarray.array3d(pygame.display.get_surface())
|
|
|
184 |
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
185 |
video_frames.append(frame)
|
186 |
|
187 |
# Save the video
|
188 |
video_path = "arkanoid_training.mp4"
|
189 |
-
|
190 |
-
video_writer = cv2.VideoWriter(video_path, fourcc, FPS, (SCREEN_WIDTH, SCREEN_HEIGHT))
|
191 |
for frame in video_frames:
|
192 |
video_writer.write(frame)
|
193 |
video_writer.release()
|
194 |
|
195 |
-
#
|
196 |
return video_path
|
197 |
|
198 |
# Main function
|
@@ -208,17 +203,3 @@ def main():
|
|
208 |
|
209 |
if __name__ == "__main__":
|
210 |
main()
|
211 |
-
|
212 |
-
# Dependencies
|
213 |
-
# =============
|
214 |
-
# The following dependencies are required to run this app:
|
215 |
-
# - pygame
|
216 |
-
# - stable-baselines3
|
217 |
-
# - torch
|
218 |
-
# - gradio
|
219 |
-
# - gymnasium
|
220 |
-
# - opencv-python
|
221 |
-
# - imageio
|
222 |
-
#
|
223 |
-
# You can install these dependencies using pip:
|
224 |
-
# pip install pygame stable-baselines3 torch gradio gymnasium opencv-python imageio
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
import numpy as np
|
3 |
import pygame
|
|
|
7 |
from stable_baselines3.common.evaluation import evaluate_policy
|
8 |
import gradio as gr
|
9 |
import cv2
|
|
|
10 |
|
11 |
# Constants
|
12 |
SCREEN_WIDTH = 640
|
|
|
62 |
|
63 |
class Brick:
|
64 |
def __init__(self, x, y):
|
65 |
+
self.rect = pygame.Rect(x, y, BRICK_WIDTH - 5, BRICK_HEIGHT - 5)
|
66 |
|
67 |
class ArkanoidEnv(gym.Env):
|
68 |
def __init__(self):
|
|
|
79 |
self.seed_value = seed
|
80 |
self.paddle = Paddle()
|
81 |
self.ball = Ball()
|
82 |
+
self.bricks = [Brick(x, y) for y in range(BRICK_HEIGHT, BRICK_HEIGHT * (BRICK_ROWS + 1), BRICK_HEIGHT)
|
83 |
+
for x in range(BRICK_WIDTH, SCREEN_WIDTH - BRICK_WIDTH, BRICK_WIDTH)]
|
84 |
self.done = False
|
85 |
self.score = 0
|
86 |
return self._get_state(), {}
|
|
|
167 |
|
168 |
for i in range(0, total_timesteps, timesteps_per_update):
|
169 |
model.learn(total_timesteps=timesteps_per_update)
|
170 |
+
obs, _ = env.reset()
|
171 |
done = False
|
172 |
truncated = False
|
173 |
while not done and not truncated:
|
174 |
action, _states = model.predict(obs, deterministic=True)
|
175 |
+
obs, reward, done, truncated, _ = env.step(action)
|
176 |
env.render()
|
177 |
# Capture the current frame
|
178 |
frame = pygame.surfarray.array3d(pygame.display.get_surface())
|
179 |
+
frame = np.rot90(frame) # Fix orientation
|
180 |
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
181 |
video_frames.append(frame)
|
182 |
|
183 |
# Save the video
|
184 |
video_path = "arkanoid_training.mp4"
|
185 |
+
video_writer = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), FPS, (SCREEN_WIDTH, SCREEN_HEIGHT))
|
|
|
186 |
for frame in video_frames:
|
187 |
video_writer.write(frame)
|
188 |
video_writer.release()
|
189 |
|
190 |
+
env.close() # Ensure the environment is properly closed
|
191 |
return video_path
|
192 |
|
193 |
# Main function
|
|
|
203 |
|
204 |
if __name__ == "__main__":
|
205 |
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|