# Source listing: 575 lines, 22 KiB, Python.
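"""
模型一致性测试脚本

测试 PTH 和 ONNX 模型的输出是否一致,以及概率分布是否合理

Model consistency test script: checks that the PyTorch (.pth) PPO policy and
its ONNX export produce matching outputs, and that their action probability
distributions look reasonable.
"""
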
import argparse
import json
import os
import sys
import traceback

import numpy as np
import onnxruntime as ort
import torch
from colorama import init, Fore, Style

from data_loader import TrainingDataLoader
from ppo_model import PPONetwork

# Enable colorama once, at import time, so the Fore/Style colour codes used
# in this script render in the terminal; autoreset=True is passed so each
# print starts from the default style.
init(autoreset=True)


class ModelConsistencyTester:
    """Check that a PyTorch (.pth) PPO policy and its ONNX export agree.

    For each sample, the same state and candidate-action list are fed to
    both models. Their logits are converted to probabilities with the same
    softmax. The tester then records:

    * whether the two models pick the same action;
    * how far apart their probability vectors are;
    * whether each distribution passes basic sanity checks;
    * optionally, whether the picks match the action recorded in the data.
    """

    def __init__(self, pth_path='models/ppo_model_best.pth',
                 onnx_path='models/ppo_model_best.onnx',
                 state_dim=572):
        """Load both models.

        Args:
            pth_path: Path of the PyTorch checkpoint.
            onnx_path: Path of the exported ONNX model.
            state_dim: Length of the state vector passed to PPONetwork.

        Raises:
            FileNotFoundError: If either model file does not exist.
        """
        self.pth_path = pth_path
        self.onnx_path = onnx_path
        self.state_dim = state_dim
        # Both models are run on CPU so the comparison uses the same device.
        self.device = torch.device('cpu')

        print(f"{Fore.CYAN}{'='*80}")
        print(f"{Fore.CYAN}模型一致性测试")
        print(f"{Fore.CYAN}{'='*80}\n")

        self._load_pth_model()
        self._load_onnx_model()

    def _load_pth_model(self):
        """Load the PyTorch checkpoint into ``self.pth_model`` in eval mode.

        Raises:
            FileNotFoundError: If ``self.pth_path`` does not exist.
            RuntimeError: If the weights do not fit the network, beyond the
                tolerated '_freqs' key differences.
        """
        print(f"{Fore.YELLOW}[1/2] 加载 PyTorch 模型...")

        if not os.path.exists(self.pth_path):
            raise FileNotFoundError(f"PTH模型不存在: {self.pth_path}")

        # weights_only=False lets the checkpoint carry non-tensor metadata
        # (total_steps, total_episodes, ...). It unpickles arbitrary objects,
        # so only load checkpoints from a trusted source.
        checkpoint = torch.load(self.pth_path, map_location=self.device, weights_only=False)

        # Training checkpoints nest the weights under 'policy_state_dict';
        # otherwise the file is treated as a bare state dict.
        if 'policy_state_dict' in checkpoint:
            state_dict = checkpoint['policy_state_dict']
        else:
            state_dict = checkpoint

        # hidden_dim is taken from the first dimension of the first shared
        # layer's weight (presumably an nn.Linear weight of shape (out, in);
        # confirm against PPONetwork).
        hidden_dim = state_dict['shared_net.0.weight'].shape[0]

        self.pth_model = PPONetwork(state_dim=self.state_dim, hidden_dim=hidden_dim)
        self._load_state_dict_compat(state_dict)
        self.pth_model.eval()

        print(f"{Fore.GREEN} [OK] PyTorch 模型加载成功")
        print(f" 路径: {self.pth_path}")
        print(f" Hidden Dim: {hidden_dim}")
        if 'total_steps' in checkpoint:
            print(f" 训练步数: {checkpoint['total_steps']}")
        if 'total_episodes' in checkpoint:
            print(f" 训练轮数: {checkpoint['total_episodes']}")
        print()

    def _load_state_dict_compat(self, state_dict):
        """Load ``state_dict`` into ``self.pth_model``, staying compatible with older models.

        A strict load is tried first. If it fails with an error mentioning
        '_freqs', the load is retried non-strictly. Only key mismatches whose
        names contain '_freqs' are accepted; any other missing or unexpected
        key is raised instead of being silently ignored.
        """
        try:
            self.pth_model.load_state_dict(state_dict)
        except RuntimeError as err:
            if '_freqs' not in str(err):
                raise
            incompatible = self.pth_model.load_state_dict(state_dict, strict=False)
            other_mismatches = [
                key
                for key in list(incompatible.missing_keys) + list(incompatible.unexpected_keys)
                if '_freqs' not in key
            ]
            if other_mismatches:
                raise RuntimeError(
                    f"PTH模型权重与网络结构不匹配(非 _freqs 键): {other_mismatches}"
                ) from err

    def _load_onnx_model(self):
        """Create the ONNX Runtime session in ``self.onnx_session``.

        Raises:
            FileNotFoundError: If ``self.onnx_path`` does not exist.
        """
        print(f"{Fore.YELLOW}[2/2] 加载 ONNX 模型...")

        if not os.path.exists(self.onnx_path):
            raise FileNotFoundError(f"ONNX模型不存在: {self.onnx_path}")

        # Pin the CPU provider. This matches the CPU-only PyTorch side, and
        # GPU builds of onnxruntime require an explicit providers list.
        self.onnx_session = ort.InferenceSession(
            self.onnx_path, providers=['CPUExecutionProvider'])

        print(f"{Fore.GREEN} [OK] ONNX 模型加载成功")
        print(f" 路径: {self.onnx_path}")

        # Show the graph's declared inputs/outputs to help diagnose mismatches.
        print(f" 输入:")
        for input_meta in self.onnx_session.get_inputs():
            print(f" - {input_meta.name}: {input_meta.shape} ({input_meta.type})")

        print(f" 输出:")
        for output_meta in self.onnx_session.get_outputs():
            print(f" - {output_meta.name}: {output_meta.shape} ({output_meta.type})")
        print()

    @staticmethod
    def _logits_to_probs(logits):
        """Return a numerically stable softmax of one row of logits.

        Accepts shape (n,) or (1, n) and always returns a 1-D array, so both
        the PyTorch and the ONNX path yield identically shaped probabilities.
        """
        flat = np.asarray(logits).reshape(-1)
        # Subtracting the max keeps np.exp from overflowing.
        exp_logits = np.exp(flat - np.max(flat))
        return exp_logits / np.sum(exp_logits)

    def predict_pth(self, state, valid_actions):
        """Run the PyTorch model on one sample.

        Args:
            state: (state_dim,) numpy array.
            valid_actions: (n_actions,) numpy array of uint64 action ids.

        Returns:
            action_probs: (n_actions,) numpy array of probabilities.
            selected_action: element of ``valid_actions`` with the highest
                probability (deterministic argmax).
        """
        with torch.no_grad():
            state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
            # Feed the action ids as int64, exactly like predict_onnx does.
            # torch has only limited uint64 support, and the exported graph
            # takes int64. Ids >= 2**63 wrap identically on both paths.
            action_ids = np.asarray(valid_actions).astype(np.int64)
            action_ids_tensor = torch.from_numpy(action_ids).to(self.device)

            # Same call sequence the ONNX export is expected to trace.
            shared_features, _ = self.pth_model.forward(state_tensor)
            action_logits = self.pth_model.get_action_logits(shared_features, action_ids_tensor)
            action_logits = action_logits.cpu().numpy()

        action_probs = self._logits_to_probs(action_logits)
        selected_action = valid_actions[int(np.argmax(action_probs))]
        return action_probs, selected_action

    def predict_onnx(self, state, valid_actions):
        """Run the ONNX model on one sample.

        Args:
            state: (state_dim,) numpy array.
            valid_actions: (n_actions,) numpy array of uint64 action ids.

        Returns:
            action_probs: (n_actions,) numpy array of probabilities.
            selected_action: element of ``valid_actions`` with the highest
                probability (deterministic argmax).
        """
        state_input = np.asarray(state).reshape(1, -1).astype(np.float32)
        actions_input = np.asarray(valid_actions).astype(np.int64)

        # The graph's first output holds the action logits.
        outputs = self.onnx_session.run(
            None,
            {
                'state': state_input,
                'valid_actions': actions_input
            }
        )

        action_probs = self._logits_to_probs(outputs[0])
        selected_action = valid_actions[int(np.argmax(action_probs))]
        return action_probs, selected_action

    def check_probability_distribution(self, probs, threshold=0.05):
        """Sanity-check one probability distribution.

        Args:
            probs: 1-D array of probabilities (a (1, n) array is flattened).
            threshold: If every probability lies within ``threshold`` of the
                uniform value 1/n, the distribution is flagged as too flat.

        Returns:
            is_valid: True if no issue was found.
            issues: list of human-readable issue strings.
        """
        probs = np.asarray(probs).reshape(-1)
        n = probs.size
        if n == 0:
            return False, ["概率分布为空"]

        issues = []

        # 1. Probabilities must sum to 1.
        prob_sum = np.sum(probs)
        if not np.isclose(prob_sum, 1.0, atol=1e-5):
            issues.append(f"概率和不为1: {prob_sum:.6f}")

        # 2. No NaN / Inf values.
        if np.any(np.isnan(probs)):
            issues.append("存在NaN值")
        if np.any(np.isinf(probs)):
            issues.append("存在Inf值")

        # 3. Over-confidence: every probability but one is ~0. A sample with
        #    a single valid action trivially has probability 1, so it is only
        #    flagged when the model actually had a choice.
        if n > 1 and np.sum(probs < 1e-6) == n - 1:
            issues.append("几乎所有概率都为0(过度自信)")

        # 4. Too flat: no probability deviates much from uniform.
        if n > 1:
            max_deviation = np.max(np.abs(probs - 1.0 / n))
            if max_deviation < threshold:
                issues.append(f"概率分布过于平均(最大偏差: {max_deviation:.4f})")

        # 5. The best action should stand out when there are many candidates.
        max_prob = np.max(probs)
        if max_prob < 0.1 and n > 5:
            issues.append(f"最大概率过低: {max_prob:.4f}")

        return len(issues) == 0, issues

    def test_single_sample(self, state, valid_actions, expected_action, prob_tolerance=0.05):
        """Run both models on one sample and compare them.

        Args:
            state: (state_dim,) numpy array.
            valid_actions: (n_actions,) numpy array of action ids.
            expected_action: action id recorded in the data.
            prob_tolerance: Largest allowed per-action probability difference
                between PTH and ONNX (float precision / kernel differences).

        Returns:
            dict with per-model correctness, agreement, validity flags,
            probabilities, chosen actions, ``max_prob_diff`` (None when it
            could not be computed) and a list of ``issues``. Prediction
            exceptions are recorded as issues rather than raised.
        """
        result = {
            'pth_correct': False,
            'onnx_correct': False,
            'models_match': False,
            'pth_prob_valid': False,
            'onnx_prob_valid': False,
            'pth_probs': None,
            'onnx_probs': None,
            'pth_action': None,
            'onnx_action': None,
            'expected_action': expected_action,
            'max_prob_diff': None,
            'issues': []
        }

        try:
            pth_probs, pth_action = self.predict_pth(state, valid_actions)
            result['pth_probs'] = pth_probs
            result['pth_action'] = pth_action
            # Compare as Python ints. With NumPy < 2, uint64 == int can promote
            # to float64 and lose precision for ids above 2**53.
            result['pth_correct'] = int(pth_action) == int(expected_action)

            pth_valid, pth_issues = self.check_probability_distribution(pth_probs)
            result['pth_prob_valid'] = pth_valid
            if not pth_valid:
                result['issues'].extend([f"PTH: {issue}" for issue in pth_issues])

            onnx_probs, onnx_action = self.predict_onnx(state, valid_actions)
            result['onnx_probs'] = onnx_probs
            result['onnx_action'] = onnx_action
            result['onnx_correct'] = int(onnx_action) == int(expected_action)

            onnx_valid, onnx_issues = self.check_probability_distribution(onnx_probs)
            result['onnx_prob_valid'] = onnx_valid
            if not onnx_valid:
                result['issues'].extend([f"ONNX: {issue}" for issue in onnx_issues])

            # Agreement: same chosen action ...
            result['models_match'] = int(pth_action) == int(onnx_action)

            # ... and close probability vectors.
            if pth_probs.shape != onnx_probs.shape:
                result['issues'].append(
                    f"PTH和ONNX输出长度不一致: {pth_probs.shape} vs {onnx_probs.shape}")
            else:
                max_prob_diff = float(np.max(np.abs(pth_probs - onnx_probs)))
                result['max_prob_diff'] = max_prob_diff
                if max_prob_diff > prob_tolerance:
                    result['issues'].append(f"PTH和ONNX概率差异过大: {max_prob_diff:.6f}")

        except Exception as e:
            # Deliberate best effort: one bad sample must not abort the run.
            result['issues'].append(f"预测异常: {str(e)}")

        return result

    def test_from_jsonl(self, jsonl_path, max_samples=None):
        """Test every step stored in one JSONL episode file.

        Args:
            jsonl_path: Path of the JSONL file.
            max_samples: Maximum number of samples to test; None for all.

        Returns:
            list of result dicts from ``test_single_sample``; empty if nothing
            could be loaded.

        Raises:
            FileNotFoundError: If ``jsonl_path`` does not exist.
        """
        print(f"{Fore.CYAN}{'='*80}")
        print(f"{Fore.CYAN}从 JSONL 文件测试: {os.path.basename(jsonl_path)}")
        print(f"{Fore.CYAN}{'='*80}\n")

        if not os.path.exists(jsonl_path):
            raise FileNotFoundError(f"JSONL文件不存在: {jsonl_path}")

        loader = TrainingDataLoader()
        episode_data = loader.load_episode_file(jsonl_path)

        if len(episode_data) == 0:
            print(f"{Fore.RED}错误: 无法加载数据")
            return []

        if max_samples is not None:
            episode_data = episode_data[:max_samples]

        sample_count = len(episode_data)
        print(f"加载了 {sample_count} 个样本\n")

        results = []
        for i, step_data in enumerate(episode_data):
            state = np.array(step_data['State'], dtype=np.float32)
            valid_actions = np.array(step_data['Actions'], dtype=np.uint64)
            expected_action = int(step_data['SelectedAction'])

            results.append(self.test_single_sample(state, valid_actions, expected_action))

            if (i + 1) % 10 == 0:
                print(f" 已测试 {i + 1}/{sample_count} 个样本...")

        print(f"\n{Fore.GREEN}测试完成!\n")
        return results

    def print_summary(self, results, skip_action_check=False):
        """Print a summary of ``results`` and the overall verdict.

        Args:
            results: List of dicts produced by ``test_single_sample``.
            skip_action_check: If True, ignore whether the models reproduced
                the recorded action. Only PTH/ONNX consistency and
                distribution sanity are then required.

        Returns:
            bool: True if the pass criteria are met; False otherwise,
            including when ``results`` is empty.
        """
        if len(results) == 0:
            print(f"{Fore.RED}无测试结果")
            return False

        total = len(results)
        pth_correct = sum(1 for r in results if r['pth_correct'])
        onnx_correct = sum(1 for r in results if r['onnx_correct'])
        models_match = sum(1 for r in results if r['models_match'])
        pth_prob_valid = sum(1 for r in results if r['pth_prob_valid'])
        onnx_prob_valid = sum(1 for r in results if r['onnx_prob_valid'])

        # A sample fully passes only when all of the following hold:
        # - it recorded no issue at all (this includes prediction errors and
        #   an excessive PTH/ONNX probability gap);
        # - both models agree and both distributions are valid;
        # - unless the action check is skipped, both models reproduced the
        #   recorded action.
        all_pass = sum(1 for r in results if (
            r['models_match']
            and r['pth_prob_valid']
            and r['onnx_prob_valid']
            and len(r['issues']) == 0
            and (skip_action_check or (r['pth_correct'] and r['onnx_correct']))
        ))

        print(f"{Fore.CYAN}{'='*80}")
        print(f"{Fore.CYAN}测试结果摘要")
        print(f"{Fore.CYAN}{'='*80}\n")

        print(f"总样本数: {total}\n")

        # Action accuracy against the recorded data (unless skipped).
        if not skip_action_check:
            print(f"{Fore.YELLOW}【动作选择(与训练数据对比)】")
            self._print_stat("PTH模型正确率", pth_correct, total)
            self._print_stat("ONNX模型正确率", onnx_correct, total)
            print(f"{Fore.CYAN} 注意: 模型已通过PPO学习,可能选择与行为树不同的动作{Style.RESET_ALL}")
            print()

        # PTH/ONNX agreement.
        print(f"{Fore.YELLOW}【模型一致性】")
        self._print_stat("PTH和ONNX输出一致", models_match, total)

        # Average only over samples where a difference could be computed.
        # Failed samples must not pull the mean towards 0.
        prob_diffs = [r['max_prob_diff'] for r in results if r.get('max_prob_diff') is not None]
        if prob_diffs:
            print(f" 平均概率差异: {np.mean(prob_diffs):.6f}")
        else:
            print(" 平均概率差异: N/A")
        print()

        # Distribution sanity.
        print(f"{Fore.YELLOW}【概率分布合理性】")
        self._print_stat("PTH概率分布合理", pth_prob_valid, total)
        self._print_stat("ONNX概率分布合理", onnx_prob_valid, total)
        print()

        # Overall pass rate.
        print(f"{Fore.YELLOW}【综合评估】")
        self._print_stat("所有检测项通过", all_pass, total, critical=True)
        print()

        # List problem samples (first 10 only).
        problem_samples = [i for i, r in enumerate(results) if len(r['issues']) > 0]
        if len(problem_samples) > 0:
            print(f"{Fore.RED}【问题样本】")
            print(f" 发现 {len(problem_samples)} 个问题样本\n")

            show_count = min(10, len(problem_samples))
            for idx in problem_samples[:show_count]:
                r = results[idx]
                print(f" 样本 #{idx + 1}:")
                print(f" 期望动作: {r['expected_action']}")
                print(f" PTH预测: {r['pth_action']} {'[OK]' if r['pth_correct'] else '[X]'}")
                print(f" ONNX预测: {r['onnx_action']} {'[OK]' if r['onnx_correct'] else '[X]'}")
                for issue in r['issues']:
                    print(f" ⚠ {issue}")
                print()

            if len(problem_samples) > show_count:
                print(f" ... 还有 {len(problem_samples) - show_count} 个问题样本未显示\n")

        print(f"{Fore.CYAN}{'='*80}")

        consistency_rate = models_match / total
        if skip_action_check:
            # Skip mode requires:
            #   (1) PTH/ONNX agreement >= 95%;
            #   (2) both distributions valid in >= 95% of samples.
            prob_valid_rate = min(pth_prob_valid, onnx_prob_valid) / total
            passed = consistency_rate >= 0.95 and prob_valid_rate >= 0.95

            if passed:
                print(f"{Fore.GREEN}{'[OK] 测试通过!':^80}")
                print(f"{Fore.GREEN}{'PTH和ONNX模型输出一致,概率分布合理':^80}")
                print(f"{Fore.GREEN}{'模型可以投入使用':^80}")
            else:
                print(f"{Fore.RED}{'[FAIL] 测试失败!':^80}")
                if consistency_rate < 0.95:
                    print(f"{Fore.RED}{'PTH和ONNX一致性不足95%':^80}")
                if prob_valid_rate < 0.95:
                    print(f"{Fore.RED}{'概率分布合理性不足95%':^80}")
                print(f"{Fore.RED}{'建议检查模型训练或导出过程':^80}")
        else:
            # Full mode requires:
            #   (1) >= 80% of samples pass every check;
            #   (2) PTH/ONNX agreement >= 95%.
            pass_rate = all_pass / total
            passed = pass_rate >= 0.8 and consistency_rate >= 0.95

            if passed:
                print(f"{Fore.GREEN}{'[OK] 测试通过!':^80}")
                print(f"{Fore.GREEN}{'所有检测项达标,模型可以投入使用':^80}")
            else:
                print(f"{Fore.RED}{'[FAIL] 测试失败!':^80}")
                if pass_rate < 0.8:
                    print(f"{Fore.RED}{'综合通过率不足80%':^80}")
                if consistency_rate < 0.95:
                    print(f"{Fore.RED}{'PTH和ONNX一致性不足95%':^80}")
                print(f"{Fore.RED}{'建议检查模型训练或导出过程':^80}")

        print(f"{Fore.CYAN}{'='*80}\n")
        return passed

    def _print_stat(self, label, count, total, critical=False):
        """Print one ``count/total`` statistic with a colour-coded bar.

        Colour thresholds are green >= 95%, yellow >= 80%, red otherwise.
        A ``critical`` statistic is already green at >= 80%.
        """
        if total <= 0:
            print(f" {label:30} N/A{Style.RESET_ALL}")
            return

        percentage = (count / total) * 100

        if percentage >= 95:
            color = Fore.GREEN
        elif percentage >= 80:
            color = Fore.YELLOW
        else:
            color = Fore.RED

        if critical and percentage >= 80:
            color = Fore.GREEN

        bar_length = 40
        filled = int(bar_length * count / total)
        bar = '█' * filled + '░' * (bar_length - filled)

        print(f" {label:30} {color}{count:4d}/{total:4d} ({percentage:5.1f}%) {bar}{Style.RESET_ALL}")


def main():
    """Command-line entry point.

    Parses the arguments and picks a data file if none was given. Then it
    runs the PTH/ONNX consistency test on that file and prints the summary.

    Returns:
        int: process exit code. 0 on success. 1 on a setup error (missing
        data directory/files or models), an exception during testing, or an
        explicit failing verdict.
    """
    parser = argparse.ArgumentParser(description='测试PTH和ONNX模型的一致性')
    parser.add_argument('--pth', type=str, default='models/ppo_model_best.pth',
                        help='PyTorch模型路径')
    parser.add_argument('--onnx', type=str, default='models/ppo_model_best.onnx',
                        help='ONNX模型路径')
    parser.add_argument('--data', type=str, default=None,
                        help='JSONL数据文件路径(默认自动选择Data目录中排序居中的文件)')
    parser.add_argument('--max-samples', type=int, default=100,
                        help='最多测试样本数(默认100)')
    parser.add_argument('--state-dim', type=int, default=572,
                        help='状态维度(默认572)')
    parser.add_argument('--skip-action-check', action='store_true',
                        help='跳过动作正确性检查(只检查模型一致性)')

    args = parser.parse_args()

    # No data file given: pick one from the Data directory.
    if args.data is None:
        data_dir = 'Data'
        if not os.path.isdir(data_dir):
            print(f"{Fore.RED}错误: Data目录不存在")
            return 1

        # Sort so the choice does not depend on os.listdir's arbitrary order.
        jsonl_files = sorted(f for f in os.listdir(data_dir) if f.endswith('.jsonl'))
        if not jsonl_files:
            print(f"{Fore.RED}错误: Data目录中没有JSONL文件")
            return 1

        # Take the middle file to avoid the extremes at either end.
        args.data = os.path.join(data_dir, jsonl_files[len(jsonl_files) // 2])
        print(f"{Fore.YELLOW}自动选择数据文件: {os.path.basename(args.data)}\n")

    if not os.path.exists(args.pth):
        print(f"{Fore.RED}错误: PTH模型不存在: {args.pth}")
        return 1

    if not os.path.exists(args.onnx):
        print(f"{Fore.RED}错误: ONNX模型不存在: {args.onnx}")
        print(f"{Fore.YELLOW}提示: 请先运行 export_to_onnx.py 导出ONNX模型")
        return 1

    try:
        tester = ModelConsistencyTester(
            pth_path=args.pth,
            onnx_path=args.onnx,
            state_dim=args.state_dim
        )
        results = tester.test_from_jsonl(args.data, max_samples=args.max_samples)
        passed = tester.print_summary(results, skip_action_check=args.skip_action_check)
    except Exception as e:
        # Top-level boundary: report and signal failure via the exit code.
        print(f"{Fore.RED}错误: {e}")
        traceback.print_exc()
        return 1

    # Only an explicit False verdict maps to a non-zero exit code.
    return 1 if passed is False else 0


if __name__ == '__main__':
    # Use main()'s return value as the process exit status so CI can detect
    # failures (sys.exit(None) still exits with status 0).
    sys.exit(main())
|