Gofor5 committed (verified)
Commit ca0a8ed · 1 Parent(s): b592a60

Upload 3 files

Files changed (3):
  1. 2048的网页实现.py +293 -0
  2. game2048.py +200 -0
  3. main.py +837 -0
2048的网页实现.py ADDED
@@ -0,0 +1,293 @@
1
+ import gradio as gr
2
+ import numpy as np
3
+ import random
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from game2048 import Game2048
8
+
9
+ # 创建游戏实例
10
+ game = Game2048(size=4)
11
+
12
+ # 方块颜色映射(根据数字值)
13
+ TILE_COLORS = {
14
+ 0: "#cdc1b4", # 空白格子
15
+ 2: "#eee4da", # 2
16
+ 4: "#ede0c8", # 4
17
+ 8: "#f2b179", # 8
18
+ 16: "#f59563", # 16
19
+ 32: "#f67c5f", # 32
20
+ 64: "#f65e3b", # 64
21
+ 128: "#edcf72", # 128
22
+ 256: "#edcc61", # 256
23
+ 512: "#edc850", # 512
24
+ 1024: "#edc53f", # 1024
25
+ 2048: "#edc22e", # 2048
26
+ 4096: "#3c3a32", # 4096+
27
+ }
28
+
29
+ # 文本颜色映射(根据背景深浅)
30
+ TEXT_COLORS = {
31
+ 0: "#776e65", # 空白格子
32
+ 2: "#776e65", # 2
33
+ 4: "#776e65", # 4
34
+ 8: "#f9f6f2", # 8+
35
+ 16: "#f9f6f2", # 16+
36
+ 32: "#f9f6f2", # 32+
37
+ 64: "#f9f6f2", # 64+
38
+ 128: "#f9f6f2", # 128+
39
+ 256: "#f9f6f2", # 256+
40
+ 512: "#f9f6f2", # 512+
41
+ 1024: "#f9f6f2", # 1024+
42
+ 2048: "#f9f6f2", # 2048+
43
+ 4096: "#f9f6f2", # 4096+
44
+ }
45
+
46
+ # 定义DQN网络结构(与训练时相同)
47
+ class DQN(nn.Module):
48
+ def __init__(self, input_channels, output_size):
49
+ super(DQN, self).__init__()
50
+ self.input_channels = input_channels
51
+
52
+ # 卷积层
53
+ self.conv1 = nn.Conv2d(input_channels, 128, kernel_size=3, padding=1)
54
+ self.conv2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
55
+ self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
56
+
57
+ # Dueling DQN架构
58
+ # 价值流
59
+ self.value_conv = nn.Conv2d(128, 4, kernel_size=1)
60
+ self.value_fc1 = nn.Linear(4 * 4 * 4, 128)
61
+ self.value_fc2 = nn.Linear(128, 1)
62
+
63
+ # 优势流
64
+ self.advantage_conv = nn.Conv2d(128, 16, kernel_size=1)
65
+ self.advantage_fc1 = nn.Linear(16 * 4 * 4, 128)
66
+ self.advantage_fc2 = nn.Linear(128, output_size)
67
+
68
+ def forward(self, x):
69
+ x = F.relu(self.conv1(x))
70
+ x = F.relu(self.conv2(x))
71
+ x = F.relu(self.conv3(x))
72
+
73
+ # 价值流
74
+ value = F.relu(self.value_conv(x))
75
+ value = value.view(value.size(0), -1)
76
+ value = F.relu(self.value_fc1(value))
77
+ value = self.value_fc2(value)
78
+
79
+ # 优势流
80
+ advantage = F.relu(self.advantage_conv(x))
81
+ advantage = advantage.view(advantage.size(0), -1)
82
+ advantage = F.relu(self.advantage_fc1(advantage))
83
+ advantage = self.advantage_fc2(advantage)
84
+
85
+ # 合并价值流和优势流
86
+ q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
87
+ return q_values
88
+
89
+ # 加载模型
90
+ def load_model(model_path):
91
+ model = DQN(4, 4) # 输入通道4,输出动作4个
92
+ try:
93
+ checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
94
+ model.load_state_dict(checkpoint['policy_net_state_dict'])
95
+ model.eval()
96
+ print("模型加载成功")
97
+ return model
98
+ except Exception as e:
99
+ print(f"模型加载失败: {e}")
100
+ return None
101
+
102
+ # 尝试加载模型
103
+ model_path = "models/dqn_2048_best_tile.pth"
104
+ model = load_model(model_path)
105
+
106
+ def render_board(board):
107
+ html = "<div style='background-color:#bbada0; padding:10px; border-radius:6px;'>"
108
+ html += "<table style='border-spacing:10px; border-collapse:separate;'>"
109
+
110
+ for i in range(game.size):
111
+ html += "<tr>"
112
+ for j in range(game.size):
113
+ value = board[i][j]
114
+ color = TILE_COLORS.get(value, "#3c3a32") # 默认深色
115
+ text_color = TEXT_COLORS.get(value, "#f9f6f2") # 默认浅色
116
+ font_size = "36px" if value < 100 else "30px" if value < 1000 else "24px"
117
+
118
+ html += f"""
119
+ <td style='background-color:{color};
120
+ width:80px; height:80px;
121
+ border-radius:4px;
122
+ text-align:center;
123
+ font-weight:bold;
124
+ font-size:{font_size};
125
+ color:{text_color};'>
126
+ {value if value > 0 else ''}
127
+ </td>
128
+ """
129
+ html += "</tr>"
130
+
131
+ html += "</table></div>"
132
+ return html
133
+
134
+ def make_move(direction):
135
+ """执行移动操作并更新界面"""
136
+ direction_names = ["上", "右", "下", "左"]
137
+
138
+ # 执行移动
139
+ new_board, game_over = game.move(direction)
140
+
141
+ # 渲染棋盘
142
+ board_html = render_board(new_board)
143
+
144
+ # 更新状态信息
145
+ status = f"<b>移动方向:</b> {direction_names[direction]}"
146
+ status += f"<br><b>当前分数:</b> {game.score}"
147
+ status += f"<br><b>最大方块:</b> {np.max(game.board)}"
148
+
149
+ if game.game_over:
150
+ status += "<br><br><div style='color:#ff0000; font-weight:bold;'>游戏结束!</div>"
151
+ status += f"<br><b>最终分数:</b> {game.score}"
152
+
153
+ return board_html, status
154
+
155
+ def reset_game():
156
+ """重置游戏"""
157
+ global game
158
+ game = Game2048(size=4)
159
+ board = game.reset()
160
+
161
+ # 渲染棋盘
162
+ board_html = render_board(board)
163
+
164
+ # 初始状态信息
165
+ status = "<b>游戏已重置!</b>"
166
+ status += f"<br><b>当前分数:</b> {game.score}"
167
+ status += f"<br><b>最大方块:</b> {np.max(game.board)}"
168
+
169
+ return board_html, status
170
+
171
+ def ai_move():
172
+ """使用AI模型进行一步移动"""
173
+ if model is None:
174
+ return render_board(game.board), "<b>错误:</b> 未加载AI模型"
175
+
176
+ # 获取当前状态
177
+ state = game.get_state()
178
+
179
+ # 获取有效移动
180
+ valid_moves = game.get_valid_moves()
181
+ if not valid_moves:
182
+ return render_board(game.board), "<b>游戏结束!</b> 没有有效移动"
183
+
184
+ # 转换状态为模型输入
185
+ state_tensor = torch.tensor(state, dtype=torch.float).unsqueeze(0)
186
+
187
+ # 模型预测
188
+ with torch.no_grad():
189
+ q_values = model(state_tensor).numpy().flatten()
190
+
191
+ # 只考虑有效动作
192
+ valid_q_values = np.full(4, -np.inf)
193
+ for move in valid_moves:
194
+ valid_q_values[move] = q_values[move]
195
+
196
+ # 选择最佳动作
197
+ action = np.argmax(valid_q_values)
198
+
199
+ # 执行移动
200
+ direction_names = ["上", "右", "下", "左"]
201
+ new_board, game_over = game.move(action)
202
+
203
+ # 渲染棋盘
204
+ board_html = render_board(new_board)
205
+
206
+ # 更新状态信息
207
+ status = f"<b>AI移动方向:</b> {direction_names[action]}"
208
+ status += f"<br><b>当前分数:</b> {game.score}"
209
+ status += f"<br><b>最大方块:</b> {np.max(game.board)}"
210
+
211
+ if game.game_over:
212
+ status += "<br><br><div style='color:#ff0000; font-weight:bold;'>游戏结束!</div>"
213
+ status += f"<br><b>最终分数:</b> {game.score}"
214
+
215
+ return board_html, status
216
+
217
+ # 创建Gradio界面
218
+ with gr.Blocks(title="2048游戏", theme="soft") as demo:
219
+ gr.Markdown("# 🎮 2048游戏")
220
+ gr.Markdown("使用方向键或下方的按钮移动方块,相同数字的方块相撞时会合并!")
221
+ with gr.Row():
222
+ with gr.Column(scale=2):
223
+ board_html = gr.HTML(render_board(game.board))
224
+ with gr.Row():
+ # 状态信息只创建一次,手动按钮与AI按钮共用同一个组件
+ status_display = gr.HTML("<b>当前分数:</b> 0<br><b>最大方块:</b> 2")
226
+ with gr.Column():
227
+ gr.Markdown("## 手动操作")
228
+ with gr.Row():
229
+ gr.Button("上 ↑", elem_id="up-btn").click(
230
+ fn=lambda: make_move(0),
231
+ outputs=[board_html, status_display]
232
+ )
233
+ gr.Button("左 ←", elem_id="left-btn").click(
234
+ fn=lambda: make_move(3),
235
+ outputs=[board_html, status_display]
236
+ )
237
+ with gr.Row():
238
+ gr.Button("下 ↓", elem_id="down-btn").click(
239
+ fn=lambda: make_move(2),
240
+ outputs=[board_html, status_display]
241
+ )
242
+ gr.Button("右 →", elem_id="right-btn").click(
243
+ fn=lambda: make_move(1),
244
+ outputs=[board_html, status_display]
245
+ )
246
+ with gr.Row():
247
+ gr.Button("🔄 重置游戏", elem_id="reset-btn").click(
248
+ fn=reset_game,
249
+ outputs=[board_html, status_display]
250
+ )
251
253
+ with gr.Column():
254
+ gr.Markdown("## AI操作")
255
+ gr.Button("🤖 AI移动一步", elem_id="ai-btn").click(
256
+ fn=ai_move,
257
+ outputs=[board_html, status_display]
258
+ )
259
+ gr.Markdown("基于DQN神经网络提供支持")
260
+
261
+ # 添加键盘快捷键支持
262
+ demo.load(
263
+ fn=None,
264
+ inputs=None,
265
+ outputs=None,
266
+ js="""() => {
267
+ document.addEventListener('keydown', function(e) {
268
+ if (e.key === 'ArrowUp') {
269
+ document.getElementById('up-btn').click();
270
+ } else if (e.key === 'ArrowRight') {
271
+ document.getElementById('right-btn').click();
272
+ } else if (e.key === 'ArrowDown') {
273
+ document.getElementById('down-btn').click();
274
+ } else if (e.key === 'ArrowLeft') {
275
+ document.getElementById('left-btn').click();
276
+ } else if (e.key === 'r' || e.key === 'R') {
277
+ document.getElementById('reset-btn').click();
278
+ } else if (e.key === 'a' || e.key === 'A') {
279
+ document.getElementById('ai-btn').click();
280
+ }
281
+ });
282
+ }"""
283
+ )
284
+ gr.Markdown("### 📚 使用说明")
285
+ gr.Markdown("1. 使用方向键或下方的按钮移动方块。")
286
+ gr.Markdown("2. 相同数字的方块相撞时会合并。")
287
+ gr.Markdown("3. 快捷键说明:上/下/左/右键移动方块,R键重置游戏,A键AI移动一步。")
288
+ gr.Markdown("4. 点击 '🤖 AI移动一步' 按钮可以使用AI模型进行一步移动。")
289
+ gr.Markdown("5. 游戏结束后,会显示最终分数和最大方块。")
290
+
291
+ # 启动界面
292
+ if __name__ == "__main__":
293
+ demo.launch(share=True)
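The ai_move() handler above keeps the agent from choosing illegal actions by masking Q-values: every action not returned by get_valid_moves() is set to negative infinity before argmax. A minimal standalone sketch of that masking step, with a made-up Q-vector in place of the real network output:

import numpy as np

q_values = np.array([0.3, 1.2, -0.5, 0.9])  # pretend Q-values for [up, right, down, left]
valid_moves = [0, 3]                        # suppose only "up" and "left" change the board
masked = np.full(4, -np.inf)
masked[valid_moves] = q_values[valid_moves]
action = int(np.argmax(masked))             # 3: best Q among the valid actions only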
game2048.py ADDED
@@ -0,0 +1,200 @@
1
+ import numpy as np
2
+ import random
3
+
4
+ class Game2048:
5
+ def __init__(self, size=4):
6
+ self.size = size
7
+ self.reset()
8
+
9
+ def reset(self):
10
+ """重置游戏状态"""
11
+ self.board = np.zeros((self.size, self.size), dtype=np.int32)
12
+ self.score = 0
13
+ self.add_tile()
14
+ self.add_tile()
15
+ self.game_over = False
16
+ return self.board.copy()
17
+
18
+ def add_tile(self):
19
+ """在随机空位置添加新方块(90%概率为2,10%概率为4)"""
20
+ empty_cells = []
21
+ for i in range(self.size):
22
+ for j in range(self.size):
23
+ if self.board[i][j] == 0:
24
+ empty_cells.append((i, j))
25
+
26
+ if empty_cells:
27
+ i, j = random.choice(empty_cells)
28
+ self.board[i][j] = 2 if random.random() < 0.9 else 4
29
+
30
+ def move(self, direction):
31
+ """
32
+ 执行移动操作
33
+ 0: 上, 1: 右, 2: 下, 3: 左
34
+ 返回: (新棋盘状态, 游戏是否结束)
35
+ """
36
+ moved = False
37
+ # 根据方向执行移动
38
+ if direction == 0: # 上
39
+ for j in range(self.size):
40
+ column = self.board[:, j].copy()
41
+ new_column, moved_col = self.slide(column)
42
+ if moved_col:
43
+ moved = True
44
+ self.board[:, j] = new_column
45
+
46
+ elif direction == 1: # 右
47
+ for i in range(self.size):
48
+ row = self.board[i, :].copy()[::-1]
49
+ new_row, moved_row = self.slide(row)
50
+ if moved_row:
51
+ moved = True
52
+ self.board[i, :] = new_row[::-1]
53
+
54
+ elif direction == 2: # 下
55
+ for j in range(self.size):
56
+ column = self.board[::-1, j].copy()
57
+ new_column, moved_col = self.slide(column)
58
+ if moved_col:
59
+ moved = True
60
+ self.board[:, j] = new_column[::-1]
61
+
62
+ elif direction == 3: # 左
63
+ for i in range(self.size):
64
+ row = self.board[i, :].copy()
65
+ new_row, moved_row = self.slide(row)
66
+ if moved_row:
67
+ moved = True
68
+ self.board[i, :] = new_row
69
+
70
+ # 如果发生了移动,添加新方块并检查游戏结束
71
+ if moved:
72
+ self.add_tile()
73
+ self.check_game_over()
74
+
75
+ return self.board.copy(), self.game_over
76
+
77
+ def slide(self, line):
78
+ """处理单行/列的移动和合并逻辑"""
79
+ non_zero = line[line != 0]
80
+ new_line = np.zeros_like(line)
81
+ idx = 0
82
+ score_inc = 0
83
+ moved = False
84
+
85
+ # 检查是否移动
86
+ if not np.array_equal(non_zero, line[:len(non_zero)]):
87
+ moved = True
88
+
89
+ # 合并相同数字
90
+ i = 0
91
+ while i < len(non_zero):
92
+ if i + 1 < len(non_zero) and non_zero[i] == non_zero[i+1]:
93
+ new_val = non_zero[i] * 2
94
+ new_line[idx] = new_val
95
+ score_inc += new_val
96
+ i += 2
97
+ idx += 1
98
+ else:
99
+ new_line[idx] = non_zero[i]
100
+ i += 1
101
+ idx += 1
102
+
103
+ self.score += score_inc
104
+ return new_line, moved or (score_inc > 0)
105
+
106
+ def check_game_over(self):
107
+ """检查游戏是否结束"""
108
+ # 检查是否还有空格子
109
+ if np.any(self.board == 0):
110
+ self.game_over = False
111
+ return
112
+
113
+ # 检查水平和垂直方向是否有可合并的方块
114
+ for i in range(self.size):
115
+ for j in range(self.size - 1):
116
+ if self.board[i][j] == self.board[i][j+1]:
117
+ self.game_over = False
118
+ return
119
+
120
+ for j in range(self.size):
121
+ for i in range(self.size - 1):
122
+ if self.board[i][j] == self.board[i+1][j]:
123
+ self.game_over = False
124
+ return
125
+
126
+ self.game_over = True
127
+
128
+ def get_valid_moves(self):
129
+ """获取当前所有有效移动方向"""
130
+ valid_moves = []
+ saved_score = self.score # slide() 会累加 self.score,先保存,检测结束后再恢复
131
+
132
+ # 检查上移是否有效
133
+ for j in range(self.size):
134
+ column = self.board[:, j].copy()
135
+ new_column, _ = self.slide(column)
136
+ if not np.array_equal(new_column, self.board[:, j]):
137
+ valid_moves.append(0)
138
+ break
139
+
140
+ # 检查右移是否有效
141
+ for i in range(self.size):
142
+ row = self.board[i, :].copy()[::-1]
143
+ new_row, _ = self.slide(row)
144
+ if not np.array_equal(new_row[::-1], self.board[i, :]):
145
+ valid_moves.append(1)
146
+ break
147
+
148
+ # 检查下移是否有效
149
+ for j in range(self.size):
150
+ column = self.board[::-1, j].copy()
151
+ new_column, _ = self.slide(column)
152
+ if not np.array_equal(new_column[::-1], self.board[:, j]):
153
+ valid_moves.append(2)
154
+ break
155
+
156
+ # 检查左移是否有效
157
+ for i in range(self.size):
158
+ row = self.board[i, :].copy()
159
+ new_row, _ = self.slide(row)
160
+ if not np.array_equal(new_row, self.board[i, :]):
161
+ valid_moves.append(3)
162
+ break
163
+
164
+ self.score = saved_score # 恢复分数,避免探测有效移动时改动计分
+ return valid_moves
165
+
166
+ def get_state(self):
167
+ """获取当前游戏状态表示(用于AI模型)"""
168
+ # 创建4个通道的状态表示
169
+ state = np.zeros((4, self.size, self.size), dtype=np.float32)
170
+
171
+ # 通道0: 当前方块值的对数(归一化)
172
+ for i in range(self.size):
173
+ for j in range(self.size):
174
+ if self.board[i][j] > 0:
175
+ state[0, i, j] = np.log2(self.board[i][j]) / 16.0 # 支持到65536 (2^16)
176
+
177
+ # 通道1: 空格子指示器
178
+ state[1] = (self.board == 0).astype(np.float32)
179
+
180
+ # 通道2: 可合并的邻居指示器
181
+ for i in range(self.size):
182
+ for j in range(self.size):
183
+ if self.board[i][j] > 0:
184
+ # 检查右侧
185
+ if j < self.size - 1 and self.board[i][j] == self.board[i][j+1]:
186
+ state[2, i, j] = 1.0
187
+ state[2, i, j+1] = 1.0
188
+ # 检查下方
189
+ if i < self.size - 1 and self.board[i][j] == self.board[i+1][j]:
190
+ state[2, i, j] = 1.0
191
+ state[2, i+1, j] = 1.0
192
+
193
+ # 通道3: 最大值位置(归一化)
194
+ max_value = np.max(self.board)
195
+ if max_value > 0:
196
+ max_positions = np.argwhere(self.board == max_value)
197
+ for pos in max_positions:
198
+ state[3, pos[0], pos[1]] = 1.0
199
+
200
+ return state
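For reference, the environment above is self-contained: directions are encoded as 0=up, 1=right, 2=down, 3=left, move() returns (board, game_over), and get_state() yields a (4, 4, 4) float32 array (log2 values, empty mask, mergeable-neighbour mask, max-tile mask). A short usage sketch, assuming game2048.py is on the import path:

from game2048 import Game2048

env = Game2048(size=4)
board = env.reset()             # 4x4 array with two starting tiles
print(env.get_valid_moves())    # subset of [0, 1, 2, 3]
board, game_over = env.move(3)  # slide left; a new tile appears if anything moved
print(env.score, game_over)
print(env.get_state().shape)    # (4, 4, 4)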
main.py ADDED
@@ -0,0 +1,837 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ import torch.nn.functional as F
6
+ import random
7
+ import os
8
+ from tqdm import tqdm
9
+ import matplotlib.pyplot as plt
10
+ import warnings
11
+
12
+ if torch.cuda.is_available():
13
+ device=torch.device("cuda")
14
+ elif hasattr(torch, "xpu") and torch.xpu.is_available(): # 旧版 PyTorch 没有 torch.xpu 属性
15
+ device=torch.device("xpu")
16
+ else:
17
+ device=torch.device("cpu")
18
+ print(f"Using device: {device}")
19
+
20
+ # 2048游戏环境(改进版)
21
+ class Game2048:
22
+ def __init__(self, size=4):
23
+ self.size = size
24
+ self.reset()
25
+
26
+ def reset(self):
27
+ self.board = np.zeros((self.size, self.size), dtype=np.int32)
28
+ self.score = 0
29
+ self.prev_score = 0
30
+ self.add_tile()
31
+ self.add_tile()
32
+ self.game_over = False
33
+ return self.get_state()
34
+
35
+ def add_tile(self):
36
+ empty_cells = []
37
+ for i in range(self.size):
38
+ for j in range(self.size):
39
+ if self.board[i][j] == 0:
40
+ empty_cells.append((i, j))
41
+
42
+ if empty_cells:
43
+ i, j = random.choice(empty_cells)
44
+ self.board[i][j] = 2 if random.random() < 0.9 else 4
45
+
46
+ def move(self, direction):
47
+ # 0: 上, 1: 右, 2: 下, 3: 左
48
+ moved = False
49
+ original_board = self.board.copy()
50
+ old_score = self.score
51
+
52
+ # 根据方向执行移动
53
+ if direction == 0: # 上
54
+ for j in range(self.size):
55
+ column = self.board[:, j].copy()
56
+ new_column, moved_col = self.slide(column)
57
+ if moved_col:
58
+ moved = True
59
+ self.board[:, j] = new_column
60
+
61
+ elif direction == 1: # 右
62
+ for i in range(self.size):
63
+ row = self.board[i, :].copy()[::-1]
64
+ new_row, moved_row = self.slide(row)
65
+ if moved_row:
66
+ moved = True
67
+ self.board[i, :] = new_row[::-1]
68
+
69
+ elif direction == 2: # 下
70
+ for j in range(self.size):
71
+ column = self.board[::-1, j].copy()
72
+ new_column, moved_col = self.slide(column)
73
+ if moved_col:
74
+ moved = True
75
+ self.board[:, j] = new_column[::-1]
76
+
77
+ elif direction == 3: # 左
78
+ for i in range(self.size):
79
+ row = self.board[i, :].copy()
80
+ new_row, moved_row = self.slide(row)
81
+ if moved_row:
82
+ moved = True
83
+ self.board[i, :] = new_row
84
+
85
+ # 如果发生了移动,添加新方块
86
+ if moved:
87
+ self.add_tile()
88
+ self.check_game_over()
89
+
90
+ reward = self.calculate_reward(old_score, original_board)
91
+ return self.get_state(), reward, self.game_over
92
+
93
+ def slide(self, line):
94
+ # 移除零并合并相同数字
95
+ non_zero = line[line != 0]
96
+ new_line = np.zeros_like(line)
97
+ idx = 0
98
+ score_inc = 0
99
+ moved = False
100
+
101
+ # 检查是否移动
102
+ if not np.array_equal(non_zero, line[:len(non_zero)]):
103
+ moved = True
104
+
105
+ # 合并相同数字
106
+ i = 0
107
+ while i < len(non_zero):
108
+ if i + 1 < len(non_zero) and non_zero[i] == non_zero[i+1]:
109
+ new_val = non_zero[i] * 2
110
+ new_line[idx] = new_val
111
+ score_inc += new_val
112
+ i += 2
113
+ idx += 1
114
+ else:
115
+ new_line[idx] = non_zero[i]
116
+ i += 1
117
+ idx += 1
118
+
119
+ self.score += score_inc
120
+ return new_line, moved or (score_inc > 0)
121
+
122
+ def calculate_reward(self, old_score, original_board):
123
+ """改进的奖励函数"""
124
+ # 1. 基本分数奖励
125
+ score_reward = (self.score - old_score) * 0.1
126
+
127
+ # 2. 空格子数量变化奖励
128
+ empty_before = np.count_nonzero(original_board == 0)
129
+ empty_after = np.count_nonzero(self.board == 0)
130
+ empty_reward = (empty_after - empty_before) * 0.15
131
+
132
+ # 3. 最大方块奖励
133
+ max_before = np.max(original_board)
134
+ max_after = np.max(self.board)
135
+ max_tile_reward = 0
136
+ if max_after > max_before:
137
+ max_tile_reward = np.log2(max_after) * 0.2
138
+
139
+ # 4. 合并奖励(鼓励合并)
140
+ merge_reward = 0
141
+ if self.score - old_score > 0:
142
+ merge_reward = np.log2(self.score - old_score) * 0.1
143
+
144
+ # 5. 单调性惩罚(鼓励有序排列)
145
+ monotonicity_penalty = -self.calculate_monotonicity_penalty() * 0.01 # 惩罚项取负号,否则会奖励无序排列
146
+
147
+ # 6. 游戏结束惩罚
148
+ game_over_penalty = 0
149
+ if self.game_over:
150
+ game_over_penalty = -10
151
+
152
+ # 7. 平滑度奖励(鼓励相邻方块值接近)
153
+ smoothness_reward = self.calculate_smoothness() * 0.01
154
+
155
+ # 总奖励
156
+ total_reward = (
157
+ score_reward +
158
+ empty_reward +
159
+ max_tile_reward +
160
+ merge_reward +
161
+ smoothness_reward +
162
+ monotonicity_penalty +
163
+ game_over_penalty
164
+ )
165
+
166
+ return total_reward
167
+
168
+ def calculate_monotonicity_penalty(self):
169
+ """计算单调性惩罚(值越低越好)"""
170
+ penalty = 0
171
+ for i in range(self.size):
172
+ for j in range(self.size - 1):
173
+ if self.board[i][j] > self.board[i][j+1]:
174
+ penalty += self.board[i][j] - self.board[i][j+1]
175
+ else:
176
+ penalty += self.board[i][j+1] - self.board[i][j]
177
+ return penalty
178
+
179
+ def calculate_smoothness(self):
180
+ """计算平滑度(值越高越好)"""
181
+ smoothness = 0
182
+ for i in range(self.size):
183
+ for j in range(self.size):
184
+ if self.board[i][j] != 0:
185
+ value = np.log2(self.board[i][j])
186
+ # 检查右侧邻居
187
+ if j < self.size - 1 and self.board[i][j+1] != 0:
188
+ neighbor_value = np.log2(self.board[i][j+1])
189
+ smoothness -= abs(value - neighbor_value)
190
+ # 检查下方邻居
191
+ if i < self.size - 1 and self.board[i+1][j] != 0:
192
+ neighbor_value = np.log2(self.board[i+1][j])
193
+ smoothness -= abs(value - neighbor_value)
194
+ return smoothness
195
+
196
+ def check_game_over(self):
197
+ # 检查是否还有空格子
198
+ if np.any(self.board == 0):
199
+ self.game_over = False
200
+ return
201
+
202
+ # 检查水平和垂直方向是否有可合并的方块
203
+ for i in range(self.size):
204
+ for j in range(self.size - 1):
205
+ if self.board[i][j] == self.board[i][j+1]:
206
+ self.game_over = False
207
+ return
208
+
209
+ for j in range(self.size):
210
+ for i in range(self.size - 1):
211
+ if self.board[i][j] == self.board[i+1][j]:
212
+ self.game_over = False
213
+ return
214
+
215
+ self.game_over = True
216
+
217
+ def get_state(self):
218
+ """改进的状态表示"""
219
+ # 创建4个通道的状态表示
220
+ state = np.zeros((4, self.size, self.size), dtype=np.float32)
221
+
222
+ # 通道0: 当前方块值的对数(归一化)
223
+ for i in range(self.size):
224
+ for j in range(self.size):
225
+ if self.board[i][j] > 0:
226
+ state[0, i, j] = np.log2(self.board[i][j]) / 16.0 # 支持到65536 (2^16)
227
+
228
+ # 通道1: 空格子指示器
229
+ state[1] = (self.board == 0).astype(np.float32)
230
+
231
+ # 通道2: 可合并的邻居指示器
232
+ for i in range(self.size):
233
+ for j in range(self.size):
234
+ if self.board[i][j] > 0:
235
+ # 检查右侧
236
+ if j < self.size - 1 and self.board[i][j] == self.board[i][j+1]:
237
+ state[2, i, j] = 1.0
238
+ state[2, i, j+1] = 1.0
239
+ # 检查下方
240
+ if i < self.size - 1 and self.board[i][j] == self.board[i+1][j]:
241
+ state[2, i, j] = 1.0
242
+ state[2, i+1, j] = 1.0
243
+
244
+ # 通道3: 最大值位置(归一化)
245
+ max_value = np.max(self.board)
246
+ if max_value > 0:
247
+ max_positions = np.argwhere(self.board == max_value)
248
+ for pos in max_positions:
249
+ state[3, pos[0], pos[1]] = 1.0
250
+
251
+ return state
252
+
253
+ def get_valid_moves(self):
254
+ """更高效的有效移动检测"""
255
+ valid_moves = []
+ saved_score = self.score # slide() 会累加 self.score,先保存,检测结束后再恢复
257
+
258
+ # 检查上移是否有效
259
+ for j in range(self.size):
260
+ column = self.board[:, j].copy()
261
+ new_column, _ = self.slide(column)
262
+ if not np.array_equal(new_column, self.board[:, j]):
263
+ valid_moves.append(0)
264
+ break
265
+
266
+ # 检查右移是否有效
267
+ for i in range(self.size):
268
+ row = self.board[i, :].copy()[::-1]
269
+ new_row, _ = self.slide(row)
270
+ if not np.array_equal(new_row[::-1], self.board[i, :]):
271
+ valid_moves.append(1)
272
+ break
273
+
274
+ # 检查下移是否有效
275
+ for j in range(self.size):
276
+ column = self.board[::-1, j].copy()
277
+ new_column, _ = self.slide(column)
278
+ if not np.array_equal(new_column[::-1], self.board[:, j]):
279
+ valid_moves.append(2)
280
+ break
281
+
282
+ # 检查左移是否有效
283
+ for i in range(self.size):
284
+ row = self.board[i, :].copy()
285
+ new_row, _ = self.slide(row)
286
+ if not np.array_equal(new_row, self.board[i, :]):
287
+ valid_moves.append(3)
288
+ break
289
+
290
+ self.score = saved_score # 恢复分数,避免探测有效移动时改动计分
+ return valid_moves
291
+
292
+ # 改进的深度Q网络(使用Dueling DQN架构)
293
+ class DQN(nn.Module):
294
+ def __init__(self, input_channels, output_size):
295
+ super(DQN, self).__init__()
296
+ self.input_channels = input_channels
297
+
298
+ # 卷积层
299
+ self.conv1 = nn.Conv2d(input_channels, 128, kernel_size=3, padding=1)
300
+ self.conv2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
301
+ self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
302
+
303
+ # Dueling DQN架构
304
+ # 价值流
305
+ self.value_conv = nn.Conv2d(128, 4, kernel_size=1)
306
+ self.value_fc1 = nn.Linear(4 * 4 * 4, 128)
307
+ self.value_fc2 = nn.Linear(128, 1)
308
+
309
+ # 优势流
310
+ self.advantage_conv = nn.Conv2d(128, 16, kernel_size=1)
311
+ self.advantage_fc1 = nn.Linear(16 * 4 * 4, 128)
312
+ self.advantage_fc2 = nn.Linear(128, output_size)
313
+
314
+ def forward(self, x):
315
+ x = F.relu(self.conv1(x))
316
+ x = F.relu(self.conv2(x))
317
+ x = F.relu(self.conv3(x))
318
+
319
+ # 价值流
320
+ value = F.relu(self.value_conv(x))
321
+ value = value.view(value.size(0), -1)
322
+ value = F.relu(self.value_fc1(value))
323
+ value = self.value_fc2(value)
324
+
325
+ # 优势流
326
+ advantage = F.relu(self.advantage_conv(x))
327
+ advantage = advantage.view(advantage.size(0), -1)
328
+ advantage = F.relu(self.advantage_fc1(advantage))
329
+ advantage = self.advantage_fc2(advantage)
330
+
331
+ # 合并价值流和优势流
332
+ q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
333
+ return q_values
334
+
335
+ # 经验回放缓冲区(带优先级)
336
+ class PrioritizedReplayBuffer:
337
+ def __init__(self, capacity, alpha=0.6):
338
+ self.capacity = capacity
339
+ self.alpha = alpha
340
+ self.buffer = []
341
+ self.priorities = np.zeros(capacity)
342
+ self.pos = 0
343
+ self.size = 0
344
+
345
+ def push(self, state, action, reward, next_state, done):
346
+ # 初始优先级设置为最大优先级
347
+ max_priority = self.priorities.max() if self.buffer else 1.0
348
+
349
+ if len(self.buffer) < self.capacity:
350
+ self.buffer.append((state, action, reward, next_state, done))
351
+ else:
352
+ self.buffer[self.pos] = (state, action, reward, next_state, done)
353
+
354
+ self.priorities[self.pos] = max_priority
355
+ self.pos = (self.pos + 1) % self.capacity
356
+ self.size = min(self.size + 1, self.capacity)
357
+
358
+ def sample(self, batch_size, beta=0.4):
359
+ if self.size == 0:
360
+ return None # 调用方用 sample is None 判断缓冲区为空
361
+
362
+ priorities = self.priorities[:self.size]
363
+ probs = priorities ** self.alpha
364
+ probs /= probs.sum()
365
+
366
+ indices = np.random.choice(self.size, batch_size, p=probs)
367
+ samples = [self.buffer[idx] for idx in indices]
368
+
369
+ # 计算重要性采样权重
370
+ weights = (self.size * probs[indices]) ** (-beta)
371
+ weights /= weights.max()
372
+ weights = np.array(weights, dtype=np.float32)
373
+
374
+ states, actions, rewards, next_states, dones = zip(*samples)
375
+ return (
376
+ torch.tensor(np.array(states)),
377
+ torch.tensor(actions, dtype=torch.long),
378
+ torch.tensor(rewards, dtype=torch.float),
379
+ torch.tensor(np.array(next_states)),
380
+ torch.tensor(dones, dtype=torch.float),
381
+ indices,
382
+ torch.tensor(weights)
383
+ )
384
+
385
+ def update_priorities(self, indices, priorities):
386
+ # 确保 priorities 是一个数组
387
+ if isinstance(priorities, np.ndarray) and priorities.ndim == 1:
388
+ for idx, priority in zip(indices, priorities):
389
+ self.priorities[idx] = priority
390
+ else:
391
+ # 处理标量情况(虽然不应该发生)
392
+ if not isinstance(priorities, (list, np.ndarray)):
393
+ priorities = [priorities] * len(indices)
394
+ for idx, priority in zip(indices, priorities):
395
+ self.priorities[idx] = priority
396
+
397
+ def __len__(self):
398
+ return self.size
399
+
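The buffer above implements proportional prioritized replay: transition i is sampled with probability P(i) = p_i^alpha / sum_k p_k^alpha and corrected by an importance weight (N * P(i))^(-beta), normalised by the largest weight, exactly as in sample(). A small numeric sketch of those two formulas with hypothetical priorities:

import numpy as np

priorities = np.array([1.0, 0.5, 2.0])  # hypothetical TD-error-based priorities
alpha, beta = 0.6, 0.4
N = len(priorities)
probs = priorities ** alpha
probs /= probs.sum()                    # sampling distribution P(i)
weights = (N * probs) ** (-beta)
weights /= weights.max()                # importance-sampling weights in (0, 1]
print(probs, weights)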
400
+ # 改进的DQN智能体
401
+ class DQNAgent:
402
+ def __init__(self, input_channels, action_size, lr=3e-4, gamma=0.99,
403
+ epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=0.999,
404
+ target_update_freq=1000, batch_size=128):
405
+ self.input_channels = input_channels
406
+ self.action_size = action_size
407
+ self.gamma = gamma
408
+ self.epsilon = epsilon_start
409
+ self.epsilon_end = epsilon_end
410
+ self.epsilon_decay = epsilon_decay
411
+ self.batch_size = batch_size
412
+ self.target_update_freq = target_update_freq
413
+
414
+ # 主网络和目标网络
415
+ self.policy_net = DQN(input_channels, action_size).to(device)
416
+ self.target_net = DQN(input_channels, action_size).to(device)
417
+ self.target_net.load_state_dict(self.policy_net.state_dict())
418
+ self.target_net.eval()
419
+
420
+ self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr, weight_decay=1e-5)
421
+ self.memory = PrioritizedReplayBuffer(50000)
422
+ self.steps_done = 0
423
+ self.loss_fn = nn.SmoothL1Loss(reduction='none')
424
+
425
+ def select_action(self, state, valid_moves):
426
+ self.steps_done += 1
427
+ self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay)
428
+
429
+ if random.random() < self.epsilon:
430
+ # 随机选择有效动作
431
+ return random.choice(valid_moves)
432
+ else:
433
+ # 使用策略网络选择动作
434
+ with torch.no_grad():
435
+ state_tensor = torch.tensor(state, dtype=torch.float).unsqueeze(0).to(device)
436
+ q_values = self.policy_net(state_tensor).cpu().numpy().flatten()
437
+
438
+ # 只考虑有效动作
439
+ valid_q_values = np.full(self.action_size, -np.inf)
440
+ for move in valid_moves:
441
+ valid_q_values[move] = q_values[move]
442
+
443
+ return np.argmax(valid_q_values)
444
+
445
+ def optimize_model(self, beta=0.4):
446
+ if len(self.memory) < self.batch_size:
447
+ return 0
448
+
449
+ # 从回放缓冲区采样
450
+ sample = self.memory.sample(self.batch_size, beta)
451
+ if sample is None:
452
+ return 0
453
+
454
+ states, actions, rewards, next_states, dones, indices, weights = sample
455
+
456
+ states = states.to(device)
457
+ actions = actions.to(device)
458
+ rewards = rewards.to(device)
459
+ next_states = next_states.to(device)
460
+ dones = dones.to(device)
461
+ weights = weights.to(device)
462
+
463
+ # 计算当前Q值
464
+ current_q = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze()
465
+
466
+ # 计算目标Q值(Double DQN)
467
+ with torch.no_grad():
468
+ next_actions = self.policy_net(next_states).max(1)[1]
469
+ next_q = self.target_net(next_states).gather(1, next_actions.unsqueeze(1)).squeeze()
470
+ target_q = rewards + (1 - dones) * self.gamma * next_q
471
+
472
+ # 计算损失
473
+ losses = self.loss_fn(current_q, target_q)
474
+ loss = (losses * weights).mean()
475
+
476
+ # 更新优先级(使用每个样本的损失绝对值)
477
+ with torch.no_grad():
478
+ priorities = losses.abs().cpu().numpy() + 1e-5
479
+ self.memory.update_priorities(indices, priorities)
480
+
481
+ # 优化模型
482
+ self.optimizer.zero_grad()
483
+ loss.backward()
484
+
485
+ # 梯度裁剪
486
+ torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 10)
487
+
488
+ self.optimizer.step()
489
+
490
+ return loss.item()
491
+
492
+ def update_target_network(self):
493
+ self.target_net.load_state_dict(self.policy_net.state_dict())
494
+
495
+ def save_model(self, path):
496
+ torch.save({
497
+ 'policy_net_state_dict': self.policy_net.state_dict(),
498
+ 'target_net_state_dict': self.target_net.state_dict(),
499
+ 'optimizer_state_dict': self.optimizer.state_dict(),
500
+ 'epsilon': self.epsilon,
501
+ 'steps_done': self.steps_done
502
+ }, path)
503
+
504
+ def load_model(self, path):
505
+ if not os.path.exists(path):
506
+ print(f"Model file not found: {path}")
507
+ return
508
+
509
+ try:
510
+ # 尝试使用 weights_only=False 加载模型
511
+ checkpoint = torch.load(path, map_location=device, weights_only=False)
512
+ self.policy_net.load_state_dict(checkpoint['policy_net_state_dict'])
513
+ self.target_net.load_state_dict(checkpoint['target_net_state_dict'])
514
+ self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
515
+ self.epsilon = checkpoint['epsilon']
516
+ self.steps_done = checkpoint['steps_done']
517
+ self.policy_net.eval()
518
+ self.target_net.eval()
519
+ print(f"Model loaded successfully from {path}")
520
+ except Exception as e:
521
+ print(f"Error loading model: {e}")
522
+ # 尝试使用旧版加载方式作为备选
523
+ try:
524
+ warnings.warn("Trying legacy load method without weights_only")
525
+ checkpoint = torch.load(path, map_location=device)
526
+ self.policy_net.load_state_dict(checkpoint['policy_net_state_dict'])
527
+ self.target_net.load_state_dict(checkpoint['target_net_state_dict'])
528
+ self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
529
+ self.epsilon = checkpoint['epsilon']
530
+ self.steps_done = checkpoint['steps_done']
531
+ self.policy_net.eval()
532
+ self.target_net.eval()
533
+ print(f"Model loaded successfully using legacy method")
534
+ except Exception as e2:
535
+ print(f"Failed to load model: {e2}")
536
+ # 训练函数(带进度记录)
537
+ def train_agent(agent, env, episodes=5000, save_path='models/dqn_2048.pth',
538
+ checkpoint_path='models/checkpoint.pth', resume=False, start_episode=0):
539
+ # 创建保存模型的目录
540
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
541
+
542
+ # 记录训练指标
543
+ scores = []
544
+ max_tiles = []
545
+ avg_scores = []
546
+ losses = []
547
+ best_score = 0
548
+ best_max_tile = 0
549
+
550
+ # 如果续训,加载训练状态
551
+ if resume and os.path.exists(checkpoint_path):
552
+ try:
553
+ # 使用 weights_only=False 加载检查点
554
+ checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
555
+ scores = checkpoint['scores']
556
+ max_tiles = checkpoint['max_tiles']
557
+ avg_scores = checkpoint['avg_scores']
558
+ losses = checkpoint['losses']
559
+ best_score = checkpoint.get('best_score', 0)
560
+ best_max_tile = checkpoint.get('best_max_tile', 0)
561
+ print(f"Resuming training from episode {start_episode}...")
562
+ except Exception as e:
563
+ print(f"Error loading checkpoint: {e}")
564
+ print("Starting training from scratch...")
565
+ resume = False
566
+
567
+ if not resume:
568
+ start_episode = 0
569
+
570
+ # 使用tqdm显示进度条
571
+ progress_bar = tqdm(range(start_episode, episodes), desc="Training")
572
+
573
+ for episode in progress_bar:
574
+ state = env.reset()
575
+ total_reward = 0
576
+ done = False
577
+ steps = 0
578
+ episode_loss = 0
579
+ loss_count = 0
580
+
581
+ while not done:
582
+ valid_moves = env.get_valid_moves()
583
+ if not valid_moves:
584
+ done = True
585
+ continue
586
+
587
+ action = agent.select_action(state, valid_moves)
588
+ next_state, reward, done = env.move(action)
589
+ total_reward += reward
590
+
591
+ agent.memory.push(state, action, reward, next_state, done)
592
+ state = next_state
593
+
594
+ # 优化模型
595
+ loss = agent.optimize_model(beta=min(1.0, episode / 1000))
596
+ if loss > 0:
597
+ episode_loss += loss
598
+ loss_count += 1
599
+
600
+ # 定期更新目标网络
601
+ if agent.steps_done % agent.target_update_freq == 0:
602
+ agent.update_target_network()
603
+
604
+ steps += 1
605
+
606
+ # 记录分数和最大方块
607
+ score = env.score
608
+ max_tile = np.max(env.board)
609
+ scores.append(score)
610
+ max_tiles.append(max_tile)
611
+
612
+ # 计算平均损失
613
+ avg_loss = episode_loss / loss_count if loss_count > 0 else 0
614
+ losses.append(avg_loss)
615
+
616
+ # 更新最佳记录
617
+ if score > best_score:
618
+ best_score = score
619
+ agent.save_model(save_path.replace('.pth', '_best_score.pth'))
620
+ if max_tile > best_max_tile:
621
+ best_max_tile = max_tile
622
+ agent.save_model(save_path.replace('.pth', '_best_tile.pth'))
623
+
624
+ # 计算最近100轮平均分数
625
+ recent_scores = scores[-100:] if len(scores) >= 100 else scores
626
+ avg_score = np.mean(recent_scores)
627
+ avg_scores.append(avg_score)
628
+
629
+ # 更新进度条描述
630
+ progress_bar.set_description(
631
+ f"Ep {episode+1}/{episodes} | "
632
+ f"Score: {score} (Avg: {avg_score:.1f}) | "
633
+ f"Max Tile: {max_tile} | "
634
+ f"Loss: {avg_loss:.4f} | "
635
+ f"Epsilon: {agent.epsilon:.4f}"
636
+ )
637
+
638
+ # 定期保存模型和训练状态
639
+ if (episode + 1) % 100 == 0:
640
+ agent.save_model(save_path)
641
+
642
+ # 保存训练状态
643
+ checkpoint = {
644
+ 'scores': scores,
645
+ 'max_tiles': max_tiles,
646
+ 'avg_scores': avg_scores,
647
+ 'losses': losses,
648
+ 'best_score': best_score,
649
+ 'best_max_tile': best_max_tile,
650
+ 'episode': episode + 1,
651
+ 'steps_done': agent.steps_done,
652
+ 'epsilon': agent.epsilon
653
+ }
654
+ try:
655
+ torch.save(checkpoint, checkpoint_path)
656
+ except Exception as e:
657
+ print(f"Error saving checkpoint: {e}")
658
+
659
+ # 绘制训练曲线
660
+ if episode > 100: # 确保有足够的数据
661
+ plt.figure(figsize=(12, 8))
662
+
663
+ # 分数曲线
664
+ plt.subplot(2, 2, 1)
665
+ plt.plot(scores, label='Score')
666
+ plt.plot(avg_scores, label='Avg Score (100 eps)')
667
+ plt.xlabel('Episode')
668
+ plt.ylabel('Score')
669
+ plt.title('Training Scores')
670
+ plt.legend()
671
+
672
+ # 最大方块曲线
673
+ plt.subplot(2, 2, 2)
674
+ plt.plot(max_tiles, 'g-')
675
+ plt.xlabel('Episode')
676
+ plt.ylabel('Max Tile')
677
+ plt.title('Max Tile Achieved')
678
+
679
+ # 损失曲线
680
+ plt.subplot(2, 2, 3)
681
+ plt.plot(losses, 'r-')
682
+ plt.xlabel('Episode')
683
+ plt.ylabel('Loss')
684
+ plt.title('Training Loss')
685
+
686
+ # 分数分布直方图
687
+ plt.subplot(2, 2, 4)
688
+ plt.hist(scores, bins=20, alpha=0.7)
689
+ plt.xlabel('Score')
690
+ plt.ylabel('Frequency')
691
+ plt.title('Score Distribution')
692
+
693
+ plt.tight_layout()
694
+ plt.savefig('training_progress.png')
695
+ plt.close()
696
+
697
+ # 保存最终模型
698
+ agent.save_model(save_path)
699
+
700
+ return scores, max_tiles, losses
701
+ # 推理函数(带可视化)
702
+ def play_with_model(agent, env, episodes=3):
703
+ agent.epsilon = 0.001 # 设置很小的epsilon值进行推理
704
+
705
+ for episode in range(episodes):
706
+ state = env.reset()
707
+ done = False
708
+ steps = 0
709
+
710
+ print(f"\nEpisode {episode+1}")
711
+ print("Initial Board:")
712
+ print(env.board)
713
+
714
+ while not done:
715
+ valid_moves = env.get_valid_moves()
716
+ if not valid_moves:
717
+ done = True
718
+ print("No valid moves left!")
719
+ continue
720
+
721
+ # 选择动作
722
+ with torch.no_grad():
723
+ state_tensor = torch.tensor(state, dtype=torch.float).unsqueeze(0).to(device)
724
+ q_values = agent.policy_net(state_tensor).cpu().numpy().flatten()
725
+
726
+ # 只考虑有效动作
727
+ valid_q_values = np.full(agent.action_size, -np.inf) # 这里需要的是动作数(4),而非棋盘尺寸
728
+ for move in valid_moves:
729
+ valid_q_values[move] = q_values[move]
730
+
731
+ action = np.argmax(valid_q_values)
732
+
733
+ # 执行动作
734
+ next_state, reward, done = env.move(action)
735
+ state = next_state
736
+ steps += 1
737
+
738
+ # 渲染游戏
739
+ print(f"\nStep {steps}: Action {['Up', 'Right', 'Down', 'Left'][action]}")
740
+ print(env.board)
741
+ print(f"Score: {env.score}, Max Tile: {np.max(env.board)}")
742
+ # 同时将结果追加写入 result.txt 文件(with 语句退出时自动关闭文件)
+ with open("result.txt", "a") as f:
+ f.write(f"Episode {episode+1}, Step {steps}, Action {['Up', 'Right', 'Down', 'Left'][action]}, Score: {env.score}, Max Tile: {np.max(env.board)}\n{env.board}\n")
746
+
747
+
748
+ print(f"\nGame Over! Final Score: {env.score}, Max Tile: {np.max(env.board)}")
749
+
750
+ # 主程序
751
+ if __name__ == "__main__":
752
+ args = {"train":0, "resume":0, "play":1, "episodes":50000}
753
+ env = Game2048(size=4)
754
+ input_channels = 4 # 状态表示的通道数
755
+ action_size = 4 # 上、右、下、左
756
+
757
+ agent = DQNAgent(
758
+ input_channels,
759
+ action_size,
760
+ lr=1e-4,
761
+ epsilon_decay=0.999, # 更慢的衰减
762
+ target_update_freq=1000,
763
+ batch_size=256
764
+ )
765
+
766
+ # 训练模型
767
+ if args.get('train') or args.get('resume'):
768
+ print("Starting training...")
769
+
770
+ # 如果续训,加载检查点
771
+ start_episode = 0
772
+ checkpoint_path = 'models/checkpoint.pth'
773
+ if args.get('resume') and os.path.exists(checkpoint_path):
774
+ try:
775
+ # 使用 weights_only=False 加载检查点
776
+ checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
777
+ start_episode = checkpoint.get('episode', 0)
778
+ agent.steps_done = checkpoint.get('steps_done', 0)
779
+ agent.epsilon = checkpoint.get('epsilon', agent.epsilon)
780
+ except Exception as e:
781
+ print(f"Error loading checkpoint: {e}")
782
+ print("Starting training from scratch...")
783
+ start_episode = 0
784
+
785
+ scores, max_tiles, losses = train_agent(
786
+ agent,
787
+ env,
788
+ episodes=args.get('episodes'),
789
+ save_path='models/dqn_2048.pth',
790
+ checkpoint_path=checkpoint_path,
791
+ resume=args.get('resume'),
792
+ start_episode=start_episode
793
+ )
794
+ print("Training completed!")
795
+
796
+ # 绘制最终训练结果
797
+ plt.figure(figsize=(15, 10))
798
+
799
+ plt.subplot(3, 1, 1)
800
+ plt.plot(scores)
801
+ plt.title('Scores per Episode')
802
+ plt.xlabel('Episode')
803
+ plt.ylabel('Score')
804
+
805
+ plt.subplot(3, 1, 2)
806
+ plt.plot(max_tiles)
807
+ plt.title('Max Tile per Episode')
808
+ plt.xlabel('Episode')
809
+ plt.ylabel('Max Tile')
810
+
811
+ plt.subplot(3, 1, 3)
812
+ plt.plot(losses)
813
+ plt.title('Training Loss per Episode')
814
+ plt.xlabel('Episode')
815
+ plt.ylabel('Loss')
816
+
817
+ plt.tight_layout()
818
+ plt.savefig('final_training_results.png')
819
+ plt.close()
820
+
821
+ # 加载模型并推理
822
+ if args.get('play'):
823
+ model_path = 'models/dqn_2048_best_tile.pth'
824
+ if not os.path.exists(model_path):
825
+ model_path = 'models/dqn_2048.pth'
826
+
827
+ if os.path.exists(model_path):
828
+ agent.load_model(model_path)
829
+ print("Playing with trained model...")
830
+ if os.path.exists("result.txt"):
+ os.remove("result.txt") # 删除之前的记录
+ play_with_model(agent, env, episodes=1)
835
+
836
+ else:
837
+ print("No trained model found. Please train the model first.")