Gofor5's picture
Update app.py
020c81d verified
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from game2048 import Game2048
# 检测可用设备
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.xpu.is_available():
device = torch.device("xpu")
else:
device = torch.device("cpu")
# 创建游戏实例
game = Game2048(size=4)
# 方块颜色映射(根据数字值)
TILE_COLORS = {
0: "#cdc1b4", # 空白格子
2: "#eee4da", # 2
4: "#ede0c8", # 4
8: "#f2b179", # 8
16: "#f59563", # 16
32: "#f67c5f", # 32
64: "#f65e3b", # 64
128: "#edcf72", # 128
256: "#edcc61", # 256
512: "#edc850", # 512
1024: "#edc53f", # 1024
2048: "#edc22e", # 2048
4096: "#3c3a32", # 4096+,
8192: "#3c3a32", # 8192+
16384: "#3c3a32", # 16384+
}
# 文本颜色映射(根据背景深浅)
TEXT_COLORS = {
0: "#776e65", # 空白格子
2: "#776e65", # 2
4: "#776e65", # 4
8: "#f9f6f2", # 8+
16: "#f9f6f2", # 16+
32: "#f9f6f2", # 32+
64: "#f9f6f2", # 64+
128: "#f9f6f2", # 128+
256: "#f9f6f2", # 256+
512: "#f9f6f2", # 512+
1024: "#f9f6f2", # 1024+
2048: "#f9f6f2", # 2048+
4096: "#f9f6f2", # 4096+
8192: "#f9f6f2", # 8192+
16384: "#f9f6f2", # 16384+
}
# 定义DQN网络结构(与训练时相同)
class DQN(nn.Module):
def __init__(self, input_channels, output_size):
super(DQN, self).__init__()
self.input_channels = input_channels
# 卷积层
self.conv1 = nn.Conv2d(input_channels, 128, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
# Dueling DQN架构
# 价值流
self.value_conv = nn.Conv2d(128, 4, kernel_size=1)
self.value_fc1 = nn.Linear(4 * 4 * 4, 128)
self.value_fc2 = nn.Linear(128, 1)
# 优势流
self.advantage_conv = nn.Conv2d(128, 16, kernel_size=1)
self.advantage_fc1 = nn.Linear(16 * 4 * 4, 128)
self.advantage_fc2 = nn.Linear(128, output_size)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = F.relu(self.conv3(x))
# 价值流
value = F.relu(self.value_conv(x))
value = value.view(value.size(0), -1)
value = F.relu(self.value_fc1(value))
value = self.value_fc2(value)
# 优势流
advantage = F.relu(self.advantage_conv(x))
advantage = advantage.view(advantage.size(0), -1)
advantage = F.relu(self.advantage_fc1(advantage))
advantage = self.advantage_fc2(advantage)
# 合并价值流和优势流
q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
return q_values
# 加载模型(支持CUDA)
def load_model(model_path, device):
model = DQN(4, 4).to(device) # 将模型移动到指定设备
try:
# 尝试加载模型
checkpoint = torch.load(model_path, map_location=device)
# 检查检查点是否包含完整的模型状态
if 'policy_net_state_dict' in checkpoint:
model.load_state_dict(checkpoint['policy_net_state_dict'])
else:
# 如果检查点不包含policy_net_state_dict,尝试直接加载
model.load_state_dict(checkpoint)
model.eval()
print(f"模型成功加载到: {device}")
return model
except Exception as e:
print(f"模型加载失败: {e}")
# 尝试备选模型路径
alt_path = model_path.replace('_best_tile', '')
try:
checkpoint = torch.load(alt_path, map_location=device)
if 'policy_net_state_dict' in checkpoint:
model.load_state_dict(checkpoint['policy_net_state_dict'])
else:
model.load_state_dict(checkpoint)
model.eval()
print(f"备选模型成功加载: {alt_path}")
return model
except Exception as e2:
print(f"备选模型加载失败: {e2}")
return None
# 尝试加载模型
model_paths = [
"models/dqn_2048_best_tile.pth",
"models/dqn_2048.pth",
"dqn_2048_best_tile.pth",
"dqn_2048.pth"
]
model = None
for path in model_paths:
model = load_model(path, device)
if model:
break
if not model:
print("警告: 未加载任何模型,AI功能将不可用")
def render_board(board):
html = "<div style='background-color:#bbada0; padding:10px; border-radius:6px;'>"
html += "<table style='border-spacing:10px; border-collapse:separate;'>"
for i in range(game.size):
html += "<tr>"
for j in range(game.size):
value = board[i][j]
color = TILE_COLORS.get(value, "#3c3a32") # 默认深色
text_color = TEXT_COLORS.get(value, "#f9f6f2") # 默认浅色
font_size = "36px" if value < 100 else "30px" if value < 1000 else "24px"
html += f"""
<td style='background-color:{color};
width:80px; height:80px;
border-radius:4px;
text-align:center;
font-weight:bold;
font-size:{font_size};
color:{text_color};'>
{value if value > 0 else ''}
</td>
"""
html += "</tr>"
html += "</table></div>"
return html
def make_move(direction):
"""执行移动操作并更新界面"""
direction_names = ["上", "右", "下", "左"]
# 执行移动
new_board, game_over = game.move(direction)
# 渲染棋盘
board_html = render_board(new_board)
# 更新状态信息
status = f"<b>移动方向:</b> {direction_names[direction]}"
status += f"<br><b>当前分数:</b> {game.score}"
status += f"<br><b>最大方块:</b> {np.max(game.board)}"
if game.game_over:
status += "<br><br><div style='color:#ff0000; font-weight:bold;'>游戏结束!</div>"
status += f"<br><b>最终分数:</b> {game.score}"
return board_html, status
def reset_game():
"""重置游戏"""
global game
game = Game2048(size=4)
board = game.reset()
# 渲染棋盘
board_html = render_board(board)
# 初始状态信息
status = "<b>游戏已重置!</b>"
status += f"<br><b>当前分数:</b> {game.score}"
status += f"<br><b>最大方块:</b> {np.max(game.board)}"
return board_html, status
def ai_move():
"""使用AI模型进行一步移动"""
if model is None:
return render_board(game.board), "<b>错误:</b> 未加载AI模型"
# 获取当前状态
state = game.get_state()
# 获取有效移动
valid_moves = game.get_valid_moves()
if not valid_moves:
return render_board(game.board), "<b>游戏结束!</b> 没有有效移动"
try:
# 转换状态为模型输入并移动到设备
state_tensor = torch.tensor(state, dtype=torch.float).unsqueeze(0).to(device)
# 模型预测
with torch.no_grad():
q_values = model(state_tensor).cpu().numpy().flatten()
# 只考虑有效动作
valid_q_values = np.full(4, -np.inf)
for move in valid_moves:
valid_q_values[move] = q_values[move]
# 选择最佳动作
action = np.argmax(valid_q_values)
# 执行移动
direction_names = ["上", "右", "下", "左"]
new_board, game_over = game.move(action)
# 渲染棋盘
board_html = render_board(new_board)
# 更新状态信息
status = f"<b>AI移动方向:</b> {direction_names[action]}"
status += f"<br><b>当前分数:</b> {game.score}"
status += f"<br><b>最大方块:</b> {np.max(game.board)}"
status += f"<br><b>设备:</b> {'GPU' if device.type == 'cuda' else 'CPU'}"
if game.game_over:
status += "<br><br><div style='color:#ff0000; font-weight:bold;'>游戏结束!</div>"
status += f"<br><b>最终分数:</b> {game.score}"
return board_html, status
except Exception as e:
print(f"AI移动出错: {e}")
return render_board(game.board), f"<b>错误:</b> AI移动失败 - {str(e)}"
# 创建Gradio界面
with gr.Blocks(title="2048游戏", theme="soft") as demo:
gr.Markdown("# 🎮 2048游戏")
gr.Markdown(f"当前运行设备: **{'GPU' if device.type == 'cuda' else 'CPU'}**")
gr.Markdown("使用方向键或下方的按钮移动方块,相同数字的方块相撞时会合并!")
with gr.Row():
with gr.Column(scale=2):
board_html = gr.HTML(render_board(game.board))
status_display = gr.HTML("<b>当前分数:</b> 0<br><b>最大方块:</b> 2")
with gr.Column():
gr.Markdown("## 手动操作")
with gr.Row():
up_btn = gr.Button("上 ↑", elem_id="up-btn")
left_btn = gr.Button("左 ←", elem_id="left-btn")
with gr.Row():
down_btn = gr.Button("下 ↓", elem_id="down-btn")
right_btn = gr.Button("右 →", elem_id="right-btn")
with gr.Row():
reset_btn = gr.Button("🔄 重置游戏", elem_id="reset-btn")
gr.Markdown("## AI操作")
with gr.Row():
ai_btn = gr.Button("🤖 AI移动一步", elem_id="ai-btn")
auto_btn = gr.Button("🚀 连续AI模式", elem_id="auto-btn")
# 连接按钮事件
up_btn.click(lambda: make_move(0), outputs=[board_html, status_display])
right_btn.click(lambda: make_move(1), outputs=[board_html, status_display])
down_btn.click(lambda: make_move(2), outputs=[board_html, status_display])
left_btn.click(lambda: make_move(3), outputs=[board_html, status_display])
reset_btn.click(reset_game, outputs=[board_html, status_display])
ai_btn.click(ai_move, outputs=[board_html, status_display])
# 连续AI模式
def auto_play():
"""连续AI移动直到游戏结束"""
if model is None:
return render_board(game.board), "<b>错误:</b> 未加载AI模型"
moves = 0
max_moves = 200 # 防止无限循环
while not game.game_over and moves < max_moves:
# 获取当前状态
state = game.get_state()
# 获取有效移动
valid_moves = game.get_valid_moves()
if not valid_moves:
break
# 转换状态为模型输入并移动到设备
state_tensor = torch.tensor(state, dtype=torch.float).unsqueeze(0).to(device)
# 模型预测
with torch.no_grad():
q_values = model(state_tensor).cpu().numpy().flatten()
# 只考虑有效动作
valid_q_values = np.full(4, -np.inf)
for move in valid_moves:
valid_q_values[move] = q_values[move]
# 选择最佳动作
action = np.argmax(valid_q_values)
# 执行移动
game.move(action)
moves += 1
# 渲染棋盘
board_html = render_board(game.board)
# 更新状态信息
status = f"<b>连续AI完成!</b>"
status += f"<br><b>移动次数:</b> {moves}"
status += f"<br><b>当前分数:</b> {game.score}"
status += f"<br><b>最大方块:</b> {np.max(game.board)}"
status += f"<br><b>设备:</b> {'GPU' if device.type == 'cuda' else 'CPU'}"
if game.game_over:
status += "<br><br><div style='color:#ff0000; font-weight:bold;'>游戏结束!</div>"
status += f"<br><b>最终分数:</b> {game.score}"
return board_html, status
auto_btn.click(auto_play, outputs=[board_html, status_display])
# 添加键盘快捷键支持
demo.load(
fn=None,
inputs=None,
outputs=None,
js="""() => {
document.addEventListener('keydown', function(e) {
if (e.key === 'ArrowUp') {
document.getElementById('up-btn').click();
} else if (e.key === 'ArrowRight') {
document.getElementById('right-btn').click();
} else if (e.key === 'ArrowDown') {
document.getElementById('down-btn').click();
} else if (e.key === 'ArrowLeft') {
document.getElementById('left-btn').click();
} else if (e.key === 'r' || e.key === 'R') {
document.getElementById('reset-btn').click();
} else if (e.key === 'a' || e.key === 'A') {
document.getElementById('ai-btn').click();
} else if (e.key === 's' || e.key === 'S') {
document.getElementById('auto-btn').click();
}
});
}"""
)
gr.Markdown("### 📚 使用说明")
gr.Markdown("1. 使用方向键或下方的按钮移动方块")
gr.Markdown("2. 相同数字的方块相撞时会合并")
gr.Markdown("3. **快捷键说明**:")
gr.Markdown(" - ↑/↓/←/→: 移动方块")
gr.Markdown(" - R: 重置游戏")
gr.Markdown(" - A: AI移动一步")
gr.Markdown(" - S: 连续AI模式(自动玩到游戏结束)")
# 启动界面
if __name__ == "__main__":
demo.launch()