TH1/AITrainPython/train.py
2025-12-06 11:44:43 +08:00

218 lines
8.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
PPO training entry point.

Loads recorded training data from the Data directory and trains a PPO model.
"""
import argparse
import os
import torch
import numpy as np
from tqdm import tqdm
from data_loader import TrainingDataLoader, ReplayBuffer
from ppo_trainer import PPOTrainer
def _enable_gpu_optimizations():
    """Apply process-wide CUDA performance settings; does nothing without CUDA."""
    if not torch.cuda.is_available():
        return
    # Original author's note: enables TF32 to speed up matrix multiplication
    # on Ampere or newer GPUs.
    torch.set_float32_matmul_precision('high')
    # Let cuDNN benchmark and select convolution algorithms automatically.
    torch.backends.cudnn.benchmark = True
    # Release cached GPU memory that is currently unused before training starts.
    torch.cuda.empty_cache()
    print("✅ GPU优化已启用: TF32 + cuDNN benchmark + 显存优化")


def _print_dataset_stats(all_episodes):
    """Print total step count, mean episode length and mean episode reward."""
    episode_count = len(all_episodes)
    total_steps = sum(len(ep[0]) for ep in all_episodes)
    total_reward = sum(ep[2].sum() for ep in all_episodes)
    avg_episode_length = total_steps / episode_count
    avg_reward = total_reward / episode_count
    print(f"总步数: {total_steps}")
    print(f"平均episode长度: {avg_episode_length:.2f}")
    print(f"平均episode奖励: {avg_reward:.2f}")


def _run_epoch(trainer, replay_buffer, all_episodes, update_freq, keep_buffer):
    """Feed the episodes to the trainer in groups of ``update_freq``; return the losses."""
    losses = []
    replay_buffer.clear()
    num_episodes = len(all_episodes)
    # Ceiling division so a trailing partial group is still trained on.
    num_batches = (num_episodes + update_freq - 1) // update_freq
    for batch_idx in tqdm(range(num_batches), desc="训练进度"):
        start_idx = batch_idx * update_freq
        group = all_episodes[start_idx:start_idx + update_freq]
        for states, actions, rewards, dones, valid_actions_list in group:
            replay_buffer.add_episode(states, actions, rewards, dones, valid_actions_list)
        if len(replay_buffer) > 0:
            losses.append(trainer.train_from_buffer(replay_buffer))
            # Unless asked to keep it, empty the buffer after every update so
            # data does not pile up across groups.
            if not keep_buffer:
                replay_buffer.clear()
    return losses


def train(args):
    """Run the full PPO training session described by ``args``.

    Loads every episode through ``TrainingDataLoader``, builds a
    ``PPOTrainer`` sized to the detected state dimension, then for
    ``args.num_epochs`` epochs shuffles the episodes and trains on them in
    groups. Periodic, best and final checkpoints are written under
    ``args.model_dir``.

    Args:
        args: Namespace of options as produced by ``main()``. The number of
            episodes per update is ``max(args.update_frequency, 5)``, i.e.
            values below 5 are raised to 5. ``args.save_frequency == 0``
            disables the periodic checkpoints.

    Returns:
        None. Returns early (after printing an error) when no episodes are
        found in ``args.data_dir``.
    """
    print("=" * 60)
    print("PPO AI训练程序")
    print("=" * 60)

    # Seed both RNGs that this function relies on.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    _enable_gpu_optimizations()

    # Load the training data.
    print("\n正在加载训练数据...")
    data_loader = TrainingDataLoader(data_dir=args.data_dir)
    all_episodes = data_loader.load_all_episodes(max_episodes=args.max_episodes)
    if len(all_episodes) == 0:
        print("错误: 没有找到训练数据!")
        print(f"请确保 {args.data_dir} 目录下有 episode_*.jsonl* 文件")
        return
    print(f"成功加载 {len(all_episodes)} 个episode")

    # Detect the state dimension from the first episode's states array.
    state_dim = all_episodes[0][0].shape[1]
    print(f"检测到状态维度: {state_dim}")

    _print_dataset_stats(all_episodes)

    print("\n初始化PPO训练器...")
    trainer = PPOTrainer(
        state_dim=state_dim,
        hidden_dim=args.hidden_dim,
        lr=args.lr,
        gamma=args.gamma,
        gae_lambda=args.gae_lambda,
        clip_epsilon=args.clip_epsilon,
        value_coef=args.value_coef,
        entropy_coef=args.entropy_coef,
        max_grad_norm=args.max_grad_norm,
        epochs_per_update=args.epochs_per_update,
        batch_size=args.batch_size,
        device=torch.device(args.device),
        log_dir=args.log_dir,
        use_amp=args.use_amp,
    )

    best_model_path = os.path.join(args.model_dir, "ppo_model_best.pth")
    final_model_path = os.path.join(args.model_dir, "ppo_model_final.pth")

    # From here on the trainer exists, so always close it, even if loading or
    # training raises or the run is interrupted.
    try:
        # Resume from a pretrained checkpoint when one was specified.
        if args.load_model:
            trainer.load_model(args.load_model)

        replay_buffer = ReplayBuffer(capacity=args.buffer_capacity)

        print(f"\n开始训练 {args.num_epochs} 轮...")
        print("=" * 60)

        # Original author's note: accumulating 5-10 episodes per update suits
        # a 6 GB GPU, hence the floor of 5.
        effective_update_freq = max(args.update_frequency, 5)
        best_loss = float('inf')

        for epoch in range(args.num_epochs):
            print(f"\nEpoch {epoch + 1}/{args.num_epochs}")

            # Present the episodes in a new random order each epoch.
            np.random.shuffle(all_episodes)

            epoch_losses = _run_epoch(
                trainer,
                replay_buffer,
                all_episodes,
                effective_update_freq,
                args.keep_buffer,
            )

            avg_loss = np.mean(epoch_losses) if epoch_losses else 0.0
            print(f"Epoch {epoch + 1} 平均Loss: {avg_loss:.4f}")

            # Release cached GPU memory between epochs.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Periodic checkpoint; a frequency of 0 turns these off instead of
            # raising ZeroDivisionError.
            if args.save_frequency and (epoch + 1) % args.save_frequency == 0:
                model_path = os.path.join(args.model_dir, f"ppo_model_epoch_{epoch + 1}.pth")
                trainer.save_model(model_path)

            # Only an epoch that actually produced losses may become "best";
            # the 0.0 placeholder above must not win the comparison.
            if epoch_losses and avg_loss < best_loss:
                best_loss = avg_loss
                trainer.save_model(best_model_path)
                print(f"保存最佳模型 (loss: {best_loss:.4f})")

        # Always write the final weights, whatever the loss.
        trainer.save_model(final_model_path)

        print("\n" + "=" * 60)
        print("训练完成!")
        print(f"最终模型保存在: {final_model_path}")
        print(f"最佳模型保存在: {best_model_path}")
        print(f"TensorBoard日志: {args.log_dir}")
        print("=" * 60)
    finally:
        trainer.close()
def _build_arg_parser():
    """Create the command-line parser holding every training option."""
    parser = argparse.ArgumentParser(description='PPO AI训练程序')

    # Data and output locations.
    parser.add_argument('--data_dir', type=str, default='Data', help='训练数据目录')
    parser.add_argument('--model_dir', type=str, default='models', help='模型保存目录')
    parser.add_argument('--log_dir', type=str, default='runs', help='TensorBoard日志目录')
    parser.add_argument('--load_model', type=str, default=None, help='加载预训练模型路径')
    parser.add_argument('--max_episodes', type=int, default=None,
                        help='最大加载episode数量用于快速测试默认None加载全部')

    # Training-loop hyperparameters.
    parser.add_argument('--num_epochs', type=int, default=100, help='训练轮数')
    parser.add_argument('--batch_size', type=int, default=64, help='批大小')
    parser.add_argument('--buffer_capacity', type=int, default=10000, help='回放缓冲区容量')
    parser.add_argument('--update_frequency', type=int, default=10, help='每N个episode更新一次')
    parser.add_argument('--keep_buffer', action='store_true', help='训练后保留缓冲区数据')
    parser.add_argument('--save_frequency', type=int, default=10, help='每N轮保存一次模型')

    # PPO hyperparameters.
    parser.add_argument('--hidden_dim', type=int, default=512, help='隐藏层维度')
    parser.add_argument('--lr', type=float, default=3e-4, help='学习率')
    parser.add_argument('--gamma', type=float, default=0.99, help='折扣因子')
    parser.add_argument('--gae_lambda', type=float, default=0.95, help='GAE lambda')
    parser.add_argument('--clip_epsilon', type=float, default=0.2, help='PPO裁剪参数')
    parser.add_argument('--value_coef', type=float, default=0.5, help='Value loss权重')
    parser.add_argument('--entropy_coef', type=float, default=0.01, help='Entropy bonus权重')
    parser.add_argument('--max_grad_norm', type=float, default=0.5, help='梯度裁剪阈值')
    parser.add_argument('--epochs_per_update', type=int, default=10, help='每次更新的训练轮数')
    # NOTE(review): this option is parsed but train() never reads it.
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help='梯度累积步数增大有效batch size提升GPU利用率')

    # Miscellaneous.
    parser.add_argument('--device', type=str,
                        default='cuda' if torch.cuda.is_available() else 'cpu',
                        help='训练设备 (cuda/cpu)')
    parser.add_argument('--seed', type=int, default=42, help='随机种子')
    # argparse %-formats help strings, so a literal percent sign must be
    # written as '%%'; a bare '%' makes --help raise ValueError.
    # The default is already True, so passing --use_amp changes nothing; the
    # flag is kept for compatibility and --no_amp is the real switch.
    parser.add_argument('--use_amp', action='store_true', default=True,
                        help='使用混合精度训练GPU时推荐可提速30%%')
    parser.add_argument('--no_amp', action='store_false', dest='use_amp',
                        help='禁用混合精度训练')
    return parser


def main(argv=None):
    """Parse command-line options, create the output directories and train.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read the process command line, which is
            the behaviour of the original zero-argument call.
    """
    args = _build_arg_parser().parse_args(argv)

    # Make sure the checkpoint and log directories exist before training.
    os.makedirs(args.model_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)

    train(args)
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()