Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,11 +1,17 @@
|
|
1 |
import gradio as gr
|
2 |
import numpy as np
|
3 |
-
import random
|
4 |
import torch
|
5 |
import torch.nn as nn
|
6 |
import torch.nn.functional as F
|
7 |
from game2048 import Game2048
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
# 创建游戏实例
|
10 |
game = Game2048(size=4)
|
11 |
|
@@ -23,7 +29,9 @@ TILE_COLORS = {
|
|
23 |
512: "#edc850", # 512
|
24 |
1024: "#edc53f", # 1024
|
25 |
2048: "#edc22e", # 2048
|
26 |
-
4096: "#3c3a32", # 4096
|
|
|
|
|
27 |
}
|
28 |
|
29 |
# 文本颜色映射(根据背景深浅)
|
@@ -41,6 +49,8 @@ TEXT_COLORS = {
|
|
41 |
1024: "#f9f6f2", # 1024+
|
42 |
2048: "#f9f6f2", # 2048+
|
43 |
4096: "#f9f6f2", # 4096+
|
|
|
|
|
44 |
}
|
45 |
|
46 |
# 定义DQN网络结构(与训练时相同)
|
@@ -86,22 +96,56 @@ class DQN(nn.Module):
|
|
86 |
q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
|
87 |
return q_values
|
88 |
|
89 |
-
#
|
90 |
-
def load_model(model_path):
|
91 |
-
model = DQN(4, 4) #
|
92 |
try:
|
93 |
-
|
94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
model.eval()
|
96 |
-
print("
|
97 |
return model
|
98 |
except Exception as e:
|
99 |
print(f"模型加载失败: {e}")
|
100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
|
102 |
# 尝试加载模型
|
103 |
-
|
104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
|
106 |
def render_board(board):
|
107 |
html = "<div style='background-color:#bbada0; padding:10px; border-radius:6px;'>"
|
@@ -181,82 +225,134 @@ def ai_move():
|
|
181 |
if not valid_moves:
|
182 |
return render_board(game.board), "<b>游戏结束!</b> 没有有效移动"
|
183 |
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
|
|
|
|
|
|
|
|
214 |
|
215 |
-
|
|
|
|
|
216 |
|
217 |
# 创建Gradio界面
|
218 |
with gr.Blocks(title="2048游戏", theme="soft") as demo:
|
219 |
gr.Markdown("# 🎮 2048游戏")
|
|
|
220 |
gr.Markdown("使用方向键或下方的按钮移动方块,相同数字的方块相撞时会合并!")
|
|
|
221 |
with gr.Row():
|
222 |
with gr.Column(scale=2):
|
223 |
board_html = gr.HTML(render_board(game.board))
|
224 |
-
|
225 |
-
|
226 |
with gr.Column():
|
227 |
gr.Markdown("## 手动操作")
|
228 |
with gr.Row():
|
229 |
-
gr.Button("上 ↑", elem_id="up-btn")
|
230 |
-
|
231 |
-
outputs=[board_html, status_display]
|
232 |
-
)
|
233 |
-
gr.Button("左 ←", elem_id="left-btn").click(
|
234 |
-
fn=lambda: make_move(3),
|
235 |
-
outputs=[board_html, status_display]
|
236 |
-
)
|
237 |
-
with gr.Row():
|
238 |
-
gr.Button("下 ↓", elem_id="down-btn").click(
|
239 |
-
fn=lambda: make_move(2),
|
240 |
-
outputs=[board_html, status_display]
|
241 |
-
)
|
242 |
-
gr.Button("右 →", elem_id="right-btn").click(
|
243 |
-
fn=lambda: make_move(1),
|
244 |
-
outputs=[board_html, status_display]
|
245 |
-
)
|
246 |
with gr.Row():
|
247 |
-
gr.Button("
|
248 |
-
|
249 |
-
outputs=[board_html, status_display]
|
250 |
-
)
|
251 |
with gr.Row():
|
252 |
-
|
253 |
-
|
254 |
gr.Markdown("## AI操作")
|
255 |
-
gr.
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
260 |
|
261 |
# 添加键盘快捷键支持
|
262 |
demo.load(
|
@@ -277,17 +373,22 @@ with gr.Blocks(title="2048游戏", theme="soft") as demo:
|
|
277 |
document.getElementById('reset-btn').click();
|
278 |
} else if (e.key === 'a' || e.key === 'A') {
|
279 |
document.getElementById('ai-btn').click();
|
|
|
|
|
280 |
}
|
281 |
});
|
282 |
}"""
|
283 |
)
|
|
|
284 |
gr.Markdown("### 📚 使用说明")
|
285 |
-
gr.Markdown("1.
|
286 |
-
gr.Markdown("2.
|
287 |
-
gr.Markdown("3.
|
288 |
-
gr.Markdown("
|
289 |
-
gr.Markdown("
|
|
|
|
|
290 |
|
291 |
# 启动界面
|
292 |
if __name__ == "__main__":
|
293 |
-
demo.launch(
|
|
|
1 |
import gradio as gr
|
2 |
import numpy as np
|
|
|
3 |
import torch
|
4 |
import torch.nn as nn
|
5 |
import torch.nn.functional as F
|
6 |
from game2048 import Game2048
|
7 |
|
8 |
+
# Pick the compute device once at import time: CUDA first, then Intel XPU,
# and finally CPU as the universal fallback.
if torch.cuda.is_available():
    device = torch.device("cuda")
# Older PyTorch releases do not ship a `torch.xpu` module, so probe for it
# before calling into it; otherwise startup dies with AttributeError.
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    device = torch.device("xpu")
else:
    device = torch.device("cpu")
|
15 |
# 创建游戏实例
|
16 |
game = Game2048(size=4)
|
17 |
|
|
|
29 |
512: "#edc850", # 512
|
30 |
1024: "#edc53f", # 1024
|
31 |
2048: "#edc22e", # 2048
|
32 |
+
4096: "#3c3a32", # 4096+,
|
33 |
+
8192: "#3c3a32", # 8192+
|
34 |
+
16384: "#3c3a32", # 16384+
|
35 |
}
|
36 |
|
37 |
# 文本颜色映射(根据背景深浅)
|
|
|
49 |
1024: "#f9f6f2", # 1024+
|
50 |
2048: "#f9f6f2", # 2048+
|
51 |
4096: "#f9f6f2", # 4096+
|
52 |
+
8192: "#f9f6f2", # 8192+
|
53 |
+
16384: "#f9f6f2", # 16384+
|
54 |
}
|
55 |
|
56 |
# 定义DQN网络结构(与训练时相同)
|
|
|
96 |
q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
|
97 |
return q_values
|
98 |
|
99 |
+
# 加载模型(支持CUDA)
|
100 |
+
def _load_weights(model, path, device):
    """Load the checkpoint stored at *path* into *model* and set eval mode.

    Any exception raised by ``torch.load`` or ``load_state_dict`` propagates
    to the caller.
    """
    # NOTE(review): torch.load unpickles the file, so only point this at
    # checkpoints from a trusted source (consider weights_only=True if the
    # files contain tensors only).
    checkpoint = torch.load(path, map_location=device)

    # A checkpoint is either a dict that wraps the weights under
    # 'policy_net_state_dict', or a bare state dict.
    if 'policy_net_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['policy_net_state_dict'])
    else:
        model.load_state_dict(checkpoint)

    model.eval()


def load_model(model_path, device):
    """Build a DQN on *device* and load weights from *model_path*.

    If loading fails, one fallback is attempted: the same path with the
    ``_best_tile`` suffix removed. The fallback is skipped when it would
    resolve to the very same file, so a missing checkpoint is not loaded
    (and reported as failed) twice.

    Args:
        model_path: Path of the checkpoint file to load.
        device: ``torch.device`` the model and its weights are mapped to.

    Returns:
        The model in eval mode, or ``None`` if no checkpoint could be loaded.
    """
    model = DQN(4, 4).to(device)

    try:
        _load_weights(model, model_path, device)
    except Exception as e:
        # Best-effort loading: report the failure and fall through to the
        # alternative path instead of crashing the app at startup.
        print(f"模型加载失败: {e}")
    else:
        print(f"模型成功加载到: {device}")
        return model

    alt_path = model_path.replace('_best_tile', '')
    if alt_path == model_path:
        # No suffix to strip: the fallback would retry the identical file.
        return None

    try:
        _load_weights(model, alt_path, device)
    except Exception as e2:
        print(f"备选模型加载失败: {e2}")
        return None

    print(f"备选模型成功加载: {alt_path}")
    return model
|
132 |
|
133 |
# 尝试加载模型
|
134 |
+
# Candidate checkpoint locations, tried in order; the first one that loads wins.
model_paths = [
    "models/dqn_2048_best_tile.pth",
    "models/dqn_2048.pth",
    "dqn_2048_best_tile.pth",
    "dqn_2048.pth",
]

model = None
for path in model_paths:
    model = load_model(path, device)
    # Compare against None explicitly rather than relying on the truthiness
    # of an nn.Module; this matches the `model is None` check in auto_play.
    if model is not None:
        break

if model is None:
    print("警告: 未加载任何模型,AI功能将不可用")
|
149 |
|
150 |
def render_board(board):
|
151 |
html = "<div style='background-color:#bbada0; padding:10px; border-radius:6px;'>"
|
|
|
225 |
if not valid_moves:
|
226 |
return render_board(game.board), "<b>游戏结束!</b> 没有有效移动"
|
227 |
|
228 |
+
try:
|
229 |
+
# 转换状态为模型输入并移动到设备
|
230 |
+
state_tensor = torch.tensor(state, dtype=torch.float).unsqueeze(0).to(device)
|
231 |
+
|
232 |
+
# 模型预测
|
233 |
+
with torch.no_grad():
|
234 |
+
q_values = model(state_tensor).cpu().numpy().flatten()
|
235 |
+
|
236 |
+
# 只考虑有效动作
|
237 |
+
valid_q_values = np.full(4, -np.inf)
|
238 |
+
for move in valid_moves:
|
239 |
+
valid_q_values[move] = q_values[move]
|
240 |
+
|
241 |
+
# 选择最佳动作
|
242 |
+
action = np.argmax(valid_q_values)
|
243 |
+
|
244 |
+
# 执行移动
|
245 |
+
direction_names = ["上", "右", "下", "左"]
|
246 |
+
new_board, game_over = game.move(action)
|
247 |
+
|
248 |
+
# 渲染棋盘
|
249 |
+
board_html = render_board(new_board)
|
250 |
+
|
251 |
+
# 更新状态信息
|
252 |
+
status = f"<b>AI移动方向:</b> {direction_names[action]}"
|
253 |
+
status += f"<br><b>当前分数:</b> {game.score}"
|
254 |
+
status += f"<br><b>最大方块:</b> {np.max(game.board)}"
|
255 |
+
status += f"<br><b>设备:</b> {'GPU' if device.type == 'cuda' else 'CPU'}"
|
256 |
+
|
257 |
+
if game.game_over:
|
258 |
+
status += "<br><br><div style='color:#ff0000; font-weight:bold;'>游戏结束!</div>"
|
259 |
+
status += f"<br><b>最终分数:</b> {game.score}"
|
260 |
+
|
261 |
+
return board_html, status
|
262 |
|
263 |
+
except Exception as e:
|
264 |
+
print(f"AI移动出错: {e}")
|
265 |
+
return render_board(game.board), f"<b>错误:</b> AI移动失败 - {str(e)}"
|
266 |
|
267 |
# 创建Gradio界面
|
268 |
with gr.Blocks(title="2048游戏", theme="soft") as demo:
|
269 |
gr.Markdown("# 🎮 2048游戏")
|
270 |
+
gr.Markdown(f"当前运行设备: **{'GPU' if device.type == 'cuda' else 'CPU'}**")
|
271 |
gr.Markdown("使用方向键或下方的按钮移动方块,相同数字的方块相撞时会合并!")
|
272 |
+
|
273 |
with gr.Row():
|
274 |
with gr.Column(scale=2):
|
275 |
board_html = gr.HTML(render_board(game.board))
|
276 |
+
status_display = gr.HTML("<b>当前分数:</b> 0<br><b>最大方块:</b> 2")
|
277 |
+
|
278 |
with gr.Column():
|
279 |
gr.Markdown("## 手动操作")
|
280 |
with gr.Row():
|
281 |
+
up_btn = gr.Button("上 ↑", elem_id="up-btn")
|
282 |
+
left_btn = gr.Button("左 ←", elem_id="left-btn")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
283 |
with gr.Row():
|
284 |
+
down_btn = gr.Button("下 ↓", elem_id="down-btn")
|
285 |
+
right_btn = gr.Button("右 →", elem_id="right-btn")
|
|
|
|
|
286 |
with gr.Row():
|
287 |
+
reset_btn = gr.Button("🔄 重置游戏", elem_id="reset-btn")
|
288 |
+
|
289 |
gr.Markdown("## AI操作")
|
290 |
+
with gr.Row():
|
291 |
+
ai_btn = gr.Button("🤖 AI移动一步", elem_id="ai-btn")
|
292 |
+
auto_btn = gr.Button("🚀 连续AI模式", elem_id="auto-btn")
|
293 |
+
|
294 |
+
# 连接按钮事件
|
295 |
+
up_btn.click(lambda: make_move(0), outputs=[board_html, status_display])
|
296 |
+
right_btn.click(lambda: make_move(1), outputs=[board_html, status_display])
|
297 |
+
down_btn.click(lambda: make_move(2), outputs=[board_html, status_display])
|
298 |
+
left_btn.click(lambda: make_move(3), outputs=[board_html, status_display])
|
299 |
+
reset_btn.click(reset_game, outputs=[board_html, status_display])
|
300 |
+
ai_btn.click(ai_move, outputs=[board_html, status_display])
|
301 |
+
|
302 |
+
# 连续AI模式
|
303 |
+
def auto_play():
    """Let the AI play move after move until the game ends or a cap is hit.

    Each step masks the model's Q-values to the currently valid moves and
    plays the highest-valued one. The loop stops when the game reports
    game over, when no valid move remains, or after ``max_moves`` moves.

    Returns:
        A ``(board_html, status_html)`` tuple for the two Gradio outputs.
        If no model was loaded, the current board and an error message
        are returned without touching the game.
    """
    if model is None:
        return render_board(game.board), "<b>错误:</b> 未加载AI模型"

    moves = 0
    max_moves = 200  # hard cap so a single click can never loop forever

    while not game.game_over and moves < max_moves:
        state = game.get_state()

        valid_moves = game.get_valid_moves()
        if not valid_moves:
            break

        # Add a leading batch dimension and place the input on the model's device.
        state_tensor = torch.tensor(state, dtype=torch.float).unsqueeze(0).to(device)

        with torch.no_grad():
            q_values = model(state_tensor).cpu().numpy().flatten()

        # Mask every action with -inf, then restore only the valid ones so
        # argmax can never select an invalid move.
        valid_q_values = np.full(4, -np.inf)
        for move in valid_moves:
            valid_q_values[move] = q_values[move]

        action = np.argmax(valid_q_values)

        game.move(action)
        moves += 1

    # Use a distinct local name so the Gradio component called `board_html`
    # in the enclosing scope is not shadowed.
    board_markup = render_board(game.board)

    # Anything that is not the CPU is an accelerator; the previous
    # `== 'cuda'` test mislabelled an XPU device as "CPU".
    device_label = 'CPU' if device.type == 'cpu' else 'GPU'

    status = "<b>连续AI完成!</b>"
    status += f"<br><b>移动次数:</b> {moves}"
    status += f"<br><b>当前分数:</b> {game.score}"
    status += f"<br><b>最大方块:</b> {np.max(game.board)}"
    status += f"<br><b>设备:</b> {device_label}"

    if game.game_over:
        status += "<br><br><div style='color:#ff0000; font-weight:bold;'>游戏结束!</div>"
        status += f"<br><b>最终分数:</b> {game.score}"

    return board_markup, status
|
354 |
+
|
355 |
+
auto_btn.click(auto_play, outputs=[board_html, status_display])
|
356 |
|
357 |
# 添加键盘快捷键支持
|
358 |
demo.load(
|
|
|
373 |
document.getElementById('reset-btn').click();
|
374 |
} else if (e.key === 'a' || e.key === 'A') {
|
375 |
document.getElementById('ai-btn').click();
|
376 |
+
} else if (e.key === 's' || e.key === 'S') {
|
377 |
+
document.getElementById('auto-btn').click();
|
378 |
}
|
379 |
});
|
380 |
}"""
|
381 |
)
|
382 |
+
|
383 |
gr.Markdown("### 📚 使用说明")
|
384 |
+
gr.Markdown("1. 使用方向键或下方的按钮移动方块")
|
385 |
+
gr.Markdown("2. 相同数字的方块相撞时会合并")
|
386 |
+
gr.Markdown("3. **快捷键说明**:")
|
387 |
+
gr.Markdown(" - ↑/↓/←/→: 移动方块")
|
388 |
+
gr.Markdown(" - R: 重置游戏")
|
389 |
+
gr.Markdown(" - A: AI移动一步")
|
390 |
+
gr.Markdown(" - S: 连续AI模式(自动玩到游戏结束)")
|
391 |
|
392 |
# 启动界面
|
393 |
if __name__ == "__main__":
|
394 |
+
demo.launch()
|