Ivan000 committed on
Commit
a14aa44
·
verified ·
1 Parent(s): 2138c0e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -25
app.py CHANGED
@@ -32,7 +32,6 @@ pygame.init()
32
  class Paddle:
33
  def __init__(self):
34
  self.rect = pygame.Rect(SCREEN_WIDTH // 2 - PADDLE_WIDTH // 2, SCREEN_HEIGHT - PADDLE_HEIGHT - 10, PADDLE_WIDTH, PADDLE_HEIGHT)
35
- self.last_position = self.rect.x
36
 
37
  def move(self, direction):
38
  if direction == -1:
@@ -73,13 +72,9 @@ class ArkanoidEnv(gym.Env):
73
  self.platform_reward = platform_reward
74
  self.inactivity_penalty = inactivity_penalty
75
  self.inactivity_counter = 0
76
- self.last_action = 0
77
  self.reset()
78
 
79
  def reset(self, seed=None, options=None):
80
- if seed is not None:
81
- random.seed(seed)
82
- np.random.seed(seed)
83
  self.paddle = Paddle()
84
  self.ball = Ball()
85
  self.bricks = [Brick(x, y) for y in range(BRICK_HEIGHT, BRICK_HEIGHT * (BRICK_ROWS + 1), BRICK_HEIGHT)
@@ -87,11 +82,9 @@ class ArkanoidEnv(gym.Env):
87
  self.done = False
88
  self.score = 0
89
  self.inactivity_counter = 0
90
- self.last_action = None
91
  return self._get_state(), {}
92
 
93
  def step(self, action):
94
- # Apply action
95
  if action == 0:
96
  self.paddle.move(0)
97
  elif action == 1:
@@ -99,7 +92,6 @@ class ArkanoidEnv(gym.Env):
99
  elif action == 2:
100
  self.paddle.move(1)
101
 
102
- # Update inactivity penalty
103
  if action == 0:
104
  self.inactivity_counter += 1 / FPS
105
  else:
@@ -109,42 +101,31 @@ class ArkanoidEnv(gym.Env):
109
  reward = self.inactivity_penalty
110
  return self._get_state(), reward, self.done, False, {}
111
 
112
- # Update ball position
113
  self.ball.move()
114
 
115
- # Collision with paddle
116
  if self.ball.rect.colliderect(self.paddle.rect):
117
  self.ball.velocity[1] = -self.ball.velocity[1]
118
- self.ball.velocity[0] += random.uniform(-1, 1) # Add random offset to angle
119
  self.score += self.platform_reward
120
 
121
- # Collision with bricks
122
  for brick in self.bricks[:]:
123
  if self.ball.rect.colliderect(brick.rect):
124
  self.bricks.remove(brick)
125
  self.ball.velocity[1] = -self.ball.velocity[1]
126
- self.ball.velocity[0] += random.uniform(-1, 1) # Add random offset to angle
127
  self.score += 1
128
  if not self.bricks:
129
  self.done = True
130
  return self._get_state(), self.reward_size, self.done, False, {}
131
 
132
- # Check if ball is out of bounds
133
  if self.ball.rect.bottom >= SCREEN_HEIGHT:
134
  self.done = True
135
- reward = self.penalty_size
136
- return self._get_state(), reward, self.done, False, {}
137
 
138
- # Calculate reward for breaking bricks
139
- reward = 0
140
- return self._get_state(), reward, self.done, False, {}
141
 
142
  def _get_state(self):
143
- return np.array([
144
- self.ball.rect.x,
145
- self.paddle.rect.x,
146
- len(self.bricks)
147
- ], dtype=np.float32)
148
 
149
  def render(self, mode='rgb_array'):
150
  surface = pygame.Surface((SCREEN_WIDTH, SCREEN_HEIGHT))
@@ -163,7 +144,32 @@ class ArkanoidEnv(gym.Env):
163
  def close(self):
164
  pygame.quit()
165
 
166
- # Main function remains unchanged
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  def main():
168
  iface = gr.Interface(
169
  fn=train_and_play,
 
32
  class Paddle:
33
  def __init__(self):
34
  self.rect = pygame.Rect(SCREEN_WIDTH // 2 - PADDLE_WIDTH // 2, SCREEN_HEIGHT - PADDLE_HEIGHT - 10, PADDLE_WIDTH, PADDLE_HEIGHT)
 
35
 
36
  def move(self, direction):
37
  if direction == -1:
 
72
  self.platform_reward = platform_reward
73
  self.inactivity_penalty = inactivity_penalty
74
  self.inactivity_counter = 0
 
75
  self.reset()
76
 
77
  def reset(self, seed=None, options=None):
 
 
 
78
  self.paddle = Paddle()
79
  self.ball = Ball()
80
  self.bricks = [Brick(x, y) for y in range(BRICK_HEIGHT, BRICK_HEIGHT * (BRICK_ROWS + 1), BRICK_HEIGHT)
 
82
  self.done = False
83
  self.score = 0
84
  self.inactivity_counter = 0
 
85
  return self._get_state(), {}
86
 
87
  def step(self, action):
 
88
  if action == 0:
89
  self.paddle.move(0)
90
  elif action == 1:
 
92
  elif action == 2:
93
  self.paddle.move(1)
94
 
 
95
  if action == 0:
96
  self.inactivity_counter += 1 / FPS
97
  else:
 
101
  reward = self.inactivity_penalty
102
  return self._get_state(), reward, self.done, False, {}
103
 
 
104
  self.ball.move()
105
 
 
106
  if self.ball.rect.colliderect(self.paddle.rect):
107
  self.ball.velocity[1] = -self.ball.velocity[1]
108
+ self.ball.velocity[0] += random.uniform(-1, 1)
109
  self.score += self.platform_reward
110
 
 
111
  for brick in self.bricks[:]:
112
  if self.ball.rect.colliderect(brick.rect):
113
  self.bricks.remove(brick)
114
  self.ball.velocity[1] = -self.ball.velocity[1]
115
+ self.ball.velocity[0] += random.uniform(-1, 1)
116
  self.score += 1
117
  if not self.bricks:
118
  self.done = True
119
  return self._get_state(), self.reward_size, self.done, False, {}
120
 
 
121
  if self.ball.rect.bottom >= SCREEN_HEIGHT:
122
  self.done = True
123
+ return self._get_state(), self.penalty_size, self.done, False, {}
 
124
 
125
+ return self._get_state(), 0, self.done, False, {}
 
 
126
 
127
  def _get_state(self):
128
+ return np.array([self.ball.rect.x, self.paddle.rect.x, len(self.bricks)], dtype=np.float32)
 
 
 
 
129
 
130
  def render(self, mode='rgb_array'):
131
  surface = pygame.Surface((SCREEN_WIDTH, SCREEN_HEIGHT))
 
144
  def close(self):
145
  pygame.quit()
146
 
147
+ # Training and playing function
148
+ def train_and_play(reward_size, penalty_size, platform_reward, inactivity_penalty, iterations):
149
+ env = ArkanoidEnv(reward_size, penalty_size, platform_reward, inactivity_penalty)
150
+ model = DQN("MlpPolicy", env, verbose=0)
151
+ model.learn(total_timesteps=iterations)
152
+
153
+ obs, _ = env.reset()
154
+ frames = []
155
+ while True:
156
+ action, _states = model.predict(obs)
157
+ obs, _, done, _, _ = env.step(action)
158
+ frame = env.render(mode="rgb_array")
159
+ frames.append(frame)
160
+ if done:
161
+ break
162
+ env.close()
163
+
164
+ video_path = "/tmp/arkanoid.mp4"
165
+ out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), FPS, (SCREEN_WIDTH, SCREEN_HEIGHT))
166
+ for frame in frames:
167
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
168
+ out.write(frame)
169
+ out.release()
170
+ return video_path
171
+
172
+ # Gradio interface
173
  def main():
174
  iface = gr.Interface(
175
  fn=train_and_play,