Gofor5 commited on
Commit
020c81d
·
verified ·
1 Parent(s): bb627c0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +179 -78
app.py CHANGED
@@ -1,11 +1,17 @@
1
  import gradio as gr
2
  import numpy as np
3
- import random
4
  import torch
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
  from game2048 import Game2048
8
 
 
 
 
 
 
 
 
9
  # 创建游戏实例
10
  game = Game2048(size=4)
11
 
@@ -23,7 +29,9 @@ TILE_COLORS = {
23
  512: "#edc850", # 512
24
  1024: "#edc53f", # 1024
25
  2048: "#edc22e", # 2048
26
- 4096: "#3c3a32", # 4096+
 
 
27
  }
28
 
29
  # 文本颜色映射(根据背景深浅)
@@ -41,6 +49,8 @@ TEXT_COLORS = {
41
  1024: "#f9f6f2", # 1024+
42
  2048: "#f9f6f2", # 2048+
43
  4096: "#f9f6f2", # 4096+
 
 
44
  }
45
 
46
  # 定义DQN网络结构(与训练时相同)
@@ -86,22 +96,56 @@ class DQN(nn.Module):
86
  q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
87
  return q_values
88
 
89
- # 加载模型
90
- def load_model(model_path):
91
- model = DQN(4, 4) # 输入通道4,输出动作4个
92
  try:
93
- checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
94
- model.load_state_dict(checkpoint['policy_net_state_dict'])
 
 
 
 
 
 
 
 
95
  model.eval()
96
- print("模型加载成功")
97
  return model
98
  except Exception as e:
99
  print(f"模型加载失败: {e}")
100
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
  # 尝试加载模型
103
- model_path = "dqn_2048_best_tile.pth"
104
- model = load_model(model_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
  def render_board(board):
107
  html = "<div style='background-color:#bbada0; padding:10px; border-radius:6px;'>"
@@ -181,82 +225,134 @@ def ai_move():
181
  if not valid_moves:
182
  return render_board(game.board), "<b>游戏结束!</b> 没有有效移动"
183
 
184
- # 转换状态为模型输入
185
- state_tensor = torch.tensor(state, dtype=torch.float).unsqueeze(0)
186
-
187
- # 模型预测
188
- with torch.no_grad():
189
- q_values = model(state_tensor).numpy().flatten()
190
-
191
- # 只考虑有效动作
192
- valid_q_values = np.full(4, -np.inf)
193
- for move in valid_moves:
194
- valid_q_values[move] = q_values[move]
195
-
196
- # 选择最佳动作
197
- action = np.argmax(valid_q_values)
198
-
199
- # 执行移动
200
- direction_names = ["上", "右", "下", "左"]
201
- new_board, game_over = game.move(action)
202
-
203
- # 渲染棋盘
204
- board_html = render_board(new_board)
205
-
206
- # 更新状态信息
207
- status = f"<b>AI移动方向:</b> {direction_names[action]}"
208
- status += f"<br><b>当前分数:</b> {game.score}"
209
- status += f"<br><b>最大方块:</b> {np.max(game.board)}"
210
-
211
- if game.game_over:
212
- status += "<br><br><div style='color:#ff0000; font-weight:bold;'>游戏结束!</div>"
213
- status += f"<br><b>最终分数:</b> {game.score}"
 
 
 
 
214
 
215
- return board_html, status
 
 
216
 
217
  # 创建Gradio界面
218
  with gr.Blocks(title="2048游戏", theme="soft") as demo:
219
  gr.Markdown("# 🎮 2048游戏")
 
220
  gr.Markdown("使用方向键或下方的按钮移动方块,相同数字的方块相撞时会合并!")
 
221
  with gr.Row():
222
  with gr.Column(scale=2):
223
  board_html = gr.HTML(render_board(game.board))
224
- with gr.Row(visible=False):
225
- status_display = gr.HTML("<b>当前分数:</b> 0<br><b>最大方块:</b> 2")
226
  with gr.Column():
227
  gr.Markdown("## 手动操作")
228
  with gr.Row():
229
- gr.Button("上 ↑", elem_id="up-btn").click(
230
- fn=lambda: make_move(0),
231
- outputs=[board_html, status_display]
232
- )
233
- gr.Button("左 ←", elem_id="left-btn").click(
234
- fn=lambda: make_move(3),
235
- outputs=[board_html, status_display]
236
- )
237
- with gr.Row():
238
- gr.Button("下 ↓", elem_id="down-btn").click(
239
- fn=lambda: make_move(2),
240
- outputs=[board_html, status_display]
241
- )
242
- gr.Button("右 →", elem_id="right-btn").click(
243
- fn=lambda: make_move(1),
244
- outputs=[board_html, status_display]
245
- )
246
  with gr.Row():
247
- gr.Button("🔄 重置游戏", elem_id="reset-btn").click(
248
- fn=reset_game,
249
- outputs=[board_html, status_display]
250
- )
251
  with gr.Row():
252
- status_display = gr.HTML("<b>当前分数:</b> 0<br><b>最大方块:</b> 2")
253
- with gr.Column():
254
  gr.Markdown("## AI操作")
255
- gr.Button("🤖 AI移动一步", elem_id="ai-btn").click(
256
- fn=ai_move,
257
- outputs=[board_html, status_display]
258
- )
259
- gr.Markdown("基于DQN神经网络提供支持")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
 
261
  # 添加键盘快捷键支持
262
  demo.load(
@@ -277,17 +373,22 @@ with gr.Blocks(title="2048游戏", theme="soft") as demo:
277
  document.getElementById('reset-btn').click();
278
  } else if (e.key === 'a' || e.key === 'A') {
279
  document.getElementById('ai-btn').click();
 
 
280
  }
281
  });
282
  }"""
283
  )
 
284
  gr.Markdown("### 📚 使用说明")
285
- gr.Markdown("1. 使用方向键或下方的按钮移动方块。")
286
- gr.Markdown("2. 相同数字的方块相撞时会合并。")
287
- gr.Markdown("3. 快捷键说明:上/下/左/右键移动方块,R键重置游戏,A键AI移动一步。")
288
- gr.Markdown("4. 点击 '🤖 AI移动一步' 按钮可以使用AI模型进行一步移动。")
289
- gr.Markdown("5. 游戏结束后,会显示最终分数和最大方块。")
 
 
290
 
291
  # 启动界面
292
  if __name__ == "__main__":
293
- demo.launch(share=True)
 
1
  import gradio as gr
2
  import numpy as np
 
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
  from game2048 import Game2048
7
 
8
+ # 检测可用设备
9
+ if torch.cuda.is_available():
10
+ device = torch.device("cuda")
11
+ elif torch.xpu.is_available():
12
+ device = torch.device("xpu")
13
+ else:
14
+ device = torch.device("cpu")
15
  # 创建游戏实例
16
  game = Game2048(size=4)
17
 
 
29
  512: "#edc850", # 512
30
  1024: "#edc53f", # 1024
31
  2048: "#edc22e", # 2048
32
+ 4096: "#3c3a32", # 4096+,
33
+ 8192: "#3c3a32", # 8192+
34
+ 16384: "#3c3a32", # 16384+
35
  }
36
 
37
  # 文本颜色映射(根据背景深浅)
 
49
  1024: "#f9f6f2", # 1024+
50
  2048: "#f9f6f2", # 2048+
51
  4096: "#f9f6f2", # 4096+
52
+ 8192: "#f9f6f2", # 8192+
53
+ 16384: "#f9f6f2", # 16384+
54
  }
55
 
56
  # 定义DQN网络结构(与训练时相同)
 
96
  q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
97
  return q_values
98
 
99
+ # 加载模型(支持CUDA)
100
+ def load_model(model_path, device):
101
+ model = DQN(4, 4).to(device) # 将模型移动到指定设备
102
  try:
103
+ # 尝试加载模型
104
+ checkpoint = torch.load(model_path, map_location=device)
105
+
106
+ # 检查检查点是否包含完整的模型状态
107
+ if 'policy_net_state_dict' in checkpoint:
108
+ model.load_state_dict(checkpoint['policy_net_state_dict'])
109
+ else:
110
+ # 如果检查点不包含policy_net_state_dict,尝试直接加载
111
+ model.load_state_dict(checkpoint)
112
+
113
  model.eval()
114
+ print(f"模型成功加载到: {device}")
115
  return model
116
  except Exception as e:
117
  print(f"模型加载失败: {e}")
118
+ # 尝试备选模型路径
119
+ alt_path = model_path.replace('_best_tile', '')
120
+ try:
121
+ checkpoint = torch.load(alt_path, map_location=device)
122
+ if 'policy_net_state_dict' in checkpoint:
123
+ model.load_state_dict(checkpoint['policy_net_state_dict'])
124
+ else:
125
+ model.load_state_dict(checkpoint)
126
+ model.eval()
127
+ print(f"备选模型成功加载: {alt_path}")
128
+ return model
129
+ except Exception as e2:
130
+ print(f"备选模型加载失败: {e2}")
131
+ return None
132
 
133
  # 尝试加载模型
134
+ model_paths = [
135
+ "models/dqn_2048_best_tile.pth",
136
+ "models/dqn_2048.pth",
137
+ "dqn_2048_best_tile.pth",
138
+ "dqn_2048.pth"
139
+ ]
140
+
141
+ model = None
142
+ for path in model_paths:
143
+ model = load_model(path, device)
144
+ if model:
145
+ break
146
+
147
+ if not model:
148
+ print("警告: 未加载任何模型,AI功能将不可用")
149
 
150
  def render_board(board):
151
  html = "<div style='background-color:#bbada0; padding:10px; border-radius:6px;'>"
 
225
  if not valid_moves:
226
  return render_board(game.board), "<b>游戏结束!</b> 没有有效移动"
227
 
228
+ try:
229
+ # 转换状态为模型输入并移动到设备
230
+ state_tensor = torch.tensor(state, dtype=torch.float).unsqueeze(0).to(device)
231
+
232
+ # 模型预测
233
+ with torch.no_grad():
234
+ q_values = model(state_tensor).cpu().numpy().flatten()
235
+
236
+ # 只考虑有效动作
237
+ valid_q_values = np.full(4, -np.inf)
238
+ for move in valid_moves:
239
+ valid_q_values[move] = q_values[move]
240
+
241
+ # 选择最佳动作
242
+ action = np.argmax(valid_q_values)
243
+
244
+ # 执行移动
245
+ direction_names = ["上", "右", "下", "左"]
246
+ new_board, game_over = game.move(action)
247
+
248
+ # 渲染棋盘
249
+ board_html = render_board(new_board)
250
+
251
+ # 更新状态信息
252
+ status = f"<b>AI移动方向:</b> {direction_names[action]}"
253
+ status += f"<br><b>当前分数:</b> {game.score}"
254
+ status += f"<br><b>最大方块:</b> {np.max(game.board)}"
255
+ status += f"<br><b>设备:</b> {'GPU' if device.type == 'cuda' else 'CPU'}"
256
+
257
+ if game.game_over:
258
+ status += "<br><br><div style='color:#ff0000; font-weight:bold;'>游戏结束!</div>"
259
+ status += f"<br><b>最终分数:</b> {game.score}"
260
+
261
+ return board_html, status
262
 
263
+ except Exception as e:
264
+ print(f"AI移动出错: {e}")
265
+ return render_board(game.board), f"<b>错误:</b> AI移动失败 - {str(e)}"
266
 
267
  # 创建Gradio界面
268
  with gr.Blocks(title="2048游戏", theme="soft") as demo:
269
  gr.Markdown("# 🎮 2048游戏")
270
+ gr.Markdown(f"当前运行设备: **{'GPU' if device.type == 'cuda' else 'CPU'}**")
271
  gr.Markdown("使用方向键或下方的按钮移动方块,相同数字的方块相撞时会合并!")
272
+
273
  with gr.Row():
274
  with gr.Column(scale=2):
275
  board_html = gr.HTML(render_board(game.board))
276
+ status_display = gr.HTML("<b>当前分数:</b> 0<br><b>最大方块:</b> 2")
277
+
278
  with gr.Column():
279
  gr.Markdown("## 手动操作")
280
  with gr.Row():
281
+ up_btn = gr.Button("上 ↑", elem_id="up-btn")
282
+ left_btn = gr.Button("左 ←", elem_id="left-btn")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
  with gr.Row():
284
+ down_btn = gr.Button(" ", elem_id="down-btn")
285
+ right_btn = gr.Button("右 →", elem_id="right-btn")
 
 
286
  with gr.Row():
287
+ reset_btn = gr.Button("🔄 重置游戏", elem_id="reset-btn")
288
+
289
  gr.Markdown("## AI操作")
290
+ with gr.Row():
291
+ ai_btn = gr.Button("🤖 AI移动一步", elem_id="ai-btn")
292
+ auto_btn = gr.Button("🚀 连续AI模式", elem_id="auto-btn")
293
+
294
+ # 连接按钮事件
295
+ up_btn.click(lambda: make_move(0), outputs=[board_html, status_display])
296
+ right_btn.click(lambda: make_move(1), outputs=[board_html, status_display])
297
+ down_btn.click(lambda: make_move(2), outputs=[board_html, status_display])
298
+ left_btn.click(lambda: make_move(3), outputs=[board_html, status_display])
299
+ reset_btn.click(reset_game, outputs=[board_html, status_display])
300
+ ai_btn.click(ai_move, outputs=[board_html, status_display])
301
+
302
+ # 连续AI模式
303
+ def auto_play():
304
+ """连续AI移动直到游戏结束"""
305
+ if model is None:
306
+ return render_board(game.board), "<b>错误:</b> 未加载AI模型"
307
+
308
+ moves = 0
309
+ max_moves = 200 # 防止无限循环
310
+
311
+ while not game.game_over and moves < max_moves:
312
+ # 获取当前状态
313
+ state = game.get_state()
314
+
315
+ # 获取有效移动
316
+ valid_moves = game.get_valid_moves()
317
+ if not valid_moves:
318
+ break
319
+
320
+ # 转换状态为模型输入并移动到设备
321
+ state_tensor = torch.tensor(state, dtype=torch.float).unsqueeze(0).to(device)
322
+
323
+ # 模型预测
324
+ with torch.no_grad():
325
+ q_values = model(state_tensor).cpu().numpy().flatten()
326
+
327
+ # 只考虑有效动作
328
+ valid_q_values = np.full(4, -np.inf)
329
+ for move in valid_moves:
330
+ valid_q_values[move] = q_values[move]
331
+
332
+ # 选择最佳动作
333
+ action = np.argmax(valid_q_values)
334
+
335
+ # 执行移动
336
+ game.move(action)
337
+ moves += 1
338
+
339
+ # 渲染棋盘
340
+ board_html = render_board(game.board)
341
+
342
+ # 更新状态信息
343
+ status = f"<b>连续AI完成!</b>"
344
+ status += f"<br><b>移动次数:</b> {moves}"
345
+ status += f"<br><b>当前分数:</b> {game.score}"
346
+ status += f"<br><b>最大方块:</b> {np.max(game.board)}"
347
+ status += f"<br><b>设备:</b> {'GPU' if device.type == 'cuda' else 'CPU'}"
348
+
349
+ if game.game_over:
350
+ status += "<br><br><div style='color:#ff0000; font-weight:bold;'>游戏结束!</div>"
351
+ status += f"<br><b>最终分数:</b> {game.score}"
352
+
353
+ return board_html, status
354
+
355
+ auto_btn.click(auto_play, outputs=[board_html, status_display])
356
 
357
  # 添加键盘快捷键支持
358
  demo.load(
 
373
  document.getElementById('reset-btn').click();
374
  } else if (e.key === 'a' || e.key === 'A') {
375
  document.getElementById('ai-btn').click();
376
+ } else if (e.key === 's' || e.key === 'S') {
377
+ document.getElementById('auto-btn').click();
378
  }
379
  });
380
  }"""
381
  )
382
+
383
  gr.Markdown("### 📚 使用说明")
384
+ gr.Markdown("1. 使用方向键或下方的按钮移动方块")
385
+ gr.Markdown("2. 相同数字的方块相撞时会合并")
386
+ gr.Markdown("3. **快捷键说明**:")
387
+ gr.Markdown(" - ↑/↓/←/→: 移动方块")
388
+ gr.Markdown(" - R: 重置游戏")
389
+ gr.Markdown(" - A: AI移动一步")
390
+ gr.Markdown(" - S: 连续AI模式(自动玩到游戏结束)")
391
 
392
  # 启动界面
393
  if __name__ == "__main__":
394
+ demo.launch()