Ivan000 committed on
Commit
2138c0e
·
verified ·
1 Parent(s): 68621da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -57
app.py CHANGED
@@ -32,6 +32,7 @@ pygame.init()
32
  class Paddle:
33
  def __init__(self):
34
  self.rect = pygame.Rect(SCREEN_WIDTH // 2 - PADDLE_WIDTH // 2, SCREEN_HEIGHT - PADDLE_HEIGHT - 10, PADDLE_WIDTH, PADDLE_HEIGHT)
 
35
 
36
  def move(self, direction):
37
  if direction == -1:
@@ -63,13 +64,16 @@ class Brick:
63
  self.rect = pygame.Rect(x, y, BRICK_WIDTH - 5, BRICK_HEIGHT - 5)
64
 
65
  class ArkanoidEnv(gym.Env):
66
- def __init__(self, reward_size=1, penalty_size=-1, platform_reward=5):
67
  super(ArkanoidEnv, self).__init__()
68
  self.action_space = gym.spaces.Discrete(3) # 0: stay, 1: move left, 2: move right
69
- self.observation_space = gym.spaces.Box(low=0, high=SCREEN_WIDTH, shape=(5 + BRICK_ROWS * BRICK_COLS * 2,), dtype=np.float32)
70
  self.reward_size = reward_size
71
  self.penalty_size = penalty_size
72
  self.platform_reward = platform_reward
 
 
 
73
  self.reset()
74
 
75
  def reset(self, seed=None, options=None):
@@ -82,9 +86,12 @@ class ArkanoidEnv(gym.Env):
82
  for x in range(BRICK_WIDTH, SCREEN_WIDTH - BRICK_WIDTH, BRICK_WIDTH)]
83
  self.done = False
84
  self.score = 0
 
 
85
  return self._get_state(), {}
86
 
87
  def step(self, action):
 
88
  if action == 0:
89
  self.paddle.move(0)
90
  elif action == 1:
@@ -92,46 +99,52 @@ class ArkanoidEnv(gym.Env):
92
  elif action == 2:
93
  self.paddle.move(1)
94
 
 
 
 
 
 
 
 
 
 
 
 
95
  self.ball.move()
96
 
 
97
  if self.ball.rect.colliderect(self.paddle.rect):
98
  self.ball.velocity[1] = -self.ball.velocity[1]
 
99
  self.score += self.platform_reward
100
 
 
101
  for brick in self.bricks[:]:
102
  if self.ball.rect.colliderect(brick.rect):
103
  self.bricks.remove(brick)
104
  self.ball.velocity[1] = -self.ball.velocity[1]
 
105
  self.score += 1
106
- reward = self.reward_size
107
  if not self.bricks:
108
- reward += self.reward_size * 10 # Bonus reward for breaking all bricks
109
  self.done = True
110
- truncated = False
111
- return self._get_state(), reward, self.done, truncated, {}
112
 
 
113
  if self.ball.rect.bottom >= SCREEN_HEIGHT:
114
  self.done = True
115
  reward = self.penalty_size
116
- truncated = False
117
- else:
118
- reward = 0
119
- truncated = False
120
 
121
- return self._get_state(), reward, self.done, truncated, {}
 
 
122
 
123
  def _get_state(self):
124
- state = [
125
- self.paddle.rect.x,
126
  self.ball.rect.x,
127
- self.ball.rect.y,
128
- self.ball.velocity[0],
129
- self.ball.velocity[1]
130
- ]
131
- for brick in self.bricks:
132
- state.extend([brick.rect.x, brick.rect.y])
133
- state.extend([0, 0] * (BRICK_ROWS * BRICK_COLS - len(self.bricks))) # Padding for missing bricks
134
- return np.array(state, dtype=np.float32)
135
 
136
  def render(self, mode='rgb_array'):
137
  surface = pygame.Surface((SCREEN_WIDTH, SCREEN_HEIGHT))
@@ -150,40 +163,7 @@ class ArkanoidEnv(gym.Env):
150
  def close(self):
151
  pygame.quit()
152
 
153
- # Training and playing with custom parameters
154
- def train_and_play(reward_size, penalty_size, platform_reward, iterations):
155
- env = ArkanoidEnv(reward_size=reward_size, penalty_size=penalty_size, platform_reward=platform_reward)
156
- model = DQN('MlpPolicy', env, verbose=1)
157
- timesteps_per_update = min(1000, iterations)
158
- video_frames = []
159
-
160
- completed_iterations = 0
161
- while completed_iterations < iterations:
162
- steps = min(timesteps_per_update, iterations - completed_iterations)
163
- model.learn(total_timesteps=steps)
164
- completed_iterations += steps
165
-
166
- obs, _ = env.reset()
167
- done = False
168
- while not done:
169
- action, _states = model.predict(obs, deterministic=True)
170
- obs, reward, done, truncated, _ = env.step(action)
171
-
172
- frame = env.render(mode='rgb_array')
173
- frame = np.rot90(frame)
174
- frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
175
- video_frames.append(frame)
176
-
177
- video_path = "arkanoid_training.mp4"
178
- video_writer = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), FPS, (SCREEN_WIDTH, SCREEN_HEIGHT))
179
- for frame in video_frames:
180
- video_writer.write(frame)
181
- video_writer.release()
182
-
183
- env.close()
184
- return video_path
185
-
186
- # Main function with Gradio interface
187
  def main():
188
  iface = gr.Interface(
189
  fn=train_and_play,
@@ -191,10 +171,10 @@ def main():
191
  gr.Number(label="Reward Size", value=1),
192
  gr.Number(label="Penalty Size", value=-1),
193
  gr.Number(label="Platform Reward", value=5),
 
194
  gr.Slider(label="Iterations", minimum=10, maximum=100000, step=10, value=10000)
195
  ],
196
- outputs="video",
197
- live=False # Disable auto-generation on slider changes
198
  )
199
  iface.launch()
200
 
 
32
  class Paddle:
33
  def __init__(self):
34
  self.rect = pygame.Rect(SCREEN_WIDTH // 2 - PADDLE_WIDTH // 2, SCREEN_HEIGHT - PADDLE_HEIGHT - 10, PADDLE_WIDTH, PADDLE_HEIGHT)
35
+ self.last_position = self.rect.x
36
 
37
  def move(self, direction):
38
  if direction == -1:
 
64
  self.rect = pygame.Rect(x, y, BRICK_WIDTH - 5, BRICK_HEIGHT - 5)
65
 
66
  class ArkanoidEnv(gym.Env):
67
+ def __init__(self, reward_size=1, penalty_size=-1, platform_reward=5, inactivity_penalty=-0.5):
68
  super(ArkanoidEnv, self).__init__()
69
  self.action_space = gym.spaces.Discrete(3) # 0: stay, 1: move left, 2: move right
70
+ self.observation_space = gym.spaces.Box(low=0, high=SCREEN_WIDTH, shape=(3,), dtype=np.float32)
71
  self.reward_size = reward_size
72
  self.penalty_size = penalty_size
73
  self.platform_reward = platform_reward
74
+ self.inactivity_penalty = inactivity_penalty
75
+ self.inactivity_counter = 0
76
+ self.last_action = 0
77
  self.reset()
78
 
79
  def reset(self, seed=None, options=None):
 
86
  for x in range(BRICK_WIDTH, SCREEN_WIDTH - BRICK_WIDTH, BRICK_WIDTH)]
87
  self.done = False
88
  self.score = 0
89
+ self.inactivity_counter = 0
90
+ self.last_action = None
91
  return self._get_state(), {}
92
 
93
  def step(self, action):
94
+ # Apply action
95
  if action == 0:
96
  self.paddle.move(0)
97
  elif action == 1:
 
99
  elif action == 2:
100
  self.paddle.move(1)
101
 
102
+ # Update inactivity penalty
103
+ if action == 0:
104
+ self.inactivity_counter += 1 / FPS
105
+ else:
106
+ self.inactivity_counter = 0
107
+
108
+ if self.inactivity_counter >= 1:
109
+ reward = self.inactivity_penalty
110
+ return self._get_state(), reward, self.done, False, {}
111
+
112
+ # Update ball position
113
  self.ball.move()
114
 
115
+ # Collision with paddle
116
  if self.ball.rect.colliderect(self.paddle.rect):
117
  self.ball.velocity[1] = -self.ball.velocity[1]
118
+ self.ball.velocity[0] += random.uniform(-1, 1) # Add random offset to angle
119
  self.score += self.platform_reward
120
 
121
+ # Collision with bricks
122
  for brick in self.bricks[:]:
123
  if self.ball.rect.colliderect(brick.rect):
124
  self.bricks.remove(brick)
125
  self.ball.velocity[1] = -self.ball.velocity[1]
126
+ self.ball.velocity[0] += random.uniform(-1, 1) # Add random offset to angle
127
  self.score += 1
 
128
  if not self.bricks:
 
129
  self.done = True
130
+ return self._get_state(), self.reward_size, self.done, False, {}
 
131
 
132
+ # Check if ball is out of bounds
133
  if self.ball.rect.bottom >= SCREEN_HEIGHT:
134
  self.done = True
135
  reward = self.penalty_size
136
+ return self._get_state(), reward, self.done, False, {}
 
 
 
137
 
138
+ # Calculate reward for breaking bricks
139
+ reward = 0
140
+ return self._get_state(), reward, self.done, False, {}
141
 
142
  def _get_state(self):
143
+ return np.array([
 
144
  self.ball.rect.x,
145
+ self.paddle.rect.x,
146
+ len(self.bricks)
147
+ ], dtype=np.float32)
 
 
 
 
 
148
 
149
  def render(self, mode='rgb_array'):
150
  surface = pygame.Surface((SCREEN_WIDTH, SCREEN_HEIGHT))
 
163
  def close(self):
164
  pygame.quit()
165
 
166
+ # Main function remains unchanged
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  def main():
168
  iface = gr.Interface(
169
  fn=train_and_play,
 
171
  gr.Number(label="Reward Size", value=1),
172
  gr.Number(label="Penalty Size", value=-1),
173
  gr.Number(label="Platform Reward", value=5),
174
+ gr.Number(label="Inactivity Penalty", value=-0.5),
175
  gr.Slider(label="Iterations", minimum=10, maximum=100000, step=10, value=10000)
176
  ],
177
+ outputs="video"
 
178
  )
179
  iface.launch()
180