"""Gradio front-end for a 2048 game with an optional Dueling-DQN player.

The module keeps one global ``Game2048`` instance and exposes callbacks for
manual moves, a reset, a single AI move and a bounded run of AI moves.

NOTE(review): the game state is a module-level global, so every browser
session connected to the same server shares one board.
"""

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from game2048 import Game2048


def _detect_device():
    """Return the preferred torch device: CUDA, then XPU, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # ``torch.xpu`` does not exist on every PyTorch build; probe for it
    # instead of assuming the attribute is present.
    xpu = getattr(torch, "xpu", None)
    if xpu is not None and xpu.is_available():
        return torch.device("xpu")
    return torch.device("cpu")


# Pick the compute device once at import time.
device = _detect_device()


def _device_label():
    """Return the user-facing device name ('GPU' for CUDA/XPU, else 'CPU')."""
    return 'GPU' if device.type in ('cuda', 'xpu') else 'CPU'


# Global game instance shared by every callback below.
game = Game2048(size=4)

# Index -> display name of a move; the order matches the direction codes
# wired to the buttons (0 = up, 1 = right, 2 = down, 3 = left).
DIRECTION_NAMES = ["上", "右", "下", "左"]

# Upper bound on the number of moves one "continuous AI" click may play,
# so a single request can never loop forever.
AUTO_PLAY_MAX_MOVES = 200

# Tile background colour keyed by tile value.
TILE_COLORS = {
    0: "#cdc1b4",      # empty cell
    2: "#eee4da",
    4: "#ede0c8",
    8: "#f2b179",
    16: "#f59563",
    32: "#f67c5f",
    64: "#f65e3b",
    128: "#edcf72",
    256: "#edcc61",
    512: "#edc850",
    1024: "#edc53f",
    2048: "#edc22e",
    4096: "#3c3a32",
    8192: "#3c3a32",
    16384: "#3c3a32",
}

# Tile text colour keyed by tile value (dark text on the light tiles,
# light text on everything from 8 upwards).
TEXT_COLORS = {
    0: "#776e65",      # empty cell
    2: "#776e65",
    4: "#776e65",
    8: "#f9f6f2",
    16: "#f9f6f2",
    32: "#f9f6f2",
    64: "#f9f6f2",
    128: "#f9f6f2",
    256: "#f9f6f2",
    512: "#f9f6f2",
    1024: "#f9f6f2",
    2048: "#f9f6f2",
    4096: "#f9f6f2",
    8192: "#f9f6f2",
    16384: "#f9f6f2",
}


class DQN(nn.Module):
    """Dueling DQN: three 3x3 conv layers feeding value and advantage heads.

    The layer names and shapes must match the checkpoint being loaded, so
    they are kept exactly as in the training-time definition.

    Args:
        input_channels: Number of channels of the encoded board.
        output_size: Number of actions (Q-values produced per sample).
        board_size: Side length of the square board. Defaults to 4, which
            reproduces the original hard-coded ``4 * 4`` flatten sizes.
    """

    def __init__(self, input_channels, output_size, board_size=4):
        super().__init__()
        self.input_channels = input_channels
        cells = board_size * board_size

        # Shared convolutional trunk; padding=1 keeps the spatial size.
        self.conv1 = nn.Conv2d(input_channels, 128, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)

        # Value stream: 1x1 conv down to 4 channels, then two linear layers.
        self.value_conv = nn.Conv2d(128, 4, kernel_size=1)
        self.value_fc1 = nn.Linear(4 * cells, 128)
        self.value_fc2 = nn.Linear(128, 1)

        # Advantage stream: 1x1 conv down to 16 channels, then two linear
        # layers ending in one output per action.
        self.advantage_conv = nn.Conv2d(128, 16, kernel_size=1)
        self.advantage_fc1 = nn.Linear(16 * cells, 128)
        self.advantage_fc2 = nn.Linear(128, output_size)

    def forward(self, x):
        """Return Q-values of shape ``(batch, output_size)``.

        Args:
            x: Tensor of shape
                ``(batch, input_channels, board_size, board_size)``.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))

        # Value stream.
        value = F.relu(self.value_conv(x)).flatten(1)
        value = F.relu(self.value_fc1(value))
        value = self.value_fc2(value)

        # Advantage stream.
        advantage = F.relu(self.advantage_conv(x)).flatten(1)
        advantage = F.relu(self.advantage_fc1(advantage))
        advantage = self.advantage_fc2(advantage)

        # Dueling combination: Q = V + (A - mean(A)).
        return value + advantage - advantage.mean(dim=1, keepdim=True)


def _load_state_dict_from(path, device):
    """Read a checkpoint file and return the policy network's state dict.

    Accepts either a full training checkpoint that stores the weights under
    ``'policy_net_state_dict'`` or a bare state dict.
    """
    # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(path, map_location=device)
    if isinstance(checkpoint, dict) and 'policy_net_state_dict' in checkpoint:
        return checkpoint['policy_net_state_dict']
    return checkpoint


def load_model(model_path, device):
    """Load a ``DQN(4, 4)`` from ``model_path`` onto ``device``.

    If loading fails, the same path with ``'_best_tile'`` removed is tried
    as a fallback (skipped when that yields the identical path).

    Args:
        model_path: Path of the checkpoint file.
        device: Target ``torch.device``.

    Returns:
        The model in eval mode, or ``None`` if no candidate could be loaded.
    """
    candidates = [model_path]
    alt_path = model_path.replace('_best_tile', '')
    if alt_path != model_path:
        candidates.append(alt_path)

    for index, path in enumerate(candidates):
        is_primary = index == 0
        # Build a fresh network per attempt so a failed, partially applied
        # load_state_dict can never leak into the next attempt.
        net = DQN(4, 4).to(device)
        try:
            net.load_state_dict(_load_state_dict_from(path, device))
        except Exception as e:  # best-effort loading: report and move on
            if is_primary:
                print(f"模型加载失败: {e}")
            else:
                print(f"备选模型加载失败: {e}")
            continue

        net.eval()
        if is_primary:
            print(f"模型成功加载到: {device}")
        else:
            print(f"备选模型成功加载: {path}")
        return net

    return None


# Candidate checkpoint locations, tried in order until one loads.
model_paths = [
    "models/dqn_2048_best_tile.pth",
    "models/dqn_2048.pth",
    "dqn_2048_best_tile.pth",
    "dqn_2048.pth",
]

model = None
for path in model_paths:
    model = load_model(path, device)
    if model is not None:
        break

if model is None:
    print("警告: 未加载任何模型,AI功能将不可用")


def render_board(board):
    """Render ``board`` (a 2-D grid of tile values) as an HTML table.

    NOTE(review): the HTML/CSS markup here is a reconstruction; the original
    tags were lost from the source, so adjust the styling if it does not
    match the intended design.
    """
    parts = [
        "<div style='display: flex; justify-content: center;'>",
        "<table style='border-collapse: separate; border-spacing: 8px; "
        "background-color: #bbada0; border-radius: 8px;'>",
    ]
    # Iterate the board itself rather than the global game's size so the
    # function renders whatever grid it is handed.
    for row in board:
        parts.append("<tr>")
        for cell in row:
            value = int(cell)
            color = TILE_COLORS.get(value, "#3c3a32")      # default: dark
            text_color = TEXT_COLORS.get(value, "#f9f6f2")  # default: light
            # Shrink the font as the number gains digits.
            if value < 100:
                font_size = "36px"
            elif value < 1000:
                font_size = "30px"
            else:
                font_size = "24px"
            parts.append(
                f"<td style='width: 80px; height: 80px; "
                f"background-color: {color}; color: {text_color}; "
                f"font-size: {font_size}; font-weight: bold; "
                f"text-align: center; vertical-align: middle; "
                f"border-radius: 6px;'>"
                f"{value if value > 0 else ''}</td>"
            )
        parts.append("</tr>")
    parts.append("</table></div>")
    return "".join(parts)


def _build_status(headline=None, *, moves=None, show_device=False):
    """Build the HTML status text from the current global game state.

    Args:
        headline: Optional first line (e.g. the move that was just made).
        moves: When given, a "moves played" line is included.
        show_device: Whether to append the compute-device line.
    """
    lines = []
    if headline is not None:
        lines.append(headline)
    if moves is not None:
        lines.append(f"移动次数: {moves}")
    lines.append(f"当前分数: {game.score}")
    lines.append(f"最大方块: {np.max(game.board)}")
    if show_device:
        lines.append(f"设备: {_device_label()}")

    status = "<br>".join(lines)
    if game.game_over:
        status += "<br><br><b>游戏结束!</b>"
        status += f"<br>最终分数: {game.score}"
    return status


def make_move(direction):
    """Apply a manual move and return ``(board_html, status_html)``.

    Args:
        direction: Index into ``DIRECTION_NAMES`` (0 up, 1 right, 2 down,
            3 left).
    """
    new_board, _game_over = game.move(direction)
    headline = f"移动方向: {DIRECTION_NAMES[direction]}"
    return render_board(new_board), _build_status(headline)


def reset_game():
    """Replace the global game with a fresh one and return the new UI state."""
    global game
    game = Game2048(size=4)
    board = game.reset()
    return render_board(board), _build_status("游戏已重置!")


def _select_ai_action(valid_moves):
    """Return the valid move with the highest predicted Q-value.

    Args:
        valid_moves: Non-empty collection of currently legal action indices.
    """
    state = game.get_state()
    # Add a batch dimension and move the input to the model's device.
    state_tensor = torch.tensor(state, dtype=torch.float).unsqueeze(0).to(device)
    with torch.no_grad():
        q_values = model(state_tensor).cpu().numpy().flatten()

    # Mask out illegal actions so argmax can only pick a legal one.
    legal = list(valid_moves)
    masked = np.full(q_values.shape, -np.inf)
    masked[legal] = q_values[legal]
    # Plain int: safe both as a list index and as an argument to game.move.
    return int(np.argmax(masked))


def ai_move():
    """Let the model play one move and return ``(board_html, status_html)``."""
    if model is None:
        return render_board(game.board), "错误: 未加载AI模型"

    valid_moves = game.get_valid_moves()
    if not valid_moves:
        return render_board(game.board), "游戏结束! 没有有效移动"

    try:
        action = _select_ai_action(valid_moves)
        new_board, _game_over = game.move(action)
        headline = f"AI移动方向: {DIRECTION_NAMES[action]}"
        return render_board(new_board), _build_status(headline, show_device=True)
    except Exception as e:  # UI boundary: report instead of crashing
        print(f"AI移动出错: {e}")
        return render_board(game.board), f"错误: AI移动失败 - {e}"


def auto_play():
    """Let the model play until the game ends or the move cap is reached.

    At most ``AUTO_PLAY_MAX_MOVES`` moves are played per call.

    Returns:
        ``(board_html, status_html)`` describing the resulting position.
    """
    if model is None:
        return render_board(game.board), "错误: 未加载AI模型"

    moves = 0
    try:
        while not game.game_over and moves < AUTO_PLAY_MAX_MOVES:
            valid_moves = game.get_valid_moves()
            if not valid_moves:
                break
            game.move(_select_ai_action(valid_moves))
            moves += 1
    except Exception as e:  # UI boundary: same handling as ai_move
        print(f"AI移动出错: {e}")
        return render_board(game.board), f"错误: AI移动失败 - {e}"

    status = _build_status("连续AI完成!", moves=moves, show_device=True)
    return render_board(game.board), status


# Browser-side keyboard shortcuts: each key clicks the matching button.
# Keys typed into form fields or combined with Ctrl/Cmd/Alt are ignored,
# and arrow keys are prevented from scrolling the page.
KEYBOARD_SHORTCUTS_JS = """() => {
    const shortcuts = {
        'ArrowUp': 'up-btn',
        'ArrowRight': 'right-btn',
        'ArrowDown': 'down-btn',
        'ArrowLeft': 'left-btn',
        'r': 'reset-btn',
        'a': 'ai-btn',
        's': 'auto-btn'
    };
    document.addEventListener('keydown', function(e) {
        if (e.ctrlKey || e.metaKey || e.altKey) {
            return;
        }
        const target = e.target;
        const tag = target && target.tagName;
        if (tag === 'INPUT' || tag === 'TEXTAREA' ||
                (target && target.isContentEditable)) {
            return;
        }
        const key = e.key.length === 1 ? e.key.toLowerCase() : e.key;
        if (!Object.prototype.hasOwnProperty.call(shortcuts, key)) {
            return;
        }
        const button = document.getElementById(shortcuts[key]);
        if (!button) {
            return;
        }
        if (key.startsWith('Arrow')) {
            e.preventDefault();
        }
        button.click();
    });
}"""


# Build the Gradio interface.
with gr.Blocks(title="2048游戏", theme="soft") as demo:
    gr.Markdown("# 🎮 2048游戏")
    gr.Markdown(f"当前运行设备: **{_device_label()}**")
    gr.Markdown("使用方向键或下方的按钮移动方块,相同数字的方块相撞时会合并!")

    with gr.Row():
        with gr.Column(scale=2):
            board_html = gr.HTML(render_board(game.board))
            # Show the real starting score and max tile (the first tile is
            # not guaranteed to be a 2).
            status_display = gr.HTML(_build_status())

        with gr.Column():
            gr.Markdown("## 手动操作")
            with gr.Row():
                up_btn = gr.Button("上 ↑", elem_id="up-btn")
                left_btn = gr.Button("左 ←", elem_id="left-btn")
            with gr.Row():
                down_btn = gr.Button("下 ↓", elem_id="down-btn")
                right_btn = gr.Button("右 →", elem_id="right-btn")
            with gr.Row():
                reset_btn = gr.Button("🔄 重置游戏", elem_id="reset-btn")

            gr.Markdown("## AI操作")
            with gr.Row():
                ai_btn = gr.Button("🤖 AI移动一步", elem_id="ai-btn")
                auto_btn = gr.Button("🚀 连续AI模式", elem_id="auto-btn")

    # Wire the buttons to their callbacks; every callback refreshes both
    # the board and the status panel.
    ui_outputs = [board_html, status_display]
    up_btn.click(lambda: make_move(0), outputs=ui_outputs)
    right_btn.click(lambda: make_move(1), outputs=ui_outputs)
    down_btn.click(lambda: make_move(2), outputs=ui_outputs)
    left_btn.click(lambda: make_move(3), outputs=ui_outputs)
    reset_btn.click(reset_game, outputs=ui_outputs)
    ai_btn.click(ai_move, outputs=ui_outputs)
    auto_btn.click(auto_play, outputs=ui_outputs)

    # Install the keyboard shortcuts once the page has loaded.
    demo.load(
        fn=None,
        inputs=None,
        outputs=None,
        js=KEYBOARD_SHORTCUTS_JS,
    )

    gr.Markdown("### 📚 使用说明")
    gr.Markdown("1. 使用方向键或下方的按钮移动方块")
    gr.Markdown("2. 相同数字的方块相撞时会合并")
    gr.Markdown("3. **快捷键说明**:")
    gr.Markdown(" - ↑/↓/←/→: 移动方块")
    gr.Markdown(" - R: 重置游戏")
    gr.Markdown(" - A: AI移动一步")
    gr.Markdown(" - S: 连续AI模式(自动玩到游戏结束)")


# Launch the app when run as a script.
if __name__ == "__main__":
    demo.launch()