263 lines
9.0 KiB
Python
263 lines
9.0 KiB
Python
"""
|
||
快速诊断工具 - 5分钟内找出训练慢的原因
|
||
"""
|
||
import torch
|
||
import time
|
||
import os
|
||
import numpy as np
|
||
|
||
print("=" * 80)
|
||
print("训练速度快速诊断工具")
|
||
print("=" * 80)
|
||
|
||
# ============================================================================
# 1. Hardware environment check
# ============================================================================
print("\n【1/5】硬件环境检查")
print("-" * 80)

# Stays None on CPU-only machines so later code can read it without a NameError.
gpu_time = None

if torch.cuda.is_available():
    print(f"✅ GPU可用: {torch.cuda.get_device_name(0)}")
    print(f" 显存: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

    # Micro-benchmark: 100 multiplications of a 1000x1000 random matrix on the GPU.
    print("\n 测试GPU计算速度...")
    x = torch.randn(1000, 1000, device='cuda')
    # Warm-up: the first matmul can include one-time library initialization
    # (e.g. cuBLAS setup), which would otherwise inflate the timed loop.
    torch.matmul(x, x)
    # CUDA launches are asynchronous; wait for all setup work to finish so the
    # clock only covers the timed loop.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(100):
        y = torch.matmul(x, x)
    # Wait for the queued kernels before reading the clock.
    torch.cuda.synchronize()
    gpu_time = time.perf_counter() - start
    print(f" GPU矩阵乘法: {gpu_time:.3f}秒 (100次 1000x1000)")

    # Flag the GPU as slow when this workload takes more than 1 second.
    if gpu_time > 1.0:
        print(" ⚠️ GPU计算较慢,可能的原因:")
        print(" - 驱动版本过旧")
        print(" - GPU负载过高(其他程序占用)")
        print(" - 散热问题导致降频")
else:
    print("❌ GPU不可用!训练将非常慢")
    print(" 请检查:")
    print(" - CUDA是否正确安装")
    print(" - PyTorch是否为GPU版本")
# ============================================================================
# 2. Data volume check
# ============================================================================
print("\n【2/5】数据量检查")
print("-" * 80)

# Defaults so later sections can always read these names, even without Data/.
episode_files = []
file_count = 0

# isdir (not exists): os.listdir would raise if 'Data' were a regular file.
if os.path.isdir('Data'):
    # Sorted so the sampled file (and any "first N files" used later) is
    # deterministic across runs; os.listdir order is arbitrary.
    episode_files = sorted(f for f in os.listdir('Data') if f.startswith('episode_'))
    file_count = len(episode_files)

    print(f" 发现 {file_count} 个episode文件")

    if file_count > 5000:
        print(" ⚠️ 数据量很大!建议:")
        print(" 1. 使用 --update_frequency 10 减少更新频率")
        print(" 2. 分批训练:每次只用部分数据")
        print(" 3. 增大 batch_size 提高效率")
    elif file_count > 1000:
        print(" ✅ 数据量适中")
    else:
        print(" 数据量较少,训练应该很快")

    # Extrapolate the total size from a single sample file.
    if file_count > 0:
        sample_file = os.path.join('Data', episode_files[0])
        file_size = os.path.getsize(sample_file) / 1024  # KB
        print(f" 单个文件大小: ~{file_size:.1f} KB")

        total_size = file_count * file_size / 1024  # MB
        print(f" 总数据量估计: ~{total_size:.1f} MB")
else:
    print(" ❌ Data目录不存在!")
# ============================================================================
# 3. Model construction timing (build the network, move it to the GPU if present)
# ============================================================================
print("\n【3/5】模型加载速度检查")
print("-" * 80)

try:
    from ppo_model import PPONetwork

    t0 = time.time()
    net = PPONetwork(state_dim=572, hidden_dim=512)
    # The measured interval covers construction plus the optional device move.
    net = net.cuda() if torch.cuda.is_available() else net
    elapsed = time.time() - t0

    n_params = sum(prm.numel() for prm in net.parameters())

    print(" ✅ 模型加载成功")
    print(f" 参数量: {n_params:,}")
    print(f" 加载时间: {elapsed:.3f}秒")

    # Warn when build + move takes longer than 2 seconds.
    if elapsed > 2.0:
        print(" ⚠️ 模型加载较慢")
except Exception as e:
    print(f" ❌ 模型加载失败: {e}")
# ============================================================================
# 4. Data loading speed check
# ============================================================================
print("\n【4/5】数据加载速度检查")
print("-" * 80)

if os.path.exists('Data') and file_count > 0:
    try:
        from data_loader import TrainingDataLoader

        loader = TrainingDataLoader('Data')

        # Time load + parse on a small sample of up to 10 episode files.
        test_files = episode_files[:10]
        n_test = len(test_files)  # >= 1 because file_count > 0
        print(f" 测试加载{n_test}个episode...")
        start = time.perf_counter()

        for fname in test_files:
            fpath = os.path.join('Data', fname)
            ep_data = loader.load_episode_file(fpath)
            _ = loader.parse_episode_data(ep_data)

        load_time = time.perf_counter() - start
        # Average over the files actually loaded: a small dataset yields fewer
        # than 10, and dividing by a fixed 10 would understate the per-file cost.
        per_file = load_time / n_test

        print(f" ✅ {n_test}个文件加载耗时: {load_time:.2f}秒")
        print(f" 单文件平均: {per_file:.3f}秒")

        # Linear extrapolation of the sample average to the whole dataset.
        total_load_time = per_file * file_count
        print(f" 估算全部{file_count}个文件需要: {total_load_time:.1f}秒 ({total_load_time/60:.1f}分钟)")

        # Flag more than 0.1 s per file as slow.
        if per_file > 0.1:
            print(" ⚠️ 数据加载较慢!建议:")
            print(" - 使用更快的存储设备(SSD)")
            print(" - 预处理数据为二进制格式(.npy)")
            print(" - 使用数据预加载和缓存")
        else:
            print(" ✅ 数据加载速度正常")

    except Exception as e:
        print(f" ❌ 数据加载测试失败: {e}")
# ============================================================================
# 5. Training loop speed check
# ============================================================================
print("\n【5/5】训练循环速度检查")
print("-" * 80)

if torch.cuda.is_available():
    # Pre-bind so the cleanup in `finally` is safe even if construction fails.
    trainer = None
    try:
        from ppo_trainer import PPOTrainer
        from data_loader import ReplayBuffer

        # Synthetic episode: random float32 states/rewards, int64 actions in
        # [0, 3), all-False done flags, and a fixed 3-entry valid-action array per step.
        n_samples = 100
        state_dim = 572

        states = np.random.randn(n_samples, state_dim).astype(np.float32)
        actions = np.random.randint(0, 3, size=n_samples, dtype=np.int64)
        rewards = np.random.randn(n_samples).astype(np.float32)
        dones = np.zeros(n_samples, dtype=np.bool_)
        valid_actions = [np.array([10, 20, 30], dtype=np.uint64) for _ in range(n_samples)]

        trainer = PPOTrainer(
            state_dim=state_dim,
            hidden_dim=512,
            batch_size=64,
            epochs_per_update=4,
            device=torch.device('cuda'),
            use_amp=True
        )

        buffer = ReplayBuffer(capacity=10000)
        buffer.add_episode(states, actions, rewards, dones, valid_actions)

        print(" 测试训练100个样本...")
        # Drain GPU work queued during setup so it is not billed to training.
        torch.cuda.synchronize()
        start = time.perf_counter()

        trainer.train_from_buffer(buffer)

        # Wait for the queued training kernels before stopping the clock.
        torch.cuda.synchronize()
        train_time = time.perf_counter() - start

        samples_per_sec = n_samples / train_time

        print(" ✅ 训练完成")
        print(f" 训练时间: {train_time:.3f}秒")
        print(f" 吞吐量: {samples_per_sec:.1f} 样本/秒")

        # Extrapolate one epoch over the real dataset, if there is one.
        if os.path.exists('Data') and file_count > 0:
            # Assumes ~8 steps per episode on average (rough estimate).
            total_samples = file_count * 8
            est_time_per_epoch = total_samples / samples_per_sec

            print(f"\n 估算单个epoch时间: {est_time_per_epoch:.1f}秒 ({est_time_per_epoch/60:.1f}分钟)")
            print(f" 100个epoch估算: {est_time_per_epoch * 100 / 3600:.1f}小时")

        if samples_per_sec < 50:
            print("\n ⚠️ 训练速度较慢!")
        elif samples_per_sec < 100:
            print("\n ⚙️ 训练速度一般,有优化空间")
        else:
            print("\n ✅ 训练速度良好")

    except Exception as e:
        print(f" ❌ 训练速度测试失败: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Always release the trainer, even when the benchmark failed part-way;
        # a failing close() is reported rather than allowed to abort the script.
        if trainer is not None:
            try:
                trainer.close()
            except Exception as e:
                print(f" ⚠️ trainer.close() 失败: {e}")
else:
    print(" ⚠️ GPU不可用,跳过训练速度测试")
# ============================================================================
# Diagnostic summary
# ============================================================================
print("\n" + "=" * 80)
print("诊断结果总结")
print("=" * 80)

print("\n🔍 主要瓶颈分析:")

# Earlier sections define their metrics only when they run to completion
# (e.g. per_file is missing when Data/ has no episode files or loading failed).
# Read them defensively so a skipped or failed section cannot crash the summary.
_gpu_time = globals().get('gpu_time')
_file_count = globals().get('file_count', 0)
_per_file = globals().get('per_file')
_samples_per_sec = globals().get('samples_per_sec')

bottlenecks = []

if not torch.cuda.is_available():
    bottlenecks.append("❌ 【严重】GPU不可用 - 必须解决")
elif _gpu_time is not None and _gpu_time > 1.0:
    bottlenecks.append("⚠️ 【中等】GPU计算慢 - 检查驱动和散热")

if os.path.exists('Data'):
    if _file_count > 5000:
        bottlenecks.append("⚠️ 【中等】数据量过大 - 建议分批或增大batch")

    if _per_file is not None and _per_file > 0.1:
        bottlenecks.append("⚠️ 【中等】数据加载慢 - 考虑使用SSD或预处理")

if _samples_per_sec is not None and _samples_per_sec < 50:
    bottlenecks.append("❌ 【严重】训练速度慢 - CPU-GPU同步问题")

if not bottlenecks:
    print("✅ 未发现明显瓶颈,性能应该正常")
else:
    for i, msg in enumerate(bottlenecks, 1):
        print(f"{i}. {msg}")

print("\n💡 优化建议:")
print("1. 使用优化后的训练脚本: train_fast.bat")
print("2. 增大batch_size到128或256")
print("3. 启用混合精度训练 --use_amp")
print("4. 减少更新频率 --update_frequency 10")
print("5. 监控GPU利用率: nvidia-smi -l 1")

print("\n📊 详细性能分析:")
print("运行: python profile_performance.py")

print("\n🚀 开始训练:")
print("运行: train_fast.bat")

print("\n" + "=" * 80)