Spaces:
Running
Running
import gradio as gr | |
import numpy as np | |
import random | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from game2048 import Game2048 | |
# 创建游戏实例 | |
game = Game2048(size=4) | |
# 方块颜色映射(根据数字值) | |
TILE_COLORS = { | |
0: "#cdc1b4", # 空白格子 | |
2: "#eee4da", # 2 | |
4: "#ede0c8", # 4 | |
8: "#f2b179", # 8 | |
16: "#f59563", # 16 | |
32: "#f67c5f", # 32 | |
64: "#f65e3b", # 64 | |
128: "#edcf72", # 128 | |
256: "#edcc61", # 256 | |
512: "#edc850", # 512 | |
1024: "#edc53f", # 1024 | |
2048: "#edc22e", # 2048 | |
4096: "#3c3a32", # 4096+ | |
} | |
# 文本颜色映射(根据背景深浅) | |
TEXT_COLORS = { | |
0: "#776e65", # 空白格子 | |
2: "#776e65", # 2 | |
4: "#776e65", # 4 | |
8: "#f9f6f2", # 8+ | |
16: "#f9f6f2", # 16+ | |
32: "#f9f6f2", # 32+ | |
64: "#f9f6f2", # 64+ | |
128: "#f9f6f2", # 128+ | |
256: "#f9f6f2", # 256+ | |
512: "#f9f6f2", # 512+ | |
1024: "#f9f6f2", # 1024+ | |
2048: "#f9f6f2", # 2048+ | |
4096: "#f9f6f2", # 4096+ | |
} | |
# 定义DQN网络结构(与训练时相同) | |
class DQN(nn.Module): | |
def __init__(self, input_channels, output_size): | |
super(DQN, self).__init__() | |
self.input_channels = input_channels | |
# 卷积层 | |
self.conv1 = nn.Conv2d(input_channels, 128, kernel_size=3, padding=1) | |
self.conv2 = nn.Conv2d(128, 128, kernel_size=3, padding=1) | |
self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1) | |
# Dueling DQN架构 | |
# 价值流 | |
self.value_conv = nn.Conv2d(128, 4, kernel_size=1) | |
self.value_fc1 = nn.Linear(4 * 4 * 4, 128) | |
self.value_fc2 = nn.Linear(128, 1) | |
# 优势流 | |
self.advantage_conv = nn.Conv2d(128, 16, kernel_size=1) | |
self.advantage_fc1 = nn.Linear(16 * 4 * 4, 128) | |
self.advantage_fc2 = nn.Linear(128, output_size) | |
def forward(self, x): | |
x = F.relu(self.conv1(x)) | |
x = F.relu(self.conv2(x)) | |
x = F.relu(self.conv3(x)) | |
# 价值流 | |
value = F.relu(self.value_conv(x)) | |
value = value.view(value.size(0), -1) | |
value = F.relu(self.value_fc1(value)) | |
value = self.value_fc2(value) | |
# 优势流 | |
advantage = F.relu(self.advantage_conv(x)) | |
advantage = advantage.view(advantage.size(0), -1) | |
advantage = F.relu(self.advantage_fc1(advantage)) | |
advantage = self.advantage_fc2(advantage) | |
# 合并价值流和优势流 | |
q_values = value + advantage - advantage.mean(dim=1, keepdim=True) | |
return q_values | |
# 加载模型 | |
def load_model(model_path): | |
model = DQN(4, 4) # 输入通道4,输出动作4个 | |
try: | |
checkpoint = torch.load(model_path, map_location=torch.device('cpu')) | |
model.load_state_dict(checkpoint['policy_net_state_dict']) | |
model.eval() | |
print("模型加载成功") | |
return model | |
except Exception as e: | |
print(f"模型加载失败: {e}") | |
return None | |
# 尝试加载模型 | |
model_path = "dqn_2048_best_tile.pth" | |
model = load_model(model_path) | |
def render_board(board): | |
html = "<div style='background-color:#bbada0; padding:10px; border-radius:6px;'>" | |
html += "<table style='border-spacing:10px; border-collapse:separate;'>" | |
for i in range(game.size): | |
html += "<tr>" | |
for j in range(game.size): | |
value = board[i][j] | |
color = TILE_COLORS.get(value, "#3c3a32") # 默认深色 | |
text_color = TEXT_COLORS.get(value, "#f9f6f2") # 默认浅色 | |
font_size = "36px" if value < 100 else "30px" if value < 1000 else "24px" | |
html += f""" | |
<td style='background-color:{color}; | |
width:80px; height:80px; | |
border-radius:4px; | |
text-align:center; | |
font-weight:bold; | |
font-size:{font_size}; | |
color:{text_color};'> | |
{value if value > 0 else ''} | |
</td> | |
""" | |
html += "</tr>" | |
html += "</table></div>" | |
return html | |
def make_move(direction): | |
"""执行移动操作并更新界面""" | |
direction_names = ["上", "右", "下", "左"] | |
# 执行移动 | |
new_board, game_over = game.move(direction) | |
# 渲染棋盘 | |
board_html = render_board(new_board) | |
# 更新状态信息 | |
status = f"<b>移动方向:</b> {direction_names[direction]}" | |
status += f"<br><b>当前分数:</b> {game.score}" | |
status += f"<br><b>最大方块:</b> {np.max(game.board)}" | |
if game.game_over: | |
status += "<br><br><div style='color:#ff0000; font-weight:bold;'>游戏结束!</div>" | |
status += f"<br><b>最终分数:</b> {game.score}" | |
return board_html, status | |
def reset_game(): | |
"""重置游戏""" | |
global game | |
game = Game2048(size=4) | |
board = game.reset() | |
# 渲染棋盘 | |
board_html = render_board(board) | |
# 初始状态信息 | |
status = "<b>游戏已重置!</b>" | |
status += f"<br><b>当前分数:</b> {game.score}" | |
status += f"<br><b>最大方块:</b> {np.max(game.board)}" | |
return board_html, status | |
def ai_move(): | |
"""使用AI模型进行一步移动""" | |
if model is None: | |
return render_board(game.board), "<b>错误:</b> 未加载AI模型" | |
# 获取当前状态 | |
state = game.get_state() | |
# 获取有效移动 | |
valid_moves = game.get_valid_moves() | |
if not valid_moves: | |
return render_board(game.board), "<b>游戏结束!</b> 没有有效移动" | |
# 转换状态为模型输入 | |
state_tensor = torch.tensor(state, dtype=torch.float).unsqueeze(0) | |
# 模型预测 | |
with torch.no_grad(): | |
q_values = model(state_tensor).numpy().flatten() | |
# 只考虑有效动作 | |
valid_q_values = np.full(4, -np.inf) | |
for move in valid_moves: | |
valid_q_values[move] = q_values[move] | |
# 选择最佳动作 | |
action = np.argmax(valid_q_values) | |
# 执行移动 | |
direction_names = ["上", "右", "下", "左"] | |
new_board, game_over = game.move(action) | |
# 渲染棋盘 | |
board_html = render_board(new_board) | |
# 更新状态信息 | |
status = f"<b>AI移动方向:</b> {direction_names[action]}" | |
status += f"<br><b>当前分数:</b> {game.score}" | |
status += f"<br><b>最大方块:</b> {np.max(game.board)}" | |
if game.game_over: | |
status += "<br><br><div style='color:#ff0000; font-weight:bold;'>游戏结束!</div>" | |
status += f"<br><b>最终分数:</b> {game.score}" | |
return board_html, status | |
# 创建Gradio界面 | |
with gr.Blocks(title="2048游戏", theme="soft") as demo: | |
gr.Markdown("# 🎮 2048游戏") | |
gr.Markdown("使用方向键或下方的按钮移动方块,相同数字的方块相撞时会合并!") | |
with gr.Row(): | |
with gr.Column(scale=2): | |
board_html = gr.HTML(render_board(game.board)) | |
with gr.Row(visible=False): | |
status_display = gr.HTML("<b>当前分数:</b> 0<br><b>最大方块:</b> 2") | |
with gr.Column(): | |
gr.Markdown("## 手动操作") | |
with gr.Row(): | |
gr.Button("上 ↑", elem_id="up-btn").click( | |
fn=lambda: make_move(0), | |
outputs=[board_html, status_display] | |
) | |
gr.Button("左 ←", elem_id="left-btn").click( | |
fn=lambda: make_move(3), | |
outputs=[board_html, status_display] | |
) | |
with gr.Row(): | |
gr.Button("下 ↓", elem_id="down-btn").click( | |
fn=lambda: make_move(2), | |
outputs=[board_html, status_display] | |
) | |
gr.Button("右 →", elem_id="right-btn").click( | |
fn=lambda: make_move(1), | |
outputs=[board_html, status_display] | |
) | |
with gr.Row(): | |
gr.Button("🔄 重置游戏", elem_id="reset-btn").click( | |
fn=reset_game, | |
outputs=[board_html, status_display] | |
) | |
with gr.Row(): | |
status_display = gr.HTML("<b>当前分数:</b> 0<br><b>最大方块:</b> 2") | |
with gr.Column(): | |
gr.Markdown("## AI操作") | |
gr.Button("🤖 AI移动一步", elem_id="ai-btn").click( | |
fn=ai_move, | |
outputs=[board_html, status_display] | |
) | |
gr.Markdown("基于DQN神经网络提供支持") | |
# 添加键盘快捷键支持 | |
demo.load( | |
fn=None, | |
inputs=None, | |
outputs=None, | |
js="""() => { | |
document.addEventListener('keydown', function(e) { | |
if (e.key === 'ArrowUp') { | |
document.getElementById('up-btn').click(); | |
} else if (e.key === 'ArrowRight') { | |
document.getElementById('right-btn').click(); | |
} else if (e.key === 'ArrowDown') { | |
document.getElementById('down-btn').click(); | |
} else if (e.key === 'ArrowLeft') { | |
document.getElementById('left-btn').click(); | |
} else if (e.key === 'r' || e.key === 'R') { | |
document.getElementById('reset-btn').click(); | |
} else if (e.key === 'a' || e.key === 'A') { | |
document.getElementById('ai-btn').click(); | |
} | |
}); | |
}""" | |
) | |
gr.Markdown("### 📚 使用说明") | |
gr.Markdown("1. 使用方向键或下方的按钮移动方块。") | |
gr.Markdown("2. 相同数字的方块相撞时会合并。") | |
gr.Markdown("3. 快捷键说明:上/下/左/右键移动方块,R键重置游戏,A键AI移动一步。") | |
gr.Markdown("4. 点击 '🤖 AI移动一步' 按钮可以使用AI模型进行一步移动。") | |
gr.Markdown("5. 游戏结束后,会显示最终分数和最大方块。") | |
# 启动界面 | |
if __name__ == "__main__": | |
demo.launch(share=True) |