import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import warnings

if torch.cuda.is_available():
    device = torch.device("cuda")
elif hasattr(torch, "xpu") and torch.xpu.is_available():  # guard: torch.xpu only exists in newer PyTorch builds
    device = torch.device("xpu")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
# 2048 game environment (improved version)
class Game2048:
    def __init__(self, size=4):
        self.size = size
        self.reset()

    def reset(self):
        self.board = np.zeros((self.size, self.size), dtype=np.int32)
        self.score = 0
        self.prev_score = 0
        self.add_tile()
        self.add_tile()
        self.game_over = False
        return self.get_state()

    def add_tile(self):
        # Place a 2 (90%) or a 4 (10%) in a randomly chosen empty cell.
        empty_cells = []
        for i in range(self.size):
            for j in range(self.size):
                if self.board[i][j] == 0:
                    empty_cells.append((i, j))
        if empty_cells:
            i, j = random.choice(empty_cells)
            self.board[i][j] = 2 if random.random() < 0.9 else 4

    def move(self, direction):
        # 0: up, 1: right, 2: down, 3: left
        moved = False
        original_board = self.board.copy()
        old_score = self.score
        # Apply the slide along the requested direction.
        if direction == 0:  # up
            for j in range(self.size):
                column = self.board[:, j].copy()
                new_column, moved_col = self.slide(column)
                if moved_col:
                    moved = True
                self.board[:, j] = new_column
        elif direction == 1:  # right
            for i in range(self.size):
                row = self.board[i, :].copy()[::-1]
                new_row, moved_row = self.slide(row)
                if moved_row:
                    moved = True
                self.board[i, :] = new_row[::-1]
        elif direction == 2:  # down
            for j in range(self.size):
                column = self.board[::-1, j].copy()
                new_column, moved_col = self.slide(column)
                if moved_col:
                    moved = True
                self.board[:, j] = new_column[::-1]
        elif direction == 3:  # left
            for i in range(self.size):
                row = self.board[i, :].copy()
                new_row, moved_row = self.slide(row)
                if moved_row:
                    moved = True
                self.board[i, :] = new_row
        # If any tile moved, spawn a new tile and re-check for game over.
        if moved:
            self.add_tile()
            self.check_game_over()
        reward = self.calculate_reward(old_score, original_board)
        return self.get_state(), reward, self.game_over

    def slide(self, line):
        # Compact the line toward index 0 and merge equal neighbours once.
        non_zero = line[line != 0]
        new_line = np.zeros_like(line)
        idx = 0
        score_inc = 0
        moved = False
        # The line moved if the non-zero tiles were not already packed at the front.
        if not np.array_equal(non_zero, line[:len(non_zero)]):
            moved = True
        # Merge equal adjacent tiles.
        i = 0
        while i < len(non_zero):
            if i + 1 < len(non_zero) and non_zero[i] == non_zero[i + 1]:
                new_val = non_zero[i] * 2
                new_line[idx] = new_val
                score_inc += new_val
                i += 2
                idx += 1
            else:
                new_line[idx] = non_zero[i]
                i += 1
                idx += 1
        self.score += score_inc
        return new_line, moved or (score_inc > 0)
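
    # Illustration of the slide semantics above (a hedged sketch of expected
    # behaviour, not an exhaustive test):
    #   slide([2, 2, 4, 0])  -> [4, 4, 0, 0], moved=True,  score += 4
    #   slide([2, 0, 2, 2])  -> [4, 2, 0, 0], moved=True,  score += 4
    #   slide([2, 4, 8, 16]) -> unchanged,    moved=False, score += 0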

    def calculate_reward(self, old_score, original_board):
        """Improved reward shaping."""
        # 1. Base score reward
        score_reward = (self.score - old_score) * 0.1
        # 2. Reward the change in the number of empty cells
        empty_before = np.count_nonzero(original_board == 0)
        empty_after = np.count_nonzero(self.board == 0)
        empty_reward = (empty_after - empty_before) * 0.15
        # 3. Reward reaching a new maximum tile
        max_before = np.max(original_board)
        max_after = np.max(self.board)
        max_tile_reward = 0
        if max_after > max_before:
            max_tile_reward = np.log2(max_after) * 0.2
        # 4. Merge reward (encourage merging)
        merge_reward = 0
        if self.score - old_score > 0:
            merge_reward = np.log2(self.score - old_score) * 0.1
        # 5. Monotonicity penalty (encourage ordered rows); negated so it reduces the total
        monotonicity_penalty = -self.calculate_monotonicity_penalty() * 0.01
        # 6. Game-over penalty
        game_over_penalty = 0
        if self.game_over:
            game_over_penalty = -10
        # 7. Smoothness reward (encourage neighbouring tiles with similar values)
        smoothness_reward = self.calculate_smoothness() * 0.01
        # Total reward
        total_reward = (
            score_reward +
            empty_reward +
            max_tile_reward +
            merge_reward +
            smoothness_reward +
            monotonicity_penalty +
            game_over_penalty
        )
        return total_reward

    def calculate_monotonicity_penalty(self):
        """Sum of absolute differences between horizontally adjacent tiles (lower is better)."""
        penalty = 0
        for i in range(self.size):
            for j in range(self.size - 1):
                if self.board[i][j] > self.board[i][j + 1]:
                    penalty += self.board[i][j] - self.board[i][j + 1]
                else:
                    penalty += self.board[i][j + 1] - self.board[i][j]
        return penalty

    def calculate_smoothness(self):
        """Smoothness of the board (higher is better)."""
        smoothness = 0
        for i in range(self.size):
            for j in range(self.size):
                if self.board[i][j] != 0:
                    value = np.log2(self.board[i][j])
                    # Check the right neighbour
                    if j < self.size - 1 and self.board[i][j + 1] != 0:
                        neighbor_value = np.log2(self.board[i][j + 1])
                        smoothness -= abs(value - neighbor_value)
                    # Check the neighbour below
                    if i < self.size - 1 and self.board[i + 1][j] != 0:
                        neighbor_value = np.log2(self.board[i + 1][j])
                        smoothness -= abs(value - neighbor_value)
        return smoothness

    def check_game_over(self):
        # The game continues while any cell is empty.
        if np.any(self.board == 0):
            self.game_over = False
            return
        # The game also continues if any horizontal or vertical merge is possible.
        for i in range(self.size):
            for j in range(self.size - 1):
                if self.board[i][j] == self.board[i][j + 1]:
                    self.game_over = False
                    return
        for j in range(self.size):
            for i in range(self.size - 1):
                if self.board[i][j] == self.board[i + 1][j]:
                    self.game_over = False
                    return
        self.game_over = True

    def get_state(self):
        """Improved state representation."""
        # Build a 4-channel state tensor.
        state = np.zeros((4, self.size, self.size), dtype=np.float32)
        # Channel 0: log2 of each tile value (normalized)
        for i in range(self.size):
            for j in range(self.size):
                if self.board[i][j] > 0:
                    state[0, i, j] = np.log2(self.board[i][j]) / 16.0  # supports tiles up to 65536 (2^16)
        # Channel 1: empty-cell indicator
        state[1] = (self.board == 0).astype(np.float32)
        # Channel 2: mergeable-neighbour indicator
        for i in range(self.size):
            for j in range(self.size):
                if self.board[i][j] > 0:
                    # Check the right neighbour
                    if j < self.size - 1 and self.board[i][j] == self.board[i][j + 1]:
                        state[2, i, j] = 1.0
                        state[2, i, j + 1] = 1.0
                    # Check the neighbour below
                    if i < self.size - 1 and self.board[i][j] == self.board[i + 1][j]:
                        state[2, i, j] = 1.0
                        state[2, i + 1, j] = 1.0
        # Channel 3: location(s) of the maximum tile
        max_value = np.max(self.board)
        if max_value > 0:
            max_positions = np.argwhere(self.board == max_value)
            for pos in max_positions:
                state[3, pos[0], pos[1]] = 1.0
        return state

    def get_valid_moves(self):
        """More efficient valid-move detection."""
        valid_moves = []
        # slide() adds merge points to self.score, so save and restore the score
        # around these probe calls to avoid inflating it.
        saved_score = self.score
        # Check whether moving up is valid
        for j in range(self.size):
            column = self.board[:, j].copy()
            new_column, _ = self.slide(column)
            if not np.array_equal(new_column, self.board[:, j]):
                valid_moves.append(0)
                break
        # Check whether moving right is valid
        for i in range(self.size):
            row = self.board[i, :].copy()[::-1]
            new_row, _ = self.slide(row)
            if not np.array_equal(new_row[::-1], self.board[i, :]):
                valid_moves.append(1)
                break
        # Check whether moving down is valid
        for j in range(self.size):
            column = self.board[::-1, j].copy()
            new_column, _ = self.slide(column)
            if not np.array_equal(new_column[::-1], self.board[:, j]):
                valid_moves.append(2)
                break
        # Check whether moving left is valid
        for i in range(self.size):
            row = self.board[i, :].copy()
            new_row, _ = self.slide(row)
            if not np.array_equal(new_row, self.board[i, :]):
                valid_moves.append(3)
                break
        self.score = saved_score
        return valid_moves
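
# A minimal, optional sanity-check sketch for the environment API above.
# The helper name `_demo_environment` is illustrative only and is never called
# by the training or inference code below.
def _demo_environment():
    env = Game2048(size=4)
    state = env.reset()                      # state shape: (4, 4, 4)
    print("state shape:", state.shape)
    valid = env.get_valid_moves()            # subset of [0, 1, 2, 3]
    print("valid moves:", valid)
    if valid:
        next_state, reward, done = env.move(random.choice(valid))
        print("reward:", reward, "done:", done)
        print(env.board)
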
# Improved deep Q-network (Dueling DQN architecture)
class DQN(nn.Module):
    def __init__(self, input_channels, output_size):
        super(DQN, self).__init__()
        self.input_channels = input_channels
        # Convolutional layers
        self.conv1 = nn.Conv2d(input_channels, 128, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        # Dueling DQN architecture
        # Value stream
        self.value_conv = nn.Conv2d(128, 4, kernel_size=1)
        self.value_fc1 = nn.Linear(4 * 4 * 4, 128)
        self.value_fc2 = nn.Linear(128, 1)
        # Advantage stream
        self.advantage_conv = nn.Conv2d(128, 16, kernel_size=1)
        self.advantage_fc1 = nn.Linear(16 * 4 * 4, 128)
        self.advantage_fc2 = nn.Linear(128, output_size)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        # Value stream
        value = F.relu(self.value_conv(x))
        value = value.view(value.size(0), -1)
        value = F.relu(self.value_fc1(value))
        value = self.value_fc2(value)
        # Advantage stream
        advantage = F.relu(self.advantage_conv(x))
        advantage = advantage.view(advantage.size(0), -1)
        advantage = F.relu(self.advantage_fc1(advantage))
        advantage = self.advantage_fc2(advantage)
        # Combine the streams: Q = V + A - mean(A)
        q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
        return q_values
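
# A minimal shape-check sketch for the Dueling DQN above (illustrative only;
# `_demo_dqn_shapes` is an assumed helper name and is not called elsewhere).
def _demo_dqn_shapes():
    net = DQN(input_channels=4, output_size=4)
    dummy = torch.zeros(2, 4, 4, 4)          # (batch, channels, height, width)
    q = net(dummy)
    assert q.shape == (2, 4)                 # one Q-value per action
    print("Q-values shape:", q.shape)
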
# Experience replay buffer (with priorities)
class PrioritizedReplayBuffer:
    def __init__(self, capacity, alpha=0.6):
        self.capacity = capacity
        self.alpha = alpha
        self.buffer = []
        self.priorities = np.zeros(capacity)
        self.pos = 0
        self.size = 0

    def push(self, state, action, reward, next_state, done):
        # New transitions start at the current maximum priority.
        max_priority = self.priorities.max() if self.buffer else 1.0
        if len(self.buffer) < self.capacity:
            self.buffer.append((state, action, reward, next_state, done))
        else:
            self.buffer[self.pos] = (state, action, reward, next_state, done)
        self.priorities[self.pos] = max_priority
        self.pos = (self.pos + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size, beta=0.4):
        if self.size == 0:
            return None
        priorities = self.priorities[:self.size]
        probs = priorities ** self.alpha
        probs /= probs.sum()
        indices = np.random.choice(self.size, batch_size, p=probs)
        samples = [self.buffer[idx] for idx in indices]
        # Importance-sampling weights
        weights = (self.size * probs[indices]) ** (-beta)
        weights /= weights.max()
        weights = np.array(weights, dtype=np.float32)
        states, actions, rewards, next_states, dones = zip(*samples)
        return (
            torch.tensor(np.array(states)),
            torch.tensor(actions, dtype=torch.long),
            torch.tensor(rewards, dtype=torch.float),
            torch.tensor(np.array(next_states)),
            torch.tensor(dones, dtype=torch.float),
            indices,
            torch.tensor(weights)
        )

    def update_priorities(self, indices, priorities):
        # Make sure priorities is one value per index.
        if isinstance(priorities, np.ndarray) and priorities.ndim == 1:
            for idx, priority in zip(indices, priorities):
                self.priorities[idx] = priority
        else:
            # Handle the scalar case (should not normally happen).
            if not isinstance(priorities, (list, np.ndarray)):
                priorities = [priorities] * len(indices)
            for idx, priority in zip(indices, priorities):
                self.priorities[idx] = priority

    def __len__(self):
        return self.size
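
# A minimal round-trip sketch for the prioritized buffer above (illustrative
# only; `_demo_replay_buffer` is an assumed helper name, not used elsewhere).
def _demo_replay_buffer():
    buf = PrioritizedReplayBuffer(capacity=100)
    dummy_state = np.zeros((4, 4, 4), dtype=np.float32)
    for _ in range(10):
        buf.push(dummy_state, action=0, reward=1.0, next_state=dummy_state, done=False)
    states, actions, rewards, next_states, dones, indices, weights = buf.sample(batch_size=4)
    print(states.shape, actions.shape, weights.shape)   # (4, 4, 4, 4), (4,), (4,)
    # New priorities would normally come from the per-sample loss; reuse ones here.
    buf.update_priorities(indices, np.ones(len(indices), dtype=np.float32))
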
# Improved DQN agent
class DQNAgent:
    def __init__(self, input_channels, action_size, lr=3e-4, gamma=0.99,
                 epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=0.999,
                 target_update_freq=1000, batch_size=128):
        self.input_channels = input_channels
        self.action_size = action_size
        self.gamma = gamma
        self.epsilon = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        self.target_update_freq = target_update_freq
        # Policy network and target network
        self.policy_net = DQN(input_channels, action_size).to(device)
        self.target_net = DQN(input_channels, action_size).to(device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr, weight_decay=1e-5)
        self.memory = PrioritizedReplayBuffer(50000)
        self.steps_done = 0
        self.loss_fn = nn.SmoothL1Loss(reduction='none')

    def select_action(self, state, valid_moves):
        self.steps_done += 1
        self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay)
        if random.random() < self.epsilon:
            # Explore: pick a random valid action
            return random.choice(valid_moves)
        else:
            # Exploit: use the policy network
            with torch.no_grad():
                state_tensor = torch.tensor(state, dtype=torch.float).unsqueeze(0).to(device)
                q_values = self.policy_net(state_tensor).cpu().numpy().flatten()
            # Only consider valid actions
            valid_q_values = np.full(self.action_size, -np.inf)
            for move in valid_moves:
                valid_q_values[move] = q_values[move]
            return np.argmax(valid_q_values)

    def optimize_model(self, beta=0.4):
        if len(self.memory) < self.batch_size:
            return 0
        # Sample from the replay buffer
        sample = self.memory.sample(self.batch_size, beta)
        if sample is None:
            return 0
        states, actions, rewards, next_states, dones, indices, weights = sample
        states = states.to(device)
        actions = actions.to(device)
        rewards = rewards.to(device)
        next_states = next_states.to(device)
        dones = dones.to(device)
        weights = weights.to(device)
        # Current Q-values
        current_q = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze()
        # Target Q-values (Double DQN)
        with torch.no_grad():
            next_actions = self.policy_net(next_states).max(1)[1]
            next_q = self.target_net(next_states).gather(1, next_actions.unsqueeze(1)).squeeze()
            target_q = rewards + (1 - dones) * self.gamma * next_q
        # Loss (importance-weighted)
        losses = self.loss_fn(current_q, target_q)
        loss = (losses * weights).mean()
        # Update priorities using each sample's absolute loss
        with torch.no_grad():
            priorities = losses.abs().cpu().numpy() + 1e-5
            self.memory.update_priorities(indices, priorities)
        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 10)
        self.optimizer.step()
        return loss.item()
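
    # For reference, the Double DQN target computed above is (a restatement of
    # the code, not new behaviour):
    #   a* = argmax_a Q_policy(s', a)
    #   y  = r + (1 - done) * gamma * Q_target(s', a*)
    # and each sample's new priority is its Smooth L1 loss against y plus 1e-5.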

    def update_target_network(self):
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def save_model(self, path):
        torch.save({
            'policy_net_state_dict': self.policy_net.state_dict(),
            'target_net_state_dict': self.target_net.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'epsilon': self.epsilon,
            'steps_done': self.steps_done
        }, path)

    def load_model(self, path):
        if not os.path.exists(path):
            print(f"Model file not found: {path}")
            return
        try:
            # Try loading the checkpoint with weights_only=False
            checkpoint = torch.load(path, map_location=device, weights_only=False)
            self.policy_net.load_state_dict(checkpoint['policy_net_state_dict'])
            self.target_net.load_state_dict(checkpoint['target_net_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.epsilon = checkpoint['epsilon']
            self.steps_done = checkpoint['steps_done']
            self.policy_net.eval()
            self.target_net.eval()
            print(f"Model loaded successfully from {path}")
        except Exception as e:
            print(f"Error loading model: {e}")
            # Fall back to the legacy load path
            try:
                warnings.warn("Trying legacy load method without weights_only")
                checkpoint = torch.load(path, map_location=device)
                self.policy_net.load_state_dict(checkpoint['policy_net_state_dict'])
                self.target_net.load_state_dict(checkpoint['target_net_state_dict'])
                self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                self.epsilon = checkpoint['epsilon']
                self.steps_done = checkpoint['steps_done']
                self.policy_net.eval()
                self.target_net.eval()
                print("Model loaded successfully using legacy method")
            except Exception as e2:
                print(f"Failed to load model: {e2}")
# Training loop (with progress logging)
def train_agent(agent, env, episodes=5000, save_path='models/dqn_2048.pth',
                checkpoint_path='models/checkpoint.pth', resume=False, start_episode=0):
    # Create the directory for saved models
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    # Training metrics
    scores = []
    max_tiles = []
    avg_scores = []
    losses = []
    best_score = 0
    best_max_tile = 0
    # When resuming, restore the recorded training state
    if resume and os.path.exists(checkpoint_path):
        try:
            # Load the checkpoint with weights_only=False
            checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
            scores = checkpoint['scores']
            max_tiles = checkpoint['max_tiles']
            avg_scores = checkpoint['avg_scores']
            losses = checkpoint['losses']
            best_score = checkpoint.get('best_score', 0)
            best_max_tile = checkpoint.get('best_max_tile', 0)
            print(f"Resuming training from episode {start_episode}...")
        except Exception as e:
            print(f"Error loading checkpoint: {e}")
            print("Starting training from scratch...")
            resume = False
    if not resume:
        start_episode = 0
    # Show progress with tqdm
    progress_bar = tqdm(range(start_episode, episodes), desc="Training")
    for episode in progress_bar:
        state = env.reset()
        total_reward = 0
        done = False
        steps = 0
        episode_loss = 0
        loss_count = 0
        while not done:
            valid_moves = env.get_valid_moves()
            if not valid_moves:
                done = True
                continue
            action = agent.select_action(state, valid_moves)
            next_state, reward, done = env.move(action)
            total_reward += reward
            agent.memory.push(state, action, reward, next_state, done)
            state = next_state
            # Optimize the model
            loss = agent.optimize_model(beta=min(1.0, episode / 1000))
            if loss > 0:
                episode_loss += loss
                loss_count += 1
            # Periodically update the target network
            if agent.steps_done % agent.target_update_freq == 0:
                agent.update_target_network()
            steps += 1
        # Record score and maximum tile
        score = env.score
        max_tile = np.max(env.board)
        scores.append(score)
        max_tiles.append(max_tile)
        # Average loss over this episode
        avg_loss = episode_loss / loss_count if loss_count > 0 else 0
        losses.append(avg_loss)
        # Update the best records
        if score > best_score:
            best_score = score
            agent.save_model(save_path.replace('.pth', '_best_score.pth'))
        if max_tile > best_max_tile:
            best_max_tile = max_tile
            agent.save_model(save_path.replace('.pth', '_best_tile.pth'))
        # Average score over the last 100 episodes
        avg_score = np.mean(scores[-100:])
        avg_scores.append(avg_score)
        # Update the progress-bar description
        progress_bar.set_description(
            f"Ep {episode+1}/{episodes} | "
            f"Score: {score} (Avg: {avg_score:.1f}) | "
            f"Max Tile: {max_tile} | "
            f"Loss: {avg_loss:.4f} | "
            f"Epsilon: {agent.epsilon:.4f}"
        )
        # Periodically save the model and training state
        if (episode + 1) % 100 == 0:
            agent.save_model(save_path)
            # Save the training state
            checkpoint = {
                'scores': scores,
                'max_tiles': max_tiles,
                'avg_scores': avg_scores,
                'losses': losses,
                'best_score': best_score,
                'best_max_tile': best_max_tile,
                'episode': episode + 1,
                'steps_done': agent.steps_done,
                'epsilon': agent.epsilon
            }
            try:
                torch.save(checkpoint, checkpoint_path)
            except Exception as e:
                print(f"Error saving checkpoint: {e}")
            # Plot training curves
            if episode > 100:  # make sure there is enough data
                plt.figure(figsize=(12, 8))
                # Score curve
                plt.subplot(2, 2, 1)
                plt.plot(scores, label='Score')
                plt.plot(avg_scores, label='Avg Score (100 eps)')
                plt.xlabel('Episode')
                plt.ylabel('Score')
                plt.title('Training Scores')
                plt.legend()
                # Max-tile curve
                plt.subplot(2, 2, 2)
                plt.plot(max_tiles, 'g-')
                plt.xlabel('Episode')
                plt.ylabel('Max Tile')
                plt.title('Max Tile Achieved')
                # Loss curve
                plt.subplot(2, 2, 3)
                plt.plot(losses, 'r-')
                plt.xlabel('Episode')
                plt.ylabel('Loss')
                plt.title('Training Loss')
                # Score distribution histogram
                plt.subplot(2, 2, 4)
                plt.hist(scores, bins=20, alpha=0.7)
                plt.xlabel('Score')
                plt.ylabel('Frequency')
                plt.title('Score Distribution')
                plt.tight_layout()
                plt.savefig('training_progress.png')
                plt.close()
    # Save the final model
    agent.save_model(save_path)
    return scores, max_tiles, losses
# Inference (with console visualization)
def play_with_model(agent, env, episodes=3):
    agent.epsilon = 0.001  # use a very small epsilon during inference
    for episode in range(episodes):
        state = env.reset()
        done = False
        steps = 0
        print(f"\nEpisode {episode+1}")
        print("Initial Board:")
        print(env.board)
        while not done:
            valid_moves = env.get_valid_moves()
            if not valid_moves:
                done = True
                print("No valid moves left!")
                continue
            # Choose an action
            with torch.no_grad():
                state_tensor = torch.tensor(state, dtype=torch.float).unsqueeze(0).to(device)
                q_values = agent.policy_net(state_tensor).cpu().numpy().flatten()
            # Only consider valid actions
            valid_q_values = np.full(agent.action_size, -np.inf)
            for move in valid_moves:
                valid_q_values[move] = q_values[move]
            action = np.argmax(valid_q_values)
            # Execute the action
            next_state, reward, done = env.move(action)
            state = next_state
            steps += 1
            # Render the game
            print(f"\nStep {steps}: Action {['Up', 'Right', 'Down', 'Left'][action]}")
            print(env.board)
            print(f"Score: {env.score}, Max Tile: {np.max(env.board)}")
            # Also append the result to result.txt
            with open("result.txt", "a") as f:
                f.write(f"Episode {episode+1}, Step {steps}, Action {['Up', 'Right', 'Down', 'Left'][action]}, "
                        f"Score: {env.score}, Max Tile: {np.max(env.board)}\n{env.board}\n")
        print(f"\nGame Over! Final Score: {env.score}, Max Tile: {np.max(env.board)}")
# Main program
if __name__ == "__main__":
    args = {"train": 0, "resume": 0, "play": 1, "episodes": 50000}
    env = Game2048(size=4)
    input_channels = 4  # number of channels in the state representation
    action_size = 4     # up, right, down, left
    agent = DQNAgent(
        input_channels,
        action_size,
        lr=1e-4,
        epsilon_decay=0.999,  # slower decay
        target_update_freq=1000,
        batch_size=256
    )
    # Train the model
    if args.get('train') or args.get('resume'):
        print("Starting training...")
        # When resuming, load the checkpoint
        start_episode = 0
        checkpoint_path = 'models/checkpoint.pth'
        if args.get('resume') and os.path.exists(checkpoint_path):
            try:
                # Load the checkpoint with weights_only=False
                checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
                start_episode = checkpoint.get('episode', 0)
                agent.steps_done = checkpoint.get('steps_done', 0)
                agent.epsilon = checkpoint.get('epsilon', agent.epsilon)
            except Exception as e:
                print(f"Error loading checkpoint: {e}")
                print("Starting training from scratch...")
                start_episode = 0
        scores, max_tiles, losses = train_agent(
            agent,
            env,
            episodes=args.get('episodes'),
            save_path='models/dqn_2048.pth',
            checkpoint_path=checkpoint_path,
            resume=args.get('resume'),
            start_episode=start_episode
        )
        print("Training completed!")
        # Plot the final training results
        plt.figure(figsize=(15, 10))
        plt.subplot(3, 1, 1)
        plt.plot(scores)
        plt.title('Scores per Episode')
        plt.xlabel('Episode')
        plt.ylabel('Score')
        plt.subplot(3, 1, 2)
        plt.plot(max_tiles)
        plt.title('Max Tile per Episode')
        plt.xlabel('Episode')
        plt.ylabel('Max Tile')
        plt.subplot(3, 1, 3)
        plt.plot(losses)
        plt.title('Training Loss per Episode')
        plt.xlabel('Episode')
        plt.ylabel('Loss')
        plt.tight_layout()
        plt.savefig('final_training_results.png')
        plt.close()
    # Load a model and run inference
    if args.get('play'):
        model_path = 'models/dqn_2048_best_tile.pth'
        if not os.path.exists(model_path):
            model_path = 'models/dqn_2048.pth'
        if os.path.exists(model_path):
            agent.load_model(model_path)
            print("Playing with trained model...")
            if os.path.exists("result.txt"):
                os.remove("result.txt")  # clear the previous log
            play_with_model(agent, env, episodes=1)
        else:
            print("No trained model found. Please train the model first.")