Update app.py
Browse files
app.py
CHANGED
@@ -32,7 +32,6 @@ pygame.init()
|
|
32 |
class Paddle:
|
33 |
def __init__(self):
|
34 |
self.rect = pygame.Rect(SCREEN_WIDTH // 2 - PADDLE_WIDTH // 2, SCREEN_HEIGHT - PADDLE_HEIGHT - 10, PADDLE_WIDTH, PADDLE_HEIGHT)
|
35 |
-
self.last_position = self.rect.x
|
36 |
|
37 |
def move(self, direction):
|
38 |
if direction == -1:
|
@@ -73,13 +72,9 @@ class ArkanoidEnv(gym.Env):
|
|
73 |
self.platform_reward = platform_reward
|
74 |
self.inactivity_penalty = inactivity_penalty
|
75 |
self.inactivity_counter = 0
|
76 |
-
self.last_action = 0
|
77 |
self.reset()
|
78 |
|
79 |
def reset(self, seed=None, options=None):
|
80 |
-
if seed is not None:
|
81 |
-
random.seed(seed)
|
82 |
-
np.random.seed(seed)
|
83 |
self.paddle = Paddle()
|
84 |
self.ball = Ball()
|
85 |
self.bricks = [Brick(x, y) for y in range(BRICK_HEIGHT, BRICK_HEIGHT * (BRICK_ROWS + 1), BRICK_HEIGHT)
|
@@ -87,11 +82,9 @@ class ArkanoidEnv(gym.Env):
|
|
87 |
self.done = False
|
88 |
self.score = 0
|
89 |
self.inactivity_counter = 0
|
90 |
-
self.last_action = None
|
91 |
return self._get_state(), {}
|
92 |
|
93 |
def step(self, action):
|
94 |
-
# Apply action
|
95 |
if action == 0:
|
96 |
self.paddle.move(0)
|
97 |
elif action == 1:
|
@@ -99,7 +92,6 @@ class ArkanoidEnv(gym.Env):
|
|
99 |
elif action == 2:
|
100 |
self.paddle.move(1)
|
101 |
|
102 |
-
# Update inactivity penalty
|
103 |
if action == 0:
|
104 |
self.inactivity_counter += 1 / FPS
|
105 |
else:
|
@@ -109,42 +101,31 @@ class ArkanoidEnv(gym.Env):
|
|
109 |
reward = self.inactivity_penalty
|
110 |
return self._get_state(), reward, self.done, False, {}
|
111 |
|
112 |
-
# Update ball position
|
113 |
self.ball.move()
|
114 |
|
115 |
-
# Collision with paddle
|
116 |
if self.ball.rect.colliderect(self.paddle.rect):
|
117 |
self.ball.velocity[1] = -self.ball.velocity[1]
|
118 |
-
self.ball.velocity[0] += random.uniform(-1, 1)
|
119 |
self.score += self.platform_reward
|
120 |
|
121 |
-
# Collision with bricks
|
122 |
for brick in self.bricks[:]:
|
123 |
if self.ball.rect.colliderect(brick.rect):
|
124 |
self.bricks.remove(brick)
|
125 |
self.ball.velocity[1] = -self.ball.velocity[1]
|
126 |
-
self.ball.velocity[0] += random.uniform(-1, 1)
|
127 |
self.score += 1
|
128 |
if not self.bricks:
|
129 |
self.done = True
|
130 |
return self._get_state(), self.reward_size, self.done, False, {}
|
131 |
|
132 |
-
# Check if ball is out of bounds
|
133 |
if self.ball.rect.bottom >= SCREEN_HEIGHT:
|
134 |
self.done = True
|
135 |
-
|
136 |
-
return self._get_state(), reward, self.done, False, {}
|
137 |
|
138 |
-
|
139 |
-
reward = 0
|
140 |
-
return self._get_state(), reward, self.done, False, {}
|
141 |
|
142 |
def _get_state(self):
|
143 |
-
return np.array([
|
144 |
-
self.ball.rect.x,
|
145 |
-
self.paddle.rect.x,
|
146 |
-
len(self.bricks)
|
147 |
-
], dtype=np.float32)
|
148 |
|
149 |
def render(self, mode='rgb_array'):
|
150 |
surface = pygame.Surface((SCREEN_WIDTH, SCREEN_HEIGHT))
|
@@ -163,7 +144,32 @@ class ArkanoidEnv(gym.Env):
|
|
163 |
def close(self):
|
164 |
pygame.quit()
|
165 |
|
166 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
def main():
|
168 |
iface = gr.Interface(
|
169 |
fn=train_and_play,
|
|
|
32 |
class Paddle:
|
33 |
def __init__(self):
|
34 |
self.rect = pygame.Rect(SCREEN_WIDTH // 2 - PADDLE_WIDTH // 2, SCREEN_HEIGHT - PADDLE_HEIGHT - 10, PADDLE_WIDTH, PADDLE_HEIGHT)
|
|
|
35 |
|
36 |
def move(self, direction):
|
37 |
if direction == -1:
|
|
|
72 |
self.platform_reward = platform_reward
|
73 |
self.inactivity_penalty = inactivity_penalty
|
74 |
self.inactivity_counter = 0
|
|
|
75 |
self.reset()
|
76 |
|
77 |
def reset(self, seed=None, options=None):
|
|
|
|
|
|
|
78 |
self.paddle = Paddle()
|
79 |
self.ball = Ball()
|
80 |
self.bricks = [Brick(x, y) for y in range(BRICK_HEIGHT, BRICK_HEIGHT * (BRICK_ROWS + 1), BRICK_HEIGHT)
|
|
|
82 |
self.done = False
|
83 |
self.score = 0
|
84 |
self.inactivity_counter = 0
|
|
|
85 |
return self._get_state(), {}
|
86 |
|
87 |
def step(self, action):
|
|
|
88 |
if action == 0:
|
89 |
self.paddle.move(0)
|
90 |
elif action == 1:
|
|
|
92 |
elif action == 2:
|
93 |
self.paddle.move(1)
|
94 |
|
|
|
95 |
if action == 0:
|
96 |
self.inactivity_counter += 1 / FPS
|
97 |
else:
|
|
|
101 |
reward = self.inactivity_penalty
|
102 |
return self._get_state(), reward, self.done, False, {}
|
103 |
|
|
|
104 |
self.ball.move()
|
105 |
|
|
|
106 |
if self.ball.rect.colliderect(self.paddle.rect):
|
107 |
self.ball.velocity[1] = -self.ball.velocity[1]
|
108 |
+
self.ball.velocity[0] += random.uniform(-1, 1)
|
109 |
self.score += self.platform_reward
|
110 |
|
|
|
111 |
for brick in self.bricks[:]:
|
112 |
if self.ball.rect.colliderect(brick.rect):
|
113 |
self.bricks.remove(brick)
|
114 |
self.ball.velocity[1] = -self.ball.velocity[1]
|
115 |
+
self.ball.velocity[0] += random.uniform(-1, 1)
|
116 |
self.score += 1
|
117 |
if not self.bricks:
|
118 |
self.done = True
|
119 |
return self._get_state(), self.reward_size, self.done, False, {}
|
120 |
|
|
|
121 |
if self.ball.rect.bottom >= SCREEN_HEIGHT:
|
122 |
self.done = True
|
123 |
+
return self._get_state(), self.penalty_size, self.done, False, {}
|
|
|
124 |
|
125 |
+
return self._get_state(), 0, self.done, False, {}
|
|
|
|
|
126 |
|
127 |
def _get_state(self):
    """Build the observation vector for the agent.

    Returns:
        A float32 NumPy array holding, in order, the ball's x position,
        the paddle's x position, and the number of bricks still present.
    """
    observation = [
        self.ball.rect.x,
        self.paddle.rect.x,
        len(self.bricks),
    ]
    return np.array(observation, dtype=np.float32)
|
|
|
|
|
|
|
|
|
129 |
|
130 |
def render(self, mode='rgb_array'):
|
131 |
surface = pygame.Surface((SCREEN_WIDTH, SCREEN_HEIGHT))
|
|
|
144 |
def close(self):
    """Shut the environment down by uninitializing all pygame modules."""
    pygame.quit()
|
146 |
|
147 |
+
# Training and playing function
|
148 |
+
def train_and_play(reward_size, penalty_size, platform_reward, inactivity_penalty, iterations):
    """Train a DQN agent on Arkanoid, then record one episode to an MP4 file.

    Args:
        reward_size: Forwarded unchanged to ``ArkanoidEnv``.
        penalty_size: Forwarded unchanged to ``ArkanoidEnv``.
        platform_reward: Forwarded unchanged to ``ArkanoidEnv``.
        inactivity_penalty: Forwarded unchanged to ``ArkanoidEnv``.
        iterations: Passed to ``model.learn`` as ``total_timesteps``.

    Returns:
        The path of the written video file (``"/tmp/arkanoid.mp4"``).

    Raises:
        RuntimeError: If the OpenCV video writer cannot be opened.
    """
    video_path = "/tmp/arkanoid.mp4"

    # Hard cap on the recording (one minute of video at FPS frames/second).
    # The episode loop used to run until `done` with no upper bound. In the
    # visible part of ArkanoidEnv.step the inactivity branch returns before
    # the ball is moved and without touching `done`, so an agent that keeps
    # idling could (presumably - the guarding condition is not fully visible
    # here) never reach a terminal state. The cap guarantees termination.
    max_frames = max(1, int(FPS * 60))

    env = ArkanoidEnv(reward_size, penalty_size, platform_reward, inactivity_penalty)
    out = None
    try:
        model = DQN("MlpPolicy", env, verbose=0)
        model.learn(total_timesteps=iterations)

        out = cv2.VideoWriter(
            video_path,
            cv2.VideoWriter_fourcc(*'mp4v'),
            FPS,
            (SCREEN_WIDTH, SCREEN_HEIGHT),
        )
        # A writer that failed to open drops every frame silently; fail loudly
        # instead of returning a path to a missing or empty file.
        if not out.isOpened():
            raise RuntimeError(f"Could not open video writer for {video_path}")

        obs, _ = env.reset()
        for _frame_index in range(max_frames):
            action, _states = model.predict(obs)
            obs, _, terminated, truncated, _ = env.step(action)
            frame = env.render(mode="rgb_array")
            # Stream each frame straight to disk instead of buffering the
            # whole episode in memory. The frame is converted RGB -> BGR
            # because OpenCV writes BGR.
            out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
            if terminated or truncated:
                break
    finally:
        # Release the writer and the environment even when training or the
        # rollout raises.
        if out is not None:
            out.release()
        env.close()

    return video_path
|
171 |
+
|
172 |
+
# Gradio interface
|
173 |
def main():
|
174 |
iface = gr.Interface(
|
175 |
fn=train_and_play,
|