Update app.py
Browse files
app.py
CHANGED
@@ -4,6 +4,7 @@ import pygame
|
|
4 |
import random
|
5 |
import gymnasium as gym
|
6 |
from stable_baselines3 import DQN
|
|
|
7 |
import gradio as gr
|
8 |
import cv2
|
9 |
|
@@ -26,8 +27,6 @@ RED = (255, 0, 0)
|
|
26 |
|
27 |
# Initialize Pygame
|
28 |
pygame.init()
|
29 |
-
screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
|
30 |
-
pygame.display.set_caption("Arkanoid")
|
31 |
|
32 |
# Game classes
|
33 |
class Paddle:
|
@@ -41,7 +40,6 @@ class Paddle:
|
|
41 |
self.rect.x += 10
|
42 |
self.rect.clamp_ip(pygame.Rect(0, 0, SCREEN_WIDTH, SCREEN_HEIGHT))
|
43 |
|
44 |
-
|
45 |
class Ball:
|
46 |
def __init__(self):
|
47 |
self.rect = pygame.Rect(SCREEN_WIDTH // 2 - BALL_RADIUS, SCREEN_HEIGHT // 2 - BALL_RADIUS, BALL_RADIUS * 2, BALL_RADIUS * 2)
|
@@ -60,14 +58,12 @@ class Ball:
|
|
60 |
self.rect = pygame.Rect(SCREEN_WIDTH // 2 - BALL_RADIUS, SCREEN_HEIGHT // 2 - BALL_RADIUS, BALL_RADIUS * 2, BALL_RADIUS * 2)
|
61 |
self.velocity = [random.choice([-5, 5]), -5]
|
62 |
|
63 |
-
|
64 |
class Brick:
|
65 |
def __init__(self, x, y):
|
66 |
self.rect = pygame.Rect(x, y, BRICK_WIDTH - 5, BRICK_HEIGHT - 5)
|
67 |
|
68 |
-
|
69 |
class ArkanoidEnv(gym.Env):
|
70 |
-
def __init__(self, reward_size=1, penalty_size=-1, platform_reward=
|
71 |
super(ArkanoidEnv, self).__init__()
|
72 |
self.action_space = gym.spaces.Discrete(3) # 0: stay, 1: move left, 2: move right
|
73 |
self.observation_space = gym.spaces.Box(low=0, high=SCREEN_WIDTH, shape=(5 + BRICK_ROWS * BRICK_COLS * 2,), dtype=np.float32)
|
@@ -77,9 +73,12 @@ class ArkanoidEnv(gym.Env):
|
|
77 |
self.reset()
|
78 |
|
79 |
def reset(self, seed=None, options=None):
|
|
|
|
|
|
|
80 |
self.paddle = Paddle()
|
81 |
self.ball = Ball()
|
82 |
-
self.bricks = [Brick(x, y) for y in range(BRICK_HEIGHT, BRICK_HEIGHT * (BRICK_ROWS + 1), BRICK_HEIGHT)
|
83 |
for x in range(BRICK_WIDTH, SCREEN_WIDTH - BRICK_WIDTH, BRICK_WIDTH)]
|
84 |
self.done = False
|
85 |
self.score = 0
|
@@ -95,28 +94,31 @@ class ArkanoidEnv(gym.Env):
|
|
95 |
|
96 |
self.ball.move()
|
97 |
|
98 |
-
reward = 0
|
99 |
if self.ball.rect.colliderect(self.paddle.rect):
|
100 |
self.ball.velocity[1] = -self.ball.velocity[1]
|
101 |
-
|
102 |
|
103 |
for brick in self.bricks[:]:
|
104 |
if self.ball.rect.colliderect(brick.rect):
|
105 |
self.bricks.remove(brick)
|
106 |
self.ball.velocity[1] = -self.ball.velocity[1]
|
107 |
self.score += 1
|
108 |
-
reward
|
109 |
-
|
110 |
if not self.bricks:
|
111 |
-
reward += self.reward_size * 10
|
112 |
self.done = True
|
113 |
-
|
|
|
114 |
|
115 |
if self.ball.rect.bottom >= SCREEN_HEIGHT:
|
116 |
self.done = True
|
117 |
-
reward
|
|
|
|
|
|
|
|
|
118 |
|
119 |
-
return self._get_state(), reward, self.done,
|
120 |
|
121 |
def _get_state(self):
|
122 |
state = [
|
@@ -128,26 +130,26 @@ class ArkanoidEnv(gym.Env):
|
|
128 |
]
|
129 |
for brick in self.bricks:
|
130 |
state.extend([brick.rect.x, brick.rect.y])
|
131 |
-
state.extend([0, 0] * (BRICK_ROWS * BRICK_COLS - len(self.bricks)))
|
132 |
return np.array(state, dtype=np.float32)
|
133 |
|
134 |
-
def render(self, mode='
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
screen.fill(BLACK)
|
140 |
-
pygame.draw.rect(screen, WHITE, self.paddle.rect)
|
141 |
-
pygame.draw.ellipse(screen, WHITE, self.ball.rect)
|
142 |
for brick in self.bricks:
|
143 |
-
pygame.draw.rect(
|
144 |
-
|
145 |
-
|
|
|
|
|
|
|
|
|
146 |
|
147 |
def close(self):
|
148 |
pygame.quit()
|
149 |
|
150 |
-
|
151 |
# Training and playing with custom parameters
|
152 |
def train_and_play(reward_size, penalty_size, platform_reward, iterations):
|
153 |
env = ArkanoidEnv(reward_size=reward_size, penalty_size=penalty_size, platform_reward=platform_reward)
|
@@ -163,18 +165,11 @@ def train_and_play(reward_size, penalty_size, platform_reward, iterations):
|
|
163 |
|
164 |
obs, _ = env.reset()
|
165 |
done = False
|
166 |
-
|
167 |
-
while not done and not truncated:
|
168 |
action, _states = model.predict(obs, deterministic=True)
|
169 |
obs, reward, done, truncated, _ = env.step(action)
|
170 |
|
171 |
-
|
172 |
-
env.render()
|
173 |
-
except pygame.error:
|
174 |
-
print("Pygame display was closed. Exiting render loop.")
|
175 |
-
return "Training interrupted."
|
176 |
-
|
177 |
-
frame = pygame.surfarray.array3d(pygame.display.get_surface())
|
178 |
frame = np.rot90(frame)
|
179 |
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
180 |
video_frames.append(frame)
|
@@ -188,22 +183,20 @@ def train_and_play(reward_size, penalty_size, platform_reward, iterations):
|
|
188 |
env.close()
|
189 |
return video_path
|
190 |
|
191 |
-
|
192 |
-
# Main function
|
193 |
def main():
|
194 |
iface = gr.Interface(
|
195 |
fn=train_and_play,
|
196 |
inputs=[
|
197 |
gr.Number(label="Reward Size", value=1),
|
198 |
gr.Number(label="Penalty Size", value=-1),
|
199 |
-
gr.Number(label="Platform Reward", value=
|
200 |
gr.Slider(label="Iterations", minimum=10, maximum=100000, step=10, value=10000)
|
201 |
],
|
202 |
outputs="video",
|
203 |
-
live=False #
|
204 |
)
|
205 |
iface.launch()
|
206 |
|
207 |
-
|
208 |
if __name__ == "__main__":
|
209 |
main()
|
|
|
4 |
import random
|
5 |
import gymnasium as gym
|
6 |
from stable_baselines3 import DQN
|
7 |
+
from stable_baselines3.common.evaluation import evaluate_policy
|
8 |
import gradio as gr
|
9 |
import cv2
|
10 |
|
|
|
27 |
|
28 |
# Initialize Pygame
|
29 |
pygame.init()
|
|
|
|
|
30 |
|
31 |
# Game classes
|
32 |
class Paddle:
|
|
|
40 |
self.rect.x += 10
|
41 |
self.rect.clamp_ip(pygame.Rect(0, 0, SCREEN_WIDTH, SCREEN_HEIGHT))
|
42 |
|
|
|
43 |
class Ball:
|
44 |
def __init__(self):
|
45 |
self.rect = pygame.Rect(SCREEN_WIDTH // 2 - BALL_RADIUS, SCREEN_HEIGHT // 2 - BALL_RADIUS, BALL_RADIUS * 2, BALL_RADIUS * 2)
|
|
|
58 |
self.rect = pygame.Rect(SCREEN_WIDTH // 2 - BALL_RADIUS, SCREEN_HEIGHT // 2 - BALL_RADIUS, BALL_RADIUS * 2, BALL_RADIUS * 2)
|
59 |
self.velocity = [random.choice([-5, 5]), -5]
|
60 |
|
|
|
61 |
class Brick:
    """A single breakable brick, represented by its pygame rectangle."""

    def __init__(self, x, y):
        # The rectangle is 5 px narrower and 5 px shorter than the
        # BRICK_WIDTH x BRICK_HEIGHT cell it is positioned in.
        width = BRICK_WIDTH - 5
        height = BRICK_HEIGHT - 5
        self.rect = pygame.Rect(x, y, width, height)
|
64 |
|
|
|
65 |
class ArkanoidEnv(gym.Env):
|
66 |
+
def __init__(self, reward_size=1, penalty_size=-1, platform_reward=5):
|
67 |
super(ArkanoidEnv, self).__init__()
|
68 |
self.action_space = gym.spaces.Discrete(3) # 0: stay, 1: move left, 2: move right
|
69 |
self.observation_space = gym.spaces.Box(low=0, high=SCREEN_WIDTH, shape=(5 + BRICK_ROWS * BRICK_COLS * 2,), dtype=np.float32)
|
|
|
73 |
self.reset()
|
74 |
|
75 |
def reset(self, seed=None, options=None):
|
76 |
+
if seed is not None:
|
77 |
+
random.seed(seed)
|
78 |
+
np.random.seed(seed)
|
79 |
self.paddle = Paddle()
|
80 |
self.ball = Ball()
|
81 |
+
self.bricks = [Brick(x, y) for y in range(BRICK_HEIGHT, BRICK_HEIGHT * (BRICK_ROWS + 1), BRICK_HEIGHT)
|
82 |
for x in range(BRICK_WIDTH, SCREEN_WIDTH - BRICK_WIDTH, BRICK_WIDTH)]
|
83 |
self.done = False
|
84 |
self.score = 0
|
|
|
94 |
|
95 |
self.ball.move()
|
96 |
|
|
|
97 |
if self.ball.rect.colliderect(self.paddle.rect):
|
98 |
self.ball.velocity[1] = -self.ball.velocity[1]
|
99 |
+
self.score += self.platform_reward
|
100 |
|
101 |
for brick in self.bricks[:]:
|
102 |
if self.ball.rect.colliderect(brick.rect):
|
103 |
self.bricks.remove(brick)
|
104 |
self.ball.velocity[1] = -self.ball.velocity[1]
|
105 |
self.score += 1
|
106 |
+
reward = self.reward_size
|
|
|
107 |
if not self.bricks:
|
108 |
+
reward += self.reward_size * 10 # Bonus reward for breaking all bricks
|
109 |
self.done = True
|
110 |
+
truncated = False
|
111 |
+
return self._get_state(), reward, self.done, truncated, {}
|
112 |
|
113 |
if self.ball.rect.bottom >= SCREEN_HEIGHT:
|
114 |
self.done = True
|
115 |
+
reward = self.penalty_size
|
116 |
+
truncated = False
|
117 |
+
else:
|
118 |
+
reward = 0
|
119 |
+
truncated = False
|
120 |
|
121 |
+
return self._get_state(), reward, self.done, truncated, {}
|
122 |
|
123 |
def _get_state(self):
|
124 |
state = [
|
|
|
130 |
]
|
131 |
for brick in self.bricks:
|
132 |
state.extend([brick.rect.x, brick.rect.y])
|
133 |
+
state.extend([0, 0] * (BRICK_ROWS * BRICK_COLS - len(self.bricks))) # Padding for missing bricks
|
134 |
return np.array(state, dtype=np.float32)
|
135 |
|
136 |
+
def render(self, mode='rgb_array'):
    """Draw the current game state and either return it or show it.

    The paddle, ball and remaining bricks are painted onto a fresh
    off-screen surface, so rendering does not require a display window
    unless ``mode == 'human'``.

    Args:
        mode: ``'rgb_array'`` returns the frame as a pixel array;
            ``'human'`` copies the frame to the pygame display window.

    Returns:
        For ``'rgb_array'``, the array produced by
        ``pygame.surfarray.array3d`` (indexed ``[x][y][channel]``, i.e.
        width first). For any other mode, ``None``.
    """
    surface = pygame.Surface((SCREEN_WIDTH, SCREEN_HEIGHT))
    surface.fill(BLACK)
    pygame.draw.rect(surface, WHITE, self.paddle.rect)
    pygame.draw.ellipse(surface, WHITE, self.ball.rect)
    for brick in self.bricks:
        pygame.draw.rect(surface, RED, brick.rect)

    if mode == 'rgb_array':
        return pygame.surfarray.array3d(surface)

    if mode == 'human':
        # get_surface() returns None until a display mode has been set, and
        # the visible module-level setup only calls pygame.init(). Create
        # the window on demand instead of failing on None.blit(...).
        window = pygame.display.get_surface()
        if window is None:
            window = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
        window.blit(surface, (0, 0))
        pygame.display.flip()

    return None
|
149 |
|
150 |
def close(self):
    """Release pygame by calling ``pygame.quit()``."""
    pygame.quit()
|
152 |
|
|
|
153 |
# Training and playing with custom parameters
|
154 |
def train_and_play(reward_size, penalty_size, platform_reward, iterations):
|
155 |
env = ArkanoidEnv(reward_size=reward_size, penalty_size=penalty_size, platform_reward=platform_reward)
|
|
|
165 |
|
166 |
obs, _ = env.reset()
|
167 |
done = False
|
168 |
+
while not done:
|
|
|
169 |
action, _states = model.predict(obs, deterministic=True)
|
170 |
obs, reward, done, truncated, _ = env.step(action)
|
171 |
|
172 |
+
frame = env.render(mode='rgb_array')
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
frame = np.rot90(frame)
|
174 |
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
175 |
video_frames.append(frame)
|
|
|
183 |
env.close()
|
184 |
return video_path
|
185 |
|
186 |
+
# Main function with Gradio interface
|
|
|
187 |
def main():
    """Build the Gradio interface around ``train_and_play`` and launch it."""
    # Input widgets, in the same order as train_and_play's parameters:
    # reward_size, penalty_size, platform_reward, iterations.
    controls = [
        gr.Number(label="Reward Size", value=1),
        gr.Number(label="Penalty Size", value=-1),
        gr.Number(label="Platform Reward", value=5),
        gr.Slider(label="Iterations", minimum=10, maximum=100000, step=10, value=10000),
    ]
    app = gr.Interface(
        fn=train_and_play,
        inputs=controls,
        outputs="video",
        live=False,  # Disable auto-generation on slider changes
    )
    app.launch()
|
200 |
|
|
|
201 |
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    main()
|