Upload 3 files
- 2048的网页实现.py +293 -0
- game2048.py +200 -0
- main.py +837 -0
2048的网页实现.py
ADDED
@@ -0,0 +1,293 @@
import gradio as gr
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from game2048 import Game2048

# Create the game instance
game = Game2048(size=4)

# Tile background colors keyed by tile value
TILE_COLORS = {
    0: "#cdc1b4",     # empty cell
    2: "#eee4da",
    4: "#ede0c8",
    8: "#f2b179",
    16: "#f59563",
    32: "#f67c5f",
    64: "#f65e3b",
    128: "#edcf72",
    256: "#edcc61",
    512: "#edc850",
    1024: "#edc53f",
    2048: "#edc22e",
    4096: "#3c3a32",  # 4096 and above
}

# Text colors keyed by tile value (light text on dark backgrounds)
TEXT_COLORS = {
    0: "#776e65",
    2: "#776e65",
    4: "#776e65",
    8: "#f9f6f2",
    16: "#f9f6f2",
    32: "#f9f6f2",
    64: "#f9f6f2",
    128: "#f9f6f2",
    256: "#f9f6f2",
    512: "#f9f6f2",
    1024: "#f9f6f2",
    2048: "#f9f6f2",
    4096: "#f9f6f2",
}

# DQN network definition (must match the architecture used during training)
class DQN(nn.Module):
    def __init__(self, input_channels, output_size):
        super(DQN, self).__init__()
        self.input_channels = input_channels

        # Convolutional layers
        self.conv1 = nn.Conv2d(input_channels, 128, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)

        # Dueling DQN architecture
        # Value stream
        self.value_conv = nn.Conv2d(128, 4, kernel_size=1)
        self.value_fc1 = nn.Linear(4 * 4 * 4, 128)
        self.value_fc2 = nn.Linear(128, 1)

        # Advantage stream
        self.advantage_conv = nn.Conv2d(128, 16, kernel_size=1)
        self.advantage_fc1 = nn.Linear(16 * 4 * 4, 128)
        self.advantage_fc2 = nn.Linear(128, output_size)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))

        # Value stream
        value = F.relu(self.value_conv(x))
        value = value.view(value.size(0), -1)
        value = F.relu(self.value_fc1(value))
        value = self.value_fc2(value)

        # Advantage stream
        advantage = F.relu(self.advantage_conv(x))
        advantage = advantage.view(advantage.size(0), -1)
        advantage = F.relu(self.advantage_fc1(advantage))
        advantage = self.advantage_fc2(advantage)

        # Combine the value and advantage streams
        q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
        return q_values

# Load the trained model
def load_model(model_path):
    model = DQN(4, 4)  # 4 input channels, 4 output actions
    try:
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['policy_net_state_dict'])
        model.eval()
        print("模型加载成功")
        return model
    except Exception as e:
        print(f"模型加载失败: {e}")
        return None

# Try to load the model
model_path = "models/dqn_2048_best_tile.pth"
model = load_model(model_path)

def render_board(board):
    html = "<div style='background-color:#bbada0; padding:10px; border-radius:6px;'>"
    html += "<table style='border-spacing:10px; border-collapse:separate;'>"

    for i in range(game.size):
        html += "<tr>"
        for j in range(game.size):
            value = board[i][j]
            color = TILE_COLORS.get(value, "#3c3a32")        # default: dark background
            text_color = TEXT_COLORS.get(value, "#f9f6f2")   # default: light text
            font_size = "36px" if value < 100 else "30px" if value < 1000 else "24px"

            html += f"""
            <td style='background-color:{color};
                       width:80px; height:80px;
                       border-radius:4px;
                       text-align:center;
                       font-weight:bold;
                       font-size:{font_size};
                       color:{text_color};'>
                {value if value > 0 else ''}
            </td>
            """
        html += "</tr>"

    html += "</table></div>"
    return html

def make_move(direction):
    """Apply a move and refresh the interface."""
    direction_names = ["上", "右", "下", "左"]

    # Execute the move
    new_board, game_over = game.move(direction)

    # Render the board
    board_html = render_board(new_board)

    # Update the status text
    status = f"<b>移动方向:</b> {direction_names[direction]}"
    status += f"<br><b>当前分数:</b> {game.score}"
    status += f"<br><b>最大方块:</b> {np.max(game.board)}"

    if game.game_over:
        status += "<br><br><div style='color:#ff0000; font-weight:bold;'>游戏结束!</div>"
        status += f"<br><b>最终分数:</b> {game.score}"

    return board_html, status

def reset_game():
    """Reset the game."""
    global game
    game = Game2048(size=4)
    board = game.reset()

    # Render the board
    board_html = render_board(board)

    # Initial status text
    status = "<b>游戏已重置!</b>"
    status += f"<br><b>当前分数:</b> {game.score}"
    status += f"<br><b>最大方块:</b> {np.max(game.board)}"

    return board_html, status

def ai_move():
    """Let the AI model play one move."""
    if model is None:
        return render_board(game.board), "<b>错误:</b> 未加载AI模型"

    # Current state
    state = game.get_state()

    # Valid moves
    valid_moves = game.get_valid_moves()
    if not valid_moves:
        return render_board(game.board), "<b>游戏结束!</b> 没有有效移动"

    # Convert the state into model input
    state_tensor = torch.tensor(state, dtype=torch.float).unsqueeze(0)

    # Model prediction
    with torch.no_grad():
        q_values = model(state_tensor).numpy().flatten()

    # Only consider valid actions
    valid_q_values = np.full(4, -np.inf)
    for move in valid_moves:
        valid_q_values[move] = q_values[move]

    # Pick the best action
    action = np.argmax(valid_q_values)

    # Execute the move
    direction_names = ["上", "右", "下", "左"]
    new_board, game_over = game.move(action)

    # Render the board
    board_html = render_board(new_board)

    # Update the status text
    status = f"<b>AI移动方向:</b> {direction_names[action]}"
    status += f"<br><b>当前分数:</b> {game.score}"
    status += f"<br><b>最大方块:</b> {np.max(game.board)}"

    if game.game_over:
        status += "<br><br><div style='color:#ff0000; font-weight:bold;'>游戏结束!</div>"
        status += f"<br><b>最终分数:</b> {game.score}"

    return board_html, status

# Build the Gradio interface
with gr.Blocks(title="2048游戏", theme="soft") as demo:
    gr.Markdown("# 🎮 2048游戏")
    gr.Markdown("使用方向键或下方的按钮移动方块,相同数字的方块相撞时会合并!")
    with gr.Row():
        with gr.Column(scale=2):
            board_html = gr.HTML(render_board(game.board))
            with gr.Row(visible=False):
                status_display = gr.HTML("<b>当前分数:</b> 0<br><b>最大方块:</b> 2")
        with gr.Column():
            gr.Markdown("## 手动操作")
            with gr.Row():
                gr.Button("上 ↑", elem_id="up-btn").click(
                    fn=lambda: make_move(0),
                    outputs=[board_html, status_display]
                )
                gr.Button("左 ←", elem_id="left-btn").click(
                    fn=lambda: make_move(3),
                    outputs=[board_html, status_display]
                )
            with gr.Row():
                gr.Button("下 ↓", elem_id="down-btn").click(
                    fn=lambda: make_move(2),
                    outputs=[board_html, status_display]
                )
                gr.Button("右 →", elem_id="right-btn").click(
                    fn=lambda: make_move(1),
                    outputs=[board_html, status_display]
                )
            with gr.Row():
                gr.Button("🔄 重置游戏", elem_id="reset-btn").click(
                    fn=reset_game,
                    outputs=[board_html, status_display]
                )
            with gr.Row():
                # Note: this rebinds status_display, so the buttons registered above
                # still target the hidden HTML component created earlier.
                status_display = gr.HTML("<b>当前分数:</b> 0<br><b>最大方块:</b> 2")
        with gr.Column():
            gr.Markdown("## AI操作")
            gr.Button("🤖 AI移动一步", elem_id="ai-btn").click(
                fn=ai_move,
                outputs=[board_html, status_display]
            )
            gr.Markdown("基于DQN神经网络提供支持")

    # Keyboard shortcut support
    demo.load(
        fn=None,
        inputs=None,
        outputs=None,
        js="""() => {
            document.addEventListener('keydown', function(e) {
                if (e.key === 'ArrowUp') {
                    document.getElementById('up-btn').click();
                } else if (e.key === 'ArrowRight') {
                    document.getElementById('right-btn').click();
                } else if (e.key === 'ArrowDown') {
                    document.getElementById('down-btn').click();
                } else if (e.key === 'ArrowLeft') {
                    document.getElementById('left-btn').click();
                } else if (e.key === 'r' || e.key === 'R') {
                    document.getElementById('reset-btn').click();
                } else if (e.key === 'a' || e.key === 'A') {
                    document.getElementById('ai-btn').click();
                }
            });
        }"""
    )
    gr.Markdown("### 📚 使用说明")
    gr.Markdown("1. 使用方向键或下方的按钮移动方块。")
    gr.Markdown("2. 相同数字的方块相撞时会合并。")
    gr.Markdown("3. 快捷键说明:上/下/左/右键移动方块,R键重置游戏,A键AI移动一步。")
    gr.Markdown("4. 点击 '🤖 AI移动一步' 按钮可以使用AI模型进行一步移动。")
    gr.Markdown("5. 游戏结束后,会显示最终分数和最大方块。")

# Launch the interface
if __name__ == "__main__":
    demo.launch(share=True)
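
For a quick check of the checkpoint outside the Gradio UI, the same greedy policy can be driven from a short script. The sketch below is an assumption, not part of the upload: it reuses Game2048 from game2048.py and the DQN class from main.py, and it expects models/dqn_2048_best_tile.pth to exist next to the scripts.

# offline_check.py (sketch): sanity-check the saved policy without launching Gradio.
import numpy as np
import torch

from game2048 import Game2048
from main import DQN  # reuse the Dueling DQN definition from main.py

game = Game2048(size=4)
model = DQN(4, 4)
checkpoint = torch.load("models/dqn_2048_best_tile.pth", map_location="cpu")
model.load_state_dict(checkpoint["policy_net_state_dict"])
model.eval()

for step in range(10):                      # play ten greedy moves
    valid_moves = game.get_valid_moves()
    if not valid_moves:
        break
    state = torch.tensor(game.get_state(), dtype=torch.float).unsqueeze(0)
    with torch.no_grad():
        q_values = model(state).numpy().flatten()
    masked = np.full(4, -np.inf)
    for move in valid_moves:                # mask out invalid directions
        masked[move] = q_values[move]
    game.move(int(np.argmax(masked)))
    print(f"step {step}: score={game.score}, max tile={np.max(game.board)}")
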
game2048.py
ADDED
@@ -0,0 +1,200 @@
import numpy as np
import random

class Game2048:
    def __init__(self, size=4):
        self.size = size
        self.reset()

    def reset(self):
        """Reset the game state."""
        self.board = np.zeros((self.size, self.size), dtype=np.int32)
        self.score = 0
        self.add_tile()
        self.add_tile()
        self.game_over = False
        return self.board.copy()

    def add_tile(self):
        """Add a new tile at a random empty position (90% chance of a 2, 10% chance of a 4)."""
        empty_cells = []
        for i in range(self.size):
            for j in range(self.size):
                if self.board[i][j] == 0:
                    empty_cells.append((i, j))

        if empty_cells:
            i, j = random.choice(empty_cells)
            self.board[i][j] = 2 if random.random() < 0.9 else 4

    def move(self, direction):
        """
        Apply a move.
        0: up, 1: right, 2: down, 3: left
        Returns: (new board state, whether the game is over)
        """
        moved = False
        # Execute the move for the given direction
        if direction == 0:  # up
            for j in range(self.size):
                column = self.board[:, j].copy()
                new_column, moved_col = self.slide(column)
                if moved_col:
                    moved = True
                self.board[:, j] = new_column

        elif direction == 1:  # right
            for i in range(self.size):
                row = self.board[i, :].copy()[::-1]
                new_row, moved_row = self.slide(row)
                if moved_row:
                    moved = True
                self.board[i, :] = new_row[::-1]

        elif direction == 2:  # down
            for j in range(self.size):
                column = self.board[::-1, j].copy()
                new_column, moved_col = self.slide(column)
                if moved_col:
                    moved = True
                self.board[:, j] = new_column[::-1]

        elif direction == 3:  # left
            for i in range(self.size):
                row = self.board[i, :].copy()
                new_row, moved_row = self.slide(row)
                if moved_row:
                    moved = True
                self.board[i, :] = new_row

        # If anything moved, add a new tile and check for game over
        if moved:
            self.add_tile()
            self.check_game_over()

        return self.board.copy(), self.game_over

    def slide(self, line):
        """Slide and merge a single row or column."""
        non_zero = line[line != 0]
        new_line = np.zeros_like(line)
        idx = 0
        score_inc = 0
        moved = False

        # Check whether anything moved
        if not np.array_equal(non_zero, line[:len(non_zero)]):
            moved = True

        # Merge equal neighbours
        i = 0
        while i < len(non_zero):
            if i + 1 < len(non_zero) and non_zero[i] == non_zero[i+1]:
                new_val = non_zero[i] * 2
                new_line[idx] = new_val
                score_inc += new_val
                i += 2
                idx += 1
            else:
                new_line[idx] = non_zero[i]
                i += 1
                idx += 1

        self.score += score_inc
        return new_line, moved or (score_inc > 0)

    def check_game_over(self):
        """Check whether the game is over."""
        # Any empty cell left?
        if np.any(self.board == 0):
            self.game_over = False
            return

        # Any horizontally or vertically mergeable neighbours?
        for i in range(self.size):
            for j in range(self.size - 1):
                if self.board[i][j] == self.board[i][j+1]:
                    self.game_over = False
                    return

        for j in range(self.size):
            for i in range(self.size - 1):
                if self.board[i][j] == self.board[i+1][j]:
                    self.game_over = False
                    return

        self.game_over = True

    def get_valid_moves(self):
        """Return all currently valid move directions."""
        valid_moves = []

        # Is moving up valid?
        for j in range(self.size):
            column = self.board[:, j].copy()
            new_column, _ = self.slide(column)
            if not np.array_equal(new_column, self.board[:, j]):
                valid_moves.append(0)
                break

        # Is moving right valid?
        for i in range(self.size):
            row = self.board[i, :].copy()[::-1]
            new_row, _ = self.slide(row)
            if not np.array_equal(new_row[::-1], self.board[i, :]):
                valid_moves.append(1)
                break

        # Is moving down valid?
        for j in range(self.size):
            column = self.board[::-1, j].copy()
            new_column, _ = self.slide(column)
            if not np.array_equal(new_column[::-1], self.board[:, j]):
                valid_moves.append(2)
                break

        # Is moving left valid?
        for i in range(self.size):
            row = self.board[i, :].copy()
            new_row, _ = self.slide(row)
            if not np.array_equal(new_row, self.board[i, :]):
                valid_moves.append(3)
                break

        return valid_moves

    def get_state(self):
        """Return the state representation used by the AI model."""
        # Four-channel state representation
        state = np.zeros((4, self.size, self.size), dtype=np.float32)

        # Channel 0: log2 of each tile value, normalised
        for i in range(self.size):
            for j in range(self.size):
                if self.board[i][j] > 0:
                    state[0, i, j] = np.log2(self.board[i][j]) / 16.0  # supports up to 65536 (2^16)

        # Channel 1: empty-cell indicator
        state[1] = (self.board == 0).astype(np.float32)

        # Channel 2: mergeable-neighbour indicator
        for i in range(self.size):
            for j in range(self.size):
                if self.board[i][j] > 0:
                    # Check the cell to the right
                    if j < self.size - 1 and self.board[i][j] == self.board[i][j+1]:
                        state[2, i, j] = 1.0
                        state[2, i, j+1] = 1.0
                    # Check the cell below
                    if i < self.size - 1 and self.board[i][j] == self.board[i+1][j]:
                        state[2, i, j] = 1.0
                        state[2, i+1, j] = 1.0

        # Channel 3: position(s) of the maximum tile
        max_value = np.max(self.board)
        if max_value > 0:
            max_positions = np.argwhere(self.board == max_value)
            for pos in max_positions:
                state[3, pos[0], pos[1]] = 1.0

        return state
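
The class above is self-contained, so it can be exercised without any model. A minimal sketch (not part of the upload) that plays random valid moves until no move is left:

# random_play.py (sketch): drive the Game2048 API directly.
import random

import numpy as np

from game2048 import Game2048

game = Game2048(size=4)
game.reset()
while True:
    valid_moves = game.get_valid_moves()   # directions that would change the board
    if not valid_moves:
        break
    board, game_over = game.move(random.choice(valid_moves))
    if game_over:
        break
print(f"final score: {game.score}, max tile: {np.max(game.board)}")
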
main.py
ADDED
@@ -0,0 +1,837 @@
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import warnings

if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.xpu.is_available():
    device = torch.device("xpu")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# 2048 game environment (improved version)
class Game2048:
    def __init__(self, size=4):
        self.size = size
        self.reset()

    def reset(self):
        self.board = np.zeros((self.size, self.size), dtype=np.int32)
        self.score = 0
        self.prev_score = 0
        self.add_tile()
        self.add_tile()
        self.game_over = False
        return self.get_state()

    def add_tile(self):
        empty_cells = []
        for i in range(self.size):
            for j in range(self.size):
                if self.board[i][j] == 0:
                    empty_cells.append((i, j))

        if empty_cells:
            i, j = random.choice(empty_cells)
            self.board[i][j] = 2 if random.random() < 0.9 else 4

    def move(self, direction):
        # 0: up, 1: right, 2: down, 3: left
        moved = False
        original_board = self.board.copy()
        old_score = self.score

        # Execute the move for the given direction
        if direction == 0:  # up
            for j in range(self.size):
                column = self.board[:, j].copy()
                new_column, moved_col = self.slide(column)
                if moved_col:
                    moved = True
                self.board[:, j] = new_column

        elif direction == 1:  # right
            for i in range(self.size):
                row = self.board[i, :].copy()[::-1]
                new_row, moved_row = self.slide(row)
                if moved_row:
                    moved = True
                self.board[i, :] = new_row[::-1]

        elif direction == 2:  # down
            for j in range(self.size):
                column = self.board[::-1, j].copy()
                new_column, moved_col = self.slide(column)
                if moved_col:
                    moved = True
                self.board[:, j] = new_column[::-1]

        elif direction == 3:  # left
            for i in range(self.size):
                row = self.board[i, :].copy()
                new_row, moved_row = self.slide(row)
                if moved_row:
                    moved = True
                self.board[i, :] = new_row

        # If anything moved, add a new tile
        if moved:
            self.add_tile()
            self.check_game_over()

        reward = self.calculate_reward(old_score, original_board)
        return self.get_state(), reward, self.game_over

    def slide(self, line):
        # Drop zeros and merge equal neighbours
        non_zero = line[line != 0]
        new_line = np.zeros_like(line)
        idx = 0
        score_inc = 0
        moved = False

        # Check whether anything moved
        if not np.array_equal(non_zero, line[:len(non_zero)]):
            moved = True

        # Merge equal neighbours
        i = 0
        while i < len(non_zero):
            if i + 1 < len(non_zero) and non_zero[i] == non_zero[i+1]:
                new_val = non_zero[i] * 2
                new_line[idx] = new_val
                score_inc += new_val
                i += 2
                idx += 1
            else:
                new_line[idx] = non_zero[i]
                i += 1
                idx += 1

        self.score += score_inc
        return new_line, moved or (score_inc > 0)

    def calculate_reward(self, old_score, original_board):
        """Improved reward function."""
        # 1. Basic score reward
        score_reward = (self.score - old_score) * 0.1

        # 2. Reward for the change in the number of empty cells
        empty_before = np.count_nonzero(original_board == 0)
        empty_after = np.count_nonzero(self.board == 0)
        empty_reward = (empty_after - empty_before) * 0.15

        # 3. Max-tile reward
        max_before = np.max(original_board)
        max_after = np.max(self.board)
        max_tile_reward = 0
        if max_after > max_before:
            max_tile_reward = np.log2(max_after) * 0.2

        # 4. Merge reward (encourages merging)
        merge_reward = 0
        if self.score - old_score > 0:
            merge_reward = np.log2(self.score - old_score) * 0.1

        # 5. Monotonicity penalty (encourages ordered rows/columns)
        monotonicity_penalty = self.calculate_monotonicity_penalty() * 0.01

        # 6. Game-over penalty
        game_over_penalty = 0
        if self.game_over:
            game_over_penalty = -10

        # 7. Smoothness reward (encourages similar neighbouring values)
        smoothness_reward = self.calculate_smoothness() * 0.01

        # Total reward
        total_reward = (
            score_reward +
            empty_reward +
            max_tile_reward +
            merge_reward +
            smoothness_reward +
            monotonicity_penalty +
            game_over_penalty
        )

        return total_reward

    def calculate_monotonicity_penalty(self):
        """Monotonicity penalty (lower is better)."""
        penalty = 0
        for i in range(self.size):
            for j in range(self.size - 1):
                if self.board[i][j] > self.board[i][j+1]:
                    penalty += self.board[i][j] - self.board[i][j+1]
                else:
                    penalty += self.board[i][j+1] - self.board[i][j]
        return penalty

    def calculate_smoothness(self):
        """Smoothness (higher is better)."""
        smoothness = 0
        for i in range(self.size):
            for j in range(self.size):
                if self.board[i][j] != 0:
                    value = np.log2(self.board[i][j])
                    # Right neighbour
                    if j < self.size - 1 and self.board[i][j+1] != 0:
                        neighbor_value = np.log2(self.board[i][j+1])
                        smoothness -= abs(value - neighbor_value)
                    # Neighbour below
                    if i < self.size - 1 and self.board[i+1][j] != 0:
                        neighbor_value = np.log2(self.board[i+1][j])
                        smoothness -= abs(value - neighbor_value)
        return smoothness

    def check_game_over(self):
        # Any empty cell left?
        if np.any(self.board == 0):
            self.game_over = False
            return

        # Any horizontally or vertically mergeable neighbours?
        for i in range(self.size):
            for j in range(self.size - 1):
                if self.board[i][j] == self.board[i][j+1]:
                    self.game_over = False
                    return

        for j in range(self.size):
            for i in range(self.size - 1):
                if self.board[i][j] == self.board[i+1][j]:
                    self.game_over = False
                    return

        self.game_over = True

    def get_state(self):
        """Improved state representation."""
        # Four-channel state representation
        state = np.zeros((4, self.size, self.size), dtype=np.float32)

        # Channel 0: log2 of each tile value, normalised
        for i in range(self.size):
            for j in range(self.size):
                if self.board[i][j] > 0:
                    state[0, i, j] = np.log2(self.board[i][j]) / 16.0  # supports up to 65536 (2^16)

        # Channel 1: empty-cell indicator
        state[1] = (self.board == 0).astype(np.float32)

        # Channel 2: mergeable-neighbour indicator
        for i in range(self.size):
            for j in range(self.size):
                if self.board[i][j] > 0:
                    # Check the cell to the right
                    if j < self.size - 1 and self.board[i][j] == self.board[i][j+1]:
                        state[2, i, j] = 1.0
                        state[2, i, j+1] = 1.0
                    # Check the cell below
                    if i < self.size - 1 and self.board[i][j] == self.board[i+1][j]:
                        state[2, i, j] = 1.0
                        state[2, i+1, j] = 1.0

        # Channel 3: position(s) of the maximum tile
        max_value = np.max(self.board)
        if max_value > 0:
            max_positions = np.argwhere(self.board == max_value)
            for pos in max_positions:
                state[3, pos[0], pos[1]] = 1.0

        return state

    def get_valid_moves(self):
        """More efficient valid-move detection."""
        valid_moves = []
        #test_board = np.zeros_like(self.board)

        # Is moving up valid?
        for j in range(self.size):
            column = self.board[:, j].copy()
            new_column, _ = self.slide(column)
            if not np.array_equal(new_column, self.board[:, j]):
                valid_moves.append(0)
                break

        # Is moving right valid?
        for i in range(self.size):
            row = self.board[i, :].copy()[::-1]
            new_row, _ = self.slide(row)
            if not np.array_equal(new_row[::-1], self.board[i, :]):
                valid_moves.append(1)
                break

        # Is moving down valid?
        for j in range(self.size):
            column = self.board[::-1, j].copy()
            new_column, _ = self.slide(column)
            if not np.array_equal(new_column[::-1], self.board[:, j]):
                valid_moves.append(2)
                break

        # Is moving left valid?
        for i in range(self.size):
            row = self.board[i, :].copy()
            new_row, _ = self.slide(row)
            if not np.array_equal(new_row, self.board[i, :]):
                valid_moves.append(3)
                break

        return valid_moves

# Improved deep Q-network (Dueling DQN architecture)
class DQN(nn.Module):
    def __init__(self, input_channels, output_size):
        super(DQN, self).__init__()
        self.input_channels = input_channels

        # Convolutional layers
        self.conv1 = nn.Conv2d(input_channels, 128, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)

        # Dueling DQN architecture
        # Value stream
        self.value_conv = nn.Conv2d(128, 4, kernel_size=1)
        self.value_fc1 = nn.Linear(4 * 4 * 4, 128)
        self.value_fc2 = nn.Linear(128, 1)

        # Advantage stream
        self.advantage_conv = nn.Conv2d(128, 16, kernel_size=1)
        self.advantage_fc1 = nn.Linear(16 * 4 * 4, 128)
        self.advantage_fc2 = nn.Linear(128, output_size)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))

        # Value stream
        value = F.relu(self.value_conv(x))
        value = value.view(value.size(0), -1)
        value = F.relu(self.value_fc1(value))
        value = self.value_fc2(value)

        # Advantage stream
        advantage = F.relu(self.advantage_conv(x))
        advantage = advantage.view(advantage.size(0), -1)
        advantage = F.relu(self.advantage_fc1(advantage))
        advantage = self.advantage_fc2(advantage)

        # Combine the value and advantage streams
        q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
        return q_values

# Experience replay buffer (with priorities)
class PrioritizedReplayBuffer:
    def __init__(self, capacity, alpha=0.6):
        self.capacity = capacity
        self.alpha = alpha
        self.buffer = []
        self.priorities = np.zeros(capacity)
        self.pos = 0
        self.size = 0

    def push(self, state, action, reward, next_state, done):
        # New transitions get the current maximum priority
        max_priority = self.priorities.max() if self.buffer else 1.0

        if len(self.buffer) < self.capacity:
            self.buffer.append((state, action, reward, next_state, done))
        else:
            self.buffer[self.pos] = (state, action, reward, next_state, done)

        self.priorities[self.pos] = max_priority
        self.pos = (self.pos + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size, beta=0.4):
        if self.size == 0:
            return None, None, None

        priorities = self.priorities[:self.size]
        probs = priorities ** self.alpha
        probs /= probs.sum()

        indices = np.random.choice(self.size, batch_size, p=probs)
        samples = [self.buffer[idx] for idx in indices]

        # Importance-sampling weights
        weights = (self.size * probs[indices]) ** (-beta)
        weights /= weights.max()
        weights = np.array(weights, dtype=np.float32)

        states, actions, rewards, next_states, dones = zip(*samples)
        return (
            torch.tensor(np.array(states)),
            torch.tensor(actions, dtype=torch.long),
            torch.tensor(rewards, dtype=torch.float),
            torch.tensor(np.array(next_states)),
            torch.tensor(dones, dtype=torch.float),
            indices,
            torch.tensor(weights)
        )

    def update_priorities(self, indices, priorities):
        # Make sure priorities is an array
        if isinstance(priorities, np.ndarray) and priorities.ndim == 1:
            for idx, priority in zip(indices, priorities):
                self.priorities[idx] = priority
        else:
            # Handle the scalar case (should not normally happen)
            if not isinstance(priorities, (list, np.ndarray)):
                priorities = [priorities] * len(indices)
            for idx, priority in zip(indices, priorities):
                self.priorities[idx] = priority

    def __len__(self):
        return self.size

# Improved DQN agent
class DQNAgent:
    def __init__(self, input_channels, action_size, lr=3e-4, gamma=0.99,
                 epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=0.999,
                 target_update_freq=1000, batch_size=128):
        self.input_channels = input_channels
        self.action_size = action_size
        self.gamma = gamma
        self.epsilon = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        self.target_update_freq = target_update_freq

        # Policy network and target network
        self.policy_net = DQN(input_channels, action_size).to(device)
        self.target_net = DQN(input_channels, action_size).to(device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr, weight_decay=1e-5)
        self.memory = PrioritizedReplayBuffer(50000)
        self.steps_done = 0
        self.loss_fn = nn.SmoothL1Loss(reduction='none')

    def select_action(self, state, valid_moves):
        self.steps_done += 1
        self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay)

        if random.random() < self.epsilon:
            # Pick a random valid action
            return random.choice(valid_moves)
        else:
            # Pick the action from the policy network
            with torch.no_grad():
                state_tensor = torch.tensor(state, dtype=torch.float).unsqueeze(0).to(device)
                q_values = self.policy_net(state_tensor).cpu().numpy().flatten()

            # Only consider valid actions
            valid_q_values = np.full(self.action_size, -np.inf)
            for move in valid_moves:
                valid_q_values[move] = q_values[move]

            return np.argmax(valid_q_values)

    def optimize_model(self, beta=0.4):
        if len(self.memory) < self.batch_size:
            return 0

        # Sample from the replay buffer
        sample = self.memory.sample(self.batch_size, beta)
        if sample is None:
            return 0

        states, actions, rewards, next_states, dones, indices, weights = sample

        states = states.to(device)
        actions = actions.to(device)
        rewards = rewards.to(device)
        next_states = next_states.to(device)
        dones = dones.to(device)
        weights = weights.to(device)

        # Current Q-values
        current_q = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze()

        # Target Q-values (Double DQN)
        with torch.no_grad():
            next_actions = self.policy_net(next_states).max(1)[1]
            next_q = self.target_net(next_states).gather(1, next_actions.unsqueeze(1)).squeeze()
            target_q = rewards + (1 - dones) * self.gamma * next_q

        # Loss
        losses = self.loss_fn(current_q, target_q)
        loss = (losses * weights).mean()

        # Update priorities (absolute per-sample loss)
        with torch.no_grad():
            priorities = losses.abs().cpu().numpy() + 1e-5
        self.memory.update_priorities(indices, priorities)

        # Optimise the model
        self.optimizer.zero_grad()
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 10)

        self.optimizer.step()

        return loss.item()

    def update_target_network(self):
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def save_model(self, path):
        torch.save({
            'policy_net_state_dict': self.policy_net.state_dict(),
            'target_net_state_dict': self.target_net.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'epsilon': self.epsilon,
            'steps_done': self.steps_done
        }, path)

    def load_model(self, path):
        if not os.path.exists(path):
            print(f"Model file not found: {path}")
            return

        try:
            # Try loading the model with weights_only=False
            checkpoint = torch.load(path, map_location=device, weights_only=False)
            self.policy_net.load_state_dict(checkpoint['policy_net_state_dict'])
            self.target_net.load_state_dict(checkpoint['target_net_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.epsilon = checkpoint['epsilon']
            self.steps_done = checkpoint['steps_done']
            self.policy_net.eval()
            self.target_net.eval()
            print(f"Model loaded successfully from {path}")
        except Exception as e:
            print(f"Error loading model: {e}")
            # Fall back to the legacy loading path
            try:
                warnings.warn("Trying legacy load method without weights_only")
                checkpoint = torch.load(path, map_location=device)
                self.policy_net.load_state_dict(checkpoint['policy_net_state_dict'])
                self.target_net.load_state_dict(checkpoint['target_net_state_dict'])
                self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                self.epsilon = checkpoint['epsilon']
                self.steps_done = checkpoint['steps_done']
                self.policy_net.eval()
                self.target_net.eval()
                print("Model loaded successfully using legacy method")
            except Exception as e2:
                print(f"Failed to load model: {e2}")

# Training function (with progress logging)
def train_agent(agent, env, episodes=5000, save_path='models/dqn_2048.pth',
                checkpoint_path='models/checkpoint.pth', resume=False, start_episode=0):
    # Create the directory for saved models
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    # Training metrics
    scores = []
    max_tiles = []
    avg_scores = []
    losses = []
    best_score = 0
    best_max_tile = 0

    # When resuming, restore the training state
    if resume and os.path.exists(checkpoint_path):
        try:
            # Load the checkpoint with weights_only=False
            checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
            scores = checkpoint['scores']
            max_tiles = checkpoint['max_tiles']
            avg_scores = checkpoint['avg_scores']
            losses = checkpoint['losses']
            best_score = checkpoint.get('best_score', 0)
            best_max_tile = checkpoint.get('best_max_tile', 0)
            print(f"Resuming training from episode {start_episode}...")
        except Exception as e:
            print(f"Error loading checkpoint: {e}")
            print("Starting training from scratch...")
            resume = False

    if not resume:
        start_episode = 0

    # Show a tqdm progress bar
    progress_bar = tqdm(range(start_episode, episodes), desc="Training")

    for episode in progress_bar:
        state = env.reset()
        total_reward = 0
        done = False
        steps = 0
        episode_loss = 0
        loss_count = 0

        while not done:
            valid_moves = env.get_valid_moves()
            if not valid_moves:
                done = True
                continue

            action = agent.select_action(state, valid_moves)
            next_state, reward, done = env.move(action)
            total_reward += reward

            agent.memory.push(state, action, reward, next_state, done)
            state = next_state

            # Optimise the model
            loss = agent.optimize_model(beta=min(1.0, episode / 1000))
            if loss > 0:
                episode_loss += loss
                loss_count += 1

            # Periodically update the target network
            if agent.steps_done % agent.target_update_freq == 0:
                agent.update_target_network()

            steps += 1

        # Record the score and max tile
        score = env.score
        max_tile = np.max(env.board)
        scores.append(score)
        max_tiles.append(max_tile)

        # Average loss for this episode
        avg_loss = episode_loss / loss_count if loss_count > 0 else 0
        losses.append(avg_loss)

        # Update the best records
        if score > best_score:
            best_score = score
            agent.save_model(save_path.replace('.pth', '_best_score.pth'))
        if max_tile > best_max_tile:
            best_max_tile = max_tile
            agent.save_model(save_path.replace('.pth', '_best_tile.pth'))

        # Average score over the last 100 episodes
        recent_scores = scores[-100:] if len(scores) >= 100 else scores
        avg_score = np.mean(recent_scores)
        avg_scores.append(avg_score)

        # Update the progress-bar description
        progress_bar.set_description(
            f"Ep {episode+1}/{episodes} | "
            f"Score: {score} (Avg: {avg_score:.1f}) | "
            f"Max Tile: {max_tile} | "
            f"Loss: {avg_loss:.4f} | "
            f"Epsilon: {agent.epsilon:.4f}"
        )

        # Periodically save the model and training state
        if (episode + 1) % 100 == 0:
            agent.save_model(save_path)

            # Save the training state
            checkpoint = {
                'scores': scores,
                'max_tiles': max_tiles,
                'avg_scores': avg_scores,
                'losses': losses,
                'best_score': best_score,
                'best_max_tile': best_max_tile,
                'episode': episode + 1,
                'steps_done': agent.steps_done,
                'epsilon': agent.epsilon
            }
            try:
                torch.save(checkpoint, checkpoint_path)
            except Exception as e:
                print(f"Error saving checkpoint: {e}")

            # Plot the training curves
            if episode > 100:  # make sure there is enough data
                plt.figure(figsize=(12, 8))

                # Score curve
                plt.subplot(2, 2, 1)
                plt.plot(scores, label='Score')
                plt.plot(avg_scores, label='Avg Score (100 eps)')
                plt.xlabel('Episode')
                plt.ylabel('Score')
                plt.title('Training Scores')
                plt.legend()

                # Max-tile curve
                plt.subplot(2, 2, 2)
                plt.plot(max_tiles, 'g-')
                plt.xlabel('Episode')
                plt.ylabel('Max Tile')
                plt.title('Max Tile Achieved')

                # Loss curve
                plt.subplot(2, 2, 3)
                plt.plot(losses, 'r-')
                plt.xlabel('Episode')
                plt.ylabel('Loss')
                plt.title('Training Loss')

                # Score histogram
                plt.subplot(2, 2, 4)
                plt.hist(scores, bins=20, alpha=0.7)
                plt.xlabel('Score')
                plt.ylabel('Frequency')
                plt.title('Score Distribution')

                plt.tight_layout()
                plt.savefig('training_progress.png')
                plt.close()

    # Save the final model
    agent.save_model(save_path)

    return scores, max_tiles, losses

# Inference function (with console output)
def play_with_model(agent, env, episodes=3):
    agent.epsilon = 0.001  # use a very small epsilon for inference

    for episode in range(episodes):
        state = env.reset()
        done = False
        steps = 0

        print(f"\nEpisode {episode+1}")
        print("Initial Board:")
        print(env.board)

        while not done:
            valid_moves = env.get_valid_moves()
            if not valid_moves:
                done = True
                print("No valid moves left!")
                continue

            # Choose an action
            with torch.no_grad():
                state_tensor = torch.tensor(state, dtype=torch.float).unsqueeze(0).to(device)
                q_values = agent.policy_net(state_tensor).cpu().numpy().flatten()

            # Only consider valid actions
            valid_q_values = np.full(env.size, -np.inf)
            for move in valid_moves:
                valid_q_values[move] = q_values[move]

            action = np.argmax(valid_q_values)

            # Execute the action
            next_state, reward, done = env.move(action)
            state = next_state
            steps += 1

            # Render the game
            print(f"\nStep {steps}: Action {['Up', 'Right', 'Down', 'Left'][action]}")
            print(env.board)
            print(f"Score: {env.score}, Max Tile: {np.max(env.board)}")
            # Also append the result to result.txt
            with open("result.txt", "a") as f:
                f.write(f"Episode {episode+1}, Step {steps}, Action {['Up', 'Right', 'Down', 'Left'][action]}, Score: {env.score}, Max Tile: {np.max(env.board)}\n{env.board}\n")
                f.close()

        print(f"\nGame Over! Final Score: {env.score}, Max Tile: {np.max(env.board)}")

# Main program
if __name__ == "__main__":
    args = {"train": 0, "resume": 0, "play": 1, "episodes": 50000}
    env = Game2048(size=4)
    input_channels = 4  # number of channels in the state representation
    action_size = 4     # up, right, down, left

    agent = DQNAgent(
        input_channels,
        action_size,
        lr=1e-4,
        epsilon_decay=0.999,  # slower decay
        target_update_freq=1000,
        batch_size=256
    )

    # Train the model
    if args.get('train') or args.get('resume'):
        print("Starting training...")

        # When resuming, load the checkpoint
        start_episode = 0
        checkpoint_path = 'models/checkpoint.pth'
        if args.get('resume') and os.path.exists(checkpoint_path):
            try:
                # Load the checkpoint with weights_only=False
                checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
                start_episode = checkpoint.get('episode', 0)
                agent.steps_done = checkpoint.get('steps_done', 0)
                agent.epsilon = checkpoint.get('epsilon', agent.epsilon)
            except Exception as e:
                print(f"Error loading checkpoint: {e}")
                print("Starting training from scratch...")
                start_episode = 0

        scores, max_tiles, losses = train_agent(
            agent,
            env,
            episodes=args.get('episodes'),
            save_path='models/dqn_2048.pth',
            checkpoint_path=checkpoint_path,
            resume=args.get('resume'),
            start_episode=start_episode
        )
        print("Training completed!")

        # Plot the final training results
        plt.figure(figsize=(15, 10))

        plt.subplot(3, 1, 1)
        plt.plot(scores)
        plt.title('Scores per Episode')
        plt.xlabel('Episode')
        plt.ylabel('Score')

        plt.subplot(3, 1, 2)
        plt.plot(max_tiles)
        plt.title('Max Tile per Episode')
        plt.xlabel('Episode')
        plt.ylabel('Max Tile')

        plt.subplot(3, 1, 3)
        plt.plot(losses)
        plt.title('Training Loss per Episode')
        plt.xlabel('Episode')
        plt.ylabel('Loss')

        plt.tight_layout()
        plt.savefig('final_training_results.png')
        plt.close()

    # Load the model and play
    if args.get('play'):
        model_path = 'models/dqn_2048_best_tile.pth'
        if not os.path.exists(model_path):
            model_path = 'models/dqn_2048.pth'

        if os.path.exists(model_path):
            agent.load_model(model_path)
            print("Playing with trained model...")
            if not os.path.exists("result.txt"):
                play_with_model(agent, env, episodes=1)
            else:
                os.remove("result.txt")  # remove the previous log
                play_with_model(agent, env, episodes=1)

        else:
            print("No trained model found. Please train the model first.")
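
The __main__ block toggles training and play through the hard-coded args dict. If command-line control is preferred, an argparse front end along these lines could replace it; the flag names below simply mirror the dict keys and are an assumption, not part of the upload:

# cli_args.py (sketch): command-line switches matching main.py's args dict.
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description="Train or play 2048 with the DQN agent")
    parser.add_argument("--train", action="store_true", help="train from scratch")
    parser.add_argument("--resume", action="store_true", help="resume from models/checkpoint.pth")
    parser.add_argument("--play", action="store_true", help="play with the saved model")
    parser.add_argument("--episodes", type=int, default=50000, help="number of training episodes")
    return parser.parse_args()

if __name__ == "__main__":
    cli = parse_args()
    # This dict is a drop-in replacement for `args` in main.py's __main__ block.
    args = {"train": int(cli.train), "resume": int(cli.resume),
            "play": int(cli.play), "episodes": cli.episodes}
    print(args)
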