319 lines
16 KiB
Python
319 lines
16 KiB
Python
"""
PPO Trainer
实现PPO算法的训练逻辑
"""
|
||
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.optim as optim
|
||
from torch.utils.tensorboard import SummaryWriter
|
||
import numpy as np
|
||
import os
|
||
from datetime import datetime
|
||
from ppo_model import PPONetwork
|
||
|
||
|
||
class PPOTrainer:
    """Trainer implementing the PPO (clipped surrogate) update for a ``PPONetwork``.

    The trainer owns the policy network, an Adam optimizer, an optional AMP
    gradient scaler and a TensorBoard ``SummaryWriter``.  Typical use is to
    call :meth:`train_from_buffer` with a filled replay buffer, and
    :meth:`save_model` / :meth:`load_model` for checkpointing.
    """

    def __init__(
        self,
        state_dim=572,
        hidden_dim=768,  # raised to 768 to make fuller use of GPU memory
        lr=3e-4,
        gamma=0.99,
        gae_lambda=0.95,
        clip_epsilon=0.2,
        value_coef=0.5,
        entropy_coef=0.01,
        max_grad_norm=0.5,
        epochs_per_update=10,
        batch_size=768,  # default 768, chosen to make use of a 6 GB GPU
        device=None,
        log_dir='runs',
        use_amp=True,
        verbose=False,
    ):
        """Build the policy, optimizer, AMP scaler and TensorBoard writer.

        Args:
            state_dim: Forwarded to ``PPONetwork`` as ``state_dim``.
            hidden_dim: Forwarded to ``PPONetwork`` as ``hidden_dim``.
            lr: Learning rate of the Adam optimizer.
            gamma: Discount factor used by :meth:`compute_gae`.
            gae_lambda: GAE lambda used by :meth:`compute_gae`.
            clip_epsilon: Clipping range of the PPO probability ratio.
            value_coef: Weight of the value loss in the total loss.
            entropy_coef: Weight of the entropy term in the total loss.
            max_grad_norm: Max gradient norm passed to ``clip_grad_norm_``.
            epochs_per_update: Passes over the data per :meth:`update` call.
            batch_size: Minibatch size used inside :meth:`update`.
            device: Device (or device string) to train on.  A falsy value
                selects CUDA when available, otherwise CPU.
            log_dir: Parent directory of the TensorBoard run directory.
            use_amp: Request mixed-precision training; only honoured when the
                selected device is a CUDA device.
            verbose: Print progress and diagnostic messages.
        """
        if not device:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Normalise so that strings such as 'cpu' are accepted as well.
        self.device = torch.device(device)
        # AMP must follow the device actually in use: enabling the CUDA grad
        # scaler while the model lives on the CPU breaks the scaled backward.
        self.use_amp = bool(use_amp) and self.device.type == 'cuda'
        self.verbose = verbose

        # Hyper-parameters
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_epsilon = clip_epsilon
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm
        self.epochs_per_update = epochs_per_update
        self.batch_size = batch_size

        # Network and optimizer
        self.policy = PPONetwork(state_dim=state_dim, hidden_dim=hidden_dim).to(self.device)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)

        # Mixed-precision training (GPU only)
        if self.use_amp:
            self.scaler = torch.cuda.amp.GradScaler()
            if self.verbose:
                print("✅ 已启用混合精度训练 (AMP) - 速度提升约30%")
        else:
            self.scaler = None

        # TensorBoard: one timestamped run directory per trainer instance
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        self.writer = SummaryWriter(log_dir=f'{log_dir}/PPO_{timestamp}')

        # Training counters
        self.total_steps = 0
        self.total_episodes = 0

        if self.verbose:
            print(f"PPO Trainer 初始化完成")
            print(f"设备: {self.device}")
            print(f"状态维度: {state_dim}, 隐藏层维度: {hidden_dim}")
            print(f"学习率: {lr}, Gamma: {gamma}, Clip ε: {clip_epsilon}")

    @staticmethod
    def _as_float32_array(data):
        """Return ``data`` (tensor or array-like) as a CPU float32 numpy array."""
        if isinstance(data, torch.Tensor):
            return data.detach().cpu().numpy().astype(np.float32)
        return np.array(data, dtype=np.float32)

    def compute_gae(self, rewards, values, dones, next_value):
        """Compute Generalized Advantage Estimation on the CPU.

        The backward recursion runs in numpy and the results are then moved
        to ``self.device``.

        Args:
            rewards: Per-step rewards (array-like or tensor, any device).
            values: Per-step value estimates (array-like or tensor).
            dones: Per-step episode-termination flags (array-like or tensor).
            next_value: Bootstrap value for the step after the last one
                (float or single-element tensor).

        Returns:
            ``(advantages, returns)`` as float32 tensors on ``self.device``.
            Both are empty when the trajectory is empty.
        """
        rewards_np = self._as_float32_array(rewards)
        values_np = self._as_float32_array(values)
        dones_np = self._as_float32_array(dones)

        horizon = len(rewards_np)
        if horizon == 0:
            # Nothing to estimate; indexing [-1] below would raise otherwise.
            empty = torch.zeros(0, dtype=torch.float32, device=self.device)
            return empty, empty.clone()

        if isinstance(next_value, torch.Tensor):
            bootstrap = float(next_value.detach().cpu().item())
        else:
            bootstrap = float(next_value)

        # V(s_{t+1}) for every step: shift values left, append the bootstrap.
        next_values_np = np.zeros_like(values_np, dtype=np.float32)
        if horizon > 1:
            next_values_np[:-1] = values_np[1:]
        next_values_np[-1] = bootstrap

        not_done = 1.0 - dones_np
        deltas = rewards_np + self.gamma * next_values_np * not_done - values_np

        # Backward-in-time recursion; a done flag cuts the accumulation so
        # advantages do not leak across episode boundaries.
        decay = self.gamma * self.gae_lambda
        advantages_np = np.zeros_like(rewards_np, dtype=np.float32)
        gae = 0.0
        for t in reversed(range(horizon)):
            gae = deltas[t] + decay * not_done[t] * gae
            advantages_np[t] = gae

        returns_np = advantages_np + values_np

        # Clamp to a large finite range as a numerical safety net.
        advantages_np = np.clip(advantages_np, -1e6, 1e6)
        returns_np = np.clip(returns_np, -1e6, 1e6)

        advantages_t = torch.from_numpy(advantages_np).to(self.device, non_blocking=True)
        returns_t = torch.from_numpy(returns_np).to(self.device, non_blocking=True)
        return advantages_t, returns_t

    def _sanitize(self, tensor, name, posinf, neginf):
        """Replace NaN/Inf entries of ``tensor``, warning when verbose."""
        if torch.isnan(tensor).any() or torch.isinf(tensor).any():
            if self.verbose:
                print(f"⚠️ 警告: {name} tensor 包含 nan 或 inf,进行修复")
            tensor = torch.nan_to_num(tensor, nan=0.0, posinf=posinf, neginf=neginf)
        return tensor

    def _compute_losses(self, states, actions, old_log_probs, returns, advantages,
                        valid_actions, action_lengths):
        """Evaluate one minibatch; return (total, policy, value, entropy) losses."""
        log_probs, state_values, entropy = self.policy.evaluate(
            states, actions, valid_actions, action_lengths
        )
        # Clipped surrogate objective.
        ratio = torch.exp(log_probs - old_log_probs)
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1.0 - self.clip_epsilon, 1.0 + self.clip_epsilon) * advantages
        policy_loss = -torch.min(surr1, surr2).mean()
        # NOTE(review): assumes evaluate() returns state_values with the same
        # shape as returns; otherwise MSELoss would broadcast — confirm.
        value_loss = nn.MSELoss()(state_values, returns)
        entropy_loss = -entropy.mean()
        loss = policy_loss + self.value_coef * value_loss + self.entropy_coef * entropy_loss
        return loss, policy_loss, value_loss, entropy_loss

    def _optimizer_step(self, loss):
        """Backpropagate ``loss``, clip gradients and step the optimizer."""
        self.optimizer.zero_grad()
        if self.use_amp:
            self.scaler.scale(loss).backward()
            # Unscale first so clipping operates on the true gradient values.
            self.scaler.unscale_(self.optimizer)
            nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            loss.backward()
            nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.optimizer.step()

    def update(self, states, actions, old_log_probs, returns, advantages, valid_actions_tensors, action_lengths):
        """Run the PPO update on tensors that already live on the device.

        Args:
            states: State tensor; dim 0 indexes samples.
            actions: Action tensor, indexed by sample.
            old_log_probs: Log-probabilities of ``actions`` under the policy
                before this update, indexed by sample.
            returns: Value targets, indexed by sample.
            advantages: Advantage estimates, indexed by sample; normalised
                to zero mean / unit std inside this method.
            valid_actions_tensors: List with one entry per sample, passed
                through to ``policy.evaluate``.
            action_lengths: Tensor indexed by sample, passed through to
                ``policy.evaluate``.

        Returns:
            Total loss (float) of the last processed minibatch, or ``0.0``
            when there is nothing to train on.
        """
        n_samples = states.shape[0]
        if n_samples == 0 or self.epochs_per_update <= 0:
            # No minibatch would run, so there is no loss to log or return.
            return 0.0

        # Numerical-stability pass, done once instead of per minibatch.
        advantages = self._sanitize(advantages, 'advantages', 1e6, -1e6)
        returns = self._sanitize(returns, 'returns', 1e6, -1e6)
        old_log_probs = self._sanitize(old_log_probs, 'old_log_probs', 10.0, -10.0)

        # Z-score normalisation of the advantages.  A degenerate std (tiny,
        # NaN — e.g. a single sample — or Inf) falls back to 1.0.
        advantages_mean = advantages.mean()
        advantages_std = advantages.std()
        if advantages_std < 1e-8 or torch.isnan(advantages_std) or torch.isinf(advantages_std):
            advantages_std = torch.tensor(1.0, device=self.device)
        advantages = (advantages - advantages_mean) / (advantages_std + 1e-8)

        self.policy.train()

        # A fresh permutation is drawn for every epoch.
        for _ in range(self.epochs_per_update):
            indices = torch.randperm(n_samples, device=self.device)
            # The valid-action list is a Python list, so it needs host indices.
            indices_list = indices.cpu().tolist()

            for start in range(0, n_samples, self.batch_size):
                end = min(start + self.batch_size, n_samples)
                batch_idx = indices[start:end]
                batch = (
                    states[batch_idx],
                    actions[batch_idx],
                    old_log_probs[batch_idx],
                    returns[batch_idx],
                    advantages[batch_idx],
                    [valid_actions_tensors[i] for i in indices_list[start:end]],
                    action_lengths[batch_idx],
                )

                if self.use_amp:
                    with torch.amp.autocast('cuda'):
                        losses = self._compute_losses(*batch)
                else:
                    losses = self._compute_losses(*batch)
                loss, policy_loss, value_loss, entropy_loss = losses

                self._optimizer_step(loss)
                self.total_steps += batch_idx.shape[0]

        # Gradient norm of the last minibatch (after clipping), for monitoring.
        squared_norm = 0.0
        for param in self.policy.parameters():
            if param.grad is not None:
                squared_norm += param.grad.detach().norm(2).item() ** 2
        total_grad_norm = squared_norm ** 0.5

        # TensorBoard receives the values of the last minibatch only.
        step = self.total_steps
        self.writer.add_scalar('Loss/Policy', policy_loss.item(), step)
        self.writer.add_scalar('Loss/Value', value_loss.item(), step)
        self.writer.add_scalar('Loss/Entropy', entropy_loss.item(), step)
        self.writer.add_scalar('Loss/Total', loss.item(), step)
        self.writer.add_scalar('Training/GradNorm', total_grad_norm, step)
        self.writer.add_scalar('Training/LearningRate', self.optimizer.param_groups[0]['lr'], step)

        return loss.item()

    def train_from_buffer(self, replay_buffer):
        """Train on everything currently stored in ``replay_buffer``.

        Args:
            replay_buffer: Object supporting ``len()`` and ``get_all()``;
                ``get_all()`` must return ``(states, actions, rewards, dones,
                valid_action_ids)`` where the first four are numpy arrays and
                the last is a per-sample list of numpy id arrays.

        Returns:
            The loss returned by :meth:`update`, or ``0.0`` for an empty buffer.
        """
        if len(replay_buffer) == 0:
            return 0.0

        states, actions, rewards, dones, valid_actions_ids_list = replay_buffer.get_all()

        # States/actions go to the device; rewards/dones stay on the CPU
        # because GAE is computed there.
        states_tensor = torch.from_numpy(states).to(self.device, non_blocking=True)
        actions_tensor = torch.from_numpy(actions).to(self.device, non_blocking=True)
        rewards_cpu = torch.from_numpy(rewards.astype(np.float32))
        dones_cpu = torch.from_numpy(dones.astype(np.float32))

        # Per-sample valid-action id arrays -> list of device tensors.
        valid_actions_tensors = [
            torch.from_numpy(ids).to(self.device, non_blocking=True)
            for ids in valid_actions_ids_list
        ]
        # Number of valid actions per sample, precomputed once.
        action_lengths = torch.tensor(
            [len(ids) for ids in valid_actions_ids_list],
            dtype=torch.long, device=self.device,
        )

        # Values and old log-probs are evaluated in chunks to bound peak memory.
        # NOTE(review): the policy is left in train() mode by update(); if the
        # network contains dropout/batch-norm these statistics are affected
        # from the second call onwards — confirm against PPONetwork.
        n_samples = len(states)
        compute_batch_size = 2048
        values_list = []
        old_log_probs_list = []
        with torch.no_grad():
            for start_idx in range(0, n_samples, compute_batch_size):
                end_idx = min(start_idx + compute_batch_size, n_samples)
                batch_states = states_tensor[start_idx:end_idx]
                batch_actions = actions_tensor[start_idx:end_idx]
                batch_valid_actions = valid_actions_tensors[start_idx:end_idx]
                batch_lengths = action_lengths[start_idx:end_idx]

                _, batch_values = self.policy.forward(batch_states)
                values_list.append(batch_values.squeeze(-1))

                batch_old_log_probs, _, _ = self.policy.evaluate(
                    batch_states, batch_actions, batch_valid_actions, batch_lengths
                )
                old_log_probs_list.append(batch_old_log_probs)

        values_tensor = torch.cat(values_list)
        old_log_probs_tensor = torch.cat(old_log_probs_list)

        # The buffer holds no successor state, so a non-terminal tail is
        # bootstrapped with the value estimate of the last stored state.
        next_value = 0.0 if dones[-1] else values_tensor[-1]
        advantages_tensor, returns_tensor = self.compute_gae(
            rewards_cpu, values_tensor, dones_cpu, next_value
        )

        # Diagnostic statistics, printed on every 100th call when verbose.
        if self.verbose and self.total_episodes % 100 == 0:
            print(f"\n[训练统计 Episode {self.total_episodes}]")
            print(f" Rewards: min={rewards_cpu.min():.4f}, max={rewards_cpu.max():.4f}, mean={rewards_cpu.mean():.4f}")
            print(f" Values: min={values_tensor.min():.4f}, max={values_tensor.max():.4f}, mean={values_tensor.mean():.4f}")
            print(f" Advantages: min={advantages_tensor.min():.4f}, max={advantages_tensor.max():.4f}, mean={advantages_tensor.mean():.4f}, std={advantages_tensor.std():.4f}")
            print(f" Returns: min={returns_tensor.min():.4f}, max={returns_tensor.max():.4f}, mean={returns_tensor.mean():.4f}")

        loss = self.update(
            states_tensor, actions_tensor, old_log_probs_tensor,
            returns_tensor, advantages_tensor, valid_actions_tensors, action_lengths,
        )

        # Local tensors are released when this frame returns; no explicit
        # del / empty_cache is needed.
        self.total_episodes += 1
        return loss

    def save_model(self, path):
        """Save policy/optimizer state and counters to ``path``.

        The save is skipped (with a printed message) when any policy
        parameter contains NaN or Inf, so a corrupted model never overwrites
        a checkpoint.  Missing parent directories are created.
        """
        has_nan = False
        has_inf = False
        for name, param in self.policy.named_parameters():
            if torch.isnan(param).any():
                print(f"⚠️ 警告: 参数 {name} 包含 NaN!")
                has_nan = True
            if torch.isinf(param).any():
                print(f"⚠️ 警告: 参数 {name} 包含 Inf!")
                has_inf = True

        if has_nan or has_inf:
            print(f"❌ 模型权重异常,取消保存: {path}")
            return

        dir_name = os.path.dirname(path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)

        torch.save({
            'policy_state_dict': self.policy.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'total_steps': self.total_steps,
            'total_episodes': self.total_episodes,
        }, path)

        if self.verbose:
            print(f"模型已保存: {path}")

    def load_model(self, path):
        """Restore policy/optimizer state and counters from ``path``.

        Returns:
            ``True`` when the checkpoint was loaded, ``False`` when ``path``
            does not exist.
        """
        if not os.path.exists(path):
            print(f"模型文件不存在: {path}")
            return False

        # SECURITY: torch.load unpickles the file; only load checkpoints from
        # trusted sources.
        checkpoint = torch.load(path, map_location=self.device)
        self.policy.load_state_dict(checkpoint['policy_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Counters are optional so that older checkpoints still load.
        self.total_steps = checkpoint.get('total_steps', 0)
        self.total_episodes = checkpoint.get('total_episodes', 0)
        print(f"模型已加载: {path}")
        return True

    def close(self):
        """Close the TensorBoard writer."""
        self.writer.close()
|