TH1/AITrainPython/quick_diagnose.py
2025-12-06 11:44:43 +08:00

263 lines
9.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
快速诊断工具 - 5分钟内找出训练慢的原因
"""
import torch
import time
import os
import numpy as np
print("=" * 80)
print("训练速度快速诊断工具")
print("=" * 80)

# ============================================================================
# 1. Hardware environment check
# ============================================================================
# Reports the GPU name and memory, then times 100 matmuls of a 1000x1000
# matrix. Binds `gpu_time` (seconds) only when CUDA is available.
print("\n【1/5】硬件环境检查")
print("-" * 80)
if torch.cuda.is_available():
    print(f"✅ GPU可用: {torch.cuda.get_device_name(0)}")
    print(f" 显存: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    # GPU compute micro-benchmark
    print("\n 测试GPU计算速度...")
    x = torch.randn(1000, 1000, device='cuda')
    # Warm-up: run one untimed matmul and wait for the device, so one-time
    # start-up cost and the still-pending randn kernel are not attributed to
    # the measured loop (CUDA kernels are launched asynchronously).
    y = torch.matmul(x, x)
    torch.cuda.synchronize()
    # perf_counter is the monotonic, high-resolution clock meant for
    # measuring elapsed intervals.
    start = time.perf_counter()
    for _ in range(100):
        y = torch.matmul(x, x)
    # Wait for all queued kernels; otherwise only launch overhead is timed.
    torch.cuda.synchronize()
    gpu_time = time.perf_counter() - start
    print(f" GPU矩阵乘法: {gpu_time:.3f}秒 (100次 1000x1000)")
    if gpu_time > 1.0:
        print(" ⚠️ GPU计算较慢可能的原因:")
        print(" - 驱动版本过旧")
        print(" - GPU负载过高其他程序占用")
        print(" - 散热问题导致降频")
else:
    print("❌ GPU不可用训练将非常慢")
    print(" 请检查:")
    print(" - CUDA是否正确安装")
    print(" - PyTorch是否为GPU版本")
# ============================================================================
# 2. Dataset size check
# ============================================================================
# Counts the `episode_*` entries under ./Data and estimates the total size
# from the first entry. Always binds `episode_files` and `file_count`.
print("\n【2/5】数据量检查")
print("-" * 80)
# Defaults, so later sections can read these names even when Data is
# missing or is not a directory.
episode_files = []
file_count = 0
# isdir (not exists): a regular file named "Data" would make os.listdir raise.
if os.path.isdir('Data'):
    episode_files = [f for f in os.listdir('Data') if f.startswith('episode_')]
    file_count = len(episode_files)
    print(f" 发现 {file_count} 个episode文件")
    if file_count > 5000:
        print(f" ⚠️ 数据量很大!建议:")
        print(f" 1. 使用 --update_frequency 10 减少更新频率")
        print(f" 2. 分批训练:每次只用部分数据")
        print(f" 3. 增大 batch_size 提高效率")
    elif file_count > 1000:
        print(f" ✅ 数据量适中")
    else:
        print(f" 数据量较少,训练应该很快")
    # Size of a single file, extrapolated to the whole directory
    if file_count > 0:
        sample_file = os.path.join('Data', episode_files[0])
        file_size = os.path.getsize(sample_file) / 1024  # KB
        print(f" 单个文件大小: ~{file_size:.1f} KB")
        # Estimate only: assumes every file is about as large as the sample.
        total_size = file_count * file_size / 1024  # MB
        print(f" 总数据量估计: ~{total_size:.1f} MB")
else:
    print(" ❌ Data目录不存在")
# ============================================================================
# 3. Model loading speed check
# ============================================================================
# Builds a PPONetwork (moved to the GPU when CUDA is available), then reports
# its parameter count and how long construction took. Any failure, including
# a missing ppo_model module, is reported instead of aborting the script.
print("\n【3/5】模型加载速度检查")
print("-" * 80)
try:
    from ppo_model import PPONetwork

    start = time.time()
    model = PPONetwork(state_dim=572, hidden_dim=512)
    if torch.cuda.is_available():
        model = model.cuda()
    load_time = time.time() - start

    # Total number of scalar parameters across all tensors of the model
    param_count = sum(tensor.numel() for tensor in model.parameters())

    print(
        f" ✅ 模型加载成功",
        f" 参数量: {param_count:,}",
        f" 加载时间: {load_time:.3f}",
        sep="\n",
    )
    if not load_time <= 2.0:
        print(" ⚠️ 模型加载较慢")
except Exception as e:
    print(f" ❌ 模型加载失败: {e}")
# ============================================================================
# 4. Data loading speed check
# ============================================================================
# Loads and parses up to 10 episode files through TrainingDataLoader and
# extrapolates the measured per-file time to the whole dataset. Binds
# `per_file` (seconds per file) only when the measurement succeeds.
print("\n【4/5】数据加载速度检查")
print("-" * 80)
if os.path.exists('Data') and file_count > 0:
    try:
        from data_loader import TrainingDataLoader
        loader = TrainingDataLoader('Data')
        # Sample at most 10 files; there may be fewer than 10 available.
        test_files = episode_files[:10]
        n_test_files = len(test_files)
        print(f" 测试加载{n_test_files}个episode...")
        start = time.perf_counter()
        for fname in test_files:
            fpath = os.path.join('Data', fname)
            ep_data = loader.load_episode_file(fpath)
            _ = loader.parse_episode_data(ep_data)
        load_time = time.perf_counter() - start
        # Divide by the number of files actually loaded, not a fixed 10;
        # otherwise small datasets report an average that is too low.
        per_file = load_time / n_test_files
        print(f" ✅ {n_test_files}个文件加载耗时: {load_time:.2f}")
        print(f" 单文件平均: {per_file:.3f}")
        # Extrapolate to the full dataset
        total_load_time = per_file * file_count
        print(f" 估算全部{file_count}个文件需要: {total_load_time:.1f}秒 ({total_load_time/60:.1f}分钟)")
        if per_file > 0.1:
            print(" ⚠️ 数据加载较慢!建议:")
            print(" - 使用更快的存储设备SSD")
            print(" - 预处理数据为二进制格式(.npy")
            print(" - 使用数据预加载和缓存")
        else:
            print(" ✅ 数据加载速度正常")
    except Exception as e:
        print(f" ❌ 数据加载测试失败: {e}")
# ============================================================================
# 5. Training loop speed check
# ============================================================================
# Runs one PPOTrainer.train_from_buffer() call on 100 synthetic samples (CUDA
# only) and reports throughput. Binds `samples_per_sec` only when the
# benchmark completes.
print("\n【5/5】训练循环速度检查")
print("-" * 80)
if torch.cuda.is_available():
    try:
        from ppo_trainer import PPOTrainer
        from data_loader import ReplayBuffer
        # Build synthetic data
        n_samples = 100
        state_dim = 572
        states = np.random.randn(n_samples, state_dim).astype(np.float32)
        actions = np.random.randint(0, 3, size=n_samples, dtype=np.int64)
        rewards = np.random.randn(n_samples).astype(np.float32)
        dones = np.zeros(n_samples, dtype=np.bool_)
        valid_actions = [np.array([10, 20, 30], dtype=np.uint64) for _ in range(n_samples)]
        # Trainer under test
        trainer = PPOTrainer(
            state_dim=state_dim,
            hidden_dim=512,
            batch_size=64,
            epochs_per_update=4,
            device=torch.device('cuda'),
            use_amp=True
        )
        try:
            buffer = ReplayBuffer(capacity=10000)
            buffer.add_episode(states, actions, rewards, dones, valid_actions)
            print(" 测试训练100个样本...")
            start = time.perf_counter()
            trainer.train_from_buffer(buffer)
            # Wait for queued GPU work so the elapsed time covers it.
            torch.cuda.synchronize()
            train_time = time.perf_counter() - start
            samples_per_sec = n_samples / train_time
            print(f" ✅ 训练完成")
            print(f" 训练时间: {train_time:.3f}")
            print(f" 吞吐量: {samples_per_sec:.1f} 样本/秒")
            # Estimate the full training time
            if os.path.exists('Data') and file_count > 0:
                # Assumes each episode averages 8 steps
                total_samples = file_count * 8
                est_time_per_epoch = total_samples / samples_per_sec
                print(f"\n 估算单个epoch时间: {est_time_per_epoch:.1f}秒 ({est_time_per_epoch/60:.1f}分钟)")
                print(f" 100个epoch估算: {est_time_per_epoch * 100 / 3600:.1f}小时")
            if samples_per_sec < 50:
                print("\n ⚠️ 训练速度较慢!")
            elif samples_per_sec < 100:
                print("\n ⚙️ 训练速度一般,有优化空间")
            else:
                print("\n ✅ 训练速度良好")
        finally:
            # Close the trainer even when the benchmark raised. What close()
            # releases is defined in ppo_trainer and not visible here. It
            # stays inside the outer try so a failing close() is still
            # reported below.
            trainer.close()
    except Exception as e:
        print(f" ❌ 训练速度测试失败: {e}")
        import traceback
        traceback.print_exc()
# ============================================================================
# Diagnosis summary
# ============================================================================
# Collects the bottlenecks found by the earlier sections into `bottlenecks`
# and prints them, followed by fixed optimisation hints.
print("\n" + "=" * 80)
print("诊断结果总结")
print("=" * 80)
print("\n🔍 主要瓶颈分析:")
# Earlier sections bind their measurements only when the check actually ran
# (e.g. `per_file` is missing when there were no episode files or the loader
# failed). Read them from the module namespace so a skipped check is "no
# finding" instead of a NameError.
_measured = globals()
bottlenecks = []
if not torch.cuda.is_available():
    bottlenecks.append("❌ 【严重】GPU不可用 - 必须解决")
elif _measured.get('gpu_time', 0.0) > 1.0:
    bottlenecks.append("⚠️ 【中等】GPU计算慢 - 检查驱动和散热")
if os.path.exists('Data'):
    if _measured.get('file_count', 0) > 5000:
        bottlenecks.append("⚠️ 【中等】数据量过大 - 建议分批或增大batch")
    if 'per_file' in _measured and _measured['per_file'] > 0.1:
        bottlenecks.append("⚠️ 【中等】数据加载慢 - 考虑使用SSD或预处理")
if 'samples_per_sec' in _measured and _measured['samples_per_sec'] < 50:
    bottlenecks.append("❌ 【严重】训练速度慢 - CPU-GPU同步问题")
if not bottlenecks:
    print("✅ 未发现明显瓶颈,性能应该正常")
else:
    for i, msg in enumerate(bottlenecks, 1):
        print(f"{i}. {msg}")
print("\n💡 优化建议:")
print("1. 使用优化后的训练脚本: train_fast.bat")
print("2. 增大batch_size到128或256")
print("3. 启用混合精度训练 --use_amp")
print("4. 减少更新频率 --update_frequency 10")
print("5. 监控GPU利用率: nvidia-smi -l 1")
print("\n📊 详细性能分析:")
print("运行: python profile_performance.py")
print("\n🚀 开始训练:")
print("运行: train_fast.bat")
print("\n" + "=" * 80)