Update app.py
Browse files
app.py
CHANGED
@@ -32,6 +32,7 @@ pygame.init()
|
|
32 |
class Paddle:
|
33 |
def __init__(self):
|
34 |
self.rect = pygame.Rect(SCREEN_WIDTH // 2 - PADDLE_WIDTH // 2, SCREEN_HEIGHT - PADDLE_HEIGHT - 10, PADDLE_WIDTH, PADDLE_HEIGHT)
|
|
|
35 |
|
36 |
def move(self, direction):
|
37 |
if direction == -1:
|
@@ -63,13 +64,16 @@ class Brick:
|
|
63 |
self.rect = pygame.Rect(x, y, BRICK_WIDTH - 5, BRICK_HEIGHT - 5)
|
64 |
|
65 |
class ArkanoidEnv(gym.Env):
|
66 |
-
def __init__(self, reward_size=1, penalty_size=-1, platform_reward=5):
|
67 |
super(ArkanoidEnv, self).__init__()
|
68 |
self.action_space = gym.spaces.Discrete(3) # 0: stay, 1: move left, 2: move right
|
69 |
-
self.observation_space = gym.spaces.Box(low=0, high=SCREEN_WIDTH, shape=(
|
70 |
self.reward_size = reward_size
|
71 |
self.penalty_size = penalty_size
|
72 |
self.platform_reward = platform_reward
|
|
|
|
|
|
|
73 |
self.reset()
|
74 |
|
75 |
def reset(self, seed=None, options=None):
|
@@ -82,9 +86,12 @@ class ArkanoidEnv(gym.Env):
|
|
82 |
for x in range(BRICK_WIDTH, SCREEN_WIDTH - BRICK_WIDTH, BRICK_WIDTH)]
|
83 |
self.done = False
|
84 |
self.score = 0
|
|
|
|
|
85 |
return self._get_state(), {}
|
86 |
|
87 |
def step(self, action):
|
|
|
88 |
if action == 0:
|
89 |
self.paddle.move(0)
|
90 |
elif action == 1:
|
@@ -92,46 +99,52 @@ class ArkanoidEnv(gym.Env):
|
|
92 |
elif action == 2:
|
93 |
self.paddle.move(1)
|
94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
self.ball.move()
|
96 |
|
|
|
97 |
if self.ball.rect.colliderect(self.paddle.rect):
|
98 |
self.ball.velocity[1] = -self.ball.velocity[1]
|
|
|
99 |
self.score += self.platform_reward
|
100 |
|
|
|
101 |
for brick in self.bricks[:]:
|
102 |
if self.ball.rect.colliderect(brick.rect):
|
103 |
self.bricks.remove(brick)
|
104 |
self.ball.velocity[1] = -self.ball.velocity[1]
|
|
|
105 |
self.score += 1
|
106 |
-
reward = self.reward_size
|
107 |
if not self.bricks:
|
108 |
-
reward += self.reward_size * 10 # Bonus reward for breaking all bricks
|
109 |
self.done = True
|
110 |
-
|
111 |
-
return self._get_state(), reward, self.done, truncated, {}
|
112 |
|
|
|
113 |
if self.ball.rect.bottom >= SCREEN_HEIGHT:
|
114 |
self.done = True
|
115 |
reward = self.penalty_size
|
116 |
-
|
117 |
-
else:
|
118 |
-
reward = 0
|
119 |
-
truncated = False
|
120 |
|
121 |
-
|
|
|
|
|
122 |
|
123 |
def _get_state(self):
|
124 |
-
|
125 |
-
self.paddle.rect.x,
|
126 |
self.ball.rect.x,
|
127 |
-
self.
|
128 |
-
self.
|
129 |
-
|
130 |
-
]
|
131 |
-
for brick in self.bricks:
|
132 |
-
state.extend([brick.rect.x, brick.rect.y])
|
133 |
-
state.extend([0, 0] * (BRICK_ROWS * BRICK_COLS - len(self.bricks))) # Padding for missing bricks
|
134 |
-
return np.array(state, dtype=np.float32)
|
135 |
|
136 |
def render(self, mode='rgb_array'):
|
137 |
surface = pygame.Surface((SCREEN_WIDTH, SCREEN_HEIGHT))
|
@@ -150,40 +163,7 @@ class ArkanoidEnv(gym.Env):
|
|
150 |
def close(self):
|
151 |
pygame.quit()
|
152 |
|
153 |
-
#
|
154 |
-
def train_and_play(reward_size, penalty_size, platform_reward, iterations):
|
155 |
-
env = ArkanoidEnv(reward_size=reward_size, penalty_size=penalty_size, platform_reward=platform_reward)
|
156 |
-
model = DQN('MlpPolicy', env, verbose=1)
|
157 |
-
timesteps_per_update = min(1000, iterations)
|
158 |
-
video_frames = []
|
159 |
-
|
160 |
-
completed_iterations = 0
|
161 |
-
while completed_iterations < iterations:
|
162 |
-
steps = min(timesteps_per_update, iterations - completed_iterations)
|
163 |
-
model.learn(total_timesteps=steps)
|
164 |
-
completed_iterations += steps
|
165 |
-
|
166 |
-
obs, _ = env.reset()
|
167 |
-
done = False
|
168 |
-
while not done:
|
169 |
-
action, _states = model.predict(obs, deterministic=True)
|
170 |
-
obs, reward, done, truncated, _ = env.step(action)
|
171 |
-
|
172 |
-
frame = env.render(mode='rgb_array')
|
173 |
-
frame = np.rot90(frame)
|
174 |
-
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
175 |
-
video_frames.append(frame)
|
176 |
-
|
177 |
-
video_path = "arkanoid_training.mp4"
|
178 |
-
video_writer = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), FPS, (SCREEN_WIDTH, SCREEN_HEIGHT))
|
179 |
-
for frame in video_frames:
|
180 |
-
video_writer.write(frame)
|
181 |
-
video_writer.release()
|
182 |
-
|
183 |
-
env.close()
|
184 |
-
return video_path
|
185 |
-
|
186 |
-
# Main function with Gradio interface
|
187 |
def main():
|
188 |
iface = gr.Interface(
|
189 |
fn=train_and_play,
|
@@ -191,10 +171,10 @@ def main():
|
|
191 |
gr.Number(label="Reward Size", value=1),
|
192 |
gr.Number(label="Penalty Size", value=-1),
|
193 |
gr.Number(label="Platform Reward", value=5),
|
|
|
194 |
gr.Slider(label="Iterations", minimum=10, maximum=100000, step=10, value=10000)
|
195 |
],
|
196 |
-
outputs="video"
|
197 |
-
live=False # Disable auto-generation on slider changes
|
198 |
)
|
199 |
iface.launch()
|
200 |
|
|
|
32 |
class Paddle:
|
33 |
def __init__(self):
|
34 |
self.rect = pygame.Rect(SCREEN_WIDTH // 2 - PADDLE_WIDTH // 2, SCREEN_HEIGHT - PADDLE_HEIGHT - 10, PADDLE_WIDTH, PADDLE_HEIGHT)
|
35 |
+
self.last_position = self.rect.x
|
36 |
|
37 |
def move(self, direction):
|
38 |
if direction == -1:
|
|
|
64 |
self.rect = pygame.Rect(x, y, BRICK_WIDTH - 5, BRICK_HEIGHT - 5)
|
65 |
|
66 |
class ArkanoidEnv(gym.Env):
|
67 |
+
def __init__(self, reward_size=1, penalty_size=-1, platform_reward=5, inactivity_penalty=-0.5):
|
68 |
super(ArkanoidEnv, self).__init__()
|
69 |
self.action_space = gym.spaces.Discrete(3) # 0: stay, 1: move left, 2: move right
|
70 |
+
self.observation_space = gym.spaces.Box(low=0, high=SCREEN_WIDTH, shape=(3,), dtype=np.float32)
|
71 |
self.reward_size = reward_size
|
72 |
self.penalty_size = penalty_size
|
73 |
self.platform_reward = platform_reward
|
74 |
+
self.inactivity_penalty = inactivity_penalty
|
75 |
+
self.inactivity_counter = 0
|
76 |
+
self.last_action = 0
|
77 |
self.reset()
|
78 |
|
79 |
def reset(self, seed=None, options=None):
|
|
|
86 |
for x in range(BRICK_WIDTH, SCREEN_WIDTH - BRICK_WIDTH, BRICK_WIDTH)]
|
87 |
self.done = False
|
88 |
self.score = 0
|
89 |
+
self.inactivity_counter = 0
|
90 |
+
self.last_action = None
|
91 |
return self._get_state(), {}
|
92 |
|
93 |
def step(self, action):
|
94 |
+
# Apply action
|
95 |
if action == 0:
|
96 |
self.paddle.move(0)
|
97 |
elif action == 1:
|
|
|
99 |
elif action == 2:
|
100 |
self.paddle.move(1)
|
101 |
|
102 |
+
# Update inactivity penalty
|
103 |
+
if action == 0:
|
104 |
+
self.inactivity_counter += 1 / FPS
|
105 |
+
else:
|
106 |
+
self.inactivity_counter = 0
|
107 |
+
|
108 |
+
if self.inactivity_counter >= 1:
|
109 |
+
reward = self.inactivity_penalty
|
110 |
+
return self._get_state(), reward, self.done, False, {}
|
111 |
+
|
112 |
+
# Update ball position
|
113 |
self.ball.move()
|
114 |
|
115 |
+
# Collision with paddle
|
116 |
if self.ball.rect.colliderect(self.paddle.rect):
|
117 |
self.ball.velocity[1] = -self.ball.velocity[1]
|
118 |
+
self.ball.velocity[0] += random.uniform(-1, 1) # Add random offset to angle
|
119 |
self.score += self.platform_reward
|
120 |
|
121 |
+
# Collision with bricks
|
122 |
for brick in self.bricks[:]:
|
123 |
if self.ball.rect.colliderect(brick.rect):
|
124 |
self.bricks.remove(brick)
|
125 |
self.ball.velocity[1] = -self.ball.velocity[1]
|
126 |
+
self.ball.velocity[0] += random.uniform(-1, 1) # Add random offset to angle
|
127 |
self.score += 1
|
|
|
128 |
if not self.bricks:
|
|
|
129 |
self.done = True
|
130 |
+
return self._get_state(), self.reward_size, self.done, False, {}
|
|
|
131 |
|
132 |
+
# Check if ball is out of bounds
|
133 |
if self.ball.rect.bottom >= SCREEN_HEIGHT:
|
134 |
self.done = True
|
135 |
reward = self.penalty_size
|
136 |
+
return self._get_state(), reward, self.done, False, {}
|
|
|
|
|
|
|
137 |
|
138 |
+
# Calculate reward for breaking bricks
|
139 |
+
reward = 0
|
140 |
+
return self._get_state(), reward, self.done, False, {}
|
141 |
|
142 |
def _get_state(self):
|
143 |
+
return np.array([
|
|
|
144 |
self.ball.rect.x,
|
145 |
+
self.paddle.rect.x,
|
146 |
+
len(self.bricks)
|
147 |
+
], dtype=np.float32)
|
|
|
|
|
|
|
|
|
|
|
148 |
|
149 |
def render(self, mode='rgb_array'):
|
150 |
surface = pygame.Surface((SCREEN_WIDTH, SCREEN_HEIGHT))
|
|
|
163 |
def close(self):
|
164 |
pygame.quit()
|
165 |
|
166 |
+
# Main function remains unchanged
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
def main():
|
168 |
iface = gr.Interface(
|
169 |
fn=train_and_play,
|
|
|
171 |
gr.Number(label="Reward Size", value=1),
|
172 |
gr.Number(label="Penalty Size", value=-1),
|
173 |
gr.Number(label="Platform Reward", value=5),
|
174 |
+
gr.Number(label="Inactivity Penalty", value=-0.5),
|
175 |
gr.Slider(label="Iterations", minimum=10, maximum=100000, step=10, value=10000)
|
176 |
],
|
177 |
+
outputs="video"
|
|
|
178 |
)
|
179 |
iface.launch()
|
180 |
|