# TH1/AITrainPython/test_speed.py
# 2025-12-06 11:44:43 +08:00

# 110 lines
# 3.1 KiB
# Python

"""
简单完整的训练速度测试
对比不同batch_size的性能
"""
import torch
import numpy as np
import time
from ppo_trainer import PPOTrainer
from data_loader import TrainingDataLoader, ReplayBuffer
import os
# Banner announcing the benchmark run.
print("=" * 80)
print("训练速度对比测试")
print("=" * 80)

# Select the compute device: CUDA when a GPU is available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\n设备: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # Release unused cached blocks held by the CUDA caching allocator so the
    # measurements below start from a clean slate.
    torch.cuda.empty_cache()
# Load recorded episodes from disk.
# The directory and the episode limit are named once so the announcement,
# the file listing and the loader cannot drift apart.
data_dir = 'Data'
max_episodes = 100
print(f"\n加载{max_episodes}个episodes...")

# Sorting makes the selected subset deterministic across runs. The order is
# lexicographic on the file name, not numeric.
data_files = sorted(
    f for f in os.listdir(data_dir) if f.startswith('episode_')
)[:max_episodes]

loader = TrainingDataLoader(data_dir)
episodes = []
for fname in data_files:
    fpath = os.path.join(data_dir, fname)
    ep_data = loader.load_episode_file(fpath)
    # Skip files that produced no records.
    if len(ep_data) > 0:
        episodes.append(loader.parse_episode_data(ep_data))

# Each parsed episode is indexable; element 0 is presumably the per-step
# state sequence, so its length is the episode's sample count -- TODO confirm
# against TrainingDataLoader.parse_episode_data.
total_samples = sum(len(ep[0]) for ep in episodes)
print(f"✅ 加载完成: {len(episodes)} episodes, {total_samples} 样本")
# Build the shared replay buffer; every configuration is benchmarked on this
# same buffer.
# NOTE(review): what happens when the loaded samples exceed `capacity`
# depends on ReplayBuffer -- confirm it does not silently skew the counts.
buffer = ReplayBuffer(capacity=100000)
for states, actions, rewards, dones, valid_actions in episodes:
    buffer.add_episode(states, actions, rewards, dones, valid_actions)
print(f"测试样本数: {len(buffer)}")
# Configurations to benchmark. Each entry is
# (batch_size, epochs_per_update, human-readable label).
baseline_config = (64, 4, "当前配置")
medium_config = (512, 2, "中等优化")
largest_config = (2048, 1, "最优配置")
configs = [baseline_config, medium_config, largest_config]
print("\n" + "=" * 80)
print("性能对比测试")
print("=" * 80)

# CUDA availability cannot change while the loop runs; query it once.
cuda_available = torch.cuda.is_available()

for batch_size, epochs, desc in configs:
    print(f"\n{desc}")
    print(f" batch_size={batch_size}, epochs_per_update={epochs}")

    # Estimate only: full mini-batches per pass over the buffer, times the
    # number of passes. The trainer's real batching is not visible here.
    updates = (len(buffer) // batch_size) * epochs
    print(f" 理论更新次数: {updates}")

    trainer = None
    try:
        # NOTE(review): state_dim / hidden_dim are hard-coded; presumably they
        # must match the feature width of the loaded data -- confirm.
        trainer = PPOTrainer(
            state_dim=572,
            hidden_dim=512,
            batch_size=batch_size,
            epochs_per_update=epochs,
            device=device,
            use_amp=True,
        )

        if cuda_available:
            # Reset the peak counter so the reported memory belongs to this
            # configuration only, then drain pending GPU work so it is not
            # billed to the timed region.
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()

        # perf_counter is monotonic and high-resolution; time.time() can jump
        # and is coarse on some platforms.
        start_time = time.perf_counter()
        trainer.train_from_buffer(buffer)
        if cuda_available:
            # CUDA kernels run asynchronously; wait for them before stopping
            # the clock.
            torch.cuda.synchronize()
        train_time = time.perf_counter() - start_time

        throughput = len(buffer) / train_time
        print(f" ✅ 训练时间: {train_time:.1f}")
        print(f" ✅ 吞吐量: {throughput:.0f} 样本/秒")
        if cuda_available:
            # Peak allocation during training is what differs between batch
            # sizes; the post-training resident figure mostly does not.
            mem = torch.cuda.max_memory_allocated() / 1e9
            print(f" ✅ 显存占用: {mem:.2f} GB")
    except Exception as e:
        # Best-effort benchmark: report the failure (e.g. out of memory) and
        # move on to the next configuration.
        print(f" ❌ 失败: {e}")
    finally:
        # Release the trainer on failure too; otherwise it would keep its GPU
        # tensors alive and empty_cache() could not reclaim them.
        if trainer is not None:
            try:
                trainer.close()
            except Exception as e:
                print(f" ❌ 失败: {e}")
        trainer = None
        torch.cuda.empty_cache()
# Closing summary: a fixed recommendation, printed one line at a time.
summary_rule = "=" * 80
summary_lines = (
    "\n" + summary_rule,
    "📊 结论",
    summary_rule,
    "问题根源: batch_size太小导致更新次数过多",
    "解决方案: 使用 train_max_gpu.bat (batch_size=2048)",
    "预期提升: 10-20倍速度",
    "\n运行: train_max_gpu.bat",
)
for summary_line in summary_lines:
    print(summary_line)