# TH1/AITrainPython/ppo_trainer.py
# 2025-12-06 11:44:43 +08:00
# 319 lines, 16 KiB, Python
#
# Note (from the repository web viewer): this file contains Unicode characters
# that might be confused with other characters. If this is intentional, the
# warning can be safely ignored.

"""
PPO Trainer
Implements the training logic of the PPO (Proximal Policy Optimization) algorithm:
GAE advantage computation, clipped-objective policy updates, TensorBoard logging,
and checkpoint save/load (see ``PPOTrainer``).
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import os
from datetime import datetime
from ppo_model import PPONetwork
class PPOTrainer:
    """Proximal Policy Optimization (PPO) trainer.

    Owns a ``PPONetwork`` policy/value model, an Adam optimizer, an optional
    AMP gradient scaler (CUDA only) and a TensorBoard ``SummaryWriter``.
    Typical use: call ``train_from_buffer(buffer)`` once per collected rollout,
    ``save_model`` / ``load_model`` for checkpoints, and ``close`` at the end.
    """

    def __init__(
        self,
        state_dim=572,
        hidden_dim=768,  # author's note: raised to 768 to make full use of GPU memory
        lr=3e-4,
        gamma=0.99,
        gae_lambda=0.95,
        clip_epsilon=0.2,
        value_coef=0.5,
        entropy_coef=0.01,
        max_grad_norm=0.5,
        epochs_per_update=10,
        batch_size=768,  # author's note: default sized to use a 6 GB GPU fully
        device=None,
        log_dir='runs',
        use_amp=True,
        verbose=False,
    ):
        """Build the network, optimizer, optional AMP scaler and TensorBoard writer.

        Args:
            state_dim: Input feature size passed to ``PPONetwork``.
            hidden_dim: Hidden layer size passed to ``PPONetwork``.
            lr: Adam learning rate.
            gamma: Discount factor used by ``compute_gae``.
            gae_lambda: GAE lambda used by ``compute_gae``.
            clip_epsilon: PPO probability-ratio clipping range.
            value_coef: Weight of the value (MSE) loss in the total loss.
            entropy_coef: Weight of the entropy term in the total loss.
            max_grad_norm: Global gradient-norm clipping threshold.
            epochs_per_update: Passes over the rollout per ``update`` call.
            batch_size: Minibatch size used inside ``update``.
            device: Torch device or device string; defaults to CUDA when
                available, otherwise CPU.
            log_dir: Parent directory of the TensorBoard run ``PPO_<timestamp>``.
            use_amp: Request mixed-precision training; only honoured when the
                chosen device is a CUDA device.
            verbose: Print diagnostic messages.
        """
        self.device = torch.device(device) if device else torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu'
        )
        # autocast('cuda') + GradScaler only make sense for CUDA tensors, so the
        # decision must follow the device actually used, not merely whether a
        # GPU exists on the machine.
        self.use_amp = bool(use_amp) and torch.cuda.is_available() and self.device.type == 'cuda'
        self.verbose = verbose

        # Hyper-parameters
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_epsilon = clip_epsilon
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm
        self.epochs_per_update = epochs_per_update
        self.batch_size = batch_size

        # Network and optimizer
        self.policy = PPONetwork(state_dim=state_dim, hidden_dim=hidden_dim).to(self.device)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)

        # Mixed-precision gradient scaler (CUDA only)
        if self.use_amp:
            if hasattr(torch.amp, 'GradScaler'):
                self.scaler = torch.amp.GradScaler('cuda')
            else:  # older PyTorch releases only ship the CUDA-specific (now deprecated) class
                self.scaler = torch.cuda.amp.GradScaler()
            if self.verbose:
                print("✅ 已启用混合精度训练 (AMP) - 速度提升约30%")
        else:
            self.scaler = None

        # TensorBoard: one timestamped sub-directory per trainer instance
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        self.writer = SummaryWriter(log_dir=f'{log_dir}/PPO_{timestamp}')

        # Training counters
        self.total_steps = 0  # advanced by each minibatch's size in update()
        self.total_episodes = 0  # advanced once per train_from_buffer() call

        if self.verbose:
            print("PPO Trainer 初始化完成")
            print(f"设备: {self.device}")
            print(f"状态维度: {state_dim}, 隐藏层维度: {hidden_dim}")
            print(f"学习率: {lr}, Gamma: {gamma}, Clip ε: {clip_epsilon}")

    @staticmethod
    def _to_float32_numpy(data):
        """Return *data* (tensor, array or sequence) as a float32 NumPy array on the CPU."""
        if isinstance(data, torch.Tensor):
            # Convert on the torch side first so dtypes NumPy lacks (e.g. bfloat16) work.
            return data.detach().to('cpu', torch.float32).numpy()
        return np.array(data, dtype=np.float32)

    def compute_gae(self, rewards, values, dones, next_value):
        """Compute Generalized Advantage Estimation on the CPU, then move results to ``self.device``.

        The backward recursion runs in NumPy to avoid launching many tiny GPU kernels.

        Args:
            rewards: Per-step rewards (NumPy array, sequence or tensor), length T.
            values: Value estimates for the same T steps (may be a GPU tensor).
            dones: Per-step terminal flags; a non-zero flag stops both
                bootstrapping and advantage propagation across that step.
            next_value: Bootstrap value after the last step (float or 1-element tensor).

        Returns:
            (advantages, returns): float32 tensors of length T on ``self.device``,
            each clipped to [-1e6, 1e6]. Empty input yields two empty tensors.
        """
        rewards_np = self._to_float32_numpy(rewards)
        values_np = self._to_float32_numpy(values)
        dones_np = self._to_float32_numpy(dones)

        T = len(rewards_np)
        if T == 0:
            # Nothing to estimate; also avoids indexing next_values_np[-1] on an empty array.
            empty = torch.zeros(0, dtype=torch.float32, device=self.device)
            return empty, empty.clone()

        if isinstance(next_value, torch.Tensor):
            next_value_f = float(next_value.detach().cpu().item())
        else:
            next_value_f = float(next_value)

        # V(s_{t+1}) for every step: shifted values, with the bootstrap value appended.
        next_values_np = np.zeros_like(values_np, dtype=np.float32)
        if T > 1:
            next_values_np[:-1] = values_np[1:]
        next_values_np[-1] = next_value_f
        deltas = rewards_np + self.gamma * next_values_np * (1.0 - dones_np) - values_np

        # Backward recursion: A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
        gae_coef = self.gamma * self.gae_lambda
        advantages_np = np.zeros_like(rewards_np, dtype=np.float32)
        gae = 0.0
        for t in reversed(range(T)):
            gae = deltas[t] + gae_coef * (1.0 - dones_np[t]) * gae
            advantages_np[t] = gae
        returns_np = advantages_np + values_np

        # Numerical safety clipping
        advantages_np = np.clip(advantages_np, -1e6, 1e6)
        returns_np = np.clip(returns_np, -1e6, 1e6)

        advantages_t = torch.from_numpy(advantages_np).to(self.device, non_blocking=True)
        returns_t = torch.from_numpy(returns_np).to(self.device, non_blocking=True)
        return advantages_t, returns_t

    def _repair_non_finite(self, tensor, name, posinf, neginf):
        """Replace NaN with 0 and +/-Inf with the given bounds, warning when verbose."""
        if not torch.isfinite(tensor).all():
            if self.verbose:
                print(f"⚠️ 警告: {name} tensor 包含 nan 或 inf进行修复")
            tensor = torch.nan_to_num(tensor, nan=0.0, posinf=posinf, neginf=neginf)
        return tensor

    def _compute_losses(self, states, actions, valid_actions, action_lengths,
                        old_log_probs, returns, advantages):
        """Return ``(total, policy, value, entropy)`` PPO losses for one minibatch."""
        log_probs, state_values, entropy = self.policy.evaluate(
            states, actions, valid_actions, action_lengths
        )
        ratio = torch.exp(log_probs - old_log_probs)
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1.0 - self.clip_epsilon, 1.0 + self.clip_epsilon) * advantages
        policy_loss = -torch.min(surr1, surr2).mean()
        value_loss = nn.MSELoss()(state_values, returns)
        entropy_loss = -entropy.mean()  # negated so minimising the loss maximises entropy
        loss = policy_loss + self.value_coef * value_loss + self.entropy_coef * entropy_loss
        return loss, policy_loss, value_loss, entropy_loss

    def update(self, states, actions, old_log_probs, returns, advantages, valid_actions_tensors, action_lengths):
        """Run ``epochs_per_update`` shuffled minibatch PPO passes over one rollout.

        Tensor arguments are indexed along dim 0 by sample and are expected to
        already live on ``self.device``.

        Args:
            states, actions, old_log_probs, returns, advantages, action_lengths:
                Per-sample tensors (first dimension = number of samples).
            valid_actions_tensors: Python list with one tensor of valid action
                ids per sample, in the same order as ``states``.

        Returns:
            float: total loss of the last minibatch, or 0.0 when there is
            nothing to train on (no samples or ``epochs_per_update <= 0``).
        """
        n_samples = states.shape[0]
        if n_samples == 0 or self.epochs_per_update <= 0:
            return 0.0

        # Sanitise non-finite values once per call rather than per minibatch.
        advantages = self._repair_non_finite(advantages, 'advantages', 1e6, -1e6)
        returns = self._repair_non_finite(returns, 'returns', 1e6, -1e6)
        old_log_probs = self._repair_non_finite(old_log_probs, 'old_log_probs', 10.0, -10.0)

        # Standard z-score normalisation of advantages. A single sample has no
        # defined standard deviation, so fall back to 1 instead of NaN.
        advantages_mean = advantages.mean()
        if n_samples > 1:
            advantages_std = advantages.std()
        else:
            advantages_std = torch.tensor(1.0, device=self.device)
        if advantages_std < 1e-8 or not torch.isfinite(advantages_std):
            advantages_std = torch.tensor(1.0, device=self.device)
        advantages = (advantages - advantages_mean) / (advantages_std + 1e-8)

        self.policy.train()
        grad_norm = None
        for _ in range(self.epochs_per_update):
            # Fresh permutation each epoch; the host-side list of per-sample
            # valid-action tensors is reordered with the same permutation.
            indices = torch.randperm(n_samples, device=self.device)
            order = indices.cpu().tolist()
            shuffled_valid_actions = [valid_actions_tensors[i] for i in order]

            for start in range(0, n_samples, self.batch_size):
                end = min(start + self.batch_size, n_samples)
                batch_idx = indices[start:end]
                batch = (
                    states[batch_idx],
                    actions[batch_idx],
                    shuffled_valid_actions[start:end],
                    action_lengths[batch_idx],
                    old_log_probs[batch_idx],
                    returns[batch_idx],
                    advantages[batch_idx],
                )

                self.optimizer.zero_grad()
                if self.use_amp:
                    with torch.amp.autocast('cuda'):
                        loss, policy_loss, value_loss, entropy_loss = self._compute_losses(*batch)
                    # Backward runs outside autocast, as recommended for AMP.
                    self.scaler.scale(loss).backward()
                    # Unscale first so clipping sees the true gradient magnitudes.
                    self.scaler.unscale_(self.optimizer)
                    grad_norm = nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    loss, policy_loss, value_loss, entropy_loss = self._compute_losses(*batch)
                    loss.backward()
                    grad_norm = nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                    self.optimizer.step()

                self.total_steps += batch_idx.shape[0]

        # TensorBoard: losses of the last minibatch, plus its global gradient
        # norm measured before clipping (the value returned by clip_grad_norm_).
        self.writer.add_scalar('Loss/Policy', policy_loss.item(), self.total_steps)
        self.writer.add_scalar('Loss/Value', value_loss.item(), self.total_steps)
        self.writer.add_scalar('Loss/Entropy', entropy_loss.item(), self.total_steps)
        self.writer.add_scalar('Loss/Total', loss.item(), self.total_steps)
        self.writer.add_scalar('Training/GradNorm', float(grad_norm), self.total_steps)
        self.writer.add_scalar('Training/LearningRate', self.optimizer.param_groups[0]['lr'], self.total_steps)
        return loss.item()

    def train_from_buffer(self, replay_buffer):
        """Train once on every sample currently held in *replay_buffer*.

        ``replay_buffer.get_all()`` must return ``(states, actions, rewards,
        dones, valid_actions_ids_list)``. The first four are passed to
        ``torch.from_numpy`` / ``.astype`` and so must be NumPy arrays. The last
        is a list of per-sample NumPy arrays of valid action ids.

        Returns:
            float: loss reported by ``update``, or 0.0 for an empty buffer.
        """
        if len(replay_buffer) == 0:
            return 0.0

        states, actions, rewards, dones, valid_actions_ids_list = replay_buffer.get_all()

        # States/actions go to the device; rewards/dones stay on the CPU for GAE.
        states_tensor = torch.from_numpy(states).to(self.device, non_blocking=True)
        actions_tensor = torch.from_numpy(actions).to(self.device, non_blocking=True)
        rewards_cpu = torch.from_numpy(rewards.astype(np.float32))
        dones_cpu = torch.from_numpy(dones.astype(np.float32))

        valid_actions_tensors = [torch.from_numpy(ids).to(self.device, non_blocking=True)
                                 for ids in valid_actions_ids_list]
        action_lengths = torch.tensor([len(ids) for ids in valid_actions_ids_list],
                                      dtype=torch.long, device=self.device)

        # Values and old log-probs are computed in chunks to bound peak memory.
        n_samples = len(states)
        compute_batch_size = 2048
        values_list = []
        old_log_probs_list = []
        with torch.no_grad():
            for start_idx in range(0, n_samples, compute_batch_size):
                end_idx = min(start_idx + compute_batch_size, n_samples)
                batch_states = states_tensor[start_idx:end_idx]
                batch_actions = actions_tensor[start_idx:end_idx]
                batch_valid_actions = valid_actions_tensors[start_idx:end_idx]
                batch_lengths = action_lengths[start_idx:end_idx]

                _, batch_values = self.policy(batch_states)
                values_list.append(batch_values.squeeze(-1))
                batch_old_log_probs, _, _ = self.policy.evaluate(
                    batch_states, batch_actions, batch_valid_actions, batch_lengths
                )
                old_log_probs_list.append(batch_old_log_probs)
        values_tensor = torch.cat(values_list)
        old_log_probs_tensor = torch.cat(old_log_probs_list)

        # NOTE(review): the buffer does not expose the state after the final
        # step, so a non-terminal rollout is bootstrapped with the value of the
        # last stored state itself; this is an approximation of V(s_{T+1}).
        next_value = 0.0 if dones[-1] else values_tensor[-1]
        advantages_tensor, returns_tensor = self.compute_gae(rewards_cpu, values_tensor, dones_cpu, next_value)

        # Diagnostics every 100 train calls
        if self.verbose and self.total_episodes % 100 == 0:
            print(f"\n[训练统计 Episode {self.total_episodes}]")
            print(f" Rewards: min={rewards_cpu.min():.4f}, max={rewards_cpu.max():.4f}, mean={rewards_cpu.mean():.4f}")
            print(f" Values: min={values_tensor.min():.4f}, max={values_tensor.max():.4f}, mean={values_tensor.mean():.4f}")
            print(f" Advantages: min={advantages_tensor.min():.4f}, max={advantages_tensor.max():.4f}, mean={advantages_tensor.mean():.4f}, std={advantages_tensor.std():.4f}")
            print(f" Returns: min={returns_tensor.min():.4f}, max={returns_tensor.max():.4f}, mean={returns_tensor.mean():.4f}")

        loss = self.update(states_tensor, actions_tensor, old_log_probs_tensor,
                           returns_tensor, advantages_tensor, valid_actions_tensors, action_lengths)

        # Drop references so the allocator can reuse memory (no forced empty_cache).
        del states_tensor, actions_tensor, old_log_probs_tensor, returns_tensor, advantages_tensor
        del valid_actions_tensors, action_lengths, values_tensor, rewards_cpu, dones_cpu
        self.total_episodes += 1
        return loss

    def save_model(self, path):
        """Save policy/optimizer state and counters to *path*, refusing if any weight is NaN/Inf."""
        has_nan = False
        has_inf = False
        for name, param in self.policy.named_parameters():
            if torch.isnan(param).any():
                print(f"⚠️ 警告: 参数 {name} 包含 NaN!")
                has_nan = True
            if torch.isinf(param).any():
                print(f"⚠️ 警告: 参数 {name} 包含 Inf!")
                has_inf = True
        if has_nan or has_inf:
            print(f"❌ 模型权重异常,取消保存: {path}")
            return

        dir_name = os.path.dirname(path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        torch.save({
            'policy_state_dict': self.policy.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'total_steps': self.total_steps,
            'total_episodes': self.total_episodes,
        }, path)
        if self.verbose:
            print(f"模型已保存: {path}")

    def load_model(self, path):
        """Load a checkpoint written by ``save_model``; return True on success, False if missing."""
        if not os.path.exists(path):
            print(f"模型文件不存在: {path}")
            return False
        # SECURITY: torch.load unpickles the file (unless weights_only=True,
        # the default only in newer PyTorch); load only trusted checkpoints.
        checkpoint = torch.load(path, map_location=self.device)
        self.policy.load_state_dict(checkpoint['policy_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.total_steps = checkpoint.get('total_steps', 0)
        self.total_episodes = checkpoint.get('total_episodes', 0)
        print(f"模型已加载: {path}")
        return True

    def close(self):
        """Close the TensorBoard writer."""
        self.writer.close()