"""Simple end-to-end training-speed benchmark.

Compares PPO training performance for several batch_size /
epochs_per_update settings on the same replay buffer.
"""

import os
import time

import torch
import numpy as np

from ppo_trainer import PPOTrainer
from data_loader import TrainingDataLoader, ReplayBuffer

_banner = "=" * 80
print(_banner)
print("训练速度对比测试")
print(_banner)

# Select the compute device: the GPU when CUDA is usable, otherwise the CPU.
_cuda_ok = torch.cuda.is_available()
device = torch.device('cuda') if _cuda_ok else torch.device('cpu')
print(f"\n设备: {device}")

if _cuda_ok:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # Return cached-but-unused CUDA blocks before the benchmark starts.
    torch.cuda.empty_cache()

# Load up to NUM_EPISODES episode files from DATA_DIR and parse them.
DATA_DIR = 'Data'
NUM_EPISODES = 100

print(f"\n加载{NUM_EPISODES}个episodes...")

# Fail with a readable message instead of a raw FileNotFoundError traceback.
if not os.path.isdir(DATA_DIR):
    raise SystemExit(f"❌ 数据目录不存在: {DATA_DIR}")

# NOTE(review): plain lexicographic sort — 'episode_10' sorts before
# 'episode_2' unless the numbers are zero-padded; confirm the naming scheme.
data_files = sorted(
    f for f in os.listdir(DATA_DIR) if f.startswith('episode_')
)[:NUM_EPISODES]

loader = TrainingDataLoader(DATA_DIR)

episodes = []
for fname in data_files:
    ep_data = loader.load_episode_file(os.path.join(DATA_DIR, fname))
    # Skip files that loaded no data so they don't yield empty episodes.
    if len(ep_data) > 0:
        episodes.append(loader.parse_episode_data(ep_data))

# Benchmarking an empty buffer is meaningless; stop here rather than
# timing a no-op (or failing obscurely inside the trainer).
if not episodes:
    raise SystemExit(f"❌ 未在 {DATA_DIR} 中找到可用的 episode 数据")

# Presumably element 0 of each parsed episode is its per-step state sequence,
# so its length is the episode's sample count — confirm against data_loader.
total_samples = sum(len(ep[0]) for ep in episodes)

print(f"✅ 加载完成: {len(episodes)} episodes, {total_samples} 样本")

# Put every parsed episode into one shared replay buffer; every config
# below trains on this same data.
buffer = ReplayBuffer(capacity=100_000)
for states, actions, rewards, dones, valid_actions in episodes:
    buffer.add_episode(states, actions, rewards, dones, valid_actions)

print(f"测试样本数: {len(buffer)}")

# Configurations to benchmark, as (batch_size, epochs_per_update, label)
# triples: the current setting, a middle ground, and the largest batch.
_batch_sizes = (64, 512, 2048)
_epochs_per_update = (4, 2, 1)
_labels = ("当前配置", "中等优化", "最优配置")
configs = list(zip(_batch_sizes, _epochs_per_update, _labels))

_section_rule = "=" * 80
print(f"\n{_section_rule}")
print("性能对比测试")
print(_section_rule)

# Train once per configuration on the shared buffer and report wall time,
# throughput and (on CUDA) current and peak GPU memory.
cuda_on = torch.cuda.is_available()

for batch_size, epochs, desc in configs:
    print(f"\n【{desc}】")
    print(f" batch_size={batch_size}, epochs_per_update={epochs}")

    # Expected optimizer steps: full batches per epoch times epochs
    # (a trailing partial batch, if the trainer uses one, is not counted).
    updates = (len(buffer) // batch_size) * epochs
    print(f" 理论更新次数: {updates}")

    trainer = None
    try:
        # NOTE(review): state_dim/hidden_dim are hard-coded; state_dim presumably
        # must match the feature width produced by data_loader — confirm.
        # NOTE(review): use_amp=True is passed even on CPU; assumes PPOTrainer
        # handles that case — confirm.
        trainer = PPOTrainer(
            state_dim=572,
            hidden_dim=512,
            batch_size=batch_size,
            epochs_per_update=epochs,
            device=device,
            use_amp=True,
        )

        if cuda_on:
            # Restart peak tracking from the current allocation so the reported
            # peak belongs to this configuration, not an earlier one.
            torch.cuda.reset_peak_memory_stats()
            # Keep previously queued GPU work out of this timing window.
            torch.cuda.synchronize()

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        trainer.train_from_buffer(buffer)
        if cuda_on:
            # CUDA kernels run asynchronously; wait for them before stopping the clock.
            torch.cuda.synchronize()
        train_time = time.perf_counter() - start_time

        # Buffer samples per wall-clock second for one full update. With
        # epochs_per_update > 1 each sample is visited several times, so this
        # is not a per-gradient-step rate.
        throughput = len(buffer) / train_time if train_time > 0 else float('inf')

        print(f" ✅ 训练时间: {train_time:.1f}秒")
        print(f" ✅ 吞吐量: {throughput:.0f} 样本/秒")
        if cuda_on:
            mem = torch.cuda.memory_allocated() / 1e9
            peak_mem = torch.cuda.max_memory_allocated() / 1e9
            print(f" ✅ 显存占用: {mem:.2f} GB")
            print(f" ✅ 峰值显存: {peak_mem:.2f} GB")

    except Exception as e:
        # Best-effort benchmark: report the failure (e.g. CUDA OOM) and move on
        # to the next configuration.
        print(f" ❌ 失败: {e}")

    finally:
        # Always release the trainer, even after a failure. Otherwise the stale
        # binding keeps its model/optimizer tensors alive and the next
        # configuration runs with less free GPU memory.
        if trainer is not None:
            try:
                trainer.close()
            except Exception as close_err:
                print(f" ❌ 失败: {close_err}")
        trainer = None
        if cuda_on:
            torch.cuda.empty_cache()

# Print the closing recommendation.
# NOTE(review): these conclusions are fixed text; they are not derived from
# the measurements printed above.
_summary_rule = "=" * 80
print(f"\n{_summary_rule}")
print("📊 结论")
print(_summary_rule)

_summary_lines = (
    "问题根源: batch_size太小导致更新次数过多",
    "解决方案: 使用 train_max_gpu.bat (batch_size=2048)",
    "预期提升: 10-20倍速度",
    "\n运行: train_max_gpu.bat",
)
for _line in _summary_lines:
    print(_line)