Ivan000 committed on
Commit
68621da
·
verified ·
1 Parent(s): a99511a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -42
app.py CHANGED
@@ -4,6 +4,7 @@ import pygame
4
  import random
5
  import gymnasium as gym
6
  from stable_baselines3 import DQN
 
7
  import gradio as gr
8
  import cv2
9
 
@@ -26,8 +27,6 @@ RED = (255, 0, 0)
26
 
27
  # Initialize Pygame
28
  pygame.init()
29
- screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
30
- pygame.display.set_caption("Arkanoid")
31
 
32
  # Game classes
33
  class Paddle:
@@ -41,7 +40,6 @@ class Paddle:
41
  self.rect.x += 10
42
  self.rect.clamp_ip(pygame.Rect(0, 0, SCREEN_WIDTH, SCREEN_HEIGHT))
43
 
44
-
45
  class Ball:
46
  def __init__(self):
47
  self.rect = pygame.Rect(SCREEN_WIDTH // 2 - BALL_RADIUS, SCREEN_HEIGHT // 2 - BALL_RADIUS, BALL_RADIUS * 2, BALL_RADIUS * 2)
@@ -60,14 +58,12 @@ class Ball:
60
  self.rect = pygame.Rect(SCREEN_WIDTH // 2 - BALL_RADIUS, SCREEN_HEIGHT // 2 - BALL_RADIUS, BALL_RADIUS * 2, BALL_RADIUS * 2)
61
  self.velocity = [random.choice([-5, 5]), -5]
62
 
63
-
64
  class Brick:
65
  def __init__(self, x, y):
66
  self.rect = pygame.Rect(x, y, BRICK_WIDTH - 5, BRICK_HEIGHT - 5)
67
 
68
-
69
  class ArkanoidEnv(gym.Env):
70
- def __init__(self, reward_size=1, penalty_size=-1, platform_reward=2):
71
  super(ArkanoidEnv, self).__init__()
72
  self.action_space = gym.spaces.Discrete(3) # 0: stay, 1: move left, 2: move right
73
  self.observation_space = gym.spaces.Box(low=0, high=SCREEN_WIDTH, shape=(5 + BRICK_ROWS * BRICK_COLS * 2,), dtype=np.float32)
@@ -77,9 +73,12 @@ class ArkanoidEnv(gym.Env):
77
  self.reset()
78
 
79
  def reset(self, seed=None, options=None):
 
 
 
80
  self.paddle = Paddle()
81
  self.ball = Ball()
82
- self.bricks = [Brick(x, y) for y in range(BRICK_HEIGHT, BRICK_HEIGHT * (BRICK_ROWS + 1), BRICK_HEIGHT)
83
  for x in range(BRICK_WIDTH, SCREEN_WIDTH - BRICK_WIDTH, BRICK_WIDTH)]
84
  self.done = False
85
  self.score = 0
@@ -95,28 +94,31 @@ class ArkanoidEnv(gym.Env):
95
 
96
  self.ball.move()
97
 
98
- reward = 0
99
  if self.ball.rect.colliderect(self.paddle.rect):
100
  self.ball.velocity[1] = -self.ball.velocity[1]
101
- reward += self.platform_reward
102
 
103
  for brick in self.bricks[:]:
104
  if self.ball.rect.colliderect(brick.rect):
105
  self.bricks.remove(brick)
106
  self.ball.velocity[1] = -self.ball.velocity[1]
107
  self.score += 1
108
- reward += self.reward_size
109
-
110
  if not self.bricks:
111
- reward += self.reward_size * 10
112
  self.done = True
113
- return self._get_state(), reward, self.done, False, {}
 
114
 
115
  if self.ball.rect.bottom >= SCREEN_HEIGHT:
116
  self.done = True
117
- reward += self.penalty_size
 
 
 
 
118
 
119
- return self._get_state(), reward, self.done, False, {}
120
 
121
  def _get_state(self):
122
  state = [
@@ -128,26 +130,26 @@ class ArkanoidEnv(gym.Env):
128
  ]
129
  for brick in self.bricks:
130
  state.extend([brick.rect.x, brick.rect.y])
131
- state.extend([0, 0] * (BRICK_ROWS * BRICK_COLS - len(self.bricks)))
132
  return np.array(state, dtype=np.float32)
133
 
134
- def render(self, mode='human'):
135
- for event in pygame.event.get():
136
- if event.type == pygame.QUIT:
137
- pygame.quit()
138
- return
139
- screen.fill(BLACK)
140
- pygame.draw.rect(screen, WHITE, self.paddle.rect)
141
- pygame.draw.ellipse(screen, WHITE, self.ball.rect)
142
  for brick in self.bricks:
143
- pygame.draw.rect(screen, RED, brick.rect)
144
- pygame.display.flip()
145
- pygame.time.Clock().tick(FPS)
 
 
 
 
146
 
147
  def close(self):
148
  pygame.quit()
149
 
150
-
151
  # Training and playing with custom parameters
152
  def train_and_play(reward_size, penalty_size, platform_reward, iterations):
153
  env = ArkanoidEnv(reward_size=reward_size, penalty_size=penalty_size, platform_reward=platform_reward)
@@ -163,18 +165,11 @@ def train_and_play(reward_size, penalty_size, platform_reward, iterations):
163
 
164
  obs, _ = env.reset()
165
  done = False
166
- truncated = False
167
- while not done and not truncated:
168
  action, _states = model.predict(obs, deterministic=True)
169
  obs, reward, done, truncated, _ = env.step(action)
170
 
171
- try:
172
- env.render()
173
- except pygame.error:
174
- print("Pygame display was closed. Exiting render loop.")
175
- return "Training interrupted."
176
-
177
- frame = pygame.surfarray.array3d(pygame.display.get_surface())
178
  frame = np.rot90(frame)
179
  frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
180
  video_frames.append(frame)
@@ -188,22 +183,20 @@ def train_and_play(reward_size, penalty_size, platform_reward, iterations):
188
  env.close()
189
  return video_path
190
 
191
-
192
- # Main function
193
  def main():
194
  iface = gr.Interface(
195
  fn=train_and_play,
196
  inputs=[
197
  gr.Number(label="Reward Size", value=1),
198
  gr.Number(label="Penalty Size", value=-1),
199
- gr.Number(label="Platform Reward", value=2),
200
  gr.Slider(label="Iterations", minimum=10, maximum=100000, step=10, value=10000)
201
  ],
202
  outputs="video",
203
- live=False # Изменено: генерация только после нажатия кнопки
204
  )
205
  iface.launch()
206
 
207
-
208
  if __name__ == "__main__":
209
  main()
 
4
  import random
5
  import gymnasium as gym
6
  from stable_baselines3 import DQN
7
+ from stable_baselines3.common.evaluation import evaluate_policy
8
  import gradio as gr
9
  import cv2
10
 
 
27
 
28
  # Initialize Pygame
29
  pygame.init()
 
 
30
 
31
  # Game classes
32
  class Paddle:
 
40
  self.rect.x += 10
41
  self.rect.clamp_ip(pygame.Rect(0, 0, SCREEN_WIDTH, SCREEN_HEIGHT))
42
 
 
43
  class Ball:
44
  def __init__(self):
45
  self.rect = pygame.Rect(SCREEN_WIDTH // 2 - BALL_RADIUS, SCREEN_HEIGHT // 2 - BALL_RADIUS, BALL_RADIUS * 2, BALL_RADIUS * 2)
 
58
  self.rect = pygame.Rect(SCREEN_WIDTH // 2 - BALL_RADIUS, SCREEN_HEIGHT // 2 - BALL_RADIUS, BALL_RADIUS * 2, BALL_RADIUS * 2)
59
  self.velocity = [random.choice([-5, 5]), -5]
60
 
 
61
  class Brick:
62
  def __init__(self, x, y):
63
  self.rect = pygame.Rect(x, y, BRICK_WIDTH - 5, BRICK_HEIGHT - 5)
64
 
 
65
  class ArkanoidEnv(gym.Env):
66
+ def __init__(self, reward_size=1, penalty_size=-1, platform_reward=5):
67
  super(ArkanoidEnv, self).__init__()
68
  self.action_space = gym.spaces.Discrete(3) # 0: stay, 1: move left, 2: move right
69
  self.observation_space = gym.spaces.Box(low=0, high=SCREEN_WIDTH, shape=(5 + BRICK_ROWS * BRICK_COLS * 2,), dtype=np.float32)
 
73
  self.reset()
74
 
75
  def reset(self, seed=None, options=None):
76
+ if seed is not None:
77
+ random.seed(seed)
78
+ np.random.seed(seed)
79
  self.paddle = Paddle()
80
  self.ball = Ball()
81
+ self.bricks = [Brick(x, y) for y in range(BRICK_HEIGHT, BRICK_HEIGHT * (BRICK_ROWS + 1), BRICK_HEIGHT)
82
  for x in range(BRICK_WIDTH, SCREEN_WIDTH - BRICK_WIDTH, BRICK_WIDTH)]
83
  self.done = False
84
  self.score = 0
 
94
 
95
  self.ball.move()
96
 
 
97
  if self.ball.rect.colliderect(self.paddle.rect):
98
  self.ball.velocity[1] = -self.ball.velocity[1]
99
+ self.score += self.platform_reward
100
 
101
  for brick in self.bricks[:]:
102
  if self.ball.rect.colliderect(brick.rect):
103
  self.bricks.remove(brick)
104
  self.ball.velocity[1] = -self.ball.velocity[1]
105
  self.score += 1
106
+ reward = self.reward_size
 
107
  if not self.bricks:
108
+ reward += self.reward_size * 10 # Bonus reward for breaking all bricks
109
  self.done = True
110
+ truncated = False
111
+ return self._get_state(), reward, self.done, truncated, {}
112
 
113
  if self.ball.rect.bottom >= SCREEN_HEIGHT:
114
  self.done = True
115
+ reward = self.penalty_size
116
+ truncated = False
117
+ else:
118
+ reward = 0
119
+ truncated = False
120
 
121
+ return self._get_state(), reward, self.done, truncated, {}
122
 
123
  def _get_state(self):
124
  state = [
 
130
  ]
131
  for brick in self.bricks:
132
  state.extend([brick.rect.x, brick.rect.y])
133
+ state.extend([0, 0] * (BRICK_ROWS * BRICK_COLS - len(self.bricks))) # Padding for missing bricks
134
  return np.array(state, dtype=np.float32)
135
 
136
+ def render(self, mode='rgb_array'):
137
+ surface = pygame.Surface((SCREEN_WIDTH, SCREEN_HEIGHT))
138
+ surface.fill(BLACK)
139
+ pygame.draw.rect(surface, WHITE, self.paddle.rect)
140
+ pygame.draw.ellipse(surface, WHITE, self.ball.rect)
 
 
 
141
  for brick in self.bricks:
142
+ pygame.draw.rect(surface, RED, brick.rect)
143
+
144
+ if mode == 'rgb_array':
145
+ return pygame.surfarray.array3d(surface)
146
+ elif mode == 'human':
147
+ pygame.display.get_surface().blit(surface, (0, 0))
148
+ pygame.display.flip()
149
 
150
  def close(self):
151
  pygame.quit()
152
 
 
153
  # Training and playing with custom parameters
154
  def train_and_play(reward_size, penalty_size, platform_reward, iterations):
155
  env = ArkanoidEnv(reward_size=reward_size, penalty_size=penalty_size, platform_reward=platform_reward)
 
165
 
166
  obs, _ = env.reset()
167
  done = False
168
+ while not done:
 
169
  action, _states = model.predict(obs, deterministic=True)
170
  obs, reward, done, truncated, _ = env.step(action)
171
 
172
+ frame = env.render(mode='rgb_array')
 
 
 
 
 
 
173
  frame = np.rot90(frame)
174
  frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
175
  video_frames.append(frame)
 
183
  env.close()
184
  return video_path
185
 
186
+ # Main function with Gradio interface
 
187
  def main():
188
  iface = gr.Interface(
189
  fn=train_and_play,
190
  inputs=[
191
  gr.Number(label="Reward Size", value=1),
192
  gr.Number(label="Penalty Size", value=-1),
193
+ gr.Number(label="Platform Reward", value=5),
194
  gr.Slider(label="Iterations", minimum=10, maximum=100000, step=10, value=10000)
195
  ],
196
  outputs="video",
197
+ live=False # Disable auto-generation on slider changes
198
  )
199
  iface.launch()
200
 
 
201
  if __name__ == "__main__":
202
  main()